pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.0.dist-info/RECORD +0 -75
  83. pg_sui-0.2.0.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
@@ -0,0 +1,1424 @@
+ from __future__ import annotations
+
+ from dataclasses import asdict, dataclass, field
+ from typing import Any, Dict, Literal, Optional, Sequence
+
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
+
+
+ @dataclass
+ class _SimParams:
+     """Container for simulation hyperparameters.
+
+     Attributes:
+         prop_missing (float): Proportion of missing values to simulate.
+         strategy (Literal["random", "random_inv_genotype"]): Strategy for simulating missing values.
+         missing_val (int | float): Value to represent missing data.
+         het_boost (float): Boost factor for heterozygous genotypes.
+         seed (int | None): Random seed for reproducibility.
+     """
+
+     prop_missing: float = 0.3
+     strategy: Literal["random", "random_inv_genotype"] = "random_inv_genotype"
+     missing_val: int | float = -1
+     het_boost: float = 2.0
+     seed: int | None = None
+
+     def to_dict(self) -> dict:
+         return asdict(self)
+
+
+ @dataclass
+ class _ImputerParams:
+     """Container for imputer hyperparameters.
+
+     Attributes:
+         n_nearest_features (int | None): Number of nearest features to consider for imputation.
+         max_iter (int): Maximum number of iterations for the imputation algorithm.
+         initial_strategy (Literal["mean", "median", "most_frequent", "constant"]): Strategy for initial imputation.
+         keep_empty_features (bool): Whether to keep features that are entirely missing.
+         random_state (int | None): Random seed for reproducibility.
+         verbose (bool): If True, enables verbose logging during imputation.
+     """
+
+     n_nearest_features: int | None = 10
+     max_iter: int = 10
+     initial_strategy: Literal["mean", "median", "most_frequent", "constant"] = (
+         "most_frequent"
+     )
+     keep_empty_features: bool = True
+     random_state: int | None = None
+     verbose: bool = False
+
+     def to_dict(self) -> dict:
+         return asdict(self)
+
+
+ @dataclass
+ class _RFParams:
+     """Container for RandomForest hyperparameters.
+
+     Attributes:
+         n_estimators (int): Number of trees in the forest.
+         max_depth (int | None): Maximum depth of the trees.
+         min_samples_split (int): Minimum number of samples required to split an internal node.
+         min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
+         max_features (Literal["sqrt", "log2"] | float | None): Number of features to consider for split.
+         criterion (Literal["gini", "entropy", "log_loss"]): Function to measure the quality of a split.
+         class_weight (Literal["balanced", "balanced_subsample", None]): Weights associated with classes.
+     """
+
+     n_estimators: int = 300
+     max_depth: int | None = None
+     min_samples_split: int = 2
+     min_samples_leaf: int = 1
+     max_features: Literal["sqrt", "log2"] | float | None = "sqrt"
+     criterion: Literal["gini", "entropy", "log_loss"] = "gini"
+     class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+
+     def to_dict(self) -> dict:
+         return asdict(self)
+
+
+ @dataclass
+ class _HGBParams:
+     """Container for HistGradientBoosting hyperparameters.
+
+     Attributes:
+         max_iter (int): Maximum number of iterations.
+         learning_rate (float): Learning rate shrinks the contribution of each tree.
+         max_depth (int | None): Maximum depth of the individual regression estimators.
+         min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
+         n_iter_no_change (int): Number of iterations with no improvement to wait before early stopping.
+         tol (float): Tolerance for the early stopping.
+         max_features (float | None): The fraction of features to consider when looking for the best split.
+         class_weight (Literal["balanced", "balanced_subsample", None]): Weights associated with classes.
+         random_state (int | None): Random seed for reproducibility.
+         verbose (bool): If True, enables verbose logging during training.
+     """
+
+     max_iter: int = 100
+     learning_rate: float = 0.1
+     max_depth: int | None = None
+     min_samples_leaf: int = 1
+     n_iter_no_change: int = 10
+     tol: float = 1e-7
+     max_features: float | None = 1.0
+     class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+     random_state: int | None = None
+     verbose: bool = False
+
+     def to_dict(self) -> dict:
+         return asdict(self)
+
+
+ @dataclass
+ class ModelConfig:
+     """Model architecture configuration.
+
+     Attributes:
+         latent_init (Literal["random", "pca"]): Method for initializing the latent space.
+         latent_dim (int): Dimensionality of the latent space.
+         dropout_rate (float): Dropout rate for regularization.
+         num_hidden_layers (int): Number of hidden layers in the neural network.
+         hidden_activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function.
+         layer_scaling_factor (float): Scaling factor for the number of neurons in hidden layers.
+         layer_schedule (Literal["pyramid", "constant", "linear"]): Schedule for scaling hidden layer sizes.
+         gamma (float): Parameter for the focal loss function.
+     """
+
+     latent_init: Literal["random", "pca"] = "random"
+     latent_dim: int = 2
+     dropout_rate: float = 0.2
+     num_hidden_layers: int = 2
+     hidden_activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
+     layer_scaling_factor: float = 5.0
+     layer_schedule: Literal["pyramid", "constant", "linear"] = "pyramid"
+     gamma: float = 2.0
+
+
+ @dataclass
+ class TrainConfig:
+     """Training procedure configuration.
+
+     Attributes:
+         batch_size (int): Number of samples per training batch.
+         learning_rate (float): Learning rate for the optimizer.
+         lr_input_factor (float): Factor to scale the learning rate for input layer.
+         l1_penalty (float): L1 regularization penalty.
+         early_stop_gen (int): Number of generations with no improvement to wait before early stopping.
+         min_epochs (int): Minimum number of epochs to train.
+         max_epochs (int): Maximum number of epochs to train.
+         validation_split (float): Proportion of data to use for validation.
+         weights_beta (float): Smoothing factor for class weights.
+         weights_max_ratio (float): Maximum ratio for class weights to prevent extreme values.
+         device (Literal["gpu", "cpu", "mps"]): Device to use for computation.
+     """
+
+     batch_size: int = 32
+     learning_rate: float = 1e-3
+     lr_input_factor: float = 1.0
+     l1_penalty: float = 0.0
+     early_stop_gen: int = 20
+     min_epochs: int = 100
+     max_epochs: int = 5000
+     validation_split: float = 0.2
+     weights_beta: float = 0.9999
+     weights_max_ratio: float = 1.0
+     device: Literal["gpu", "cpu", "mps"] = "cpu"
+
+
+ @dataclass
+ class TuneConfig:
+     """Hyperparameter tuning configuration.
+
+     Attributes:
+         enabled (bool): If True, enables hyperparameter tuning.
+         metric (Literal["f1", "accuracy", "pr_macro"]): Metric to optimize during tuning.
+         n_trials (int): Number of hyperparameter trials to run.
+         resume (bool): If True, resumes tuning from a previous state.
+         save_db (bool): If True, saves the tuning results to a database.
+         fast (bool): If True, uses a faster but less thorough tuning approach.
+         max_samples (int): Maximum number of samples to use for tuning. 0 means all samples.
+         max_loci (int): Maximum number of loci to use for tuning. 0 means all loci.
+         epochs (int): Number of epochs to train each trial.
+         batch_size (int): Batch size for training during tuning.
+         eval_interval (int): Interval (in epochs) at which to evaluate the model during tuning.
+         infer_epochs (int): Number of epochs for inference during tuning.
+         patience (int): Number of evaluations with no improvement before stopping early.
+         proxy_metric_batch (int): If > 0, uses a subset of data for proxy metric evaluation.
+     """
+
+     enabled: bool = False
+     metric: Literal[
+         "f1",
+         "accuracy",
+         "pr_macro",
+         "average_precision",
+         "roc_auc",
+         "precision",
+         "recall",
+     ] = "f1"
+     n_trials: int = 100
+     resume: bool = False
+     save_db: bool = False
+     fast: bool = True
+     max_samples: int = 512
+     max_loci: int = 0  # 0 = all
+     epochs: int = 500
+     batch_size: int = 64
+     eval_interval: int = 20
+     infer_epochs: int = 100
+     patience: int = 10
+     proxy_metric_batch: int = 0
+
+
+ @dataclass
+ class EvalConfig:
+     """Evaluation configuration.
+
+     Attributes:
+         eval_latent_steps (int): Number of optimization steps for latent space evaluation.
+         eval_latent_lr (float): Learning rate for latent space optimization.
+         eval_latent_weight_decay (float): Weight decay for latent space optimization.
+     """
+
+     eval_latent_steps: int = 50
+     eval_latent_lr: float = 1e-2
+     eval_latent_weight_decay: float = 0.0
+
+
+ @dataclass
+ class PlotConfig:
+     """Plotting configuration.
+
+     Attributes:
+         fmt (Literal["pdf", "png", "jpg", "jpeg", "svg"]): Output file format.
+         dpi (int): Dots per inch for the output figure.
+         fontsize (int): Font size for text in the plots.
+         despine (bool): If True, removes the top and right spines from plots.
+         show (bool): If True, displays the plot interactively.
+     """
+
+     fmt: Literal["pdf", "png", "jpg", "jpeg", "svg"] = "pdf"
+     dpi: int = 300
+     fontsize: int = 18
+     despine: bool = True
+     show: bool = False
+
+
+ @dataclass
+ class IOConfig:
+     """I/O configuration.
+
+     Attributes:
+         prefix (str): Prefix for output files. Default is "pgsui".
+         verbose (bool): If True, enables verbose logging. Default is False.
+         debug (bool): If True, enables debug mode. Default is False.
+         seed (int | None): Random seed for reproducibility. Default is None.
+         n_jobs (int): Number of parallel jobs to run. Default is 1.
+         scoring_averaging (Literal["macro", "micro", "weighted"]): Averaging method.
+     """
+
+     prefix: str = "pgsui"
+     verbose: bool = False
+     debug: bool = False
+     seed: int | None = None
+     n_jobs: int = 1
+     scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"
+
+
+ @dataclass
+ class SimConfig:
+     """Top-level configuration for data simulation and imputation.
+
+     Attributes:
+         simulate_missing (bool): If True, simulates missing data.
+         sim_strategy (Literal["random", ...]): Strategy for simulating missing data.
+         sim_prop (float): Proportion of data to simulate as missing.
+         sim_kwargs (dict | None): Additional keyword arguments for simulation.
+     """
+
+     simulate_missing: bool = False
+     sim_strategy: Literal[
+         "random",
+         "random_weighted",
+         "random_weighted_inv",
+         "nonrandom",
+         "nonrandom_weighted",
+     ] = "random"
+     sim_prop: float = 0.10
+     sim_kwargs: dict | None = None
+
+
+ @dataclass
+ class NLPCAConfig:
+     """Top-level configuration for ImputeNLPCA.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+         sim (SimConfig): Simulation configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "NLPCAConfig":
+         """Build a NLPCAConfig from a named preset."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Common baselines
+         cfg.io.verbose = False
+         cfg.train.validation_split = 0.20
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+         cfg.model.latent_init = "random"
+         cfg.evaluate.eval_latent_lr = 1e-2
+         cfg.evaluate.eval_latent_weight_decay = 0.0
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+
+         if preset == "fast":
+             # Model
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 120
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 25
+             cfg.tune.epochs = 120
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.infer_epochs = 20
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 20
+
+         elif preset == "balanced":
+             # Model
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 3.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 8e-4
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 600
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 75
+             cfg.tune.epochs = 300
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 2048
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.infer_epochs = 40
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 30
+
+         else:  # thorough
+             # Model
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 5.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+             # Train
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 6e-4
+             cfg.train.early_stop_gen = 20  # Reduced from 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 800  # Reduced from 1200
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 150
+             cfg.tune.epochs = 600
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 5000  # Capped from 0
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.infer_epochs = 80
+             cfg.tune.patience = 15  # Reduced from 20
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 50
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "NLPCAConfig":
+         """Apply flat dot-key overrides."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
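NLPCAConfig above and the UBP, Autoencoder, and VAE configs that follow all share this preset-plus-overrides pattern: build from a named preset, then adjust individual fields with flat dot-keys. A minimal usage sketch (the override keys are illustrative; the class is the one defined above, wherever this module is imported from):

cfg = NLPCAConfig.from_preset("balanced")
cfg = cfg.apply_overrides({"model.latent_dim": 12, "train.max_epochs": 400, "io.prefix": "run1"})

assert cfg.model.latent_dim == 12   # nested field updated in place
nested = cfg.to_dict()              # plain nested dict, e.g. for logging or YAML export
# Misspelled keys raise KeyError ("Unknown config key: ..."), which catches typos early:
# cfg.apply_overrides({"model.latent_dims": 12})
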
+ @dataclass
+ class UBPConfig:
+     """Top-level configuration for ImputeUBP.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+         sim (SimConfig): Simulated-missing configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "UBPConfig":
+         """Build a UBPConfig from a named preset."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Common baselines
+         cfg.io.verbose = False
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+         cfg.model.latent_init = "random"
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+
+         if preset == "fast":
+             # Model
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 120
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 25
+             cfg.tune.epochs = 120
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.infer_epochs = 20
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 20
+             cfg.evaluate.eval_latent_lr = 1e-2
+             cfg.evaluate.eval_latent_weight_decay = 0.0
+
+         elif preset == "balanced":
+             # Model
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 3.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+             # Train
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 8e-4
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 600
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 75
+             cfg.tune.epochs = 300
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 2048
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.infer_epochs = 40
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 30
+             cfg.evaluate.eval_latent_lr = 1e-2
+             cfg.evaluate.eval_latent_weight_decay = 0.0
+
+         else:  # thorough
+             # Model
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 5.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+             # Train
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 6e-4
+             cfg.train.early_stop_gen = 20  # Reduced from 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 800  # Reduced from 1200
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             # Tuning
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 150
+             cfg.tune.epochs = 600
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 5000  # Capped from 0
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.infer_epochs = 80
+             cfg.tune.patience = 15  # Reduced from 20
+             cfg.tune.proxy_metric_batch = 0
+             # Eval
+             cfg.evaluate.eval_latent_steps = 50
+             cfg.evaluate.eval_latent_lr = 1e-2
+             cfg.evaluate.eval_latent_weight_decay = 0.0
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "UBPConfig":
+         """Apply flat dot-key overrides."""
+         if not overrides:
+             return self
+
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class AutoencoderConfig:
+     """Top-level configuration for ImputeAutoencoder.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+         sim (SimConfig): Simulated-missing configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "AutoencoderConfig":
+         """Build an AutoencoderConfig from a named preset."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Common baselines (no latent refinement at eval)
+         cfg.io.verbose = False
+         cfg.train.validation_split = 0.20
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+         cfg.evaluate.eval_latent_steps = 0
+         cfg.evaluate.eval_latent_lr = 0.0
+         cfg.evaluate.eval_latent_weight_decay = 0.0
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+
+         if preset == "fast":
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 120
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 25
+             cfg.tune.epochs = 120
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         elif preset == "balanced":
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 3.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 8e-4
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 600
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 75
+             cfg.tune.epochs = 300
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 2048
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         else:  # thorough
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 5.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 6e-4
+             cfg.train.early_stop_gen = 20  # Reduced from 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 800  # Reduced from 1200
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 150
+             cfg.tune.epochs = 600
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 5000  # Capped from 0
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.patience = 15  # Reduced from 20
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "AutoencoderConfig":
+         """Apply flat dot-key overrides."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class VAEExtraConfig:
+     """VAE-specific knobs.
+
+     Attributes:
+         kl_beta (float): Final β for KL divergence term.
+         kl_warmup (int): Number of epochs with β=0 (warm-up period).
+         kl_ramp (int): Number of epochs for linear ramp to final β.
+     """
+
+     kl_beta: float = 1.0
+     kl_warmup: int = 50
+     kl_ramp: int = 200
+
+
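One plausible reading of these three knobs (β held at 0 for kl_warmup epochs, then ramped linearly up to kl_beta over kl_ramp epochs and held there) can be written as a small helper. This is a sketch of the schedule the docstring describes, not code from the package, and kl_beta_at is a hypothetical name:

def kl_beta_at(epoch: int, cfg: VAEExtraConfig) -> float:
    """Effective KL weight at a given epoch, assuming warm-up followed by a linear ramp."""
    if epoch < cfg.kl_warmup:
        return 0.0                                  # warm-up: reconstruction loss only
    ramp_frac = (epoch - cfg.kl_warmup) / max(cfg.kl_ramp, 1)
    return cfg.kl_beta * min(ramp_frac, 1.0)        # linear ramp, then hold at kl_beta

# With the defaults (kl_warmup=50, kl_ramp=200, kl_beta=1.0):
# epoch 49 -> 0.0, epoch 150 -> 0.5, epoch 250 and later -> 1.0
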
+ @dataclass
+ class VAEConfig:
+     """Top-level configuration for ImputeVAE (AE-parity + VAE extras).
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         model (ModelConfig): Model architecture configuration.
+         train (TrainConfig): Training procedure configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         evaluate (EvalConfig): Evaluation configuration.
+         plot (PlotConfig): Plotting configuration.
+         vae (VAEExtraConfig): VAE-specific configuration.
+         sim (SimConfig): Simulated-missing configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     model: ModelConfig = field(default_factory=ModelConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     evaluate: EvalConfig = field(default_factory=EvalConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     vae: VAEExtraConfig = field(default_factory=VAEExtraConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "VAEConfig":
+         """Build a VAEConfig from a named preset."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+
+         # Common baselines (match AE; no latent refinement at eval)
+         cfg.io.verbose = False
+         cfg.train.validation_split = 0.20
+         cfg.model.hidden_activation = "relu"
+         cfg.model.layer_schedule = "pyramid"
+         cfg.evaluate.eval_latent_steps = 0
+         cfg.evaluate.eval_latent_lr = 0.0
+         cfg.evaluate.eval_latent_weight_decay = 0.0
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+
+         # VAE KL schedules, shortened for speed
+         cfg.vae.kl_beta = 1.0
+         cfg.vae.kl_warmup = 25
+         cfg.vae.kl_ramp = 100
+
+         if preset == "fast":
+             cfg.model.latent_dim = 4
+             cfg.model.num_hidden_layers = 1
+             cfg.model.layer_scaling_factor = 2.0
+             cfg.model.dropout_rate = 0.10
+             cfg.model.gamma = 1.5
+             cfg.vae.kl_beta = 0.5  # Lower beta for fast training
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 1e-3
+             cfg.train.early_stop_gen = 5
+             cfg.train.min_epochs = 10
+             cfg.train.max_epochs = 120
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 25
+             cfg.tune.epochs = 120
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 512
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.patience = 5
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         elif preset == "balanced":
+             cfg.model.latent_dim = 8
+             cfg.model.num_hidden_layers = 2
+             cfg.model.layer_scaling_factor = 3.0
+             cfg.model.dropout_rate = 0.20
+             cfg.model.gamma = 2.0
+             cfg.train.batch_size = 128
+             cfg.train.learning_rate = 8e-4
+             cfg.train.early_stop_gen = 15
+             cfg.train.min_epochs = 50
+             cfg.train.max_epochs = 600
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             cfg.tune.enabled = True
+             cfg.tune.fast = True
+             cfg.tune.n_trials = 75
+             cfg.tune.epochs = 300
+             cfg.tune.batch_size = 128
+             cfg.tune.max_samples = 2048
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 20
+             cfg.tune.patience = 10
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         else:  # thorough
+             cfg.model.latent_dim = 16
+             cfg.model.num_hidden_layers = 3
+             cfg.model.layer_scaling_factor = 5.0
+             cfg.model.dropout_rate = 0.30
+             cfg.model.gamma = 2.5
+             cfg.train.batch_size = 64
+             cfg.train.learning_rate = 6e-4
+             cfg.train.early_stop_gen = 20  # Reduced from 30
+             cfg.train.min_epochs = 100
+             cfg.train.max_epochs = 800  # Reduced from 1200
+             cfg.train.weights_beta = 0.9999
+             cfg.train.weights_max_ratio = 2.0
+             cfg.tune.enabled = True
+             cfg.tune.fast = False
+             cfg.tune.n_trials = 150
+             cfg.tune.epochs = 600
+             cfg.tune.batch_size = 64
+             cfg.tune.max_samples = 5000  # Capped from 0
+             cfg.tune.max_loci = 0
+             cfg.tune.eval_interval = 10
+             cfg.tune.patience = 15  # Reduced from 20
+             cfg.tune.proxy_metric_batch = 0
+             if hasattr(cfg.tune, "infer_epochs"):
+                 cfg.tune.infer_epochs = 0
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "VAEConfig":
+         """Apply flat dot-key overrides."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 raise KeyError(f"Unknown config key: {k}")
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class MostFrequentAlgoConfig:
+     """Algorithmic knobs for ImputeMostFrequent.
+
+     Attributes:
+         by_populations (bool): Whether to compute per-population modes.
+         default (int): Fallback mode if no valid entries in a locus.
+         missing (int): Code for missing genotypes in 0/1/2.
+     """
+
+     by_populations: bool = False
+     default: int = 0
+     missing: int = -1
+
+
+ @dataclass
+ class DeterministicSplitConfig:
+     """Evaluation split configuration shared by deterministic imputers.
+
+     Attributes:
+         test_size (float): Proportion of data to use as the test set.
+         test_indices (Optional[Sequence[int]]): Specific indices to use as the test set.
+     """
+
+     test_size: float = 0.2
+     test_indices: Optional[Sequence[int]] = None
+
+
+ @dataclass
+ class MostFrequentConfig:
+     """Top-level configuration for ImputeMostFrequent.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         plot (PlotConfig): Plotting configuration.
+         split (DeterministicSplitConfig): Data splitting configuration.
+         algo (MostFrequentAlgoConfig): Algorithmic configuration.
+         sim (SimConfig): Simulation configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         train (TrainConfig): Training configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     split: DeterministicSplitConfig = field(default_factory=DeterministicSplitConfig)
+     algo: MostFrequentAlgoConfig = field(default_factory=MostFrequentAlgoConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+
+     @classmethod
+     def from_preset(
+         cls,
+         preset: Literal["fast", "balanced", "thorough"] = "balanced",
+     ) -> "MostFrequentConfig":
+         """Construct a preset configuration."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+         cfg.io.verbose = False
+         cfg.split.test_size = 0.2
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "MostFrequentConfig":
+         """Apply dot-key overrides."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 pass
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
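Note one behavioral difference from the neural-network configs above: MostFrequentConfig.apply_overrides (and RefAlleleConfig.apply_overrides below) silently skips keys it does not recognize instead of raising KeyError. A short illustration using only the class defined above:

cfg = MostFrequentConfig.from_preset("balanced")
cfg.apply_overrides({"algo.by_populations": True})   # known key: applied
cfg.apply_overrides({"algo.not_a_field": 123})       # unknown key: silently ignored
assert cfg.algo.by_populations is True
assert not hasattr(cfg.algo, "not_a_field")
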
+ @dataclass
+ class RefAlleleAlgoConfig:
+     """Algorithmic knobs for ImputeRefAllele.
+
+     Attributes:
+         missing (int): Code for missing genotypes in 0/1/2.
+     """
+
+     missing: int = -1
+
+
+ @dataclass
+ class RefAlleleConfig:
+     """Top-level configuration for ImputeRefAllele.
+
+     Attributes:
+         io (IOConfig): I/O configuration.
+         plot (PlotConfig): Plotting configuration.
+         split (DeterministicSplitConfig): Data splitting configuration.
+         algo (RefAlleleAlgoConfig): Algorithmic configuration.
+         sim (SimConfig): Simulation configuration.
+         tune (TuneConfig): Hyperparameter tuning configuration.
+         train (TrainConfig): Training configuration.
+     """
+
+     io: IOConfig = field(default_factory=IOConfig)
+     plot: PlotConfig = field(default_factory=PlotConfig)
+     split: DeterministicSplitConfig = field(default_factory=DeterministicSplitConfig)
+     algo: RefAlleleAlgoConfig = field(default_factory=RefAlleleAlgoConfig)
+     sim: SimConfig = field(default_factory=SimConfig)
+     tune: TuneConfig = field(default_factory=TuneConfig)
+     train: TrainConfig = field(default_factory=TrainConfig)
+
+     @classmethod
+     def from_preset(
+         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+     ) -> "RefAlleleConfig":
+         """Presets mainly keep parity with logging/IO and split test_size."""
+         if preset not in {"fast", "balanced", "thorough"}:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         cfg = cls()
+         cfg.io.verbose = False
+         cfg.split.test_size = 0.2
+         cfg.sim.simulate_missing = True
+         cfg.sim.sim_strategy = "random"
+         cfg.sim.sim_prop = 0.2
+         return cfg
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "RefAlleleConfig":
+         """Apply dot-key overrides."""
+         if not overrides:
+             return self
+         for k, v in overrides.items():
+             node = self
+             parts = k.split(".")
+             for p in parts[:-1]:
+                 node = getattr(node, p)
+             last = parts[-1]
+             if hasattr(node, last):
+                 setattr(node, last, v)
+             else:
+                 pass
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ def _flatten_dict(
+     d: Dict[str, Any], prefix: str = "", out: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, Any]:
+     """Flatten a nested dictionary into dot-key format."""
+     out = {} if out is None else out
+     for k, v in d.items():
+         kk = f"{prefix}.{k}" if prefix else k
+         if isinstance(v, dict):
+             _flatten_dict(v, kk, out)
+         else:
+             out[kk] = v
+     return out
+
+
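_flatten_dict is the bridge between nested mappings (for example a parsed YAML document) and the flat dot-key form that the apply_overrides methods in this module expect. A small sketch using the classes defined above:

nested = {"model": {"latent_dim": 12, "dropout_rate": 0.25}, "io": {"prefix": "run1"}}
flat = _flatten_dict(nested)
# {"model.latent_dim": 12, "model.dropout_rate": 0.25, "io.prefix": "run1"}

cfg = NLPCAConfig.from_preset("fast").apply_overrides(flat)
assert cfg.model.latent_dim == 12 and cfg.io.prefix == "run1"
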
+ @dataclass
+ class IOConfigSupervised:
+     """I/O, logging, and run identity.
+
+     Attributes:
+         prefix (str): Prefix for output files and logs.
+         seed (Optional[int]): Random seed for reproducibility.
+         n_jobs (int): Number of parallel jobs to use.
+         verbose (bool): Whether to enable verbose logging.
+         debug (bool): Whether to enable debug mode.
+     """
+
+     prefix: str = "pgsui"
+     seed: Optional[int] = None
+     n_jobs: int = 1
+     verbose: bool = False
+     debug: bool = False
+
+
+ @dataclass
+ class PlotConfigSupervised:
+     """Plot/figure styling.
+
+     Attributes:
+         fmt (Literal["pdf", "png", "jpg", "jpeg"]): File format.
+         dpi (int): Resolution in dots per inch.
+         fontsize (int): Base font size for plot text.
+         despine (bool): Whether to remove top/right spines.
+         show (bool): Whether to display plots interactively.
+     """
+
+     fmt: Literal["pdf", "png", "jpg", "jpeg"] = "pdf"
+     dpi: int = 300
+     fontsize: int = 18
+     despine: bool = True
+     show: bool = False
+
+
+ @dataclass
+ class TrainConfigSupervised:
+     """Training/evaluation split (by samples).
+
+     Attributes:
+         validation_split (float): Proportion of data to use for validation.
+     """
+
+     validation_split: float = 0.20
+
+     def __post_init__(self):
+         if not (0.0 < self.validation_split < 1.0):
+             raise ValueError("validation_split must be between 0.0 and 1.0")
+
+
+ @dataclass
+ class ImputerConfigSupervised:
+     """IterativeImputer-like scaffolding used by current supervised wrappers.
+
+     Attributes:
+         n_nearest_features (Optional[int]): Number of nearest features to use.
+         max_iter (int): Maximum number of imputation iterations to perform.
+     """
+
+     n_nearest_features: Optional[int] = 10
+     max_iter: int = 10
+
+
+ @dataclass
+ class SimConfigSupervised:
+     """Simulation of missingness for evaluation.
+
+     Attributes:
+         prop_missing (float): Proportion of features to set as missing.
+         strategy (Literal["random", "random_inv_genotype"]): Strategy.
+         het_boost (float): Boost factor for heterozygous genotypes.
+         missing_val (int): Internal code for missing genotypes.
+     """
+
+     prop_missing: float = 0.5
+     strategy: Literal["random", "random_inv_genotype"] = "random_inv_genotype"
+     het_boost: float = 2.0
+     missing_val: int = -1
+
+
+ @dataclass
+ class TuningConfigSupervised:
+     """Optuna tuning envelope."""
+
+     enabled: bool = True
+     n_trials: int = 100
+     metric: str = "pr_macro"
+     n_jobs: int = 8
+     fast: bool = True
+
+
+ @dataclass
+ class RFModelConfig:
+     """Random Forest hyperparameters.
+
+     Attributes:
+         n_estimators (int): Number of trees in the forest.
+         max_depth (Optional[int]): Maximum depth of the trees.
+         min_samples_split (int): Minimum number of samples required to split.
+         min_samples_leaf (int): Minimum number of samples required at a leaf.
+         max_features (Literal["sqrt", "log2"] | float | None): Features to consider.
+         criterion (Literal["gini", "entropy", "log_loss"]): Split quality metric.
+         class_weight (Literal["balanced", "balanced_subsample", None]): Class weights.
+     """
+
+     n_estimators: int = 100
+     max_depth: Optional[int] = None
+     min_samples_split: int = 2
+     min_samples_leaf: int = 1
+     max_features: Literal["sqrt", "log2"] | float | None = "sqrt"
+     criterion: Literal["gini", "entropy", "log_loss"] = "gini"
+     class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+
+
+ @dataclass
+ class HGBModelConfig:
+     """Histogram-based Gradient Boosting hyperparameters.
+
+     Attributes:
+         n_estimators (int): Number of boosting iterations (max_iter).
+         learning_rate (float): Step size for each boosting iteration.
+         max_depth (Optional[int]): Maximum depth of each tree.
+         min_samples_leaf (int): Minimum number of samples required at a leaf.
+         max_features (float | None): Proportion of features to consider.
+         n_iter_no_change (int): Iterations to wait for early stopping.
+         tol (float): Minimum improvement in the loss.
+     """
+
+     n_estimators: int = 100  # maps to max_iter
+     learning_rate: float = 0.1
+     max_depth: Optional[int] = None
+     min_samples_leaf: int = 1
+     max_features: float | None = 1.0
+     n_iter_no_change: int = 10
+     tol: float = 1e-7
+
+     def __post_init__(self) -> None:
+         if isinstance(self.max_features, float):
+             if not (0.0 < self.max_features <= 1.0):
+                 raise ValueError("max_features as float must be in (0.0, 1.0]")
+
+         if self.n_estimators <= 0:
+             raise ValueError("n_estimators must be a positive integer")
+
+
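The __post_init__ validators above fail fast on out-of-range values at construction time rather than deferring the error to training; for example:

HGBModelConfig(max_features=0.5)        # accepted: float in (0.0, 1.0]
try:
    HGBModelConfig(max_features=1.5)    # rejected by __post_init__
except ValueError as err:
    print(err)                          # max_features as float must be in (0.0, 1.0]
try:
    HGBModelConfig(n_estimators=0)      # rejected as well
except ValueError as err:
    print(err)                          # n_estimators must be a positive integer
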
+ @dataclass
+ class RFConfig:
+     """Configuration for ImputeRandomForest.
+
+     Attributes:
+         io (IOConfigSupervised): Run identity, logging, and seeds.
+         model (RFModelConfig): RandomForest hyperparameters.
+         train (TrainConfigSupervised): Sample split for validation.
+         imputer (ImputerConfigSupervised): IterativeImputer scaffolding.
+         sim (SimConfigSupervised): Simulated missingness.
+         plot (PlotConfigSupervised): Plot styling.
+         tune (TuningConfigSupervised): Optuna knobs.
+     """
+
+     io: IOConfigSupervised = field(default_factory=IOConfigSupervised)
+     model: RFModelConfig = field(default_factory=RFModelConfig)
+     train: TrainConfigSupervised = field(default_factory=TrainConfigSupervised)
+     imputer: ImputerConfigSupervised = field(default_factory=ImputerConfigSupervised)
+     sim: SimConfigSupervised = field(default_factory=SimConfigSupervised)
+     plot: PlotConfigSupervised = field(default_factory=PlotConfigSupervised)
+     tune: TuningConfigSupervised = field(default_factory=TuningConfigSupervised)
+
+     @classmethod
+     def from_preset(cls, preset: str = "balanced") -> "RFConfig":
+         """Build a config from a named preset."""
+         cfg = cls()
+         if preset == "fast":
+             cfg.model.n_estimators = 100  # Increased from 50
+             cfg.model.max_depth = None
+             cfg.imputer.max_iter = 5
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+         elif preset == "balanced":
+             cfg.model.n_estimators = 200  # Increased from 100
+             cfg.model.max_depth = None
+             cfg.imputer.max_iter = 10
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+             cfg.tune.n_trials = 100
+         elif preset == "thorough":
+             cfg.model.n_estimators = 500
+             cfg.model.max_depth = 50  # Added safety cap
+             cfg.imputer.max_iter = 15
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+             cfg.tune.n_trials = 250
+         else:
+             raise ValueError(f"Unknown preset: {preset}")
+
+         return cfg
+
+     @classmethod
+     def from_yaml(cls, path: str) -> "RFConfig":
+         """Load from YAML; honors optional top-level 'preset'."""
+         return load_yaml_to_dataclass(path, cls)
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "RFConfig":
+         """Apply flat dot-key overrides."""
+         if overrides:
+             apply_dot_overrides(self, overrides)
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+     def to_imputer_kwargs(self) -> Dict[str, Any]:
+         return {
+             "prefix": self.io.prefix,
+             "seed": self.io.seed,
+             "n_jobs": self.io.n_jobs,
+             "verbose": self.io.verbose,
+             "debug": self.io.debug,
+             "model_n_estimators": self.model.n_estimators,
+             "model_max_depth": self.model.max_depth,
+             "model_min_samples_split": self.model.min_samples_split,
+             "model_min_samples_leaf": self.model.min_samples_leaf,
+             "model_max_features": self.model.max_features,
+             "model_criterion": self.model.criterion,
+             "model_validation_split": self.train.validation_split,
+             "model_n_nearest_features": self.imputer.n_nearest_features,
+             "model_max_iter": self.imputer.max_iter,
+             "sim_prop_missing": self.sim.prop_missing,
+             "sim_strategy": self.sim.strategy,
+             "sim_het_boost": self.sim.het_boost,
+             "plot_format": self.plot.fmt,
+             "plot_fontsize": self.plot.fontsize,
+             "plot_despine": self.plot.despine,
+             "plot_dpi": self.plot.dpi,
+             "plot_show_plots": self.plot.show,
+         }
+
+
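to_imputer_kwargs flattens the grouped dataclasses into the flat keyword names consumed by the supervised wrappers; a sketch of the intended call pattern (the ImputeRandomForest signature is inferred from these key names, not shown in this hunk):

cfg = RFConfig.from_preset("thorough").apply_overrides({"model.n_estimators": 1000})
kwargs = cfg.to_imputer_kwargs()
# kwargs["model_n_estimators"] == 1000, kwargs["sim_strategy"] == "random_inv_genotype", ...

# Presumably passed through to the wrapper added in
# pgsui/impute/supervised/imputers/random_forest.py:
# imputer = ImputeRandomForest(genotype_data, **kwargs)
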
+ @dataclass
+ class HGBConfig:
+     """Configuration for ImputeHistGradientBoosting.
+
+     Attributes:
+         io (IOConfigSupervised): Run identity, logging, and seeds.
+         model (HGBModelConfig): HistGradientBoosting hyperparameters.
+         train (TrainConfigSupervised): Sample split for validation.
+         imputer (ImputerConfigSupervised): IterativeImputer scaffolding.
+         sim (SimConfigSupervised): Simulated missingness.
+         plot (PlotConfigSupervised): Plot styling.
+         tune (TuningConfigSupervised): Optuna knobs.
+     """
+
+     io: IOConfigSupervised = field(default_factory=IOConfigSupervised)
+     model: HGBModelConfig = field(default_factory=HGBModelConfig)
+     train: TrainConfigSupervised = field(default_factory=TrainConfigSupervised)
+     imputer: ImputerConfigSupervised = field(default_factory=ImputerConfigSupervised)
+     sim: SimConfigSupervised = field(default_factory=SimConfigSupervised)
+     plot: PlotConfigSupervised = field(default_factory=PlotConfigSupervised)
+     tune: TuningConfigSupervised = field(default_factory=TuningConfigSupervised)
+
+     @classmethod
+     def from_preset(cls, preset: str = "balanced") -> "HGBConfig":
+         """Build a config from a named preset."""
+         cfg = cls()
+         if preset == "fast":
+             cfg.model.n_estimators = 50
+             cfg.model.learning_rate = 0.15
+             cfg.model.max_depth = None
+             cfg.imputer.max_iter = 5
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+             cfg.tune.n_trials = 50
+         elif preset == "balanced":
+             cfg.model.n_estimators = 100
+             cfg.model.learning_rate = 0.1
+             cfg.model.max_depth = None
+             cfg.imputer.max_iter = 10
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+             cfg.tune.n_trials = 100
+         elif preset == "thorough":
+             cfg.model.n_estimators = 500
+             cfg.model.learning_rate = 0.05  # Reduced from 0.08
+             cfg.model.n_iter_no_change = 20  # Increased patience
+             cfg.model.max_depth = None
+             cfg.imputer.max_iter = 15
+             cfg.io.n_jobs = 1
+             cfg.tune.enabled = False
+             cfg.tune.n_trials = 250
+         else:
+             raise ValueError(f"Unknown preset: {preset}")
+         return cfg
+
+     @classmethod
+     def from_yaml(cls, path: str) -> "HGBConfig":
+         return load_yaml_to_dataclass(path, cls)
+
+     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "HGBConfig":
+         if overrides:
+             apply_dot_overrides(self, overrides)
+         return self
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+     def to_imputer_kwargs(self) -> Dict[str, Any]:
+         return {
+             "prefix": self.io.prefix,
+             "seed": self.io.seed,
+             "n_jobs": self.io.n_jobs,
+             "verbose": self.io.verbose,
+             "debug": self.io.debug,
+             "model_n_estimators": self.model.n_estimators,
+             "model_learning_rate": self.model.learning_rate,
+             "model_n_iter_no_change": self.model.n_iter_no_change,
+             "model_tol": self.model.tol,
+             "model_max_depth": self.model.max_depth,
+             "model_min_samples_leaf": self.model.min_samples_leaf,
+             "model_max_features": self.model.max_features,
+             "model_validation_split": self.train.validation_split,
+             "model_n_nearest_features": self.imputer.n_nearest_features,
+             "model_max_iter": self.imputer.max_iter,
+             "sim_prop_missing": self.sim.prop_missing,
+             "sim_strategy": self.sim.strategy,
+             "sim_het_boost": self.sim.het_boost,
+             "plot_format": self.plot.fmt,
+             "plot_fontsize": self.plot.fontsize,
+             "plot_despine": self.plot.despine,
+             "plot_dpi": self.plot.dpi,
+             "plot_show_plots": self.plot.show,
+         }