pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
@@ -0,0 +1,1436 @@
+ from __future__ import annotations
+
+ from dataclasses import asdict, dataclass, field
+ from typing import Any, Dict, Literal, Optional, Sequence
+
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
+
+
+ @dataclass
+ class _SimParams:
+     """Container for simulation hyperparameters.
+
+     Attributes:
+         prop_missing (float): Proportion of missing values to simulate.
+         strategy (Literal["random", "random_inv_genotype"]): Strategy for simulating missing values.
+         missing_val (int | float): Value to represent missing data.
+         het_boost (float): Boost factor for heterozygous genotypes.
+         seed (int | None): Random seed for reproducibility.
+     """
+
+     prop_missing: float = 0.3
+     strategy: Literal["random", "random_inv_genotype"] = "random_inv_genotype"
+     missing_val: int | float = -1
+     het_boost: float = 2.0
+     seed: int | None = None
+
+     def to_dict(self) -> dict:
+         return asdict(self)
+
+
+ @dataclass
+ class _ImputerParams:
+     """Container for imputer hyperparameters.
+
+     Attributes:
+         n_nearest_features (int | None): Number of nearest features to consider for imputation.
+         max_iter (int): Maximum number of iterations for the imputation algorithm.
+         initial_strategy (Literal["mean", "median", "most_frequent", "constant"]): Strategy for initial imputation.
+         keep_empty_features (bool): Whether to keep features that are entirely missing.
+         random_state (int | None): Random seed for reproducibility.
+         verbose (bool): If True, enables verbose logging during imputation.
+     """
+
+     n_nearest_features: int | None = 10
+     max_iter: int = 10
+     initial_strategy: Literal["mean", "median", "most_frequent", "constant"] = (
+         "most_frequent"
+     )
+     keep_empty_features: bool = True
+     random_state: int | None = None
+     verbose: bool = False
+
+     def to_dict(self) -> dict:
+         return asdict(self)
+
+
+ @dataclass
+ class _RFParams:
+     """Container for RandomForest hyperparameters.
+
+     Attributes:
+         n_estimators (int): Number of trees in the forest.
+         max_depth (int | None): Maximum depth of the trees.
+         min_samples_split (int): Minimum number of samples required to split an internal node.
+         min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
+         max_features (Literal["sqrt", "log2"] | float | None): Number of features to consider for split.
+         criterion (Literal["gini", "entropy", "log_loss"]): Function to measure the quality of a split.
+         class_weight (Literal["balanced", "balanced_subsample", None]): Weights associated with classes.
+     """
+
+     n_estimators: int = 300
+     max_depth: int | None = None
+     min_samples_split: int = 2
+     min_samples_leaf: int = 1
+     max_features: Literal["sqrt", "log2"] | float | None = "sqrt"
+     criterion: Literal["gini", "entropy", "log_loss"] = "gini"
+     class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+
+     def to_dict(self) -> dict:
+         return asdict(self)
+
+
+ @dataclass
+ class _HGBParams:
+     """Container for HistGradientBoosting hyperparameters.
+
+     Attributes:
+         max_iter (int): Maximum number of iterations.
+         learning_rate (float): Learning rate shrinks the contribution of each tree.
+         max_depth (int | None): Maximum depth of the individual regression estimators.
+         min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
+         n_iter_no_change (int): Number of iterations with no improvement to wait before early stopping.
+         tol (float): Tolerance for the early stopping.
+         max_features (float | None): The fraction of features to consider when looking for the best split.
+         class_weight (Literal["balanced", "balanced_subsample", None]): Weights associated with classes.
+         random_state (int | None): Random seed for reproducibility.
+         verbose (bool): If True, enables verbose logging during training.
+     """
+
+     max_iter: int = 100
+     learning_rate: float = 0.1
+     max_depth: int | None = None
+     min_samples_leaf: int = 1
+     n_iter_no_change: int = 10
+     tol: float = 1e-7
+     max_features: float | None = 1.0
+     class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+     random_state: int | None = None
+     verbose: bool = False
+
+     def to_dict(self) -> dict:
+         return asdict(self)
+
+
+ @dataclass
+ class ModelConfig:
+     """Model architecture configuration.
+
+     Attributes:
+         latent_init (Literal["random", "pca"]): Method for initializing the latent space.
+         latent_dim (int): Dimensionality of the latent space.
+         dropout_rate (float): Dropout rate for regularization.
+         num_hidden_layers (int): Number of hidden layers in the neural network.
+         hidden_activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function.
+         layer_scaling_factor (float): Scaling factor for the number of neurons in hidden layers.
+         layer_schedule (Literal["pyramid", "constant", "linear"]): Schedule for scaling hidden layer sizes.
+         gamma (float): Parameter for the focal loss function.
+     """
+
+     latent_init: Literal["random", "pca"] = "random"
+     latent_dim: int = 2
+     dropout_rate: float = 0.2
+     num_hidden_layers: int = 2
+     hidden_activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
+     layer_scaling_factor: float = 5.0
+     layer_schedule: Literal["pyramid", "constant", "linear"] = "pyramid"
+     gamma: float = 2.0
+
+
+ @dataclass
+ class TrainConfig:
+     """Training procedure configuration.
+
+     Attributes:
+         batch_size (int): Number of samples per training batch.
+         learning_rate (float): Learning rate for the optimizer.
+         lr_input_factor (float): Factor to scale the learning rate for input layer.
+         l1_penalty (float): L1 regularization penalty.
+         early_stop_gen (int): Number of generations with no improvement to wait before early stopping.
+         min_epochs (int): Minimum number of epochs to train.
+         max_epochs (int): Maximum number of epochs to train.
+         validation_split (float): Proportion of data to use for validation.
+         weights_beta (float): Smoothing factor for class weights.
+         weights_max_ratio (float): Maximum ratio for class weights to prevent extreme values.
+         device (Literal["gpu", "cpu", "mps"]): Device to use for computation.
+     """
+
+     batch_size: int = 32
+     learning_rate: float = 1e-3
+     lr_input_factor: float = 1.0
+     l1_penalty: float = 0.0
+     early_stop_gen: int = 20
+     min_epochs: int = 100
+     max_epochs: int = 5000
+     validation_split: float = 0.2
+     weights_beta: float = 0.9999
+     weights_max_ratio: float = 1.0
+     device: Literal["gpu", "cpu", "mps"] = "cpu"
+
+
+ @dataclass
+ class TuneConfig:
+     """Hyperparameter tuning configuration.
+
+     Attributes:
+         enabled (bool): If True, enables hyperparameter tuning.
+         metric (Literal["f1", "accuracy", "pr_macro"]): Metric to optimize during tuning.
+         n_trials (int): Number of hyperparameter trials to run.
+         resume (bool): If True, resumes tuning from a previous state.
+         save_db (bool): If True, saves the tuning results to a database.
+         fast (bool): If True, uses a faster but less thorough tuning approach.
+         max_samples (int): Maximum number of samples to use for tuning. 0 means all samples.
+         max_loci (int): Maximum number of loci to use for tuning. 0 means all loci.
+         epochs (int): Number of epochs to train each trial.
+         batch_size (int): Batch size for training during tuning.
+         eval_interval (int): Interval (in epochs) at which to evaluate the model during tuning.
+         infer_epochs (int): Number of epochs for inference during tuning.
+         patience (int): Number of evaluations with no improvement before stopping early.
+         proxy_metric_batch (int): If > 0, uses a subset of data for proxy metric evaluation.
+     """
+
+     enabled: bool = False
+     metric: Literal[
+         "f1",
+         "accuracy",
+         "pr_macro",
+         "average_precision",
+         "roc_auc",
+         "precision",
+         "recall",
+     ] = "f1"
+     n_trials: int = 100
+     resume: bool = False
+     save_db: bool = False
+     fast: bool = True
+     max_samples: int = 512
+     max_loci: int = 0  # 0 = all
+     epochs: int = 500
+     batch_size: int = 64
+     eval_interval: int = 20
+     infer_epochs: int = 100
+     patience: int = 10
+     proxy_metric_batch: int = 0
+
+
+ @dataclass
+ class EvalConfig:
+     """Evaluation configuration.
+
+     Attributes:
+         eval_latent_steps (int): Number of optimization steps for latent space evaluation.
+         eval_latent_lr (float): Learning rate for latent space optimization.
+         eval_latent_weight_decay (float): Weight decay for latent space optimization.
+     """
+
+     eval_latent_steps: int = 50
+     eval_latent_lr: float = 1e-2
+     eval_latent_weight_decay: float = 0.0
+
+
+ @dataclass
+ class PlotConfig:
+     """Plotting configuration.
+
+     Attributes:
+         fmt (Literal["pdf", "png", "jpg", "jpeg", "svg"]): Output file format.
+         dpi (int): Dots per inch for the output figure.
+         fontsize (int): Font size for text in the plots.
+         despine (bool): If True, removes the top and right spines from plots.
+         show (bool): If True, displays the plot interactively.
+     """
+
+     fmt: Literal["pdf", "png", "jpg", "jpeg", "svg"] = "pdf"
+     dpi: int = 300
+     fontsize: int = 18
+     despine: bool = True
+     show: bool = False
+
+
+ @dataclass
+ class IOConfig:
+     """I/O configuration.
+
+     Attributes:
+         prefix (str): Prefix for output files. Default is "pgsui".
+         verbose (bool): If True, enables verbose logging. Default is False.
+         debug (bool): If True, enables debug mode. Default is False.
+         seed (int | None): Random seed for reproducibility. Default is None.
+         n_jobs (int): Number of parallel jobs to run. Default is 1.
+         scoring_averaging (Literal["macro", "micro", "weighted"]): Averaging method.
+     """
+
+     prefix: str = "pgsui"
+     verbose: bool = False
+     debug: bool = False
+     seed: int | None = None
+     n_jobs: int = 1
+     scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"
+
+
+ @dataclass
+ class SimConfig:
+     """Top-level configuration for data simulation and imputation.
+
+     Attributes:
+         simulate_missing (bool): If True, simulates missing data.
+         sim_strategy (Literal["random", ...]): Strategy for simulating missing data.
+         sim_prop (float): Proportion of data to simulate as missing.
+         sim_kwargs (dict | None): Additional keyword arguments for simulation.
+     """
+
+     simulate_missing: bool = False
+     sim_strategy: Literal[
+         "random",
+         "random_weighted",
+         "random_weighted_inv",
+         "nonrandom",
+         "nonrandom_weighted",
+     ] = "random"
+     sim_prop: float = 0.10
+     sim_kwargs: dict | None = None
+
+
+ @dataclass
+ class NLPCAConfig:
+     """Top-level configuration for ImputeNLPCA.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+         sim (SimConfig): Simulation configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "NLPCAConfig":
+         """Build an NLPCAConfig from a named preset."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Common baselines
+         cfg.io.verbose = False
+         cfg.train.validation_split = 0.20
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+         cfg.model.latent_init = "random"
+         cfg.evaluate.eval_latent_lr = 1e-2
+         cfg.evaluate.eval_latent_weight_decay = 0.0
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+
+         if preset == "fast":
+             # Model
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+             # Train
+             cfg.train.batch_size = 256
+             cfg.train.learning_rate = 2e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 150
+             cfg.train.weights_beta = 0.999
+             cfg.train.weights_max_ratio = 5.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 20
+             cfg.tune.epochs = 150
+             cfg.tune.batch_size = 256
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.infer_epochs = 20
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 20
+
+         elif preset == "balanced":
+             # Model
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 3.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 600
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 5.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 60
+             cfg.tune.epochs = 200
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 2048
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.infer_epochs = 50
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 40
+
+         else:  # thorough
+             # Model
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 5.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+             # Train
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 5e-4
+             cfg.train.early_stop_gen = 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 2000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 5.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = False  # Full search
+             cfg.tune.n_trials = 100
+             cfg.tune.epochs = 600
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 0  # No limit
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.infer_epochs = 80
+             cfg.tune.patience = 20
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 100
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "NLPCAConfig":
+         """Apply flat dot-key overrides."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class UBPConfig:
+     """Top-level configuration for ImputeUBP.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+         sim (SimConfig): Simulated-missing configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "UBPConfig":
+         """Build a UBPConfig from a named preset."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Common baselines
+         cfg.io.verbose = False
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+         cfg.model.latent_init = "random"
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+
+         if preset == "fast":
+             # Model
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+             # Train
+             cfg.train.batch_size = 256
+             cfg.train.learning_rate = 2e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 150
+             cfg.train.weights_beta = 0.999
+             cfg.train.weights_max_ratio = 5.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 20
+             cfg.tune.epochs = 150
+             cfg.tune.batch_size = 256
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.infer_epochs = 20
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 20
+             cfg.evaluate.eval_latent_lr = 1e-2
+             cfg.evaluate.eval_latent_weight_decay = 0.0
+
+         elif preset == "balanced":
+             # Model
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 3.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 600
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 5.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 60
+             cfg.tune.epochs = 200
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 2048
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.infer_epochs = 50
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 40
+             cfg.evaluate.eval_latent_lr = 1e-2
+             cfg.evaluate.eval_latent_weight_decay = 0.0
+
+         else:  # thorough
+             # Model
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 5.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+             # Train
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 5e-4
+             cfg.train.early_stop_gen = 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 2000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 5.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 100
+             cfg.tune.epochs = 600
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 0
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.infer_epochs = 80
+             cfg.tune.patience = 20
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 100
+             cfg.evaluate.eval_latent_lr = 1e-2
+             cfg.evaluate.eval_latent_weight_decay = 0.0
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "UBPConfig":
+         """Apply flat dot-key overrides."""
+         if not overrides:
+             return self
+
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class AutoencoderConfig:
+     """Top-level configuration for ImputeAutoencoder.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+         sim (SimConfig): Simulated-missing configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "AutoencoderConfig":
+         """Build an AutoencoderConfig from a named preset."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Common baselines (no latent refinement at eval)
+         cfg.io.verbose = False
+         cfg.train.validation_split = 0.20
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+         cfg.evaluate.eval_latent_steps = 0
+         cfg.evaluate.eval_latent_lr = 0.0
+         cfg.evaluate.eval_latent_weight_decay = 0.0
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+
+         if preset == "fast":
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+             cfg.train.batch_size = 256
+             cfg.train.learning_rate = 2e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 150
+             cfg.train.weights_beta = 0.999
+             cfg.train.weights_max_ratio = 5.0
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 20
+             cfg.tune.epochs = 150
+             cfg.tune.batch_size = 256
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         elif preset == "balanced":
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 3.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 600
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 5.0
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 60
+             cfg.tune.epochs = 200
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 2048
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         else:  # thorough
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 5.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 5e-4
+             cfg.train.early_stop_gen = 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 2000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 5.0
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 100
+             cfg.tune.epochs = 600
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 0
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.patience = 20
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "AutoencoderConfig":
+         """Apply flat dot-key overrides."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class VAEExtraConfig:
+     """VAE-specific knobs.
+
+     Attributes:
+         kl_beta (float): Final β for KL divergence term.
+         kl_warmup (int): Number of epochs with β=0 (warm-up period).
+         kl_ramp (int): Number of epochs for linear ramp to final β.
+     """
+
+     kl_beta: float = 1.0
+     kl_warmup: int = 50
+     kl_ramp: int = 200
+
+
+ @dataclass
+ class VAEConfig:
+     """Top-level configuration for ImputeVAE (AE-parity + VAE extras).
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+         vae (VAEExtraConfig): VAE-specific configuration.
+         sim (SimConfig): Simulated-missing configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     vae: VAEExtraConfig = field(default_factory=VAEExtraConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "VAEConfig":
+         """Build a VAEConfig from a named preset."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Common baselines (match AE; no latent refinement at eval)
+         cfg.io.verbose = False
+         cfg.train.validation_split = 0.20
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+         cfg.evaluate.eval_latent_steps = 0
+         cfg.evaluate.eval_latent_lr = 0.0
+         cfg.evaluate.eval_latent_weight_decay = 0.0
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+
+         if preset == "fast":
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+             # VAE specifics
+             cfg.vae.kl_beta = 0.5
+             cfg.vae.kl_warmup = 10
+             cfg.vae.kl_ramp = 40
+             # Train
+             cfg.train.batch_size = 256
+             cfg.train.learning_rate = 2e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 150
+             cfg.train.weights_beta = 0.999
+             cfg.train.weights_max_ratio = 5.0
+             # Tune
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 20
+             cfg.tune.epochs = 150
+             cfg.tune.batch_size = 256
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         elif preset == "balanced":
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 3.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+             # VAE specifics
+             cfg.vae.kl_beta = 1.0
+             cfg.vae.kl_warmup = 50
+             cfg.vae.kl_ramp = 150
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 600
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 5.0
+             # Tune
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 60
+             cfg.tune.epochs = 200
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 2048
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         else:  # thorough
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 5.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+             # VAE specifics
+             cfg.vae.kl_beta = 1.0
+             cfg.vae.kl_warmup = 100
+             cfg.vae.kl_ramp = 400
+             # Train
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 5e-4
+             cfg.train.early_stop_gen = 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 2000
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 5.0
+             # Tune
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 100
+             cfg.tune.epochs = 600
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 0
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.patience = 20
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "VAEConfig":
+         """Apply flat dot-key overrides."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class MostFrequentAlgoConfig:
+     """Algorithmic knobs for ImputeMostFrequent.
+
+     Attributes:
+         by_populations (bool): Whether to compute per-population modes.
+         default (int): Fallback mode if no valid entries in a locus.
+         missing (int): Code for missing genotypes in 0/1/2.
+     """
+
+     by_populations: bool = False
+     default: int = 0
+     missing: int = -1
+
+
+ @dataclass
+ class DeterministicSplitConfig:
+     """Evaluation split configuration shared by deterministic imputers.
+
+     Attributes:
+         test_size (float): Proportion of data to use as the test set.
+         test_indices (Optional[Sequence[int]]): Specific indices to use as the test set.
+     """
+
+     test_size: float = 0.2
+     test_indices: Optional[Sequence[int]] = None
+
+
+ @dataclass
+ class MostFrequentConfig:
+     """Top-level configuration for ImputeMostFrequent.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         plot (PlotConfig): Plotting configuration.
+         split (DeterministicSplitConfig): Data splitting configuration.
+         algo (MostFrequentAlgoConfig): Algorithmic configuration.
+         sim (SimConfig): Simulation configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         train (TrainConfig): Training configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     split: DeterministicSplitConfig = field(default_factory=DeterministicSplitConfig)
+     algo: MostFrequentAlgoConfig = field(default_factory=MostFrequentAlgoConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+
+     @classmethod
+     def from_preset(
+         cls,
+         preset: Literal["fast", "balanced", "thorough"] = "balanced",
+     ) -> "MostFrequentConfig":
+         """Construct a preset configuration."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+         cfg.io.verbose = False
+         cfg.split.test_size = 0.2
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "MostFrequentConfig":
+         """Apply dot-key overrides."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 pass
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class RefAlleleAlgoConfig:
+     """Algorithmic knobs for ImputeRefAllele.
+
+     Attributes:
+         missing (int): Code for missing genotypes in 0/1/2.
+     """
+
+     missing: int = -1
+
+
+ @dataclass
+ class RefAlleleConfig:
+     """Top-level configuration for ImputeRefAllele.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         plot (PlotConfig): Plotting configuration.
+         split (DeterministicSplitConfig): Data splitting configuration.
+         algo (RefAlleleAlgoConfig): Algorithmic configuration.
+         sim (SimConfig): Simulation configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         train (TrainConfig): Training configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     split: DeterministicSplitConfig = field(default_factory=DeterministicSplitConfig)
+     algo: RefAlleleAlgoConfig = field(default_factory=RefAlleleAlgoConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "RefAlleleConfig":
+         """Presets mainly keep parity with logging/IO and split test_size."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+         cfg.io.verbose = False
+         cfg.split.test_size = 0.2
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "RefAlleleConfig":
+         """Apply dot-key overrides."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 pass
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ def _flatten_dict(
+     d: Dict[str, Any], prefix: str = "", out: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, Any]:
+     """Flatten a nested dictionary into dot-key format."""
+     out = out or {}
+     for k, v in d.items():
+         kk = f"{prefix}.{k}" if prefix else k
+         if isinstance(v, dict):
+             _flatten_dict(v, kk, out)
+         else:
+             out[kk] = v
+     return out
+
+
+ @dataclass
+ class IOConfigSupervised:
+     """I/O, logging, and run identity.
+
+     Attributes:
+         prefix (str): Prefix for output files and logs.
+         seed (Optional[int]): Random seed for reproducibility.
+         n_jobs (int): Number of parallel jobs to use.
+         verbose (bool): Whether to enable verbose logging.
+         debug (bool): Whether to enable debug mode.
+     """
+
+     prefix: str = "pgsui"
+     seed: Optional[int] = None
+     n_jobs: int = 1
+     verbose: bool = False
+     debug: bool = False
+
+
+ @dataclass
+ class PlotConfigSupervised:
+     """Plot/figure styling.
+
+     Attributes:
+         fmt (Literal["pdf", "png", "jpg", "jpeg"]): File format.
+         dpi (int): Resolution in dots per inch.
+         fontsize (int): Base font size for plot text.
+         despine (bool): Whether to remove top/right spines.
+         show (bool): Whether to display plots interactively.
+     """
+
+     fmt: Literal["pdf", "png", "jpg", "jpeg"] = "pdf"
+     dpi: int = 300
+     fontsize: int = 18
+     despine: bool = True
+     show: bool = False
+
+
+ @dataclass
+ class TrainConfigSupervised:
+     """Training/evaluation split (by samples).
+
+     Attributes:
+         validation_split (float): Proportion of data to use for validation.
+     """
+
+     validation_split: float = 0.20
+
+     def __post_init__(self):
+         if not (0.0 < self.validation_split < 1.0):
+             raise ValueError("validation_split must be between 0.0 and 1.0")
+
+
+ @dataclass
+ class ImputerConfigSupervised:
+     """IterativeImputer-like scaffolding used by current supervised wrappers.
+
+     Attributes:
+         n_nearest_features (Optional[int]): Number of nearest features to use.
+         max_iter (int): Maximum number of imputation iterations to perform.
+     """
+
+     n_nearest_features: Optional[int] = 10
+     max_iter: int = 10
+
+
+ @dataclass
+ class SimConfigSupervised:
+     """Simulation of missingness for evaluation.
+
+     Attributes:
+         prop_missing (float): Proportion of features to set as missing.
+         strategy (Literal["random", "random_inv_genotype"]): Strategy.
+         het_boost (float): Boosting factor for heterogeneity.
+         missing_val (int): Internal code for missing genotypes.
+     """
+
+     prop_missing: float = 0.5
+     strategy: Literal["random", "random_inv_genotype"] = "random_inv_genotype"
+     het_boost: float = 2.0
+     missing_val: int = -1
+
+
+ @dataclass
+ class TuningConfigSupervised:
+     """Optuna tuning envelope."""
+
+     enabled: bool = True
+     n_trials: int = 100
+     metric: str = "pr_macro"
+     n_jobs: int = 8
+     fast: bool = True
+
+
+ @dataclass
+ class RFModelConfig:
+     """Random Forest hyperparameters.
+
+     Attributes:
+         n_estimators (int): Number of trees in the forest.
+         max_depth (Optional[int]): Maximum depth of the trees.
+         min_samples_split (int): Minimum number of samples required to split.
+         min_samples_leaf (int): Minimum number of samples required at a leaf.
+         max_features (Literal["sqrt", "log2"] | float | None): Features to consider.
+         criterion (Literal["gini", "entropy", "log_loss"]): Split quality metric.
+         class_weight (Literal["balanced", "balanced_subsample", None]): Class weights.
+     """
+
+     n_estimators: int = 100
+     max_depth: Optional[int] = None
+     min_samples_split: int = 2
+     min_samples_leaf: int = 1
+     max_features: Literal["sqrt", "log2"] | float | None = "sqrt"
+     criterion: Literal["gini", "entropy", "log_loss"] = "gini"
+     class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+
+
+ @dataclass
+ class HGBModelConfig:
+     """Histogram-based Gradient Boosting hyperparameters.
+
+     Attributes:
+         n_estimators (int): Number of boosting iterations (max_iter).
+         learning_rate (float): Step size for each boosting iteration.
+         max_depth (Optional[int]): Maximum depth of each tree.
+         min_samples_leaf (int): Minimum number of samples required at a leaf.
+         max_features (float | None): Proportion of features to consider.
+         n_iter_no_change (int): Iterations to wait for early stopping.
+         tol (float): Minimum improvement in the loss.
+     """
+
+     n_estimators: int = 100  # maps to max_iter
+     learning_rate: float = 0.1
+     max_depth: Optional[int] = None
+     min_samples_leaf: int = 1
+     max_features: float | None = 1.0
+     n_iter_no_change: int = 10
+     tol: float = 1e-7
+
+     def __post_init__(self) -> None:
+         if isinstance(self.max_features, float):
+             if not (0.0 < self.max_features <= 1.0):
+                 raise ValueError("max_features as float must be in (0.0, 1.0]")
+
+         if self.n_estimators <= 0:
+             raise ValueError("n_estimators must be a positive integer")
+
+
+ @dataclass
+ class RFConfig:
+     """Configuration for ImputeRandomForest.
+
+     Attributes:
+         io (IOConfigSupervised): Run identity, logging, and seeds.
+         model (RFModelConfig): RandomForest hyperparameters.
+         train (TrainConfigSupervised): Sample split for validation.
+         imputer (ImputerConfigSupervised): IterativeImputer scaffolding.
+         sim (SimConfigSupervised): Simulated missingness.
+         plot (PlotConfigSupervised): Plot styling.
+         tune (TuningConfigSupervised): Optuna knobs.
+     """
+
+     io: IOConfigSupervised = field(default_factory=IOConfigSupervised)
+     model: RFModelConfig = field(default_factory=RFModelConfig)
+     train: TrainConfigSupervised = field(default_factory=TrainConfigSupervised)
+     imputer: ImputerConfigSupervised = field(default_factory=ImputerConfigSupervised)
+     sim: SimConfigSupervised = field(default_factory=SimConfigSupervised)
+     plot: PlotConfigSupervised = field(default_factory=PlotConfigSupervised)
+     tune: TuningConfigSupervised = field(default_factory=TuningConfigSupervised)
+
+     @classmethod
+     def from_preset(cls, preset: str = "balanced") -> "RFConfig":
+         """Build a config from a named preset."""
+         cfg = cls()
+         if preset == "fast":
+             cfg.model.n_estimators = 50
+             cfg.model.max_depth = None
+             cfg.imputer.max_iter = 5
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+         elif preset == "balanced":
+             cfg.model.n_estimators = 200
+             cfg.model.max_depth = None
+             cfg.imputer.max_iter = 10
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+             cfg.tune.n_trials = 100
+         elif preset == "thorough":
+             cfg.model.n_estimators = 500
+             cfg.model.max_depth = 50  # Added safety cap
+             cfg.imputer.max_iter = 20
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+             cfg.tune.n_trials = 250
+         else:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         return cfg
+
+     @classmethod
+     def from_yaml(cls, path: str) -> "RFConfig":
+         """Load from YAML; honors optional top-level 'preset'."""
+         return load_yaml_to_dataclass(path, cls)
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "RFConfig":
+         """Apply flat dot-key overrides."""
+         if overrides:
+             apply_dot_overrides(self, overrides)
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+     def to_imputer_kwargs(self) -> Dict[str, Any]:
+         return {
+             "prefix": self.io.prefix,
+             "seed": self.io.seed,
+             "n_jobs": self.io.n_jobs,
+             "verbose": self.io.verbose,
+             "debug": self.io.debug,
+             "model_n_estimators": self.model.n_estimators,
+             "model_max_depth": self.model.max_depth,
+             "model_min_samples_split": self.model.min_samples_split,
+             "model_min_samples_leaf": self.model.min_samples_leaf,
+             "model_max_features": self.model.max_features,
+             "model_criterion": self.model.criterion,
+             "model_validation_split": self.train.validation_split,
+             "model_n_nearest_features": self.imputer.n_nearest_features,
+             "model_max_iter": self.imputer.max_iter,
+             "sim_prop_missing": self.sim.prop_missing,
+             "sim_strategy": self.sim.strategy,
+             "sim_het_boost": self.sim.het_boost,
+             "plot_format": self.plot.fmt,
+             "plot_fontsize": self.plot.fontsize,
+             "plot_despine": self.plot.despine,
+             "plot_dpi": self.plot.dpi,
+             "plot_show_plots": self.plot.show,
+         }
+
+
+ @dataclass
+ class HGBConfig:
+     """Configuration for ImputeHistGradientBoosting.
+
+     Attributes:
+         io (IOConfigSupervised): Run identity, logging, and seeds.
+         model (HGBModelConfig): HistGradientBoosting hyperparameters.
+         train (TrainConfigSupervised): Sample split for validation.
+         imputer (ImputerConfigSupervised): IterativeImputer scaffolding.
+         sim (SimConfigSupervised): Simulated missingness.
+         plot (PlotConfigSupervised): Plot styling.
+         tune (TuningConfigSupervised): Optuna knobs.
+     """
+
+     io: IOConfigSupervised = field(default_factory=IOConfigSupervised)
+     model: HGBModelConfig = field(default_factory=HGBModelConfig)
+     train: TrainConfigSupervised = field(default_factory=TrainConfigSupervised)
+     imputer: ImputerConfigSupervised = field(default_factory=ImputerConfigSupervised)
+     sim: SimConfigSupervised = field(default_factory=SimConfigSupervised)
+     plot: PlotConfigSupervised = field(default_factory=PlotConfigSupervised)
+     tune: TuningConfigSupervised = field(default_factory=TuningConfigSupervised)
+
+     @classmethod
+     def from_preset(cls, preset: str = "balanced") -> "HGBConfig":
+         """Build a config from a named preset."""
+         cfg = cls()
+         if preset == "fast":
+             cfg.model.n_estimators = 50
+             cfg.model.learning_rate = 0.2
+             cfg.model.max_depth = None
+             cfg.imputer.max_iter = 5
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+             cfg.tune.n_trials = 50
+         elif preset == "balanced":
+             cfg.model.n_estimators = 150
+             cfg.model.learning_rate = 0.1
+             cfg.model.max_depth = None
+             cfg.imputer.max_iter = 10
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+             cfg.tune.n_trials = 100
+         elif preset == "thorough":
+             cfg.model.n_estimators = 500
+             cfg.model.learning_rate = 0.05
+             cfg.model.n_iter_no_change = 20  # Increased patience
+             cfg.model.max_depth = None
+             cfg.imputer.max_iter = 20
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+             cfg.tune.n_trials = 250
+         else:
+             raise ValueError(f"Unknown preset: {preset}")
+         return cfg
+
+     @classmethod
+     def from_yaml(cls, path: str) -> "HGBConfig":
+         return load_yaml_to_dataclass(path, cls)
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "HGBConfig":
+         if overrides:
+             apply_dot_overrides(self, overrides)
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+     def to_imputer_kwargs(self) -> Dict[str, Any]:
+         return {
+             "prefix": self.io.prefix,
+             "seed": self.io.seed,
+             "n_jobs": self.io.n_jobs,
+             "verbose": self.io.verbose,
+             "debug": self.io.debug,
+             "model_n_estimators": self.model.n_estimators,
+             "model_learning_rate": self.model.learning_rate,
+             "model_n_iter_no_change": self.model.n_iter_no_change,
+             "model_tol": self.model.tol,
+             "model_max_depth": self.model.max_depth,
+             "model_min_samples_leaf": self.model.min_samples_leaf,
+             "model_max_features": self.model.max_features,
+             "model_validation_split": self.train.validation_split,
+             "model_n_nearest_features": self.imputer.n_nearest_features,
+             "model_max_iter": self.imputer.max_iter,
+             "sim_prop_missing": self.sim.prop_missing,
+             "sim_strategy": self.sim.strategy,
+             "sim_het_boost": self.sim.het_boost,
+             "plot_format": self.plot.fmt,
+             "plot_fontsize": self.plot.fontsize,
+             "plot_despine": self.plot.despine,
+             "plot_dpi": self.plot.dpi,
+             "plot_show_plots": self.plot.show,
+         }
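
Usage sketch (illustration, not part of the diff): the dataclasses added above share one pattern — build from a named preset, then adjust nested fields with flat dot-key overrides, and, for the supervised configs, flatten to keyword arguments. A minimal sketch, assuming these classes are importable from pgsui.data_processing.containers as the file path in this diff suggests; how ImputeNLPCA or ImputeRandomForest consume the resulting config is not shown here.

from pgsui.data_processing.containers import NLPCAConfig, RFConfig

# Start from a preset, then override nested fields via dot-keys.
cfg = NLPCAConfig.from_preset("balanced").apply_overrides(
    {"train.batch_size": 64, "tune.enabled": False, "io.prefix": "run01"}
)
assert cfg.to_dict()["train"]["batch_size"] == 64

# Unknown keys raise KeyError for the neural-network configs
# (the deterministic configs silently ignore them instead).
try:
    cfg.apply_overrides({"train.not_a_field": 1})
except KeyError as err:
    print(err)

# Supervised configs flatten to the keyword arguments expected by the wrappers.
rf_kwargs = RFConfig.from_preset("fast").to_imputer_kwargs()
assert rf_kwargs["model_n_estimators"] == 50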