pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. pg_sui-1.7.0.dist-info/METADATA +288 -0
  2. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +591 -126
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +218 -533
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1268 -530
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
  16. pgsui/impute/unsupervised/imputers/vae.py +928 -696
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
  27. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
  28. pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
  29. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  30. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  31. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  32. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  33. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  34. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -121,20 +121,18 @@ class ModelConfig:
  latent_dim (int): Dimensionality of the latent space.
  dropout_rate (float): Dropout rate for regularization.
  num_hidden_layers (int): Number of hidden layers in the neural network.
- hidden_activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function.
+ activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function.
  layer_scaling_factor (float): Scaling factor for the number of neurons in hidden layers.
- layer_schedule (Literal["pyramid", "constant", "linear"]): Schedule for scaling hidden layer sizes.
- gamma (float): Parameter for the focal loss function.
+ layer_schedule (Literal["pyramid", "linear"]): Schedule for scaling hidden layer sizes.
  """

  latent_init: Literal["random", "pca"] = "random"
  latent_dim: int = 2
  dropout_rate: float = 0.2
  num_hidden_layers: int = 2
- hidden_activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
+ activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
  layer_scaling_factor: float = 5.0
- layer_schedule: Literal["pyramid", "constant", "linear"] = "pyramid"
- gamma: float = 2.0
+ layer_schedule: Literal["pyramid", "linear"] = "pyramid"


  @dataclass
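In 1.7.0, ModelConfig renames hidden_activation to activation, drops the "constant" layer schedule, and no longer carries a gamma field (the focal-loss gamma moves to TrainConfig in the next hunk). A minimal migration sketch, assuming the module path pgsui.data_processing.config from the file list above:

    from pgsui.data_processing.config import ModelConfig

    # 1.6.x-style keywords such as hidden_activation="elu" no longer exist;
    # the 1.7.0 field names are used below.
    model_cfg = ModelConfig(activation="elu", layer_schedule="pyramid", latent_dim=4)
    print(model_cfg.activation, model_cfg.layer_schedule)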
@@ -144,28 +142,39 @@ class TrainConfig:
  Attributes:
  batch_size (int): Number of samples per training batch.
  learning_rate (float): Learning rate for the optimizer.
- lr_input_factor (float): Factor to scale the learning rate for input layer.
  l1_penalty (float): L1 regularization penalty.
  early_stop_gen (int): Number of generations with no improvement to wait before early stopping.
  min_epochs (int): Minimum number of epochs to train.
  max_epochs (int): Maximum number of epochs to train.
  validation_split (float): Proportion of data to use for validation.
- weights_beta (float): Smoothing factor for class weights.
- weights_max_ratio (float): Maximum ratio for class weights to prevent extreme values.
+ weights_max_ratio (float | None): Maximum ratio for class weights to prevent extreme values.
+ gamma (float): Focusing parameter for focal loss.
  device (Literal["gpu", "cpu", "mps"]): Device to use for computation.
  """

- batch_size: int = 32
+ batch_size: int = 64
  learning_rate: float = 1e-3
- lr_input_factor: float = 1.0
  l1_penalty: float = 0.0
- early_stop_gen: int = 20
+ early_stop_gen: int = 25
  min_epochs: int = 100
- max_epochs: int = 5000
+ max_epochs: int = 2000
  validation_split: float = 0.2
- weights_beta: float = 0.9999
- weights_max_ratio: float = 1.0
  device: Literal["gpu", "cpu", "mps"] = "cpu"
+ weights_max_ratio: Optional[float] = None
+ weights_power: float = 1.0
+ weights_normalize: bool = True
+ weights_inverse: bool = False
+ gamma: float = 0.0
+ gamma_schedule: bool = False
+
+
+ def _default_train_config() -> TrainConfig:
+ """Typed default factory for TrainConfig (helps some type checkers).
+
+ Using the class object directly (default_factory=TrainConfig) is valid at runtime but certain type checkers can fail to match dataclasses.field overloads.
+ """
+
+ return TrainConfig()


  @dataclass
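The new module-level _default_train_config exists only to hand dataclasses.field a plainly annotated callable, as its docstring above explains. A self-contained sketch of the pattern with stand-in classes (not the real pgsui definitions):

    from dataclasses import dataclass, field

    @dataclass
    class TrainConfig:  # abridged stand-in for pgsui's TrainConfig
        batch_size: int = 64
        learning_rate: float = 1e-3

    def _default_train_config() -> TrainConfig:
        # A named, annotated factory; default_factory=TrainConfig also works at
        # runtime but can confuse some type checkers' dataclasses.field overloads.
        return TrainConfig()

    @dataclass
    class TopLevelConfig:  # stand-in for AutoencoderConfig, VAEConfig, etc.
        train: TrainConfig = field(default_factory=_default_train_config)

    print(TopLevelConfig().train.batch_size)  # 64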
@@ -174,19 +183,13 @@ class TuneConfig:

  Attributes:
  enabled (bool): If True, enables hyperparameter tuning.
- metric (Literal["f1", "accuracy", "pr_macro"]): Metric to optimize during tuning.
+ metric (Literal["f1", "accuracy", "pr_macro", "average_precision", "roc_auc", "precision", "recall", "mcc", "jaccard"]): Metric to optimize during tuning.
  n_trials (int): Number of hyperparameter trials to run.
  resume (bool): If True, resumes tuning from a previous state.
  save_db (bool): If True, saves the tuning results to a database.
- fast (bool): If True, uses a faster but less thorough tuning approach.
- max_samples (int): Maximum number of samples to use for tuning. 0 means all samples.
- max_loci (int): Maximum number of loci to use for tuning. 0 means all loci.
  epochs (int): Number of epochs to train each trial.
  batch_size (int): Batch size for training during tuning.
- eval_interval (int): Interval (in epochs) at which to evaluate the model during tuning.
- infer_epochs (int): Number of epochs for inference during tuning.
  patience (int): Number of evaluations with no improvement before stopping early.
- proxy_metric_batch (int): If > 0, uses a subset of data for proxy metric evaluation.
  """

  enabled: bool = False
@@ -198,34 +201,15 @@ class TuneConfig:
  "roc_auc",
  "precision",
  "recall",
+ "mcc",
+ "jaccard",
  ] = "f1"
  n_trials: int = 100
  resume: bool = False
  save_db: bool = False
- fast: bool = True
- max_samples: int = 512
- max_loci: int = 0 # 0 = all
  epochs: int = 500
  batch_size: int = 64
- eval_interval: int = 20
- infer_epochs: int = 100
  patience: int = 10
- proxy_metric_batch: int = 0
-
-
- @dataclass
- class EvalConfig:
- """Evaluation configuration.
-
- Attributes:
- eval_latent_steps (int): Number of optimization steps for latent space evaluation.
- eval_latent_lr (float): Learning rate for latent space optimization.
- eval_latent_weight_decay (float): Weight decay for latent space optimization.
- """
-
- eval_latent_steps: int = 50
- eval_latent_lr: float = 1e-2
- eval_latent_weight_decay: float = 0.0


  @dataclass
@@ -244,15 +228,18 @@ class PlotConfig:
  dpi: int = 300
  fontsize: int = 18
  despine: bool = True
- show: bool = False
+ show: bool = True


  @dataclass
  class IOConfig:
  """I/O configuration.

+ Dataclass that includes configuration settings for file naming, logging verbosity, random seed, and parallelism.
+
  Attributes:
  prefix (str): Prefix for output files. Default is "pgsui".
+ ploidy (int): Ploidy level of the organism. Default is 2.
  verbose (bool): If True, enables verbose logging. Default is False.
  debug (bool): If True, enables debug mode. Default is False.
  seed (int | None): Random seed for reproducibility. Default is None.
@@ -261,6 +248,7 @@ class IOConfig:
  """

  prefix: str = "pgsui"
+ ploidy: int = 2
  verbose: bool = False
  debug: bool = False
  seed: int | None = None
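ploidy is new on IOConfig in 1.7.0, alongside the existing prefix/verbose/debug/seed fields. A short usage sketch, assuming the module path from the file list above:

    from pgsui.data_processing.config import IOConfig

    io_cfg = IOConfig(prefix="run1", ploidy=2, verbose=True, seed=42)
    print(io_cfg.ploidy)  # 2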
@@ -287,37 +275,46 @@ class SimConfig:
  "nonrandom",
  "nonrandom_weighted",
  ] = "random"
- sim_prop: float = 0.10
+ sim_prop: float = 0.20
  sim_kwargs: dict | None = None


  @dataclass
- class NLPCAConfig:
- """Top-level configuration for ImputeNLPCA.
+ class AutoencoderConfig:
+ """Top-level configuration for ImputeAutoencoder.
+
+ This configuration class encapsulates all settings required for the
+ ImputeAutoencoder model, including I/O, model architecture, training,
+ hyperparameter tuning, plotting, and simulated-missing configuration.

  Attributes:
  io (IOConfig): I/O configuration.
  model (ModelConfig): Model architecture configuration.
  train (TrainConfig): Training procedure configuration.
  tune (TuneConfig): Hyperparameter tuning configuration.
- evaluate (EvalConfig): Evaluation configuration.
  plot (PlotConfig): Plotting configuration.
- sim (SimConfig): Simulation configuration.
+ sim (SimConfig): Simulated-missing configuration.
  """

  io: IOConfig = field(default_factory=IOConfig)
  model: ModelConfig = field(default_factory=ModelConfig)
- train: TrainConfig = field(default_factory=TrainConfig)
+ train: TrainConfig = field(default_factory=_default_train_config)
  tune: TuneConfig = field(default_factory=TuneConfig)
- evaluate: EvalConfig = field(default_factory=EvalConfig)
  plot: PlotConfig = field(default_factory=PlotConfig)
  sim: SimConfig = field(default_factory=SimConfig)

  @classmethod
  def from_preset(
  cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
- ) -> "NLPCAConfig":
- """Build a NLPCAConfig from a named preset."""
+ ) -> "AutoencoderConfig":
+ """Build a AutoencoderConfig from a named preset.
+
+ Args:
+ preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+ Returns:
+ AutoencoderConfig: Configuration instance corresponding to the preset.
+ """
  if preset not in {"fast", "balanced", "thorough"}:
  raise ValueError(f"Unknown preset: {preset}")

@@ -325,414 +322,87 @@ class NLPCAConfig:

  # Common baselines
  cfg.io.verbose = False
- cfg.train.validation_split = 0.20
- cfg.model.hidden_activation = "relu"
+ cfg.io.ploidy = 2
+ cfg.train.validation_split = 0.2
+ cfg.model.activation = "relu"
  cfg.model.layer_schedule = "pyramid"
- cfg.model.latent_init = "random"
- cfg.evaluate.eval_latent_lr = 1e-2
- cfg.evaluate.eval_latent_weight_decay = 0.0
- cfg.sim.simulate_missing = True
+ cfg.model.layer_scaling_factor = 2.0
  cfg.sim.sim_strategy = "random"
  cfg.sim.sim_prop = 0.2
+ cfg.plot.show = True
+
+ # Train settings
+ cfg.train.weights_max_ratio = None
+ cfg.train.weights_power = 1.0
+ cfg.train.weights_normalize = True
+ cfg.train.weights_inverse = False
+ cfg.train.gamma = 0.0
+ cfg.train.gamma_schedule = False
+ cfg.train.min_epochs = 100
+
+ # Tune
+ cfg.tune.enabled = False
+ cfg.tune.n_trials = 100

  if preset == "fast":
  # Model
  cfg.model.latent_dim = 4
  cfg.model.num_hidden_layers = 1
- cfg.model.layer_scaling_factor = 2.0
  cfg.model.dropout_rate = 0.10
- cfg.model.gamma = 1.5
- # Train
- cfg.train.batch_size = 128
- cfg.train.learning_rate = 1e-3
- cfg.train.early_stop_gen = 5
- cfg.train.min_epochs = 10
- cfg.train.max_epochs = 120
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- # Tuning
- cfg.tune.enabled = True
- cfg.tune.fast = True
- cfg.tune.n_trials = 25
- cfg.tune.epochs = 120
- cfg.tune.batch_size = 128
- cfg.tune.max_samples = 512
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 20
- cfg.tune.infer_epochs = 20
- cfg.tune.patience = 5
- cfg.tune.proxy_metric_batch = 0
- # Eval
- cfg.evaluate.eval_latent_steps = 20

- elif preset == "balanced":
- # Model
- cfg.model.latent_dim = 8
- cfg.model.num_hidden_layers = 2
- cfg.model.layer_scaling_factor = 3.0
- cfg.model.dropout_rate = 0.20
- cfg.model.gamma = 2.0
  # Train
  cfg.train.batch_size = 128
- cfg.train.learning_rate = 8e-4
+ cfg.train.learning_rate = 2e-3
  cfg.train.early_stop_gen = 15
- cfg.train.min_epochs = 50
- cfg.train.max_epochs = 600
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- # Tuning
- cfg.tune.enabled = True
- cfg.tune.fast = True
- cfg.tune.n_trials = 75
- cfg.tune.epochs = 300
- cfg.tune.batch_size = 128
- cfg.tune.max_samples = 2048
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 20
- cfg.tune.infer_epochs = 40
- cfg.tune.patience = 10
- cfg.tune.proxy_metric_batch = 0
- # Eval
- cfg.evaluate.eval_latent_steps = 30
+ cfg.train.max_epochs = 200
+ cfg.train.weights_max_ratio = None

- else: # thorough
- # Model
- cfg.model.latent_dim = 16
- cfg.model.num_hidden_layers = 3
- cfg.model.layer_scaling_factor = 5.0
- cfg.model.dropout_rate = 0.30
- cfg.model.gamma = 2.5
- # Train
- cfg.train.batch_size = 64
- cfg.train.learning_rate = 6e-4
- cfg.train.early_stop_gen = 20 # Reduced from 30
- cfg.train.min_epochs = 100
- cfg.train.max_epochs = 800 # Reduced from 1200
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- # Tuning
- cfg.tune.enabled = True
- cfg.tune.fast = False
- cfg.tune.n_trials = 150
- cfg.tune.epochs = 600
- cfg.tune.batch_size = 64
- cfg.tune.max_samples = 5000 # Capped from 0
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 10
- cfg.tune.infer_epochs = 80
- cfg.tune.patience = 15 # Reduced from 20
- cfg.tune.proxy_metric_batch = 0
- # Eval
- cfg.evaluate.eval_latent_steps = 50
-
- return cfg
-
- def apply_overrides(self, overrides: Dict[str, Any] | None) -> "NLPCAConfig":
- """Apply flat dot-key overrides."""
- if not overrides:
- return self
- for k, v in overrides.items():
- node = self
- parts = k.split(".")
- for p in parts[:-1]:
- node = getattr(node, p)
- last = parts[-1]
- if hasattr(node, last):
- setattr(node, last, v)
- else:
- raise KeyError(f"Unknown config key: {k}")
- return self
-
- def to_dict(self) -> Dict[str, Any]:
- return asdict(self)
-
-
- @dataclass
- class UBPConfig:
- """Top-level configuration for ImputeUBP.
-
- Attributes:
- io (IOConfig): I/O configuration.
- model (ModelConfig): Model architecture configuration.
- train (TrainConfig): Training procedure configuration.
- tune (TuneConfig): Hyperparameter tuning configuration.
- evaluate (EvalConfig): Evaluation configuration.
- plot (PlotConfig): Plotting configuration.
- sim (SimConfig): Simulated-missing configuration.
- """
-
- io: IOConfig = field(default_factory=IOConfig)
- model: ModelConfig = field(default_factory=ModelConfig)
- train: TrainConfig = field(default_factory=TrainConfig)
- tune: TuneConfig = field(default_factory=TuneConfig)
- evaluate: EvalConfig = field(default_factory=EvalConfig)
- plot: PlotConfig = field(default_factory=PlotConfig)
- sim: SimConfig = field(default_factory=SimConfig)
-
- @classmethod
- def from_preset(
- cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
- ) -> "UBPConfig":
- """Build a UBPConfig from a named preset."""
- if preset not in {"fast", "balanced", "thorough"}:
- raise ValueError(f"Unknown preset: {preset}")
-
- cfg = cls()
-
- # Common baselines
- cfg.io.verbose = False
- cfg.model.hidden_activation = "relu"
- cfg.model.layer_schedule = "pyramid"
- cfg.model.latent_init = "random"
- cfg.sim.simulate_missing = True
- cfg.sim.sim_strategy = "random"
- cfg.sim.sim_prop = 0.2
-
- if preset == "fast":
- # Model
- cfg.model.latent_dim = 4
- cfg.model.num_hidden_layers = 1
- cfg.model.layer_scaling_factor = 2.0
- cfg.model.dropout_rate = 0.10
- cfg.model.gamma = 1.5
- # Train
- cfg.train.batch_size = 128
- cfg.train.learning_rate = 1e-3
- cfg.train.early_stop_gen = 5
- cfg.train.min_epochs = 10
- cfg.train.max_epochs = 120
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- # Tuning
- cfg.tune.enabled = True
- cfg.tune.fast = True
- cfg.tune.n_trials = 25
- cfg.tune.epochs = 120
- cfg.tune.batch_size = 128
- cfg.tune.max_samples = 512
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 20
- cfg.tune.infer_epochs = 20
- cfg.tune.patience = 5
- cfg.tune.proxy_metric_batch = 0
- # Eval
- cfg.evaluate.eval_latent_steps = 20
- cfg.evaluate.eval_latent_lr = 1e-2
- cfg.evaluate.eval_latent_weight_decay = 0.0
+ # Tune
+ cfg.tune.patience = 15

  elif preset == "balanced":
  # Model
  cfg.model.latent_dim = 8
  cfg.model.num_hidden_layers = 2
- cfg.model.layer_scaling_factor = 3.0
  cfg.model.dropout_rate = 0.20
- cfg.model.gamma = 2.0
- # Train
- cfg.train.batch_size = 128
- cfg.train.learning_rate = 8e-4
- cfg.train.early_stop_gen = 15
- cfg.train.min_epochs = 50
- cfg.train.max_epochs = 600
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- # Tuning
- cfg.tune.enabled = True
- cfg.tune.fast = True
- cfg.tune.n_trials = 75
- cfg.tune.epochs = 300
- cfg.tune.batch_size = 128
- cfg.tune.max_samples = 2048
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 20
- cfg.tune.infer_epochs = 40
- cfg.tune.patience = 10
- cfg.tune.proxy_metric_batch = 0
- # Eval
- cfg.evaluate.eval_latent_steps = 30
- cfg.evaluate.eval_latent_lr = 1e-2
- cfg.evaluate.eval_latent_weight_decay = 0.0

- else: # thorough
- # Model
- cfg.model.latent_dim = 16
- cfg.model.num_hidden_layers = 3
- cfg.model.layer_scaling_factor = 5.0
- cfg.model.dropout_rate = 0.30
- cfg.model.gamma = 2.5
  # Train
  cfg.train.batch_size = 64
- cfg.train.learning_rate = 6e-4
- cfg.train.early_stop_gen = 20 # Reduced from 30
- cfg.train.min_epochs = 100
- cfg.train.max_epochs = 800 # Reduced from 1200
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- # Tuning
- cfg.tune.enabled = True
- cfg.tune.fast = False
- cfg.tune.n_trials = 150
- cfg.tune.epochs = 600
- cfg.tune.batch_size = 64
- cfg.tune.max_samples = 5000 # Capped from 0
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 10
- cfg.tune.infer_epochs = 80
- cfg.tune.patience = 15 # Reduced from 20
- cfg.tune.proxy_metric_batch = 0
- # Eval
- cfg.evaluate.eval_latent_steps = 50
- cfg.evaluate.eval_latent_lr = 1e-2
- cfg.evaluate.eval_latent_weight_decay = 0.0
-
- return cfg
-
- def apply_overrides(self, overrides: Dict[str, Any] | None) -> "UBPConfig":
- """Apply flat dot-key overrides."""
- if not overrides:
- return self
-
- for k, v in overrides.items():
- node = self
- parts = k.split(".")
- for p in parts[:-1]:
- node = getattr(node, p)
- last = parts[-1]
- if hasattr(node, last):
- setattr(node, last, v)
- else:
- raise KeyError(f"Unknown config key: {k}")
- return self
-
- def to_dict(self) -> Dict[str, Any]:
- return asdict(self)
-
-
- @dataclass
- class AutoencoderConfig:
- """Top-level configuration for ImputeAutoencoder.
-
- Attributes:
- io (IOConfig): I/O configuration.
- model (ModelConfig): Model architecture configuration.
- train (TrainConfig): Training procedure configuration.
- tune (TuneConfig): Hyperparameter tuning configuration.
- evaluate (EvalConfig): Evaluation configuration.
- plot (PlotConfig): Plotting configuration.
- sim (SimConfig): Simulated-missing configuration.
- """
-
- io: IOConfig = field(default_factory=IOConfig)
- model: ModelConfig = field(default_factory=ModelConfig)
- train: TrainConfig = field(default_factory=TrainConfig)
- tune: TuneConfig = field(default_factory=TuneConfig)
- evaluate: EvalConfig = field(default_factory=EvalConfig)
- plot: PlotConfig = field(default_factory=PlotConfig)
- sim: SimConfig = field(default_factory=SimConfig)
-
- @classmethod
- def from_preset(
- cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
- ) -> "AutoencoderConfig":
- """Build a AutoencoderConfig from a named preset."""
- if preset not in {"fast", "balanced", "thorough"}:
- raise ValueError(f"Unknown preset: {preset}")
-
- cfg = cls()
-
- # Common baselines (no latent refinement at eval)
- cfg.io.verbose = False
- cfg.train.validation_split = 0.20
- cfg.model.hidden_activation = "relu"
- cfg.model.layer_schedule = "pyramid"
- cfg.evaluate.eval_latent_steps = 0
- cfg.evaluate.eval_latent_lr = 0.0
- cfg.evaluate.eval_latent_weight_decay = 0.0
- cfg.sim.simulate_missing = True
- cfg.sim.sim_strategy = "random"
- cfg.sim.sim_prop = 0.2
-
- if preset == "fast":
- cfg.model.latent_dim = 4
- cfg.model.num_hidden_layers = 1
- cfg.model.layer_scaling_factor = 2.0
- cfg.model.dropout_rate = 0.10
- cfg.model.gamma = 1.5
- cfg.train.batch_size = 128
  cfg.train.learning_rate = 1e-3
- cfg.train.early_stop_gen = 5
- cfg.train.min_epochs = 10
- cfg.train.max_epochs = 120
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- cfg.tune.enabled = True
- cfg.tune.fast = True
- cfg.tune.n_trials = 25
- cfg.tune.epochs = 120
- cfg.tune.batch_size = 128
- cfg.tune.max_samples = 512
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 20
- cfg.tune.patience = 5
- cfg.tune.proxy_metric_batch = 0
- if hasattr(cfg.tune, "infer_epochs"):
- cfg.tune.infer_epochs = 0
+ cfg.train.early_stop_gen = 25
+ cfg.train.max_epochs = 500
+ cfg.train.weights_max_ratio = None

- elif preset == "balanced":
- cfg.model.latent_dim = 8
- cfg.model.num_hidden_layers = 2
- cfg.model.layer_scaling_factor = 3.0
- cfg.model.dropout_rate = 0.20
- cfg.model.gamma = 2.0
- cfg.train.batch_size = 128
- cfg.train.learning_rate = 8e-4
- cfg.train.early_stop_gen = 15
- cfg.train.min_epochs = 50
- cfg.train.max_epochs = 600
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- cfg.tune.enabled = True
- cfg.tune.fast = True
- cfg.tune.n_trials = 75
- cfg.tune.epochs = 300
- cfg.tune.batch_size = 128
- cfg.tune.max_samples = 2048
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 20
- cfg.tune.patience = 10
- cfg.tune.proxy_metric_batch = 0
- if hasattr(cfg.tune, "infer_epochs"):
- cfg.tune.infer_epochs = 0
+ # Tune
+ cfg.tune.patience = 25

  else: # thorough
+ # Model
  cfg.model.latent_dim = 16
  cfg.model.num_hidden_layers = 3
- cfg.model.layer_scaling_factor = 5.0
  cfg.model.dropout_rate = 0.30
- cfg.model.gamma = 2.5
+
+ # Train
  cfg.train.batch_size = 64
- cfg.train.learning_rate = 6e-4
- cfg.train.early_stop_gen = 20 # Reduced from 30
- cfg.train.min_epochs = 100
- cfg.train.max_epochs = 800 # Reduced from 1200
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- cfg.tune.enabled = True
- cfg.tune.fast = False
- cfg.tune.n_trials = 150
- cfg.tune.epochs = 600
- cfg.tune.batch_size = 64
- cfg.tune.max_samples = 5000 # Capped from 0
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 10
- cfg.tune.patience = 15 # Reduced from 20
- cfg.tune.proxy_metric_batch = 0
- if hasattr(cfg.tune, "infer_epochs"):
- cfg.tune.infer_epochs = 0
+ cfg.train.learning_rate = 5e-4
+ cfg.train.early_stop_gen = 50
+ cfg.train.max_epochs = 1000
+ cfg.train.weights_max_ratio = None
+
+ # Tune
+ cfg.tune.patience = 50

  return cfg

  def apply_overrides(self, overrides: Dict[str, Any] | None) -> "AutoencoderConfig":
- """Apply flat dot-key overrides."""
+ """Apply flat dot-key overrides.
+
+ Args:
+ overrides (Dict[str, Any] | None): Dictionary of overrides with dot-separated keys.
+
+ Returns:
+ AutoencoderConfig: New configuration instance with overrides applied.
+ """
  if not overrides:
  return self
  for k, v in overrides.items():
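The retained AutoencoderConfig keeps the preset-plus-overrides workflow: from_preset builds a baseline and apply_overrides walks flat dot-keys through the nested sections, raising KeyError for unknown keys. A hedged usage sketch, assuming pg-sui 1.7.0 is installed and importable under the path shown in the file list:

    from pgsui.data_processing.config import AutoencoderConfig

    cfg = AutoencoderConfig.from_preset("balanced")
    # Dot-keys map onto nested dataclass attributes (train.*, model.*, tune.*, ...).
    cfg = cfg.apply_overrides({"train.learning_rate": 5e-4, "model.latent_dim": 16})
    print(cfg.train.learning_rate, cfg.model.latent_dim)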
@@ -753,29 +423,22 @@ class AutoencoderConfig:

  @dataclass
  class VAEExtraConfig:
- """VAE-specific knobs.
-
- Attributes:
- kl_beta (float): Final β for KL divergence term.
- kl_warmup (int): Number of epochs with β=0 (warm-up period).
- kl_ramp (int): Number of epochs for linear ramp to final β.
- """
-
  kl_beta: float = 1.0
- kl_warmup: int = 50
- kl_ramp: int = 200
+ kl_beta_schedule: bool = False


  @dataclass
  class VAEConfig:
  """Top-level configuration for ImputeVAE (AE-parity + VAE extras).

+ Mirrors AutoencoderConfig sections and adds a ``vae`` block with KL-beta
+ controls for the VAE loss.
+
  Attributes:
  io (IOConfig): I/O configuration.
  model (ModelConfig): Model architecture configuration.
  train (TrainConfig): Training procedure configuration.
  tune (TuneConfig): Hyperparameter tuning configuration.
- evaluate (EvalConfig): Evaluation configuration.
  plot (PlotConfig): Plotting configuration.
  vae (VAEExtraConfig): VAE-specific configuration.
  sim (SimConfig): Simulated-missing configuration.
@@ -783,9 +446,8 @@ class VAEConfig:

  io: IOConfig = field(default_factory=IOConfig)
  model: ModelConfig = field(default_factory=ModelConfig)
- train: TrainConfig = field(default_factory=TrainConfig)
+ train: TrainConfig = field(default_factory=_default_train_config)
  tune: TuneConfig = field(default_factory=TuneConfig)
- evaluate: EvalConfig = field(default_factory=EvalConfig)
  plot: PlotConfig = field(default_factory=PlotConfig)
  vae: VAEExtraConfig = field(default_factory=VAEExtraConfig)
  sim: SimConfig = field(default_factory=SimConfig)
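VAEConfig carries the same sections as AutoencoderConfig plus a vae block; in 1.7.0 that block is reduced to kl_beta and a boolean kl_beta_schedule (the epoch-based kl_warmup/kl_ramp knobs are removed). A hedged sketch, assuming the same module path as above:

    from pgsui.data_processing.config import VAEConfig

    cfg = VAEConfig.from_preset("fast")
    cfg.vae.kl_beta = 0.5             # KL weight in the VAE loss
    cfg.vae.kl_beta_schedule = True   # boolean switch added in 1.7.0 (replaces kl_warmup/kl_ramp)
    print(cfg.vae)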
@@ -794,107 +456,92 @@ class VAEConfig:
  def from_preset(
  cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
  ) -> "VAEConfig":
- """Build a VAEConfig from a named preset."""
+ """Build a VAEConfig from a named preset.
+
+ Args:
+ preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+ Returns:
+ VAEConfig: Configuration instance corresponding to the preset.
+ """
  if preset not in {"fast", "balanced", "thorough"}:
  raise ValueError(f"Unknown preset: {preset}")

  cfg = cls()

- # Common baselines (match AE; no latent refinement at eval)
+ # General settings
  cfg.io.verbose = False
- cfg.train.validation_split = 0.20
- cfg.model.hidden_activation = "relu"
+ cfg.io.ploidy = 2
+ cfg.train.validation_split = 0.2
+ cfg.model.activation = "relu"
  cfg.model.layer_schedule = "pyramid"
- cfg.evaluate.eval_latent_steps = 0
- cfg.evaluate.eval_latent_lr = 0.0
- cfg.evaluate.eval_latent_weight_decay = 0.0
+ cfg.model.layer_scaling_factor = 2.0
  cfg.sim.simulate_missing = True
  cfg.sim.sim_strategy = "random"
  cfg.sim.sim_prop = 0.2
-
- # VAE KL schedules, shortened for speed
+ cfg.plot.show = True
+
+ # Train settings
+ cfg.train.weights_max_ratio = None
+ cfg.train.weights_power = 1.0
+ cfg.train.weights_normalize = True
+ cfg.train.weights_inverse = False
+ cfg.train.gamma = 0.0
+ cfg.train.gamma_schedule = False
+ cfg.train.min_epochs = 100
+
+ # VAE-specific
  cfg.vae.kl_beta = 1.0
- cfg.vae.kl_warmup = 25
- cfg.vae.kl_ramp = 100
+ cfg.vae.kl_beta_schedule = False
+
+ # Tune
+ cfg.tune.enabled = False
+ cfg.tune.n_trials = 100

  if preset == "fast":
+ # Model
  cfg.model.latent_dim = 4
- cfg.model.num_hidden_layers = 1
- cfg.model.layer_scaling_factor = 2.0
+ cfg.model.num_hidden_layers = 2
  cfg.model.dropout_rate = 0.10
- cfg.model.gamma = 1.5
- cfg.vae.kl_beta = 0.5 # Lower beta for fast training
+
+ # Train
  cfg.train.batch_size = 128
- cfg.train.learning_rate = 1e-3
- cfg.train.early_stop_gen = 5
- cfg.train.min_epochs = 10
- cfg.train.max_epochs = 120
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- cfg.tune.enabled = True
- cfg.tune.fast = True
- cfg.tune.n_trials = 25
- cfg.tune.epochs = 120
- cfg.tune.batch_size = 128
- cfg.tune.max_samples = 512
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 20
- cfg.tune.patience = 5
- cfg.tune.proxy_metric_batch = 0
- if hasattr(cfg.tune, "infer_epochs"):
- cfg.tune.infer_epochs = 0
+ cfg.train.learning_rate = 2e-3
+ cfg.train.early_stop_gen = 15
+ cfg.train.max_epochs = 200
+
+ # Tune
+ cfg.tune.patience = 15

  elif preset == "balanced":
+ # Model
  cfg.model.latent_dim = 8
- cfg.model.num_hidden_layers = 2
- cfg.model.layer_scaling_factor = 3.0
+ cfg.model.num_hidden_layers = 4
  cfg.model.dropout_rate = 0.20
- cfg.model.gamma = 2.0
- cfg.train.batch_size = 128
- cfg.train.learning_rate = 8e-4
- cfg.train.early_stop_gen = 15
- cfg.train.min_epochs = 50
- cfg.train.max_epochs = 600
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- cfg.tune.enabled = True
- cfg.tune.fast = True
- cfg.tune.n_trials = 75
- cfg.tune.epochs = 300
- cfg.tune.batch_size = 128
- cfg.tune.max_samples = 2048
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 20
- cfg.tune.patience = 10
- cfg.tune.proxy_metric_batch = 0
- if hasattr(cfg.tune, "infer_epochs"):
- cfg.tune.infer_epochs = 0
+
+ # Train
+ cfg.train.batch_size = 64
+ cfg.train.learning_rate = 1e-3
+ cfg.train.early_stop_gen = 25
+ cfg.train.max_epochs = 500
+
+ # Tune
+ cfg.tune.patience = 25

  else: # thorough
+ # Model
  cfg.model.latent_dim = 16
- cfg.model.num_hidden_layers = 3
- cfg.model.layer_scaling_factor = 5.0
+ cfg.model.num_hidden_layers = 8
  cfg.model.dropout_rate = 0.30
- cfg.model.gamma = 2.5
+
+ # Train
  cfg.train.batch_size = 64
- cfg.train.learning_rate = 6e-4
- cfg.train.early_stop_gen = 20 # Reduced from 30
- cfg.train.min_epochs = 100
- cfg.train.max_epochs = 800 # Reduced from 1200
- cfg.train.weights_beta = 0.9999
- cfg.train.weights_max_ratio = 2.0
- cfg.tune.enabled = True
- cfg.tune.fast = False
- cfg.tune.n_trials = 150
- cfg.tune.epochs = 600
- cfg.tune.batch_size = 64
- cfg.tune.max_samples = 5000 # Capped from 0
- cfg.tune.max_loci = 0
- cfg.tune.eval_interval = 10
- cfg.tune.patience = 15 # Reduced from 20
- cfg.tune.proxy_metric_batch = 0
- if hasattr(cfg.tune, "infer_epochs"):
- cfg.tune.infer_epochs = 0
+ cfg.train.learning_rate = 5e-4
+ cfg.train.early_stop_gen = 50
+ cfg.train.max_epochs = 1000
+
+ # Tune
+ cfg.tune.patience = 50

  return cfg

@@ -923,9 +570,9 @@ class MostFrequentAlgoConfig:
  """Algorithmic knobs for ImputeMostFrequent.

  Attributes:
- by_populations (bool): Whether to compute per-population modes.
- default (int): Fallback mode if no valid entries in a locus.
- missing (int): Code for missing genotypes in 0/1/2.
+ by_populations (bool): Whether to compute per-population modes. Default is False.
+ default (int): Fallback mode if no valid entries in a locus. Default is 0.
+ missing (int): Code for missing genotypes in 0/1/2. Default is -1.
  """

  by_populations: bool = False
@@ -938,8 +585,8 @@ class DeterministicSplitConfig:
  """Evaluation split configuration shared by deterministic imputers.

  Attributes:
- test_size (float): Proportion of data to use as the test set.
- test_indices (Optional[Sequence[int]]): Specific indices to use as the test set.
+ test_size (float): Proportion of data to use as the test set. Default is 0.2.
+ test_indices (Optional[Sequence[int]]): Specific indices to use as the test set. Default is None.
  """

  test_size: float = 0.2
@@ -950,6 +597,10 @@
  class MostFrequentConfig:
  """Top-level configuration for ImputeMostFrequent.

+ Deterministic imputers primarily use ``io``, ``plot``, ``split``, ``algo``,
+ and ``sim``. The ``train`` and ``tune`` sections are retained for schema
+ parity with NN models but are not currently used by ImputeMostFrequent.
+
  Attributes:
  io (IOConfig): I/O configuration.
  plot (PlotConfig): Plotting configuration.
@@ -966,19 +617,27 @@ class MostFrequentConfig:
  algo: MostFrequentAlgoConfig = field(default_factory=MostFrequentAlgoConfig)
  sim: SimConfig = field(default_factory=SimConfig)
  tune: TuneConfig = field(default_factory=TuneConfig)
- train: TrainConfig = field(default_factory=TrainConfig)
+ train: TrainConfig = field(default_factory=_default_train_config)

  @classmethod
  def from_preset(
  cls,
  preset: Literal["fast", "balanced", "thorough"] = "balanced",
  ) -> "MostFrequentConfig":
- """Construct a preset configuration."""
+ """Construct a preset configuration.
+
+ Args:
+ preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+ Returns:
+ MostFrequentConfig: Configuration instance corresponding to the preset.
+ """
  if preset not in {"fast", "balanced", "thorough"}:
  raise ValueError(f"Unknown preset: {preset}")

  cfg = cls()
  cfg.io.verbose = False
+ cfg.io.ploidy = 2
  cfg.split.test_size = 0.2
  cfg.sim.simulate_missing = True
  cfg.sim.sim_strategy = "random"
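For the deterministic imputers the interesting knobs live under algo and split, while train/tune are schema placeholders, as the new docstrings note. A hedged sketch, assuming the same module path as above:

    from pgsui.data_processing.config import MostFrequentConfig

    cfg = MostFrequentConfig.from_preset("fast")
    cfg.algo.by_populations = True   # compute per-population modes
    cfg.split.test_size = 0.25       # evaluation hold-out fraction
    print(cfg.algo.by_populations, cfg.split.test_size)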
@@ -1021,6 +680,10 @@ class RefAlleleAlgoConfig:
  class RefAlleleConfig:
  """Top-level configuration for ImputeRefAllele.

+ Deterministic imputers primarily use ``io``, ``plot``, ``split``, ``algo``,
+ and ``sim``. The ``train`` and ``tune`` sections are retained for schema
+ parity with NN models but are not currently used by ImputeRefAllele.
+
  Attributes:
  io (IOConfig): I/O configuration.
  plot (PlotConfig): Plotting configuration.
@@ -1037,18 +700,26 @@ class RefAlleleConfig:
  algo: RefAlleleAlgoConfig = field(default_factory=RefAlleleAlgoConfig)
  sim: SimConfig = field(default_factory=SimConfig)
  tune: TuneConfig = field(default_factory=TuneConfig)
- train: TrainConfig = field(default_factory=TrainConfig)
+ train: TrainConfig = field(default_factory=_default_train_config)

  @classmethod
  def from_preset(
  cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
  ) -> "RefAlleleConfig":
- """Presets mainly keep parity with logging/IO and split test_size."""
+ """Presets mainly keep parity with logging/IO and split test_size.
+
+ Args:
+ preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+ Returns:
+ RefAlleleConfig: Configuration instance corresponding to the preset.
+ """
  if preset not in {"fast", "balanced", "thorough"}:
  raise ValueError(f"Unknown preset: {preset}")

  cfg = cls()
  cfg.io.verbose = False
+ cfg.io.ploidy = 2
  cfg.split.test_size = 0.2
  cfg.sim.simulate_missing = True
  cfg.sim.sim_strategy = "random"
@@ -1261,16 +932,23 @@ class RFConfig:

  @classmethod
  def from_preset(cls, preset: str = "balanced") -> "RFConfig":
- """Build a config from a named preset."""
+ """Build a config from a named preset.
+
+ Args:
+ preset (str): Preset name.
+
+ Returns:
+ RFConfig: Configuration instance corresponding to the preset.
+ """
  cfg = cls()
  if preset == "fast":
- cfg.model.n_estimators = 100 # Increased from 50
+ cfg.model.n_estimators = 50
  cfg.model.max_depth = None
  cfg.imputer.max_iter = 5
  cfg.io.n_jobs = 1
  cfg.tune.enabled = False
  elif preset == "balanced":
- cfg.model.n_estimators = 200 # Increased from 100
+ cfg.model.n_estimators = 200
  cfg.model.max_depth = None
  cfg.imputer.max_iter = 10
  cfg.io.n_jobs = 1
@@ -1279,7 +957,7 @@ class RFConfig:
  elif preset == "thorough":
  cfg.model.n_estimators = 500
  cfg.model.max_depth = 50 # Added safety cap
- cfg.imputer.max_iter = 15
+ cfg.imputer.max_iter = 20
  cfg.io.n_jobs = 1
  cfg.tune.enabled = False
  cfg.tune.n_trials = 250
@@ -1353,18 +1031,25 @@ class HGBConfig:

  @classmethod
  def from_preset(cls, preset: str = "balanced") -> "HGBConfig":
- """Build a config from a named preset."""
+ """Build a config from a named preset.
+
+ Args:
+ preset (str): Preset name.
+
+ Returns:
+ HGBConfig: Configuration instance corresponding to the preset.
+ """
  cfg = cls()
  if preset == "fast":
  cfg.model.n_estimators = 50
- cfg.model.learning_rate = 0.15
+ cfg.model.learning_rate = 0.2
  cfg.model.max_depth = None
  cfg.imputer.max_iter = 5
  cfg.io.n_jobs = 1
  cfg.tune.enabled = False
  cfg.tune.n_trials = 50
  elif preset == "balanced":
- cfg.model.n_estimators = 100
+ cfg.model.n_estimators = 150
  cfg.model.learning_rate = 0.1
  cfg.model.max_depth = None
  cfg.imputer.max_iter = 10
@@ -1373,10 +1058,10 @@ class HGBConfig:
  cfg.tune.n_trials = 100
  elif preset == "thorough":
  cfg.model.n_estimators = 500
- cfg.model.learning_rate = 0.05 # Reduced from 0.08
+ cfg.model.learning_rate = 0.05
  cfg.model.n_iter_no_change = 20 # Increased patience
  cfg.model.max_depth = None
- cfg.imputer.max_iter = 15
+ cfg.imputer.max_iter = 20
  cfg.io.n_jobs = 1
  cfg.tune.enabled = False
  cfg.tune.n_trials = 250
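The supervised configs share the same preset idiom; in 1.7.0 the thorough presets raise imputer.max_iter to 20 while the fast RF preset drops back to 50 estimators. A hedged sketch, assuming the same module path as above:

    from pgsui.data_processing.config import HGBConfig, RFConfig

    hgb = HGBConfig.from_preset("thorough")
    rf = RFConfig.from_preset("fast")
    print(hgb.imputer.max_iter, rf.model.n_estimators)  # 20 50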