pg-sui 1.6.16a3-py3-none-any.whl → 1.7.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
  2. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +577 -125
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +203 -530
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1269 -534
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
  16. pgsui/impute/unsupervised/imputers/vae.py +931 -787
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
  27. pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
  28. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  29. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  30. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  31. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  32. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  33. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -121,20 +121,18 @@ class ModelConfig:
121
121
  latent_dim (int): Dimensionality of the latent space.
122
122
  dropout_rate (float): Dropout rate for regularization.
123
123
  num_hidden_layers (int): Number of hidden layers in the neural network.
124
- hidden_activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function.
124
+ activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function.
125
125
  layer_scaling_factor (float): Scaling factor for the number of neurons in hidden layers.
126
- layer_schedule (Literal["pyramid", "constant", "linear"]): Schedule for scaling hidden layer sizes.
127
- gamma (float): Parameter for the focal loss function.
126
+ layer_schedule (Literal["pyramid", "linear"]): Schedule for scaling hidden layer sizes.
128
127
  """
129
128
 
130
129
  latent_init: Literal["random", "pca"] = "random"
131
130
  latent_dim: int = 2
132
131
  dropout_rate: float = 0.2
133
132
  num_hidden_layers: int = 2
134
- hidden_activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
133
+ activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
135
134
  layer_scaling_factor: float = 5.0
136
- layer_schedule: Literal["pyramid", "constant", "linear"] = "pyramid"
137
- gamma: float = 2.0
135
+ layer_schedule: Literal["pyramid", "linear"] = "pyramid"
138
136
 
139
137
 
140
138
  @dataclass
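Editor's note: the hunk above renames `hidden_activation` to `activation` and drops the "constant" option from `layer_schedule`. The sizing logic itself is not part of this diff, so the following is only a hedged sketch of how `layer_schedule` and `layer_scaling_factor` could translate into hidden-layer widths; the function name and formulas are assumptions, not pg-sui's code.

```python
# Hypothetical sketch only: field names (latent_dim, num_hidden_layers,
# layer_scaling_factor, layer_schedule) come from ModelConfig above; the
# formulas are assumptions.
from typing import List, Literal


def hidden_layer_sizes(
    n_features: int,
    latent_dim: int,
    num_hidden_layers: int,
    layer_scaling_factor: float,
    layer_schedule: Literal["pyramid", "linear"] = "pyramid",
) -> List[int]:
    """Return widths for the hidden layers between input and latent space."""
    widest = max(latent_dim, int(round(latent_dim * layer_scaling_factor)))
    sizes = []
    for i in range(1, num_hidden_layers + 1):
        if layer_schedule == "pyramid":
            # Taper from the widest layer down toward the latent bottleneck.
            frac = (num_hidden_layers - i + 1) / num_hidden_layers
            size = latent_dim + (widest - latent_dim) * frac
        else:  # "linear"
            # Evenly spaced widths between the input size and latent_dim.
            size = n_features - (n_features - latent_dim) * i / (num_hidden_layers + 1)
        sizes.append(max(int(round(size)), latent_dim))
    return sizes


# e.g. hidden_layer_sizes(500, latent_dim=8, num_hidden_layers=2,
#                         layer_scaling_factor=5.0) -> [40, 24]
```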
@@ -144,28 +142,39 @@ class TrainConfig:
144
142
  Attributes:
145
143
  batch_size (int): Number of samples per training batch.
146
144
  learning_rate (float): Learning rate for the optimizer.
147
- lr_input_factor (float): Factor to scale the learning rate for input layer.
148
145
  l1_penalty (float): L1 regularization penalty.
149
146
  early_stop_gen (int): Number of generations with no improvement to wait before early stopping.
150
147
  min_epochs (int): Minimum number of epochs to train.
151
148
  max_epochs (int): Maximum number of epochs to train.
152
149
  validation_split (float): Proportion of data to use for validation.
153
- weights_beta (float): Smoothing factor for class weights.
154
- weights_max_ratio (float): Maximum ratio for class weights to prevent extreme values.
150
+ weights_max_ratio (float | None): Maximum ratio for class weights to prevent extreme values.
151
+ gamma (float): Focusing parameter for focal loss.
155
152
  device (Literal["gpu", "cpu", "mps"]): Device to use for computation.
156
153
  """
157
154
 
158
- batch_size: int = 32
155
+ batch_size: int = 64
159
156
  learning_rate: float = 1e-3
160
- lr_input_factor: float = 1.0
161
157
  l1_penalty: float = 0.0
162
- early_stop_gen: int = 20
158
+ early_stop_gen: int = 25
163
159
  min_epochs: int = 100
164
- max_epochs: int = 5000
160
+ max_epochs: int = 2000
165
161
  validation_split: float = 0.2
166
- weights_beta: float = 0.9999
167
- weights_max_ratio: float = 1.0
168
162
  device: Literal["gpu", "cpu", "mps"] = "cpu"
163
+ weights_max_ratio: Optional[float] = None
164
+ weights_power: float = 1.0
165
+ weights_normalize: bool = True
166
+ weights_inverse: bool = False
167
+ gamma: float = 0.0
168
+ gamma_schedule: bool = False
169
+
170
+
171
+ def _default_train_config() -> TrainConfig:
172
+ """Typed default factory for TrainConfig (helps some type checkers).
173
+
174
+ Using the class object directly (default_factory=TrainConfig) is valid at runtime but certain type checkers can fail to match dataclasses.field overloads.
175
+ """
176
+
177
+ return TrainConfig()
169
178
 
170
179
 
171
180
  @dataclass
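Editor's note: TrainConfig replaces the old `weights_beta` smoothing with explicit class-weight knobs (`weights_power`, `weights_inverse`, `weights_normalize`, `weights_max_ratio`) and moves the focal-loss `gamma` here from ModelConfig. Below is a hedged sketch of one plausible way these fields could combine into a class-weighted, focal-modulated cross-entropy; only the field names are taken from the config, the weighting formulas and the loss itself are assumptions, not pg-sui's implementation.

```python
# Hypothetical sketch: field names from TrainConfig above, formulas assumed.
import numpy as np
import torch
import torch.nn.functional as F


def class_weights(
    counts: np.ndarray,               # per-genotype-class counts (e.g. 0/1/2)
    power: float = 1.0,               # weights_power
    inverse: bool = False,            # weights_inverse
    normalize: bool = True,           # weights_normalize
    max_ratio: float | None = None,   # weights_max_ratio
) -> torch.Tensor:
    freqs = counts / counts.sum()
    w = 1.0 / np.clip(freqs, 1e-12, None) if inverse else 1.0 - freqs
    w = w**power
    if max_ratio is not None:
        # Cap how much larger any class weight may be than the smallest weight.
        w = np.minimum(w, w.min() * max_ratio)
    if normalize:
        w = w * len(w) / w.sum()      # rescale so the mean weight is 1
    return torch.as_tensor(w, dtype=torch.float32)


def focal_cross_entropy(
    logits: torch.Tensor, target: torch.Tensor,
    weight: torch.Tensor, gamma: float = 0.0,
) -> torch.Tensor:
    """Class-weighted cross-entropy with focal modulation; gamma=0 is plain CE."""
    ce = F.cross_entropy(logits, target, reduction="none")  # -log p_t
    pt = torch.exp(-ce)                                      # prob. of true class
    return (weight[target] * (1.0 - pt) ** gamma * ce).mean()
```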
@@ -174,19 +183,13 @@ class TuneConfig:
174
183
 
175
184
  Attributes:
176
185
  enabled (bool): If True, enables hyperparameter tuning.
177
- metric (Literal["f1", "accuracy", "pr_macro"]): Metric to optimize during tuning.
186
+ metric (Literal["f1", "accuracy", "pr_macro", "average_precision", "roc_auc", "precision", "recall", "mcc", "jaccard"]): Metric to optimize during tuning.
178
187
  n_trials (int): Number of hyperparameter trials to run.
179
188
  resume (bool): If True, resumes tuning from a previous state.
180
189
  save_db (bool): If True, saves the tuning results to a database.
181
- fast (bool): If True, uses a faster but less thorough tuning approach.
182
- max_samples (int): Maximum number of samples to use for tuning. 0 means all samples.
183
- max_loci (int): Maximum number of loci to use for tuning. 0 means all loci.
184
190
  epochs (int): Number of epochs to train each trial.
185
191
  batch_size (int): Batch size for training during tuning.
186
- eval_interval (int): Interval (in epochs) at which to evaluate the model during tuning.
187
- infer_epochs (int): Number of epochs for inference during tuning.
188
192
  patience (int): Number of evaluations with no improvement before stopping early.
189
- proxy_metric_batch (int): If > 0, uses a subset of data for proxy metric evaluation.
190
193
  """
191
194
 
192
195
  enabled: bool = False
@@ -198,34 +201,15 @@ class TuneConfig:
198
201
  "roc_auc",
199
202
  "precision",
200
203
  "recall",
204
+ "mcc",
205
+ "jaccard",
201
206
  ] = "f1"
202
207
  n_trials: int = 100
203
208
  resume: bool = False
204
209
  save_db: bool = False
205
- fast: bool = True
206
- max_samples: int = 512
207
- max_loci: int = 0 # 0 = all
208
210
  epochs: int = 500
209
211
  batch_size: int = 64
210
- eval_interval: int = 20
211
- infer_epochs: int = 100
212
212
  patience: int = 10
213
- proxy_metric_batch: int = 0
214
-
215
-
216
- @dataclass
217
- class EvalConfig:
218
- """Evaluation configuration.
219
-
220
- Attributes:
221
- eval_latent_steps (int): Number of optimization steps for latent space evaluation.
222
- eval_latent_lr (float): Learning rate for latent space optimization.
223
- eval_latent_weight_decay (float): Weight decay for latent space optimization.
224
- """
225
-
226
- eval_latent_steps: int = 50
227
- eval_latent_lr: float = 1e-2
228
- eval_latent_weight_decay: float = 0.0
229
213
 
230
214
 
231
215
  @dataclass
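Editor's note: TuneConfig's `metric` literal gains "mcc" and "jaccard". pg-sui's actual scorer wiring lives in `pgsui/utils/scorers.py` (also touched in this diff) and is not shown here; the mapping below is only an assumed illustration of how such metric names can resolve to scikit-learn functions.

```python
# Assumed illustration only; see pgsui/utils/scorers.py for the real wiring.
from sklearn import metrics

METRIC_FUNCS = {
    "accuracy": metrics.accuracy_score,
    "f1": lambda y, p: metrics.f1_score(y, p, average="macro"),
    "precision": lambda y, p: metrics.precision_score(y, p, average="macro"),
    "recall": lambda y, p: metrics.recall_score(y, p, average="macro"),
    "mcc": metrics.matthews_corrcoef,
    "jaccard": lambda y, p: metrics.jaccard_score(y, p, average="macro"),
    # "pr_macro", "average_precision", and "roc_auc" need class probabilities
    # rather than hard labels, so they would be handled separately.
}


def tuning_score(metric_name: str, y_true, y_pred) -> float:
    return float(METRIC_FUNCS[metric_name](y_true, y_pred))
```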
@@ -244,15 +228,18 @@ class PlotConfig:
244
228
  dpi: int = 300
245
229
  fontsize: int = 18
246
230
  despine: bool = True
247
- show: bool = False
231
+ show: bool = True
248
232
 
249
233
 
250
234
  @dataclass
251
235
  class IOConfig:
252
236
  """I/O configuration.
253
237
 
238
+ Dataclass that includes configuration settings for file naming, logging verbosity, random seed, and parallelism.
239
+
254
240
  Attributes:
255
241
  prefix (str): Prefix for output files. Default is "pgsui".
242
+ ploidy (int): Ploidy level of the organism. Default is 2.
256
243
  verbose (bool): If True, enables verbose logging. Default is False.
257
244
  debug (bool): If True, enables debug mode. Default is False.
258
245
  seed (int | None): Random seed for reproducibility. Default is None.
@@ -261,6 +248,7 @@ class IOConfig:
261
248
  """
262
249
 
263
250
  prefix: str = "pgsui"
251
+ ploidy: int = 2
264
252
  verbose: bool = False
265
253
  debug: bool = False
266
254
  seed: int | None = None
@@ -287,37 +275,46 @@ class SimConfig:
287
275
  "nonrandom",
288
276
  "nonrandom_weighted",
289
277
  ] = "random"
290
- sim_prop: float = 0.10
278
+ sim_prop: float = 0.20
291
279
  sim_kwargs: dict | None = None
292
280
 
293
281
 
294
282
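Editor's note: SimConfig's default `sim_prop` rises from 0.10 to 0.20. Simulated missingness masks a proportion of known genotypes so the imputer can be scored against held-out truth; the helper below is a hedged sketch of the "random" strategy only (the `nonrandom*` strategies would bias which entries are masked) and is not pg-sui's implementation.

```python
# Hedged sketch of the "random" simulated-missing strategy; not pg-sui's code.
import numpy as np


def simulate_missing(
    X: np.ndarray, sim_prop: float = 0.20, missing: int = -1,
    rng: np.random.Generator | None = None,
):
    """Mask `sim_prop` of the observed entries of a (samples x loci) matrix."""
    rng = rng or np.random.default_rng()
    X_sim = X.copy()
    observed = np.argwhere(X != missing)                    # candidate positions
    n_mask = int(round(sim_prop * len(observed)))
    picked = observed[rng.choice(len(observed), size=n_mask, replace=False)]
    X_sim[picked[:, 0], picked[:, 1]] = missing
    return X_sim, picked   # `picked` holds the positions with known truth
```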
  @dataclass
295
- class NLPCAConfig:
296
- """Top-level configuration for ImputeNLPCA.
283
+ class AutoencoderConfig:
284
+ """Top-level configuration for ImputeAutoencoder.
285
+
286
+ This configuration class encapsulates all settings required for the
287
+ ImputeAutoencoder model, including I/O, model architecture, training,
288
+ hyperparameter tuning, plotting, and simulated-missing configuration.
297
289
 
298
290
  Attributes:
299
291
  io (IOConfig): I/O configuration.
300
292
  model (ModelConfig): Model architecture configuration.
301
293
  train (TrainConfig): Training procedure configuration.
302
294
  tune (TuneConfig): Hyperparameter tuning configuration.
303
- evaluate (EvalConfig): Evaluation configuration.
304
295
  plot (PlotConfig): Plotting configuration.
305
- sim (SimConfig): Simulation configuration.
296
+ sim (SimConfig): Simulated-missing configuration.
306
297
  """
307
298
 
308
299
  io: IOConfig = field(default_factory=IOConfig)
309
300
  model: ModelConfig = field(default_factory=ModelConfig)
310
- train: TrainConfig = field(default_factory=TrainConfig)
301
+ train: TrainConfig = field(default_factory=_default_train_config)
311
302
  tune: TuneConfig = field(default_factory=TuneConfig)
312
- evaluate: EvalConfig = field(default_factory=EvalConfig)
313
303
  plot: PlotConfig = field(default_factory=PlotConfig)
314
304
  sim: SimConfig = field(default_factory=SimConfig)
315
305
 
316
306
  @classmethod
317
307
  def from_preset(
318
308
  cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
319
- ) -> "NLPCAConfig":
320
- """Build a NLPCAConfig from a named preset."""
309
+ ) -> "AutoencoderConfig":
310
+ """Build a AutoencoderConfig from a named preset.
311
+
312
+ Args:
313
+ preset (Literal["fast", "balanced", "thorough"]): Preset name.
314
+
315
+ Returns:
316
+ AutoencoderConfig: Configuration instance corresponding to the preset.
317
+ """
321
318
  if preset not in {"fast", "balanced", "thorough"}:
322
319
  raise ValueError(f"Unknown preset: {preset}")
323
320
 
@@ -325,414 +322,87 @@ class NLPCAConfig:
325
322
 
326
323
  # Common baselines
327
324
  cfg.io.verbose = False
328
- cfg.train.validation_split = 0.20
329
- cfg.model.hidden_activation = "relu"
325
+ cfg.io.ploidy = 2
326
+ cfg.train.validation_split = 0.2
327
+ cfg.model.activation = "relu"
330
328
  cfg.model.layer_schedule = "pyramid"
331
- cfg.model.latent_init = "random"
332
- cfg.evaluate.eval_latent_lr = 1e-2
333
- cfg.evaluate.eval_latent_weight_decay = 0.0
334
- cfg.sim.simulate_missing = True
329
+ cfg.model.layer_scaling_factor = 2.0
335
330
  cfg.sim.sim_strategy = "random"
336
331
  cfg.sim.sim_prop = 0.2
332
+ cfg.plot.show = True
333
+
334
+ # Train settings
335
+ cfg.train.weights_max_ratio = None
336
+ cfg.train.weights_power = 1.0
337
+ cfg.train.weights_normalize = True
338
+ cfg.train.weights_inverse = False
339
+ cfg.train.gamma = 0.0
340
+ cfg.train.gamma_schedule = False
341
+ cfg.train.min_epochs = 100
342
+
343
+ # Tune
344
+ cfg.tune.enabled = False
345
+ cfg.tune.n_trials = 100
337
346
 
338
347
  if preset == "fast":
339
348
  # Model
340
349
  cfg.model.latent_dim = 4
341
350
  cfg.model.num_hidden_layers = 1
342
- cfg.model.layer_scaling_factor = 2.0
343
351
  cfg.model.dropout_rate = 0.10
344
- cfg.model.gamma = 1.5
345
- # Train
346
- cfg.train.batch_size = 256
347
- cfg.train.learning_rate = 2e-3
348
- cfg.train.early_stop_gen = 5
349
- cfg.train.min_epochs = 10
350
- cfg.train.max_epochs = 150
351
- cfg.train.weights_beta = 0.999
352
- cfg.train.weights_max_ratio = 5.0
353
- # Tuning
354
- cfg.tune.enabled = True
355
- cfg.tune.fast = True
356
- cfg.tune.n_trials = 20
357
- cfg.tune.epochs = 150
358
- cfg.tune.batch_size = 256
359
- cfg.tune.max_samples = 512
360
- cfg.tune.max_loci = 0
361
- cfg.tune.eval_interval = 20
362
- cfg.tune.infer_epochs = 20
363
- cfg.tune.patience = 5
364
- cfg.tune.proxy_metric_batch = 0
365
- # Eval
366
- cfg.evaluate.eval_latent_steps = 20
367
352
 
368
- elif preset == "balanced":
369
- # Model
370
- cfg.model.latent_dim = 8
371
- cfg.model.num_hidden_layers = 2
372
- cfg.model.layer_scaling_factor = 3.0
373
- cfg.model.dropout_rate = 0.20
374
- cfg.model.gamma = 2.0
375
353
  # Train
376
354
  cfg.train.batch_size = 128
377
- cfg.train.learning_rate = 1e-3
355
+ cfg.train.learning_rate = 2e-3
378
356
  cfg.train.early_stop_gen = 15
379
- cfg.train.min_epochs = 50
380
- cfg.train.max_epochs = 600
381
- cfg.train.weights_beta = 0.9999
382
- cfg.train.weights_max_ratio = 5.0
383
- # Tuning
384
- cfg.tune.enabled = True
385
- cfg.tune.fast = False
386
- cfg.tune.n_trials = 60
387
- cfg.tune.epochs = 200
388
- cfg.tune.batch_size = 128
389
- cfg.tune.max_samples = 2048
390
- cfg.tune.max_loci = 0
391
- cfg.tune.eval_interval = 10
392
- cfg.tune.infer_epochs = 50
393
- cfg.tune.patience = 10
394
- cfg.tune.proxy_metric_batch = 0
395
- # Eval
396
- cfg.evaluate.eval_latent_steps = 40
397
-
398
- else: # thorough
399
- # Model
400
- cfg.model.latent_dim = 16
401
- cfg.model.num_hidden_layers = 3
402
- cfg.model.layer_scaling_factor = 5.0
403
- cfg.model.dropout_rate = 0.30
404
- cfg.model.gamma = 2.5
405
- # Train
406
- cfg.train.batch_size = 64
407
- cfg.train.learning_rate = 5e-4
408
- cfg.train.early_stop_gen = 30
409
- cfg.train.min_epochs = 100
410
- cfg.train.max_epochs = 2000
411
- cfg.train.weights_beta = 0.9999
412
- cfg.train.weights_max_ratio = 5.0
413
- # Tuning
414
- cfg.tune.enabled = True
415
- cfg.tune.fast = False # Full search
416
- cfg.tune.n_trials = 100
417
- cfg.tune.epochs = 600
418
- cfg.tune.batch_size = 64
419
- cfg.tune.max_samples = 0 # No limit
420
- cfg.tune.max_loci = 0
421
- cfg.tune.eval_interval = 10
422
- cfg.tune.infer_epochs = 80
423
- cfg.tune.patience = 20
424
- cfg.tune.proxy_metric_batch = 0
425
- # Eval
426
- cfg.evaluate.eval_latent_steps = 100
427
-
428
- return cfg
429
-
430
- def apply_overrides(self, overrides: Dict[str, Any] | None) -> "NLPCAConfig":
431
- """Apply flat dot-key overrides."""
432
- if not overrides:
433
- return self
434
- for k, v in overrides.items():
435
- node = self
436
- parts = k.split(".")
437
- for p in parts[:-1]:
438
- node = getattr(node, p)
439
- last = parts[-1]
440
- if hasattr(node, last):
441
- setattr(node, last, v)
442
- else:
443
- raise KeyError(f"Unknown config key: {k}")
444
- return self
445
-
446
- def to_dict(self) -> Dict[str, Any]:
447
- return asdict(self)
448
-
449
-
450
- @dataclass
451
- class UBPConfig:
452
- """Top-level configuration for ImputeUBP.
453
-
454
- Attributes:
455
- io (IOConfig): I/O configuration.
456
- model (ModelConfig): Model architecture configuration.
457
- train (TrainConfig): Training procedure configuration.
458
- tune (TuneConfig): Hyperparameter tuning configuration.
459
- evaluate (EvalConfig): Evaluation configuration.
460
- plot (PlotConfig): Plotting configuration.
461
- sim (SimConfig): Simulated-missing configuration.
462
- """
357
+ cfg.train.max_epochs = 200
358
+ cfg.train.weights_max_ratio = None
463
359
 
464
- io: IOConfig = field(default_factory=IOConfig)
465
- model: ModelConfig = field(default_factory=ModelConfig)
466
- train: TrainConfig = field(default_factory=TrainConfig)
467
- tune: TuneConfig = field(default_factory=TuneConfig)
468
- evaluate: EvalConfig = field(default_factory=EvalConfig)
469
- plot: PlotConfig = field(default_factory=PlotConfig)
470
- sim: SimConfig = field(default_factory=SimConfig)
471
-
472
- @classmethod
473
- def from_preset(
474
- cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
475
- ) -> "UBPConfig":
476
- """Build a UBPConfig from a named preset."""
477
- if preset not in {"fast", "balanced", "thorough"}:
478
- raise ValueError(f"Unknown preset: {preset}")
479
-
480
- cfg = cls()
481
-
482
- # Common baselines
483
- cfg.io.verbose = False
484
- cfg.model.hidden_activation = "relu"
485
- cfg.model.layer_schedule = "pyramid"
486
- cfg.model.latent_init = "random"
487
- cfg.sim.simulate_missing = True
488
- cfg.sim.sim_strategy = "random"
489
- cfg.sim.sim_prop = 0.2
490
-
491
- if preset == "fast":
492
- # Model
493
- cfg.model.latent_dim = 4
494
- cfg.model.num_hidden_layers = 1
495
- cfg.model.layer_scaling_factor = 2.0
496
- cfg.model.dropout_rate = 0.10
497
- cfg.model.gamma = 1.5
498
- # Train
499
- cfg.train.batch_size = 256
500
- cfg.train.learning_rate = 2e-3
501
- cfg.train.early_stop_gen = 5
502
- cfg.train.min_epochs = 10
503
- cfg.train.max_epochs = 150
504
- cfg.train.weights_beta = 0.999
505
- cfg.train.weights_max_ratio = 5.0
506
- # Tuning
507
- cfg.tune.enabled = True
508
- cfg.tune.fast = True
509
- cfg.tune.n_trials = 20
510
- cfg.tune.epochs = 150
511
- cfg.tune.batch_size = 256
512
- cfg.tune.max_samples = 512
513
- cfg.tune.max_loci = 0
514
- cfg.tune.eval_interval = 20
515
- cfg.tune.infer_epochs = 20
516
- cfg.tune.patience = 5
517
- cfg.tune.proxy_metric_batch = 0
518
- # Eval
519
- cfg.evaluate.eval_latent_steps = 20
520
- cfg.evaluate.eval_latent_lr = 1e-2
521
- cfg.evaluate.eval_latent_weight_decay = 0.0
360
+ # Tune
361
+ cfg.tune.patience = 15
522
362
 
523
363
  elif preset == "balanced":
524
364
  # Model
525
365
  cfg.model.latent_dim = 8
526
366
  cfg.model.num_hidden_layers = 2
527
- cfg.model.layer_scaling_factor = 3.0
528
367
  cfg.model.dropout_rate = 0.20
529
- cfg.model.gamma = 2.0
368
+
530
369
  # Train
531
- cfg.train.batch_size = 128
370
+ cfg.train.batch_size = 64
532
371
  cfg.train.learning_rate = 1e-3
533
- cfg.train.early_stop_gen = 15
534
- cfg.train.min_epochs = 50
535
- cfg.train.max_epochs = 600
536
- cfg.train.weights_beta = 0.9999
537
- cfg.train.weights_max_ratio = 5.0
538
- # Tuning
539
- cfg.tune.enabled = True
540
- cfg.tune.fast = False
541
- cfg.tune.n_trials = 60
542
- cfg.tune.epochs = 200
543
- cfg.tune.batch_size = 128
544
- cfg.tune.max_samples = 2048
545
- cfg.tune.max_loci = 0
546
- cfg.tune.eval_interval = 10
547
- cfg.tune.infer_epochs = 50
548
- cfg.tune.patience = 10
549
- cfg.tune.proxy_metric_batch = 0
550
- # Eval
551
- cfg.evaluate.eval_latent_steps = 40
552
- cfg.evaluate.eval_latent_lr = 1e-2
553
- cfg.evaluate.eval_latent_weight_decay = 0.0
372
+ cfg.train.early_stop_gen = 25
373
+ cfg.train.max_epochs = 500
374
+ cfg.train.weights_max_ratio = None
375
+
376
+ # Tune
377
+ cfg.tune.patience = 25
554
378
 
555
379
  else: # thorough
556
380
  # Model
557
381
  cfg.model.latent_dim = 16
558
382
  cfg.model.num_hidden_layers = 3
559
- cfg.model.layer_scaling_factor = 5.0
560
383
  cfg.model.dropout_rate = 0.30
561
- cfg.model.gamma = 2.5
384
+
562
385
  # Train
563
386
  cfg.train.batch_size = 64
564
387
  cfg.train.learning_rate = 5e-4
565
- cfg.train.early_stop_gen = 30
566
- cfg.train.min_epochs = 100
567
- cfg.train.max_epochs = 2000
568
- cfg.train.weights_beta = 0.9999
569
- cfg.train.weights_max_ratio = 5.0
570
- # Tuning
571
- cfg.tune.enabled = True
572
- cfg.tune.fast = False
573
- cfg.tune.n_trials = 100
574
- cfg.tune.epochs = 600
575
- cfg.tune.batch_size = 64
576
- cfg.tune.max_samples = 0
577
- cfg.tune.max_loci = 0
578
- cfg.tune.eval_interval = 10
579
- cfg.tune.infer_epochs = 80
580
- cfg.tune.patience = 20
581
- cfg.tune.proxy_metric_batch = 0
582
- # Eval
583
- cfg.evaluate.eval_latent_steps = 100
584
- cfg.evaluate.eval_latent_lr = 1e-2
585
- cfg.evaluate.eval_latent_weight_decay = 0.0
586
-
587
- return cfg
588
-
589
- def apply_overrides(self, overrides: Dict[str, Any] | None) -> "UBPConfig":
590
- """Apply flat dot-key overrides."""
591
- if not overrides:
592
- return self
593
-
594
- for k, v in overrides.items():
595
- node = self
596
- parts = k.split(".")
597
- for p in parts[:-1]:
598
- node = getattr(node, p)
599
- last = parts[-1]
600
- if hasattr(node, last):
601
- setattr(node, last, v)
602
- else:
603
- raise KeyError(f"Unknown config key: {k}")
604
- return self
605
-
606
- def to_dict(self) -> Dict[str, Any]:
607
- return asdict(self)
608
-
609
-
610
- @dataclass
611
- class AutoencoderConfig:
612
- """Top-level configuration for ImputeAutoencoder.
613
-
614
- Attributes:
615
- io (IOConfig): I/O configuration.
616
- model (ModelConfig): Model architecture configuration.
617
- train (TrainConfig): Training procedure configuration.
618
- tune (TuneConfig): Hyperparameter tuning configuration.
619
- evaluate (EvalConfig): Evaluation configuration.
620
- plot (PlotConfig): Plotting configuration.
621
- sim (SimConfig): Simulated-missing configuration.
622
- """
623
-
624
- io: IOConfig = field(default_factory=IOConfig)
625
- model: ModelConfig = field(default_factory=ModelConfig)
626
- train: TrainConfig = field(default_factory=TrainConfig)
627
- tune: TuneConfig = field(default_factory=TuneConfig)
628
- evaluate: EvalConfig = field(default_factory=EvalConfig)
629
- plot: PlotConfig = field(default_factory=PlotConfig)
630
- sim: SimConfig = field(default_factory=SimConfig)
631
-
632
- @classmethod
633
- def from_preset(
634
- cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
635
- ) -> "AutoencoderConfig":
636
- """Build a AutoencoderConfig from a named preset."""
637
- if preset not in {"fast", "balanced", "thorough"}:
638
- raise ValueError(f"Unknown preset: {preset}")
639
-
640
- cfg = cls()
641
-
642
- # Common baselines (no latent refinement at eval)
643
- cfg.io.verbose = False
644
- cfg.train.validation_split = 0.20
645
- cfg.model.hidden_activation = "relu"
646
- cfg.model.layer_schedule = "pyramid"
647
- cfg.evaluate.eval_latent_steps = 0
648
- cfg.evaluate.eval_latent_lr = 0.0
649
- cfg.evaluate.eval_latent_weight_decay = 0.0
650
- cfg.sim.simulate_missing = True
651
- cfg.sim.sim_strategy = "random"
652
- cfg.sim.sim_prop = 0.2
653
-
654
- if preset == "fast":
655
- cfg.model.latent_dim = 4
656
- cfg.model.num_hidden_layers = 1
657
- cfg.model.layer_scaling_factor = 2.0
658
- cfg.model.dropout_rate = 0.10
659
- cfg.model.gamma = 1.5
660
- cfg.train.batch_size = 256
661
- cfg.train.learning_rate = 2e-3
662
- cfg.train.early_stop_gen = 5
663
- cfg.train.min_epochs = 10
664
- cfg.train.max_epochs = 150
665
- cfg.train.weights_beta = 0.999
666
- cfg.train.weights_max_ratio = 5.0
667
- cfg.tune.enabled = True
668
- cfg.tune.fast = True
669
- cfg.tune.n_trials = 20
670
- cfg.tune.epochs = 150
671
- cfg.tune.batch_size = 256
672
- cfg.tune.max_samples = 512
673
- cfg.tune.max_loci = 0
674
- cfg.tune.eval_interval = 20
675
- cfg.tune.patience = 5
676
- cfg.tune.proxy_metric_batch = 0
677
- if hasattr(cfg.tune, "infer_epochs"):
678
- cfg.tune.infer_epochs = 0
388
+ cfg.train.early_stop_gen = 50
389
+ cfg.train.max_epochs = 1000
390
+ cfg.train.weights_max_ratio = None
679
391
 
680
- elif preset == "balanced":
681
- cfg.model.latent_dim = 8
682
- cfg.model.num_hidden_layers = 2
683
- cfg.model.layer_scaling_factor = 3.0
684
- cfg.model.dropout_rate = 0.20
685
- cfg.model.gamma = 2.0
686
- cfg.train.batch_size = 128
687
- cfg.train.learning_rate = 1e-3
688
- cfg.train.early_stop_gen = 15
689
- cfg.train.min_epochs = 50
690
- cfg.train.max_epochs = 600
691
- cfg.train.weights_beta = 0.9999
692
- cfg.train.weights_max_ratio = 5.0
693
- cfg.tune.enabled = True
694
- cfg.tune.fast = False
695
- cfg.tune.n_trials = 60
696
- cfg.tune.epochs = 200
697
- cfg.tune.batch_size = 128
698
- cfg.tune.max_samples = 2048
699
- cfg.tune.max_loci = 0
700
- cfg.tune.eval_interval = 10
701
- cfg.tune.patience = 10
702
- cfg.tune.proxy_metric_batch = 0
703
- if hasattr(cfg.tune, "infer_epochs"):
704
- cfg.tune.infer_epochs = 0
705
-
706
- else: # thorough
707
- cfg.model.latent_dim = 16
708
- cfg.model.num_hidden_layers = 3
709
- cfg.model.layer_scaling_factor = 5.0
710
- cfg.model.dropout_rate = 0.30
711
- cfg.model.gamma = 2.5
712
- cfg.train.batch_size = 64
713
- cfg.train.learning_rate = 5e-4
714
- cfg.train.early_stop_gen = 30
715
- cfg.train.min_epochs = 100
716
- cfg.train.max_epochs = 2000
717
- cfg.train.weights_beta = 0.9999
718
- cfg.train.weights_max_ratio = 5.0
719
- cfg.tune.enabled = True
720
- cfg.tune.fast = False
721
- cfg.tune.n_trials = 100
722
- cfg.tune.epochs = 600
723
- cfg.tune.batch_size = 64
724
- cfg.tune.max_samples = 0
725
- cfg.tune.max_loci = 0
726
- cfg.tune.eval_interval = 10
727
- cfg.tune.patience = 20
728
- cfg.tune.proxy_metric_batch = 0
729
- if hasattr(cfg.tune, "infer_epochs"):
730
- cfg.tune.infer_epochs = 0
392
+ # Tune
393
+ cfg.tune.patience = 50
731
394
 
732
395
  return cfg
733
396
 
734
397
  def apply_overrides(self, overrides: Dict[str, Any] | None) -> "AutoencoderConfig":
735
- """Apply flat dot-key overrides."""
398
+ """Apply flat dot-key overrides.
399
+
400
+ Args:
401
+ overrides (Dict[str, Any] | None): Dictionary of overrides with dot-separated keys.
402
+
403
+ Returns:
404
+ AutoencoderConfig: The configuration instance with overrides applied in place.
405
+ """
736
406
  if not overrides:
737
407
  return self
738
408
  for k, v in overrides.items():
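Editor's note: taken together, `from_preset` and the dot-key `apply_overrides` shown above give the typical construction pattern. A short usage sketch follows; the import path is an assumption inferred from `pgsui/data_processing/config.py` in the file list, and `to_dict()` is assumed to be retained from the earlier config classes.

```python
# Usage sketch; the import path and retained to_dict() helper are assumptions.
from pgsui.data_processing.config import AutoencoderConfig

cfg = AutoencoderConfig.from_preset("balanced")
cfg.apply_overrides({
    "train.learning_rate": 5e-4,   # dot-keys address nested dataclass fields
    "model.latent_dim": 16,
    "tune.enabled": True,
})
# Unknown keys raise KeyError, so typos in override names fail loudly.
print(cfg.to_dict()["train"]["learning_rate"])  # 0.0005
```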
@@ -753,29 +423,22 @@ class AutoencoderConfig:
753
423
 
754
424
  @dataclass
755
425
  class VAEExtraConfig:
756
- """VAE-specific knobs.
757
-
758
- Attributes:
759
- kl_beta (float): Final β for KL divergence term.
760
- kl_warmup (int): Number of epochs with β=0 (warm-up period).
761
- kl_ramp (int): Number of epochs for linear ramp to final β.
762
- """
763
-
764
426
  kl_beta: float = 1.0
765
- kl_warmup: int = 50
766
- kl_ramp: int = 200
427
+ kl_beta_schedule: bool = False
767
428
 
768
429
 
769
430
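Editor's note: VAEExtraConfig drops the epoch-based `kl_warmup`/`kl_ramp` pair in favor of a single `kl_beta_schedule` flag. The schedule's shape is not visible in this hunk, so the sketch below only illustrates the general idea of a beta-weighted KL term with optional annealing; the ramp shape and function name are assumptions.

```python
# Illustration only: field names kl_beta / kl_beta_schedule come from
# VAEExtraConfig above; the annealing shape is assumed, not pg-sui's code.
import torch


def vae_loss(
    recon_nll: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor,
    epoch: int, max_epochs: int,
    kl_beta: float = 1.0, kl_beta_schedule: bool = False,
) -> torch.Tensor:
    # Analytic KL between N(mu, sigma^2) and the standard normal prior.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
    beta = kl_beta
    if kl_beta_schedule:
        # Assumed: anneal beta linearly from 0 to kl_beta over half of training.
        beta = kl_beta * min(1.0, epoch / max(1, max_epochs // 2))
    return recon_nll + beta * kl
```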
  @dataclass
770
431
  class VAEConfig:
771
432
  """Top-level configuration for ImputeVAE (AE-parity + VAE extras).
772
433
 
434
+ Mirrors AutoencoderConfig sections and adds a ``vae`` block with KL-beta
435
+ controls for the VAE loss.
436
+
773
437
  Attributes:
774
438
  io (IOConfig): I/O configuration.
775
439
  model (ModelConfig): Model architecture configuration.
776
440
  train (TrainConfig): Training procedure configuration.
777
441
  tune (TuneConfig): Hyperparameter tuning configuration.
778
- evaluate (EvalConfig): Evaluation configuration.
779
442
  plot (PlotConfig): Plotting configuration.
780
443
  vae (VAEExtraConfig): VAE-specific configuration.
781
444
  sim (SimConfig): Simulated-missing configuration.
@@ -783,9 +446,8 @@ class VAEConfig:
783
446
 
784
447
  io: IOConfig = field(default_factory=IOConfig)
785
448
  model: ModelConfig = field(default_factory=ModelConfig)
786
- train: TrainConfig = field(default_factory=TrainConfig)
449
+ train: TrainConfig = field(default_factory=_default_train_config)
787
450
  tune: TuneConfig = field(default_factory=TuneConfig)
788
- evaluate: EvalConfig = field(default_factory=EvalConfig)
789
451
  plot: PlotConfig = field(default_factory=PlotConfig)
790
452
  vae: VAEExtraConfig = field(default_factory=VAEExtraConfig)
791
453
  sim: SimConfig = field(default_factory=SimConfig)
@@ -794,119 +456,92 @@ class VAEConfig:
794
456
  def from_preset(
795
457
  cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
796
458
  ) -> "VAEConfig":
797
- """Build a VAEConfig from a named preset."""
459
+ """Build a VAEConfig from a named preset.
460
+
461
+ Args:
462
+ preset (Literal["fast", "balanced", "thorough"]): Preset name.
463
+
464
+ Returns:
465
+ VAEConfig: Configuration instance corresponding to the preset.
466
+ """
798
467
  if preset not in {"fast", "balanced", "thorough"}:
799
468
  raise ValueError(f"Unknown preset: {preset}")
800
469
 
801
470
  cfg = cls()
802
471
 
803
- # Common baselines (match AE; no latent refinement at eval)
472
+ # General settings
804
473
  cfg.io.verbose = False
805
- cfg.train.validation_split = 0.20
806
- cfg.model.hidden_activation = "relu"
474
+ cfg.io.ploidy = 2
475
+ cfg.train.validation_split = 0.2
476
+ cfg.model.activation = "relu"
807
477
  cfg.model.layer_schedule = "pyramid"
808
- cfg.evaluate.eval_latent_steps = 0
809
- cfg.evaluate.eval_latent_lr = 0.0
810
- cfg.evaluate.eval_latent_weight_decay = 0.0
478
+ cfg.model.layer_scaling_factor = 2.0
811
479
  cfg.sim.simulate_missing = True
812
480
  cfg.sim.sim_strategy = "random"
813
481
  cfg.sim.sim_prop = 0.2
482
+ cfg.plot.show = True
483
+
484
+ # Train settings
485
+ cfg.train.weights_max_ratio = None
486
+ cfg.train.weights_power = 1.0
487
+ cfg.train.weights_normalize = True
488
+ cfg.train.weights_inverse = False
489
+ cfg.train.gamma = 0.0
490
+ cfg.train.gamma_schedule = False
491
+ cfg.train.min_epochs = 100
492
+
493
+ # VAE-specific
494
+ cfg.vae.kl_beta = 1.0
495
+ cfg.vae.kl_beta_schedule = False
496
+
497
+ # Tune
498
+ cfg.tune.enabled = False
499
+ cfg.tune.n_trials = 100
814
500
 
815
501
  if preset == "fast":
502
+ # Model
816
503
  cfg.model.latent_dim = 4
817
- cfg.model.num_hidden_layers = 1
818
- cfg.model.layer_scaling_factor = 2.0
504
+ cfg.model.num_hidden_layers = 2
819
505
  cfg.model.dropout_rate = 0.10
820
- cfg.model.gamma = 1.5
821
- # VAE specifics
822
- cfg.vae.kl_beta = 0.5
823
- cfg.vae.kl_warmup = 10
824
- cfg.vae.kl_ramp = 40
506
+
825
507
  # Train
826
- cfg.train.batch_size = 256
508
+ cfg.train.batch_size = 128
827
509
  cfg.train.learning_rate = 2e-3
828
- cfg.train.early_stop_gen = 5
829
- cfg.train.min_epochs = 10
830
- cfg.train.max_epochs = 150
831
- cfg.train.weights_beta = 0.999
832
- cfg.train.weights_max_ratio = 5.0
510
+ cfg.train.early_stop_gen = 15
511
+ cfg.train.max_epochs = 200
512
+
833
513
  # Tune
834
- cfg.tune.enabled = True
835
- cfg.tune.fast = True
836
- cfg.tune.n_trials = 20
837
- cfg.tune.epochs = 150
838
- cfg.tune.batch_size = 256
839
- cfg.tune.max_samples = 512
840
- cfg.tune.max_loci = 0
841
- cfg.tune.eval_interval = 20
842
- cfg.tune.patience = 5
843
- cfg.tune.proxy_metric_batch = 0
844
- if hasattr(cfg.tune, "infer_epochs"):
845
- cfg.tune.infer_epochs = 0
514
+ cfg.tune.patience = 15
846
515
 
847
516
  elif preset == "balanced":
517
+ # Model
848
518
  cfg.model.latent_dim = 8
849
- cfg.model.num_hidden_layers = 2
850
- cfg.model.layer_scaling_factor = 3.0
519
+ cfg.model.num_hidden_layers = 4
851
520
  cfg.model.dropout_rate = 0.20
852
- cfg.model.gamma = 2.0
853
- # VAE specifics
854
- cfg.vae.kl_beta = 1.0
855
- cfg.vae.kl_warmup = 50
856
- cfg.vae.kl_ramp = 150
521
+
857
522
  # Train
858
- cfg.train.batch_size = 128
523
+ cfg.train.batch_size = 64
859
524
  cfg.train.learning_rate = 1e-3
860
- cfg.train.early_stop_gen = 15
861
- cfg.train.min_epochs = 50
862
- cfg.train.max_epochs = 600
863
- cfg.train.weights_beta = 0.9999
864
- cfg.train.weights_max_ratio = 5.0
525
+ cfg.train.early_stop_gen = 25
526
+ cfg.train.max_epochs = 500
527
+
865
528
  # Tune
866
- cfg.tune.enabled = True
867
- cfg.tune.fast = False
868
- cfg.tune.n_trials = 60
869
- cfg.tune.epochs = 200
870
- cfg.tune.batch_size = 128
871
- cfg.tune.max_samples = 2048
872
- cfg.tune.max_loci = 0
873
- cfg.tune.eval_interval = 10
874
- cfg.tune.patience = 10
875
- cfg.tune.proxy_metric_batch = 0
876
- if hasattr(cfg.tune, "infer_epochs"):
877
- cfg.tune.infer_epochs = 0
529
+ cfg.tune.patience = 25
878
530
 
879
531
  else: # thorough
532
+ # Model
880
533
  cfg.model.latent_dim = 16
881
- cfg.model.num_hidden_layers = 3
882
- cfg.model.layer_scaling_factor = 5.0
534
+ cfg.model.num_hidden_layers = 8
883
535
  cfg.model.dropout_rate = 0.30
884
- cfg.model.gamma = 2.5
885
- # VAE specifics
886
- cfg.vae.kl_beta = 1.0
887
- cfg.vae.kl_warmup = 100
888
- cfg.vae.kl_ramp = 400
536
+
889
537
  # Train
890
538
  cfg.train.batch_size = 64
891
539
  cfg.train.learning_rate = 5e-4
892
- cfg.train.early_stop_gen = 30
893
- cfg.train.min_epochs = 100
894
- cfg.train.max_epochs = 2000
895
- cfg.train.weights_beta = 0.9999
896
- cfg.train.weights_max_ratio = 5.0
540
+ cfg.train.early_stop_gen = 50
541
+ cfg.train.max_epochs = 1000
542
+
897
543
  # Tune
898
- cfg.tune.enabled = True
899
- cfg.tune.fast = False
900
- cfg.tune.n_trials = 100
901
- cfg.tune.epochs = 600
902
- cfg.tune.batch_size = 64
903
- cfg.tune.max_samples = 0
904
- cfg.tune.max_loci = 0
905
- cfg.tune.eval_interval = 10
906
- cfg.tune.patience = 20
907
- cfg.tune.proxy_metric_batch = 0
908
- if hasattr(cfg.tune, "infer_epochs"):
909
- cfg.tune.infer_epochs = 0
544
+ cfg.tune.patience = 50
910
545
 
911
546
  return cfg
912
547
 
@@ -935,9 +570,9 @@ class MostFrequentAlgoConfig:
935
570
  """Algorithmic knobs for ImputeMostFrequent.
936
571
 
937
572
  Attributes:
938
- by_populations (bool): Whether to compute per-population modes.
939
- default (int): Fallback mode if no valid entries in a locus.
940
- missing (int): Code for missing genotypes in 0/1/2.
573
+ by_populations (bool): Whether to compute per-population modes. Default is False.
574
+ default (int): Fallback mode if no valid entries in a locus. Default is 0.
575
+ missing (int): Code for missing genotypes in 0/1/2. Default is -1.
941
576
  """
942
577
 
943
578
  by_populations: bool = False
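Editor's note: the docstring above describes 0/1/2 genotype codes with -1 for missing, an optional per-population mode (grouping samples by population before taking the mode), and a fallback `default` when a locus has no observed calls. The snippet below is a minimal illustration of that idea using global modes only; it is not pg-sui's implementation.

```python
# Minimal illustration of per-locus mode imputation; not pg-sui's code.
import numpy as np


def impute_mode(X: np.ndarray, missing: int = -1, default: int = 0) -> np.ndarray:
    """Fill missing entries of a (samples x loci) 0/1/2 matrix with each locus's mode."""
    X = X.copy()
    for j in range(X.shape[1]):
        col = X[:, j]
        observed = col[col != missing]
        # Fall back to `default` when the whole locus is missing.
        mode = default if observed.size == 0 else int(np.bincount(observed).argmax())
        col[col == missing] = mode
    return X


X = np.array([[0, 2, -1],
              [0, -1, 1],
              [-1, 2, 1]])
print(impute_mode(X))  # [[0 2 1] [0 2 1] [0 2 1]]
```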
@@ -950,8 +585,8 @@ class DeterministicSplitConfig:
950
585
  """Evaluation split configuration shared by deterministic imputers.
951
586
 
952
587
  Attributes:
953
- test_size (float): Proportion of data to use as the test set.
954
- test_indices (Optional[Sequence[int]]): Specific indices to use as the test set.
588
+ test_size (float): Proportion of data to use as the test set. Default is 0.2.
589
+ test_indices (Optional[Sequence[int]]): Specific indices to use as the test set. Default is None.
955
590
  """
956
591
 
957
592
  test_size: float = 0.2
@@ -962,6 +597,10 @@ class DeterministicSplitConfig:
962
597
  class MostFrequentConfig:
963
598
  """Top-level configuration for ImputeMostFrequent.
964
599
 
600
+ Deterministic imputers primarily use ``io``, ``plot``, ``split``, ``algo``,
601
+ and ``sim``. The ``train`` and ``tune`` sections are retained for schema
602
+ parity with NN models but are not currently used by ImputeMostFrequent.
603
+
965
604
  Attributes:
966
605
  io (IOConfig): I/O configuration.
967
606
  plot (PlotConfig): Plotting configuration.
@@ -978,19 +617,27 @@ class MostFrequentConfig:
978
617
  algo: MostFrequentAlgoConfig = field(default_factory=MostFrequentAlgoConfig)
979
618
  sim: SimConfig = field(default_factory=SimConfig)
980
619
  tune: TuneConfig = field(default_factory=TuneConfig)
981
- train: TrainConfig = field(default_factory=TrainConfig)
620
+ train: TrainConfig = field(default_factory=_default_train_config)
982
621
 
983
622
  @classmethod
984
623
  def from_preset(
985
624
  cls,
986
625
  preset: Literal["fast", "balanced", "thorough"] = "balanced",
987
626
  ) -> "MostFrequentConfig":
988
- """Construct a preset configuration."""
627
+ """Construct a preset configuration.
628
+
629
+ Args:
630
+ preset (Literal["fast", "balanced", "thorough"]): Preset name.
631
+
632
+ Returns:
633
+ MostFrequentConfig: Configuration instance corresponding to the preset.
634
+ """
989
635
  if preset not in {"fast", "balanced", "thorough"}:
990
636
  raise ValueError(f"Unknown preset: {preset}")
991
637
 
992
638
  cfg = cls()
993
639
  cfg.io.verbose = False
640
+ cfg.io.ploidy = 2
994
641
  cfg.split.test_size = 0.2
995
642
  cfg.sim.simulate_missing = True
996
643
  cfg.sim.sim_strategy = "random"
@@ -1033,6 +680,10 @@ class RefAlleleAlgoConfig:
1033
680
  class RefAlleleConfig:
1034
681
  """Top-level configuration for ImputeRefAllele.
1035
682
 
683
+ Deterministic imputers primarily use ``io``, ``plot``, ``split``, ``algo``,
684
+ and ``sim``. The ``train`` and ``tune`` sections are retained for schema
685
+ parity with NN models but are not currently used by ImputeRefAllele.
686
+
1036
687
  Attributes:
1037
688
  io (IOConfig): I/O configuration.
1038
689
  plot (PlotConfig): Plotting configuration.
@@ -1049,18 +700,26 @@ class RefAlleleConfig:
1049
700
  algo: RefAlleleAlgoConfig = field(default_factory=RefAlleleAlgoConfig)
1050
701
  sim: SimConfig = field(default_factory=SimConfig)
1051
702
  tune: TuneConfig = field(default_factory=TuneConfig)
1052
- train: TrainConfig = field(default_factory=TrainConfig)
703
+ train: TrainConfig = field(default_factory=_default_train_config)
1053
704
 
1054
705
  @classmethod
1055
706
  def from_preset(
1056
707
  cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
1057
708
  ) -> "RefAlleleConfig":
1058
- """Presets mainly keep parity with logging/IO and split test_size."""
709
+ """Presets mainly keep parity with logging/IO and split test_size.
710
+
711
+ Args:
712
+ preset (Literal["fast", "balanced", "thorough"]): Preset name.
713
+
714
+ Returns:
715
+ RefAlleleConfig: Configuration instance corresponding to the preset.
716
+ """
1059
717
  if preset not in {"fast", "balanced", "thorough"}:
1060
718
  raise ValueError(f"Unknown preset: {preset}")
1061
719
 
1062
720
  cfg = cls()
1063
721
  cfg.io.verbose = False
722
+ cfg.io.ploidy = 2
1064
723
  cfg.split.test_size = 0.2
1065
724
  cfg.sim.simulate_missing = True
1066
725
  cfg.sim.sim_strategy = "random"
@@ -1273,7 +932,14 @@ class RFConfig:
1273
932
 
1274
933
  @classmethod
1275
934
  def from_preset(cls, preset: str = "balanced") -> "RFConfig":
1276
- """Build a config from a named preset."""
935
+ """Build a config from a named preset.
936
+
937
+ Args:
938
+ preset (str): Preset name.
939
+
940
+ Returns:
941
+ RFConfig: Configuration instance corresponding to the preset.
942
+ """
1277
943
  cfg = cls()
1278
944
  if preset == "fast":
1279
945
  cfg.model.n_estimators = 50
@@ -1365,7 +1031,14 @@ class HGBConfig:
1365
1031
 
1366
1032
  @classmethod
1367
1033
  def from_preset(cls, preset: str = "balanced") -> "HGBConfig":
1368
- """Build a config from a named preset."""
1034
+ """Build a config from a named preset.
1035
+
1036
+ Args:
1037
+ preset (str): Preset name.
1038
+
1039
+ Returns:
1040
+ HGBConfig: Configuration instance corresponding to the preset.
1041
+ """
1369
1042
  cfg = cls()
1370
1043
  if preset == "fast":
1371
1044
  cfg.model.n_estimators = 50