pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pg-sui might be problematic.

Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/data_processing/containers.py (new file)
@@ -0,0 +1,1782 @@
+ from __future__ import annotations
+
+ from dataclasses import asdict, dataclass, field
+ from typing import Any, Dict, Literal, Optional, Sequence
+
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
+
+
+ @dataclass
+ class _SimParams:
+     """Container for simulation hyperparameters.
+
+     This class holds the hyperparameters for the simulation process, including the proportion of missing values, the imputation strategy, and other relevant settings.
+
+     Attributes:
+         prop_missing (float): Proportion of missing values to simulate.
+         strategy (Literal["random", "random_inv_genotype"]): Strategy for simulating missing values.
+         missing_val (int | float): Value to represent missing data.
+         het_boost (float): Boost factor for heterozygous genotypes.
+         seed (int | None): Random seed for reproducibility.
+
+     Notes:
+         - The `strategy` attribute determines how missing values are simulated.
+           "random" selects missing values uniformly at random, while "random_inv_genotype" selects missing values based on the inverse of the genotype distribution.
+     """
+
+     prop_missing: float = 0.3
+     strategy: Literal["random", "random_inv_genotype"] = "random_inv_genotype"
+     missing_val: int | float = -1
+     het_boost: float = 2.0
+     seed: int | None = None
+
+     def to_dict(self) -> dict:
+         """Convert the simulation parameters to a dictionary.
+
+         Uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.
+
+         Returns:
+             dict: A dictionary representation of the simulation parameters.
+         """
+         return asdict(self)
+
+
+ @dataclass
+ class _ImputerParams:
+     """Container for imputer hyperparameters.
+
+     This class holds the hyperparameters for the imputation process, including the number of nearest features to consider, the maximum number of iterations, and other relevant settings.
+
+     Attributes:
+         n_nearest_features (int | None): Number of nearest features to consider for imputation
+         max_iter (int): Maximum number of iterations for the imputation algorithm.
+         initial_strategy (Literal["mean", "median", "most_frequent", "constant"]): Strategy for initial imputation of missing values.
+         keep_empty_features (bool): Whether to keep features that are entirely missing.
+         random_state (int | None): Random seed for reproducibility.
+         verbose (bool): If True, enables verbose logging during imputation.
+
+     Notes:
+         - The `initial_strategy` attribute determines how initial missing values are imputed before the iterative process begins.
+     """
+
+     n_nearest_features: int | None = 10
+     max_iter: int = 10
+     initial_strategy: Literal["mean", "median", "most_frequent", "constant"] = (
+         "most_frequent"
+     )
+     keep_empty_features: bool = True
+     random_state: int | None = None
+     verbose: bool = False
+
+     def to_dict(self) -> dict:
+         """Convert the imputer parameters to a dictionary.
+
+         Uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.
+
+         Returns:
+             dict: A dictionary representation of the imputer parameters.
+         """
+
+         return asdict(self)
+
+
+ @dataclass
+ class _RFParams:
+     """Container for RandomForest hyperparameters.
+
+     This class holds the hyperparameters for the RandomForest classifier, including the number of estimators, maximum depth, and other relevant settings.
+
+     Attributes:
+         n_estimators (int): Number of trees in the forest.
+         max_depth (int | None): Maximum depth of the trees.
+         min_samples_split (int): Minimum number of samples required to split an internal node.
+         min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
+         max_features (Literal["sqrt", "log2"] | float | None): Number of
+             features to consider when looking for the best split.
+         criterion (Literal["gini", "entropy", "log_loss"]): Function to measure
+             the quality of a split.
+         class_weight (Literal["balanced", "balanced_subsample", None]): Weights
+             associated with classes.
+     """
+
+     n_estimators: int = 300
+     max_depth: int | None = None
+     min_samples_split: int = 2
+     min_samples_leaf: int = 1
+     max_features: Literal["sqrt", "log2"] | float | None = "sqrt"
+     criterion: Literal["gini", "entropy", "log_loss"] = "gini"
+     class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+
+     def to_dict(self) -> dict:
+         """Convert the RandomForest parameters to a dictionary.
+
+         Uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.
+
+         Returns:
+             dict: A dictionary representation of the RandomForest parameters.
+         """
+         return asdict(self)
+
+
+ @dataclass
+ class _HGBParams:
+     """Container for HistGradientBoosting hyperparameters.
+
+     This class holds the hyperparameters for the HistGradientBoosting classifier, including the number of iterations, learning rate, and other relevant settings.
+
+     Attributes:
+         max_iter (int): Maximum number of iterations.
+         learning_rate (float): Learning rate shrinks the contribution of each tree.
+         max_depth (int | None): Maximum depth of the individual regression estimators.
+         min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
+         n_iter_no_change (int): Number of iterations with no improvement to wait before early stopping
+         tol (float): Tolerance for the early stopping.
+         max_features (float): The fraction of features to consider when looking for the best split.
+         class_weight (Literal["balanced", "balanced_subsample", None]): Weights associated with classes.
+         random_state (int | None): Random seed for reproducibility.
+         verbose (bool): If True, enables verbose logging during training.
+
+     Notes:
+         - The `class_weight` attribute helps to handle class imbalance by adjusting the weights associated with classes.
+     """
+
+     max_iter: int = 100
+     learning_rate: float = 0.1
+     max_depth: int | None = None
+     min_samples_leaf: int = 1
+     n_iter_no_change: int = 10
+     tol: float = 1e-7
+     max_features: float = 1.0
+     class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+     random_state: int | None = None
+     verbose: bool = False
+
+     def to_dict(self) -> dict:
+         """Convert the HistGradientBoosting parameters to a dictionary.
+
+         Uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.
+
+         Returns:
+             dict: A dictionary representation of the HistGradientBoosting parameters.
+         """
+         return asdict(self)
+
+
+ @dataclass
+ class ModelConfig:
+     """Model architecture configuration.
+
+     This class contains configuration options for the model architecture, including latent space initialization, dimensionality, dropout rate, and other relevant settings.
+
+     Attributes:
+         latent_init (Literal["random", "pca"]): Method for initializing the latent space.
+         latent_dim (int): Dimensionality of the latent space.
+         dropout_rate (float): Dropout rate for regularization.
+         num_hidden_layers (int): Number of hidden layers in the neural network.
+         hidden_activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function for hidden layers.
+         layer_scaling_factor (float): Scaling factor for the number of neurons in hidden layers.
+         layer_schedule (Literal["pyramid", "constant", "linear"]): Schedule for scaling hidden layer sizes.
+         gamma (float): Parameter for the loss function.
+
+     Notes:
+         - The `layer_schedule` attribute determines how the size of hidden layers changes across the network (e.g., "pyramid" means decreasing size).
+         - The `latent_init` attribute specifies how the latent space is initialized, either randomly or using PCA.
+     """
+
+     latent_init: Literal["random", "pca"] = "random"
+     latent_dim: int = 2
+     dropout_rate: float = 0.2
+     num_hidden_layers: int = 2
+     hidden_activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
+     layer_scaling_factor: float = 5.0
+     layer_schedule: Literal["pyramid", "constant", "linear"] = "pyramid"
+     gamma: float = 2.0
+
+
+ @dataclass
+ class TrainConfig:
+     """Training procedure configuration.
+
+     This class contains configuration options for the training procedure, including batch size, learning rate, early stopping criteria, and other relevant settings.
+
+     Attributes:
+         batch_size (int): Number of samples per training batch.
+         learning_rate (float): Learning rate for the optimizer.
+         lr_input_factor (float): Factor to scale the learning rate for input layer.
+         l1_penalty (float): L1 regularization penalty.
+         early_stop_gen (int): Number of generations with no improvement to wait before early stopping.
+         min_epochs (int): Minimum number of epochs to train.
+         max_epochs (int): Maximum number of epochs to train.
+         validation_split (float): Proportion of data to use for validation.
+         weights_beta (float): Smoothing factor for class weights.
+         weights_max_ratio (float): Maximum ratio for class weights to prevent extreme values.
+         device (Literal["gpu", "cpu", "mps"]): Device to use for computation.
+
+     Notes:
+         - The `device` attribute specifies the computation device to use, such as "gpu", "cpu", or "mps" (for Apple Silicon).
+     """
+
+     batch_size: int = 32
+     learning_rate: float = 1e-3
+     lr_input_factor: float = 1.0
+     l1_penalty: float = 0.0
+     early_stop_gen: int = 20
+     min_epochs: int = 100
+     max_epochs: int = 5000
+     validation_split: float = 0.2
+     weights_beta: float = 0.9999
+     weights_max_ratio: float = 1.0
+     device: Literal["gpu", "cpu", "mps"] = "cpu"
+
+
+ @dataclass
+ class TuneConfig:
+     """Hyperparameter tuning configuration.
+
+     This class contains configuration options for hyperparameter tuning, including the number of trials, evaluation metrics, and other relevant settings.
+
+     Attributes:
+         enabled (bool): If True, enables hyperparameter tuning.
+         metric (Literal["f1", "accuracy", "pr_macro"]): Metric to optimize during tuning.
+         n_trials (int): Number of hyperparameter trials to run.
+         resume (bool): If True, resumes tuning from a previous state.
+         save_db (bool): If True, saves the tuning results to a database.
+         fast (bool): If True, uses a faster but less thorough tuning approach.
+         max_samples (int): Maximum number of samples to use for tuning. 0 means all samples.
+         max_loci (int): Maximum number of loci to use for tuning. 0 means all loci.
+         epochs (int): Number of epochs to train each trial.
+         batch_size (int): Batch size for training during tuning.
+         eval_interval (int): Interval (in epochs) at which to evaluate the model during tuning.
+         infer_epochs (int): Number of epochs for inference during tuning.
+         patience (int): Number of evaluations with no improvement before stopping early.
+         proxy_metric_batch (int): If > 0, uses a subset of data for proxy metric evaluation.
+     """
+
+     enabled: bool = False
+     metric: Literal["f1", "accuracy", "pr_macro"] = "f1"
+     n_trials: int = 100
+     resume: bool = False
+     save_db: bool = False
+     fast: bool = True
+     max_samples: int = 512
+     max_loci: int = 0  # 0 = all
+     epochs: int = 500
+     batch_size: int = 64
+     eval_interval: int = 1
+     infer_epochs: int = 100
+     patience: int = 10
+     proxy_metric_batch: int = 0
+
+
+ @dataclass
+ class EvalConfig:
+     """Evaluation configuration.
+
+     This class contains configuration options for the evaluation process, including batch size, evaluation intervals, and other relevant settings.
+
+     Attributes:
+         eval_latent_steps (int): Number of optimization steps for latent space evaluation.
+         eval_latent_lr (float): Learning rate for latent space optimization.
+         eval_latent_weight_decay (float): Weight decay for latent space optimization.
+     """
+
+     eval_latent_steps: int = 50
+     eval_latent_lr: float = 1e-2
+     eval_latent_weight_decay: float = 0.0
+
+
+ @dataclass
+ class PlotConfig:
+     """Plotting configuration.
+
+     This class contains configuration options for plotting, including file format, resolution, and other relevant settings.
+
+     Attributes:
+         fmt (Literal["pdf", "png", "jpg", "jpeg", "svg"]): Output file format.
+         dpi (int): Dots per inch for the output figure.
+         fontsize (int): Font size for text in the plots.
+         despine (bool): If True, removes the top and right spines from plots.
+         show (bool): If True, displays the plot interactively.
+     """
+
+     fmt: Literal["pdf", "png", "jpg", "jpeg", "svg"] = "pdf"
+     dpi: int = 300
+     fontsize: int = 18
+     despine: bool = True
+     show: bool = False
+
+
+ @dataclass
+ class IOConfig:
+     """I/O configuration.
+
+     This class contains configuration options for input/output operations, including file prefixes, verbosity, random seed, and other relevant settings.
+
+     Attributes:
+         prefix (str): Prefix for output files. Default is "pgsui".
+         verbose (bool): If True, enables verbose logging. Default is False.
+         debug (bool): If True, enables debug mode. Default is False.
+         seed (int | None): Random seed for reproducibility. Default is None.
+         n_jobs (int): Number of parallel jobs to run. Default is 1.
+         scoring_averaging (Literal["macro", "micro", "weighted"]): Averaging
+             method for scoring metrics. Default is "macro".
+     """
+
+     prefix: str = "pgsui"
+     verbose: bool = False
+     debug: bool = False
+     seed: int | None = None
+     n_jobs: int = 1
+     scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"
+
+
+ @dataclass
+ class NLPCAConfig:
+     """Top-level configuration for ImputeNLPCA.
+
+     This class contains all the configuration options for the ImputeNLPCA model. The configuration is organized into several sections, each represented by a dataclass.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+
+     Notes:
+         - fast: Quick baseline; tiny net; NO tuning by default.
+         - balanced: Practical default balancing speed and model performance; moderate tuning.
+         - thorough: Prioritizes model performance; deeper nets; extensive tuning.
+         - Overrides: Overrides are applied after presets and can be used to fine-tune specific parameters. Specifically uses flat dot-keys like {"model.latent_dim": 8}.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "NLPCAConfig":
+         """Build a config from a named preset.
+
+         This method allows for easy construction of a NLPCAConfig instance with sensible defaults based on the chosen preset. Presets adjust both model capacity and training/tuning behavior across speed/quality tradeoffs:
+
+         Args:
+             preset (Literal["fast", "balanced", "thorough"]): One of {"fast", "balanced", "thorough"}.
+
+         Returns:
+             NLPCAConfig: Configuration instance with preset values applied.
+         """
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()  # start from dataclass defaults
+
+         # Common sensible baselines
+         cfg.io.verbose = True
+         cfg.train.validation_split = 0.2
+         cfg.evaluate.eval_latent_steps = 50
+         cfg.evaluate.eval_latent_lr = 1e-2
+         cfg.evaluate.eval_latent_weight_decay = 0.0
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+         cfg.model.latent_init = "random"
+
+         if preset == "fast":
+             # Model
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 100
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 1.0  # no rebalancing pressure
+
+             # Tuning (off for true "fast")
+             cfg.tune.enabled = False
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 50
+             cfg.tune.epochs = 100
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 512  # cap data for speed
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 1
+             cfg.tune.infer_epochs = 50
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+
+         elif preset == "balanced":
+             # Model
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 4.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 8e-4
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 1000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 1.0
+
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = True  # favor speed with good coverage
+             cfg.tune.n_trials = 100  # more trials
+             cfg.tune.epochs = 250
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 1024
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 1
+             cfg.tune.infer_epochs = 80
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+
+         else:  # thorough
+             # Model
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 6.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+
+             # Train
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 6e-4
+             cfg.train.early_stop_gen = 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 3000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 1.0
+
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 250
+             cfg.tune.epochs = 1000
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 0  # use all samples
+             cfg.tune.max_loci = 0  # use all loci
+             cfg.tune.eval_interval = 1
+             cfg.tune.infer_epochs = 120
+             cfg.tune.patience = 20
+             cfg.tune.proxy_metric_batch = 0
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "NLPCAConfig":
+         """Apply flat dot-key overrides (e.g. {'model.latent_dim': 4}).
+
+         This method allows for easy modification of the configuration by specifying the keys to change in a flat dictionary format.
+
+         Args:
+             overrides (Dict[str, Any] | None): A mapping of dot-key paths to values to override.
+
+         Returns:
+             NLPCAConfig: The updated config instance (same as `self`).
+         """
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Return the config as a nested dictionary.
+
+         This method uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.
+
+         Returns:
+             Dict[str, Any]: The config as a nested dictionary.
+         """
+         return asdict(self)
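
A minimal usage sketch for the preset/override pattern above (editor's illustration, not part of the diff; it assumes only the NLPCAConfig API shown in this hunk):

    cfg = NLPCAConfig.from_preset("balanced")
    cfg.apply_overrides({"model.latent_dim": 4, "tune.enabled": False})
    assert cfg.to_dict()["model"]["latent_dim"] == 4
    # Unknown dot-keys raise KeyError, e.g. cfg.apply_overrides({"model.bogus": 1})
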
+
+
+ @dataclass
+ class UBPConfig:
+     """Top-level configuration for ImputeUBP.
+
+     This class contains all the configuration options for the ImputeUBP model. The configuration is organized into several sections, each represented by a dataclass.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+
+     Notes:
+         - fast: Quick baseline; tiny net; NO tuning by default.
+         - balanced: Practical default balancing speed and model performance; moderate tuning.
+         - thorough: Prioritizes model performance; deeper nets; extensive tuning.
+         - Overrides: Overrides are applied after presets and can be used to fine-tune specific parameters. Specifically uses flat dot-keys like {"model.latent_dim": 8}.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "UBPConfig":
+         """Build a UBPConfig from a named preset.
+
+         This method allows for easy construction of a UBPConfig instance with sensible defaults based on the chosen preset. UBP is often used when classes (genotype states) are imbalanced. Presets adjust both capacity and weighting behavior across speed/quality tradeoffs.
+
+         Args:
+             preset (Literal["fast", "balanced", "thorough"]): One of {"fast","balanced","thorough"}.
+
+         Returns:
+             UBPConfig: Populated config instance.
+
+         Notes:
+             - fast: Quick baseline; tiny net; NO tuning by default.
+             - balanced: Practical default balancing speed and model performance; moderate tuning.
+             - thorough: Prioritizes model performance; deeper nets; extensive tuning.
+             - Overrides: Overrides are applied after presets and can be used to fine-tune specific parameters. Specifically uses flat dot-keys like {"model.latent_dim": 8}.
+
+         Raises:
+             ValueError: If an unknown preset is provided.
+         """
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Shared baselines
+         cfg.io.verbose = True
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+         cfg.model.latent_init = "random"
+
+         if preset == "fast":
+             # Model (slightly smaller than NLPCA fast)
+             cfg.model.latent_dim = 3
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5  # lighter focusing
+
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 100
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0  # allow mild rebalancing
+
+             # Tuning (off for true "fast")
+             cfg.tune.enabled = False
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 50
+             cfg.tune.epochs = 100
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 1
+             cfg.tune.infer_epochs = 50
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+
+         elif preset == "balanced":
+             # Model
+             cfg.model.latent_dim = 6
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 3.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 8e-4
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 1000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 3.0  # moderate cap for imbalance
+
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 100
+             cfg.tune.epochs = 250
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 1024
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 1
+             cfg.tune.infer_epochs = 80
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+
+         else:  # thorough
+             # Model
+             cfg.model.latent_dim = 12
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 5.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5  # stronger focusing for harder imbalance
+
+             # Train
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 6e-4
+             cfg.train.early_stop_gen = 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 3000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 5.0  # allow stronger class weighting
+
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 250
+             cfg.tune.epochs = 1000
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 0  # all samples
+             cfg.tune.max_loci = 0  # all loci
+             cfg.tune.eval_interval = 1
+             cfg.tune.infer_epochs = 120
+             cfg.tune.patience = 20
+             cfg.tune.proxy_metric_batch = 0
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "UBPConfig":
+         """Apply flat dot-key overrides (e.g. {'model.latent_dim': 4}).
+
+         Args:
+             overrides (Dict[str, Any] | None): Mapping of dot-key paths to values to override.
+
+         Returns:
+             UBPConfig: This instance after applying overrides.
+         """
+         if overrides is None or not overrides:
+             return self
+
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Return the config as a nested dictionary.
+
+         This method uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.
+
+         Returns:
+             Dict[str, Any]: Nested dictionary.
+         """
+         return asdict(self)
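
UBPConfig mirrors NLPCAConfig but allows progressively stronger class re-weighting per preset; a small check of that difference (illustrative only, using the preset values above):

    assert UBPConfig.from_preset("fast").train.weights_max_ratio == 2.0
    assert UBPConfig.from_preset("balanced").train.weights_max_ratio == 3.0
    assert UBPConfig.from_preset("thorough").train.weights_max_ratio == 5.0
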
+
+
+ @dataclass
+ class AutoencoderConfig:
+     """Top-level configuration for ImputeAutoencoder.
+
+     This class contains all the configuration options for the ImputeAutoencoder model. The configuration is organized into several sections, each represented by a dataclass.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+
+     Notes:
+         - fast: Quick baseline; tiny net; NO tuning by default.
+         - balanced: Practical default; moderate tuning.
+         - thorough: Prioritizes model performance; deeper nets; extensive tuning.
+         - Overrides: flat dot-keys like {"model.latent_dim": 8}.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "AutoencoderConfig":
+         """Build an AutoencoderConfig from a named preset.
+
+         This method allows for easy construction of an AutoencoderConfig instance with sensible defaults based on the chosen preset. Presets adjust both model capacity and training/tuning behavior across speed/quality tradeoffs.
+
+         Args:
+             preset (Literal["fast", "balanced", "thorough"]): One of {"fast","balanced", "thorough"}.
+
+         Returns:
+             AutoencoderConfig: Populated config instance.
+         """
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Common sensible baselines (aligned with NLPCA)
+         cfg.io.verbose = True
+         cfg.train.validation_split = 0.2
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+
+         # AE difference: no latent refinement during eval
+         cfg.evaluate.eval_latent_steps = 0
+         cfg.evaluate.eval_latent_lr = 0.0
+         cfg.evaluate.eval_latent_weight_decay = 0.0
+
+         if preset == "fast":
+             # Model
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 100
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 1.0
+             # Tuning (off for true fast)
+             cfg.tune.enabled = False
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 50
+             cfg.tune.epochs = 100
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 1
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         elif preset == "balanced":
+             # Model
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 4.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 8e-4
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 1000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 1.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 100
+             cfg.tune.epochs = 250
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 1024
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 1
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         else:  # thorough
+             # Model
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 6.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+             # Train
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 6e-4
+             cfg.train.early_stop_gen = 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 3000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 1.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 250
+             cfg.tune.epochs = 1000
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 0  # use all samples
+             cfg.tune.max_loci = 0  # use all loci
+             cfg.tune.eval_interval = 1
+             cfg.tune.patience = 20
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "AutoencoderConfig":
+         """Apply flat dot-key overrides (e.g. {'model.latent_dim': 4}).
+
+         Args:
+             overrides (Dict[str, Any] | None): Mapping of dot-key paths to values to override.
+
+         Returns:
+             AutoencoderConfig: This instance after applying overrides.
+         """
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class VAEExtraConfig:
+     """VAE-specific knobs.
+
+     This class contains additional configuration options specific to Variational Autoencoders (VAEs), particularly for controlling the KL divergence term in the loss function.
+
+     Attributes:
+         kl_beta (float): Final β for KL divergence term.
+         kl_warmup (int): Number of epochs with β=0 (warm-up period
+             to stabilize training).
+         kl_ramp (int): Number of epochs for linear ramp to final β.
+
+     Notes:
+         - These parameters control the behavior of the KL divergence term in the VAE loss function.
+         - The warm-up period helps to stabilize training by gradually introducing the KL term.
+         - The ramp period defines how quickly the KL term reaches its final value.
+     """
+
+     kl_beta: float = 1.0  # final β for KL
+     kl_warmup: int = 50  # epochs with β=0
+     kl_ramp: int = 200  # linear ramp to β
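
The kl_warmup/kl_ramp fields describe a warm-up-then-linear-ramp schedule for the KL weight; the schedule itself is implemented in the VAE model code (not shown in this hunk), but a sketch of the behavior the docstring describes might look like:

    def kl_beta_at(epoch: int, cfg: VAEExtraConfig) -> float:
        # beta stays at 0 during warm-up, then ramps linearly to kl_beta over kl_ramp epochs
        if epoch < cfg.kl_warmup:
            return 0.0
        progress = min(1.0, (epoch - cfg.kl_warmup) / max(1, cfg.kl_ramp))
        return cfg.kl_beta * progress
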
+
+
+ @dataclass
+ class VAEConfig:
+     """Top-level configuration for ImputeVAE (AE-parity + VAE extras).
+
+     This class contains all the configuration options for the ImputeVAE model. The configuration is organized into several sections, each represented by a dataclass.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+         vae (VAEExtraConfig): VAE-specific configuration.
+
+     Notes:
+         - fast: Quick baseline; tiny net; NO tuning by default.
+         - balanced: Practical default; moderate tuning.
+         - thorough: Prioritizes model performance; deeper nets; extensive tuning.
+         - Overrides: flat dot-keys like {"model.latent_dim": 8}.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     vae: VAEExtraConfig = field(default_factory=VAEExtraConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "VAEConfig":
+         """Mirror AutoencoderConfig presets and add VAE defaults.
+
+         This method allows for easy construction of a VAEConfig instance with sensible defaults based on the chosen preset. Presets adjust both model capacity and training/tuning behavior across speed/quality tradeoffs.
+
+         Args:
+             preset (Literal["fast", "balanced", "thorough"]): One of {"fast", "balanced", "thorough"}.
+
+         Returns:
+             VAEConfig: Configuration instance with preset values applied.
+         """
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Common sensible baselines (match AE/NLPCA style)
+         cfg.io.verbose = True
+         cfg.train.validation_split = 0.2
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+
+         # Like AE, no latent refinement during eval
+         cfg.evaluate.eval_latent_steps = 0
+         cfg.evaluate.eval_latent_lr = 0.0
+         cfg.evaluate.eval_latent_weight_decay = 0.0
+
+         # VAE-specific schedule defaults (can be overridden)
+         cfg.vae.kl_beta = 1.0
+         cfg.vae.kl_warmup = 50
+         cfg.vae.kl_ramp = 200
+
+         if preset == "fast":
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 100
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 1.0
+
+             cfg.tune.enabled = False
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 50
+             cfg.tune.epochs = 100
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 1
+             cfg.tune.patience = 5
+
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         elif preset == "balanced":
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 4.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 8e-4
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 1000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 1.0
+
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 100
+             cfg.tune.epochs = 250
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 1024
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 1
+             cfg.tune.patience = 10
+
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         else:  # thorough
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 6.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 6e-4
+             cfg.train.early_stop_gen = 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 3000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 1.0
+
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 250
+             cfg.tune.epochs = 1000
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 0
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 1
+             cfg.tune.patience = 20
+
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "VAEConfig":
+         """Apply flat dot-key overrides (e.g., {'vae.kl_beta': 2.0})."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class MostFrequentAlgoConfig:
+     """Algorithmic knobs for ImputeMostFrequent.
+
+     This class contains configuration options specific to the most frequent genotype imputation algorithm.
+
+     Attributes:
+         by_populations (bool): Whether to compute per-population modes when populations are available.
+         default (int): Fallback mode if no valid entries in a locus.
+         missing (int): Code for missing genotypes in 0/1/2.
+     """
+
+     by_populations: bool = False  # per-pop modes if pops available
+     default: int = 0  # fallback mode if no valid entries in a locus
+     missing: int = -1  # code for missing genotypes in 0/1/2
+
+
+ @dataclass
+ class DeterministicSplitConfig:
+     """Evaluation split configuration shared by deterministic imputers.
+
+     This class contains configuration options for splitting data into training and testing sets for deterministic imputation algorithms. The split can be defined by a proportion of the data or by specific indices.
+
+     Attributes:
+         test_size (float): Proportion of data to use as the test set.
+         test_indices (Optional[Sequence[int]]): Specific indices to use as the test set. If provided, this overrides the `test_size` parameter.
+     """
+
+     test_size: float = 0.2
+     # If provided, overrides test_size.
+     test_indices: Optional[Sequence[int]] = None
+
+
+ @dataclass
+ class MostFrequentConfig:
+     """Top-level configuration for ImputeMostFrequent.
+
+     This class contains all the configuration options for the ImputeMostFrequent model. The configuration is organized into several sections, each represented by a dataclass.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         plot (PlotConfig): Plotting configuration.
+         split (DeterministicSplitConfig): Data splitting configuration.
+         algo (MostFrequentAlgoConfig): Algorithmic configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     split: DeterministicSplitConfig = field(default_factory=DeterministicSplitConfig)
+     algo: MostFrequentAlgoConfig = field(default_factory=MostFrequentAlgoConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "MostFrequentConfig":
+         """Presets mainly keep parity with logging/IO and split test_size.
+
+         Deterministic imputers don't have model/train knobs; presets exist for interface symmetry and minor UX defaults.
+
+         Args:
+             preset (Literal["fast", "balanced", "thorough"]): One of {"fast", "balanced", "thorough"}.
+
+         Returns:
+             MostFrequentConfig: Populated config instance.
+         """
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+         cfg.io.verbose = True
+         cfg.split.test_size = 0.2  # keep stable across presets
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "MostFrequentConfig":
+         """Apply dot-key overrides (e.g., {'algo.by_populations': True}).
+
+         Args:
+             overrides (Dict[str, Any]): Mapping of dot-key paths to values to override.
+
+         Returns:
+             MostFrequentConfig: This instance after applying overrides.
+         """
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 pass
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Return the config as a dictionary.
+
+         Returns:
+             Dict[str, Any]: The config as a nested dictionary.
+         """
+         return asdict(self)
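
Note that MostFrequentConfig.apply_overrides silently ignores unknown keys (the `else: pass` branch above), unlike the neural-network configs, which raise KeyError. A small override sketch (illustrative only; `algo.not_a_field` is a deliberately bogus key):

    cfg = MostFrequentConfig.from_preset("balanced")
    cfg.apply_overrides({"algo.by_populations": True, "algo.not_a_field": 1})  # second key is ignored
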
+
+
+ @dataclass
+ class RefAlleleAlgoConfig:
+     """Algorithmic knobs for ImputeRefAllele.
+
+     This class contains configuration options specific to the reference allele imputation algorithm.
+
+     Attributes:
+         missing (int): Code for missing genotypes in 0/1/2.
+     """
+
+     missing: int = -1
+
+
+ @dataclass
+ class RefAlleleConfig:
+     """Top-level configuration for ImputeRefAllele.
+
+     This class contains all the configuration options for the ImputeRefAllele model. The configuration is organized into several sections, each represented by a dataclass.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         plot (PlotConfig): Plotting configuration.
+         split (DeterministicSplitConfig): Data splitting configuration.
+         algo (RefAlleleAlgoConfig): Algorithmic configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     split: DeterministicSplitConfig = field(default_factory=DeterministicSplitConfig)
+     algo: RefAlleleAlgoConfig = field(default_factory=RefAlleleAlgoConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "RefAlleleConfig":
+         """Presets mainly keep parity with logging/IO and split test_size.
+
+         Deterministic imputers don't have model/train knobs; presets exist for interface symmetry and minor UX defaults.
+
+         Args:
+             preset (Literal["fast", "balanced", "thorough"]): One of {"fast", "balanced", "thorough"}.
+
+         Returns:
+             RefAlleleConfig: Populated config instance.
+         """
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+         cfg.io.verbose = True
+         cfg.split.test_size = 0.2
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "RefAlleleConfig":
+         """Apply dot-key overrides (e.g., {'split.test_size': 0.3}).
+
+         This method allows for easy modification of the configuration by specifying the keys to change in a flat dictionary format.
+
+         Args:
+             overrides (Dict[str, Any] | None): A mapping of dot-key paths to values to override.
+
+         Returns:
+             RefAlleleConfig: The updated config instance (same as `self`).
+         """
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 pass
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert the config to a dictionary.
+
+         Returns:
+             Dict[str, Any]: The config as a nested dictionary.
+         """
+         return asdict(self)
+
+
+ def _flatten_dict(
+     d: Dict[str, Any], prefix: str = "", out: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, Any]:
+     """Flatten a nested dictionary into dot-key format.
+
+     Args:
+         d (Dict[str, Any]): The nested dictionary to flatten.
+         prefix (str): The prefix to use for keys (used in recursion).
+         out (Optional[Dict[str, Any]]): The output dictionary to populate.
+
+     Returns:
+         Dict[str, Any]: The flattened dictionary with dot-key format.
+     """
+     out = out or {}
+     for k, v in d.items():
+         kk = f"{prefix}.{k}" if prefix else k
+         if isinstance(v, dict):
+             _flatten_dict(v, kk, out)
+         else:
+             out[kk] = v
+     return out
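
_flatten_dict converts a nested config dictionary into the flat dot-key form consumed by apply_overrides; a quick sketch (illustrative only, grounded in the "fast" preset values above):

    nested = NLPCAConfig.from_preset("fast").to_dict()
    flat = _flatten_dict(nested)
    # e.g. {"model.latent_dim": 4, "train.batch_size": 128, ...}
    assert flat["model.latent_dim"] == 4
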
1288
+
1289
+
1290
+ @dataclass
1291
+ class IOConfigSupervised:
1292
+ """I/O, logging, and run identity.
1293
+
1294
+ This class contains configuration options for input/output operations, logging, and run identification.
1295
+
1296
+ Attributes:
1297
+ prefix (str): Prefix for output files and logs.
1298
+ seed (Optional[int]): Random seed for reproducibility.
1299
+ n_jobs (int): Number of parallel jobs to use. -1 uses all available cores.
1300
+ verbose (bool): Whether to enable verbose logging.
1301
+ debug (bool): Whether to enable debug mode with more detailed logs.
1302
+
1303
+ Notes:
1304
+ - The prefix is used to name output files and logs, helping to organize results from different runs.
1305
+ - Setting a random seed ensures that results are reproducible across different runs.
1306
+ - The number of jobs can be adjusted based on the available computational resources.
1307
+ - Verbose and debug modes provide additional logging information, which can be useful for troubleshooting.
1308
+ """
1309
+
1310
+ prefix: str = "pgsui"
1311
+ seed: Optional[int] = None
1312
+ n_jobs: int = 1
1313
+ verbose: bool = False
1314
+ debug: bool = False
1315
+
+
+@dataclass
+class PlotConfigSupervised:
+    """Plot/figure styling.
+
+    This class contains parameters for controlling the appearance of plots generated during the imputation process.
+
+    Attributes:
+        fmt (Literal["pdf", "png", "jpg", "jpeg"]): File format for saving plots.
+        dpi (int): Resolution in dots per inch for raster formats.
+        fontsize (int): Base font size for plot text.
+        despine (bool): Whether to remove top/right spines from plots.
+        show (bool): Whether to display plots interactively.
+
+    Notes:
+        - Supported formats: "pdf", "png", "jpg", "jpeg".
+        - Higher DPI values yield better quality in raster images.
+        - Despining is a common aesthetic choice for cleaner plots.
+    """
+
+    fmt: Literal["pdf", "png", "jpg", "jpeg"] = "pdf"
+    dpi: int = 300
+    fontsize: int = 18
+    despine: bool = True
+    show: bool = False
+
+
+@dataclass
+class TrainConfigSupervised:
+    """Training/evaluation split (by samples).
+
+    This class contains configuration options for splitting the dataset into training and validation sets during the training process.
+
+    Attributes:
+        validation_split (float): Proportion of data to use for validation.
+
+    Notes:
+        - Value must be strictly between 0.0 and 1.0.
+    """
+
+    validation_split: float = 0.20
+
+    def __post_init__(self):
+        """Validate that validation_split is strictly between 0.0 and 1.0."""
+        if not (0.0 < self.validation_split < 1.0):
+            raise ValueError("validation_split must be strictly between 0.0 and 1.0")
+
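Because the range check runs in `__post_init__`, an out-of-range split fails at construction time rather than partway through a run. For example (illustrative value):

try:
    TrainConfigSupervised(validation_split=1.5)
except ValueError as err:
    print(err)  # validation_split must be strictly between 0.0 and 1.0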
+
+@dataclass
+class ImputerConfigSupervised:
+    """IterativeImputer-like scaffolding used by current supervised wrappers.
+
+    This class contains configuration options for the imputation process, specifically for iterative imputation methods.
+
+    Attributes:
+        n_nearest_features (Optional[int]): Number of nearest features to use for imputation. If None, all features are used.
+        max_iter (int): Maximum number of imputation iterations to perform.
+
+    Notes:
+        - n_nearest_features can help speed up imputation by limiting the number of features considered.
+        - max_iter controls how many times the imputation process is repeated to refine estimates.
+        - If n_nearest_features is None, the imputer will consider all features for each missing value.
+        - Default max_iter is set to 10, which is typically sufficient for convergence.
+        - Iterative imputation can be computationally intensive; consider adjusting n_nearest_features for large datasets.
+    """
+
+    n_nearest_features: Optional[int] = 10
+    max_iter: int = 10
+
+
+@dataclass
+class SimConfigSupervised:
+    """Simulation of missingness for evaluation.
+
+    This class contains configuration options for simulating missing data during the evaluation process.
+
+    Attributes:
+        prop_missing (float): Proportion of features to randomly set as missing.
+        strategy (Literal["random", "random_inv_genotype"]): Strategy for generating missingness.
+        het_boost (float): Boost factor applied to heterozygous genotypes when simulating missingness.
+        missing_val (int): Internal code for missing genotypes (e.g., -1).
+
+    Notes:
+        - The choice of strategy can affect the realism of the missing data simulation.
+        - Boosting heterozygote missingness can be useful for testing model robustness.
+    """
+
+    prop_missing: float = 0.5
+    strategy: Literal["random", "random_inv_genotype"] = "random_inv_genotype"
+    het_boost: float = 2.0
+    missing_val: int = -1  # internal use; your wrappers expect -1
+
+
+@dataclass
+class TuningConfigSupervised:
+    """Optuna tuning envelope (kept for parity with unsupervised)."""
+
+    enabled: bool = True
+    n_trials: int = 100
+    metric: str = "pr_macro"
+    n_jobs: int = 8  # for parallel eval (model-dependent)
+    fast: bool = True  # placeholder; trees don't need it but kept for consistency
+
+
+@dataclass
+class RFModelConfig:
+    """Random Forest hyperparameters.
+
+    This class contains configuration options for the Random Forest model used in imputation.
+
+    Attributes:
+        n_estimators (int): Number of trees in the forest.
+        max_depth (Optional[int]): Maximum depth of the trees. If None, nodes are expanded until all leaves are pure or contain fewer than min_samples_split samples.
+        min_samples_split (int): Minimum number of samples required to split an internal node.
+        min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
+        max_features (Literal["sqrt", "log2"] | float | None): Number of features to consider when looking for the best split.
+        criterion (Literal["gini", "entropy", "log_loss"]): Function to measure the quality of a split.
+        class_weight (Literal["balanced", "balanced_subsample", None]): Weights associated with classes. If "balanced", class weights are adjusted inversely proportional to class frequencies in the input data. If "balanced_subsample", weights are adjusted based on the bootstrap sample for each tree. If None, all classes have a weight of 1.0.
+    """
+
+    n_estimators: int = 100
+    max_depth: Optional[int] = None
+    min_samples_split: int = 2
+    min_samples_leaf: int = 1
+    max_features: Literal["sqrt", "log2"] | float | None = "sqrt"
+    criterion: Literal["gini", "entropy", "log_loss"] = "gini"
+    class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+
+
+@dataclass
+class HGBModelConfig:
+    """Histogram-based Gradient Boosting hyperparameters.
+
+    This class contains configuration options for the Histogram-based Gradient Boosting (HGB) model used in imputation.
+
+    Attributes:
+        n_estimators (int): Number of boosting iterations.
+        learning_rate (float): Step size for each boosting iteration.
+        max_depth (Optional[int]): Maximum depth of each tree. If None, tree depth is unconstrained.
+        min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
+        max_features (float | None): Proportion of features to consider when looking for the best split. If None, all features are considered.
+        n_iter_no_change (int): Number of iterations with no improvement to wait before early stopping.
+        tol (float): Absolute tolerance used when comparing scores for early stopping.
+
+    Notes:
+        - These parameters control the complexity and learning behavior of the HGB model.
+        - Early stopping is implemented to prevent overfitting.
+        - The model is sensitive to the learning_rate; smaller values require more estimators.
+        - max_features can be set to a float between 0.0 and 1.0 to use a proportion of features.
+        - Early stopping is driven by ``n_iter_no_change / tol``; sklearn controls randomness via random_state.
+    """
+
+    # sklearn.HistGradientBoostingClassifier uses 'max_iter' as the number of
+    # boosting iterations instead of 'n_estimators'.
+    n_estimators: int = 100  # maps to max_iter
+    learning_rate: float = 0.1
+    max_depth: Optional[int] = None
+    min_samples_leaf: int = 1
+    max_features: float | None = 1.0
+    n_iter_no_change: int = 10
+    tol: float = 1e-7
+
+    def __post_init__(self) -> None:
+        """Validate max_features and n_estimators.
+
+        This method checks that `max_features`, when given as a float, falls within the valid range (0.0, 1.0], and that `n_estimators` is a positive integer.
+        """
+        if isinstance(self.max_features, float):
+            if not (0.0 < self.max_features <= 1.0):
+                raise ValueError("max_features as float must be in (0.0, 1.0]")
+
+        if self.n_estimators <= 0:
+            raise ValueError("n_estimators must be a positive integer")
+
+
+@dataclass
+class RFConfig:
+    """Configuration for ImputeRandomForest.
+
+    This dataclass mirrors the legacy ``__init__`` signature while supporting presets, YAML loading, and dot-key overrides. Use ``to_imputer_kwargs()`` to call the current constructor, or refactor the imputer to accept ``config: RFConfig``.
+
+    Attributes:
+        io (IOConfigSupervised): Run identity, logging, and seeds.
+        model (RFModelConfig): RandomForest hyperparameters.
+        train (TrainConfigSupervised): Sample split for validation.
+        imputer (ImputerConfigSupervised): IterativeImputer scaffolding (neighbors/iters).
+        sim (SimConfigSupervised): Simulated missingness used during evaluation.
+        plot (PlotConfigSupervised): Plot styling and export options.
+        tune (TuningConfigSupervised): Optuna knobs (not required by RF itself).
+    """
+
+    io: IOConfigSupervised = field(default_factory=IOConfigSupervised)
+    model: RFModelConfig = field(default_factory=RFModelConfig)
+    train: TrainConfigSupervised = field(default_factory=TrainConfigSupervised)
+    imputer: ImputerConfigSupervised = field(default_factory=ImputerConfigSupervised)
+    sim: SimConfigSupervised = field(default_factory=SimConfigSupervised)
+    plot: PlotConfigSupervised = field(default_factory=PlotConfigSupervised)
+    tune: TuningConfigSupervised = field(default_factory=TuningConfigSupervised)
+
+    @classmethod
+    def from_preset(
+        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+    ) -> "RFConfig":
+        """Build a config from a named preset.
+
+        This method allows for easy construction of an RFConfig instance with sensible defaults based on the chosen preset. Presets adjust both model capacity and training/tuning behavior across speed/quality tradeoffs.
+
+        Args:
+            preset: One of {"fast", "balanced", "thorough"}.
+                - fast: Quick baseline; fewer trees; fewer imputer iters.
+                - balanced: Balances speed and model performance; moderate trees and imputer iters.
+                - thorough: Prioritizes model performance; more trees; more imputer iters.
+
+        Returns:
+            RFConfig: Config with preset values applied.
+        """
+        cfg = cls()
+        if preset == "fast":
+            cfg.model.n_estimators = 50
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 5
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+        elif preset == "balanced":
+            cfg.model.n_estimators = 100
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 10
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 100
+        elif preset == "thorough":
+            cfg.model.n_estimators = 500
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 15
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 250
+        else:
+            raise ValueError(f"Unknown preset: {preset}")
+
+        return cfg
+
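The presets only touch a handful of fields; every other field keeps its dataclass default. A quick sketch of how the presets defined above differ in practice:

fast = RFConfig.from_preset("fast")
thorough = RFConfig.from_preset("thorough")

print(fast.model.n_estimators, fast.imputer.max_iter)          # 50 5
print(thorough.model.n_estimators, thorough.imputer.max_iter)  # 500 15
print(fast.plot.dpi == thorough.plot.dpi == 300)               # untouched defaults -> True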
+    @classmethod
+    def from_yaml(cls, path: str) -> "RFConfig":
+        """Load from YAML; honors optional top-level 'preset' then merges keys.
+
+        This method allows for easy construction of an RFConfig instance from a YAML file, with support for presets. If the YAML file specifies a top-level 'preset', the corresponding preset values are applied first, and then any additional keys in the YAML file override those preset values.
+
+        Args:
+            path (str): Path to the YAML configuration file.
+
+        Returns:
+            RFConfig: Config instance populated from the YAML file.
+        """
+        return load_yaml_to_dataclass(path, cls, preset_builder=cls.from_preset)
+
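The merge itself is delegated to `load_yaml_to_dataclass`, which is defined elsewhere in this module and not shown here, so the exact YAML schema is an assumption. Based on the docstring (optional top-level 'preset' applied first, remaining keys merged on top), a config file would plausibly look like this:

# config.yaml (illustrative; section and key names follow the dataclasses above)
#
# preset: thorough
# io:
#   prefix: my_run
#   seed: 42
# model:
#   n_estimators: 1000

cfg = RFConfig.from_yaml("config.yaml")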
+    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "RFConfig":
+        """Apply flat dot-key overrides (e.g., {'model.n_estimators': 500}).
+
+        This method allows for easy application of overrides to the config instance using a flat dictionary structure.
+
+        Args:
+            overrides (Dict[str, Any] | None): Mapping of dot-key paths to values to override.
+
+        Returns:
+            RFConfig: This instance after applying overrides.
+        """
+        if overrides:
+            apply_dot_overrides(self, overrides)
+        return self
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return as nested dictionary.
+
+        This method converts the config instance into a nested dictionary format, which can be useful for serialization or inspection.
+
+        Returns:
+            Dict[str, Any]: The config as a nested dictionary.
+        """
+        return asdict(self)
+
+    def to_imputer_kwargs(self) -> Dict[str, Any]:
+        """Map config fields to current ImputeRandomForest ``__init__`` kwargs.
+
+        This method extracts relevant configuration fields and maps them to keyword arguments suitable for initializing the ImputeRandomForest class.
+
+        Returns:
+            Dict[str, Any]: kwargs compatible with ImputeRandomForest(..., \*\*kwargs).
+        """
+        return {
+            # General
+            "prefix": self.io.prefix,
+            "seed": self.io.seed,
+            "n_jobs": self.io.n_jobs,
+            "verbose": self.io.verbose,
+            "debug": self.io.debug,
+            # Model hyperparameters
+            "model_n_estimators": self.model.n_estimators,
+            "model_max_depth": self.model.max_depth,
+            "model_min_samples_split": self.model.min_samples_split,
+            "model_min_samples_leaf": self.model.min_samples_leaf,
+            "model_max_features": self.model.max_features,
+            "model_criterion": self.model.criterion,
+            "model_validation_split": self.train.validation_split,
+            "model_n_nearest_features": self.imputer.n_nearest_features,
+            "model_max_iter": self.imputer.max_iter,
+            # Simulation
+            "sim_prop_missing": self.sim.prop_missing,
+            "sim_strategy": self.sim.strategy,
+            "sim_het_boost": self.sim.het_boost,
+            # Plotting
+            "plot_format": self.plot.fmt,
+            "plot_fontsize": self.plot.fontsize,
+            "plot_despine": self.plot.despine,
+            "plot_dpi": self.plot.dpi,
+            "plot_show_plots": self.plot.show,
+        }
+
+
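Putting the RF pieces together, the intended flow appears to be preset, then overrides, then kwargs passed to the imputer. A hedged sketch: the import path follows the package layout in this release, the constructor is assumed to accept the legacy keyword names emitted above, and any required positional arguments (such as the genotype data object) are omitted:

from pgsui.impute.supervised.imputers.random_forest import ImputeRandomForest  # assumed path

cfg = RFConfig.from_preset("balanced").apply_overrides(
    {"model.n_estimators": 300, "io.prefix": "rf_run1", "io.seed": 42}
)
kwargs = cfg.to_imputer_kwargs()
# kwargs["model_n_estimators"] == 300, kwargs["prefix"] == "rf_run1"

imputer = ImputeRandomForest(**kwargs)  # plus any required data arguments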
+@dataclass
+class HGBConfig:
+    """Configuration for ImputeHistGradientBoosting.
+
+    Mirrors the legacy __init__ signature and provides presets/YAML/overrides. Use `to_imputer_kwargs()` now, or switch the imputer to accept `config: HGBConfig`.
+
+    Attributes:
+        io (IOConfigSupervised): Run identity, logging, and seeds.
+        model (HGBModelConfig): HistGradientBoosting hyperparameters.
+        train (TrainConfigSupervised): Sample split for validation.
+        imputer (ImputerConfigSupervised): IterativeImputer scaffolding (neighbors/iters).
+        sim (SimConfigSupervised): Simulated missingness used during evaluation.
+        plot (PlotConfigSupervised): Plot styling and export options.
+        tune (TuningConfigSupervised): Optuna knobs (not required by HGB itself).
+    """
+
+    io: IOConfigSupervised = field(default_factory=IOConfigSupervised)
+    model: HGBModelConfig = field(default_factory=HGBModelConfig)
+    train: TrainConfigSupervised = field(default_factory=TrainConfigSupervised)
+    imputer: ImputerConfigSupervised = field(default_factory=ImputerConfigSupervised)
+    sim: SimConfigSupervised = field(default_factory=SimConfigSupervised)
+    plot: PlotConfigSupervised = field(default_factory=PlotConfigSupervised)
+    tune: TuningConfigSupervised = field(default_factory=TuningConfigSupervised)
+
+    @classmethod
+    def from_preset(
+        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+    ) -> "HGBConfig":
+        """Build a config from a named preset.
+
+        This class method allows for easy construction of an HGBConfig instance with sensible defaults based on the chosen preset. Presets adjust both model capacity and training/tuning behavior across speed/quality tradeoffs.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): One of {"fast", "balanced", "thorough"}.
+                - fast: Quick baseline; fewer trees; fewer imputer iters.
+                - balanced: Balances speed and model performance; moderate trees and imputer iters.
+                - thorough: Prioritizes model performance; more trees; more imputer iterations.
+
+        Returns:
+            HGBConfig: Config with preset values applied.
+        """
+        cfg = cls()
+        if preset == "fast":
+            cfg.model.n_estimators = 50
+            cfg.model.learning_rate = 0.15
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 5
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 50
+        elif preset == "balanced":
+            cfg.model.n_estimators = 100
+            cfg.model.learning_rate = 0.1
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 10
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 100
+        elif preset == "thorough":
+            cfg.model.n_estimators = 500
+            cfg.model.learning_rate = 0.08
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 15
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 250
+        else:
+            raise ValueError(f"Unknown preset: {preset}")
+        return cfg
+
+    @classmethod
+    def from_yaml(cls, path: str) -> "HGBConfig":
+        """Load from YAML; honors optional top-level 'preset' then merges keys.
+
+        This method allows for easy construction of an HGBConfig instance from a YAML file, with support for presets. If the YAML file specifies a top-level 'preset', the corresponding preset values are applied first, and then any additional keys in the YAML file override those preset values.
+
+        Args:
+            path (str): Path to the YAML configuration file.
+
+        Returns:
+            HGBConfig: Config instance populated from the YAML file.
+        """
+        return load_yaml_to_dataclass(path, cls, preset_builder=cls.from_preset)
+
+    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "HGBConfig":
+        """Apply flat dot-key overrides (e.g., {'model.learning_rate': 0.05}).
+
+        This method allows for easy application of overrides to the configuration fields using a flat dot-key notation.
+
+        Args:
+            overrides (Dict[str, Any] | None): Mapping of dot-key paths to values to override.
+
+        Returns:
+            HGBConfig: This instance after applying overrides.
+        """
+        if overrides:
+            apply_dot_overrides(self, overrides)
+        return self
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return as nested dict.
+
+        This method converts the configuration instance into a nested dictionary format, which can be useful for serialization or inspection.
+
+        Returns:
+            Dict[str, Any]: The config as a nested dictionary.
+        """
+        return asdict(self)
+
+    def to_imputer_kwargs(self) -> Dict[str, Any]:
+        """Map config fields to current ImputeHistGradientBoosting ``__init__`` kwargs.
+
+        This method maps the configuration fields to the keyword arguments expected by the ImputeHistGradientBoosting class.
+
+        Returns:
+            Dict[str, Any]: kwargs compatible with ImputeHistGradientBoosting(..., \*\*kwargs).
+        """
+        return {
+            # General
+            "prefix": self.io.prefix,
+            "seed": self.io.seed,
+            "n_jobs": self.io.n_jobs,
+            "verbose": self.io.verbose,
+            "debug": self.io.debug,
+            # Model hyperparameters (note the mapping to sklearn's HGB)
+            "model_n_estimators": self.model.n_estimators,  # -> max_iter
+            "model_learning_rate": self.model.learning_rate,
+            "model_n_iter_no_change": self.model.n_iter_no_change,
+            "model_tol": self.model.tol,
+            "model_max_depth": self.model.max_depth,
+            "model_min_samples_leaf": self.model.min_samples_leaf,
+            "model_max_features": self.model.max_features,
+            "model_validation_split": self.train.validation_split,
+            "model_n_nearest_features": self.imputer.n_nearest_features,
+            "model_max_iter": self.imputer.max_iter,
+            # Simulation
+            "sim_prop_missing": self.sim.prop_missing,
+            "sim_strategy": self.sim.strategy,
+            "sim_het_boost": self.sim.het_boost,
+            # Plotting
+            "plot_format": self.plot.fmt,
+            "plot_fontsize": self.plot.fontsize,
+            "plot_despine": self.plot.despine,
+            "plot_dpi": self.plot.dpi,
+            "plot_show_plots": self.plot.show,
+        }
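The HGB flow mirrors the RF one above. Again a hedged sketch: the import path follows the package layout in this release and the constructor signature is assumed, with any required data arguments omitted:

from pgsui.impute.supervised.imputers.hist_gradient_boosting import (
    ImputeHistGradientBoosting,  # assumed path
)

cfg = HGBConfig.from_preset("fast").apply_overrides({"model.learning_rate": 0.05})
imputer = ImputeHistGradientBoosting(**cfg.to_imputer_kwargs())  # plus any required data arguments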