pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.0.dist-info/RECORD +0 -75
- pg_sui-0.2.0.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
pgsui/data_processing/containers.py
@@ -0,0 +1,1424 @@
from __future__ import annotations

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional, Sequence

from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass


@dataclass
class _SimParams:
    """Container for simulation hyperparameters.

    Attributes:
        prop_missing (float): Proportion of missing values to simulate.
        strategy (Literal["random", "random_inv_genotype"]): Strategy for simulating missing values.
        missing_val (int | float): Value to represent missing data.
        het_boost (float): Boost factor for heterozygous genotypes.
        seed (int | None): Random seed for reproducibility.
    """

    prop_missing: float = 0.3
    strategy: Literal["random", "random_inv_genotype"] = "random_inv_genotype"
    missing_val: int | float = -1
    het_boost: float = 2.0
    seed: int | None = None

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class _ImputerParams:
    """Container for imputer hyperparameters.

    Attributes:
        n_nearest_features (int | None): Number of nearest features to consider for imputation.
        max_iter (int): Maximum number of iterations for the imputation algorithm.
        initial_strategy (Literal["mean", "median", "most_frequent", "constant"]): Strategy for initial imputation.
        keep_empty_features (bool): Whether to keep features that are entirely missing.
        random_state (int | None): Random seed for reproducibility.
        verbose (bool): If True, enables verbose logging during imputation.
    """

    n_nearest_features: int | None = 10
    max_iter: int = 10
    initial_strategy: Literal["mean", "median", "most_frequent", "constant"] = (
        "most_frequent"
    )
    keep_empty_features: bool = True
    random_state: int | None = None
    verbose: bool = False

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class _RFParams:
    """Container for RandomForest hyperparameters.

    Attributes:
        n_estimators (int): Number of trees in the forest.
        max_depth (int | None): Maximum depth of the trees.
        min_samples_split (int): Minimum number of samples required to split an internal node.
        min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
        max_features (Literal["sqrt", "log2"] | float | None): Number of features to consider for split.
        criterion (Literal["gini", "entropy", "log_loss"]): Function to measure the quality of a split.
        class_weight (Literal["balanced", "balanced_subsample", None]): Weights associated with classes.
    """

    n_estimators: int = 300
    max_depth: int | None = None
    min_samples_split: int = 2
    min_samples_leaf: int = 1
    max_features: Literal["sqrt", "log2"] | float | None = "sqrt"
    criterion: Literal["gini", "entropy", "log_loss"] = "gini"
    class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class _HGBParams:
    """Container for HistGradientBoosting hyperparameters.

    Attributes:
        max_iter (int): Maximum number of iterations.
        learning_rate (float): Learning rate shrinks the contribution of each tree.
        max_depth (int | None): Maximum depth of the individual regression estimators.
        min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
        n_iter_no_change (int): Number of iterations with no improvement to wait before early stopping.
        tol (float): Tolerance for the early stopping.
        max_features (float | None): The fraction of features to consider when looking for the best split.
        class_weight (Literal["balanced", "balanced_subsample", None]): Weights associated with classes.
        random_state (int | None): Random seed for reproducibility.
        verbose (bool): If True, enables verbose logging during training.
    """

    max_iter: int = 100
    learning_rate: float = 0.1
    max_depth: int | None = None
    min_samples_leaf: int = 1
    n_iter_no_change: int = 10
    tol: float = 1e-7
    max_features: float | None = 1.0
    class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
    random_state: int | None = None
    verbose: bool = False

    def to_dict(self) -> dict:
        return asdict(self)

@dataclass
class ModelConfig:
    """Model architecture configuration.

    Attributes:
        latent_init (Literal["random", "pca"]): Method for initializing the latent space.
        latent_dim (int): Dimensionality of the latent space.
        dropout_rate (float): Dropout rate for regularization.
        num_hidden_layers (int): Number of hidden layers in the neural network.
        hidden_activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function.
        layer_scaling_factor (float): Scaling factor for the number of neurons in hidden layers.
        layer_schedule (Literal["pyramid", "constant", "linear"]): Schedule for scaling hidden layer sizes.
        gamma (float): Parameter for the focal loss function.
    """

    latent_init: Literal["random", "pca"] = "random"
    latent_dim: int = 2
    dropout_rate: float = 0.2
    num_hidden_layers: int = 2
    hidden_activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
    layer_scaling_factor: float = 5.0
    layer_schedule: Literal["pyramid", "constant", "linear"] = "pyramid"
    gamma: float = 2.0


@dataclass
class TrainConfig:
    """Training procedure configuration.

    Attributes:
        batch_size (int): Number of samples per training batch.
        learning_rate (float): Learning rate for the optimizer.
        lr_input_factor (float): Factor to scale the learning rate for input layer.
        l1_penalty (float): L1 regularization penalty.
        early_stop_gen (int): Number of generations with no improvement to wait before early stopping.
        min_epochs (int): Minimum number of epochs to train.
        max_epochs (int): Maximum number of epochs to train.
        validation_split (float): Proportion of data to use for validation.
        weights_beta (float): Smoothing factor for class weights.
        weights_max_ratio (float): Maximum ratio for class weights to prevent extreme values.
        device (Literal["gpu", "cpu", "mps"]): Device to use for computation.
    """

    batch_size: int = 32
    learning_rate: float = 1e-3
    lr_input_factor: float = 1.0
    l1_penalty: float = 0.0
    early_stop_gen: int = 20
    min_epochs: int = 100
    max_epochs: int = 5000
    validation_split: float = 0.2
    weights_beta: float = 0.9999
    weights_max_ratio: float = 1.0
    device: Literal["gpu", "cpu", "mps"] = "cpu"


@dataclass
class TuneConfig:
    """Hyperparameter tuning configuration.

    Attributes:
        enabled (bool): If True, enables hyperparameter tuning.
        metric (Literal["f1", "accuracy", "pr_macro"]): Metric to optimize during tuning.
        n_trials (int): Number of hyperparameter trials to run.
        resume (bool): If True, resumes tuning from a previous state.
        save_db (bool): If True, saves the tuning results to a database.
        fast (bool): If True, uses a faster but less thorough tuning approach.
        max_samples (int): Maximum number of samples to use for tuning. 0 means all samples.
        max_loci (int): Maximum number of loci to use for tuning. 0 means all loci.
        epochs (int): Number of epochs to train each trial.
        batch_size (int): Batch size for training during tuning.
        eval_interval (int): Interval (in epochs) at which to evaluate the model during tuning.
        infer_epochs (int): Number of epochs for inference during tuning.
        patience (int): Number of evaluations with no improvement before stopping early.
        proxy_metric_batch (int): If > 0, uses a subset of data for proxy metric evaluation.
    """

    enabled: bool = False
    metric: Literal[
        "f1",
        "accuracy",
        "pr_macro",
        "average_precision",
        "roc_auc",
        "precision",
        "recall",
    ] = "f1"
    n_trials: int = 100
    resume: bool = False
    save_db: bool = False
    fast: bool = True
    max_samples: int = 512
    max_loci: int = 0  # 0 = all
    epochs: int = 500
    batch_size: int = 64
    eval_interval: int = 20
    infer_epochs: int = 100
    patience: int = 10
    proxy_metric_batch: int = 0


@dataclass
class EvalConfig:
    """Evaluation configuration.

    Attributes:
        eval_latent_steps (int): Number of optimization steps for latent space evaluation.
        eval_latent_lr (float): Learning rate for latent space optimization.
        eval_latent_weight_decay (float): Weight decay for latent space optimization.
    """

    eval_latent_steps: int = 50
    eval_latent_lr: float = 1e-2
    eval_latent_weight_decay: float = 0.0


@dataclass
class PlotConfig:
    """Plotting configuration.

    Attributes:
        fmt (Literal["pdf", "png", "jpg", "jpeg", "svg"]): Output file format.
        dpi (int): Dots per inch for the output figure.
        fontsize (int): Font size for text in the plots.
        despine (bool): If True, removes the top and right spines from plots.
        show (bool): If True, displays the plot interactively.
    """

    fmt: Literal["pdf", "png", "jpg", "jpeg", "svg"] = "pdf"
    dpi: int = 300
    fontsize: int = 18
    despine: bool = True
    show: bool = False


@dataclass
class IOConfig:
    """I/O configuration.

    Attributes:
        prefix (str): Prefix for output files. Default is "pgsui".
        verbose (bool): If True, enables verbose logging. Default is False.
        debug (bool): If True, enables debug mode. Default is False.
        seed (int | None): Random seed for reproducibility. Default is None.
        n_jobs (int): Number of parallel jobs to run. Default is 1.
        scoring_averaging (Literal["macro", "micro", "weighted"]): Averaging method.
    """

    prefix: str = "pgsui"
    verbose: bool = False
    debug: bool = False
    seed: int | None = None
    n_jobs: int = 1
    scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"


@dataclass
class SimConfig:
    """Top-level configuration for data simulation and imputation.

    Attributes:
        simulate_missing (bool): If True, simulates missing data.
        sim_strategy (Literal["random", ...]): Strategy for simulating missing data.
        sim_prop (float): Proportion of data to simulate as missing.
        sim_kwargs (dict | None): Additional keyword arguments for simulation.
    """

    simulate_missing: bool = False
    sim_strategy: Literal[
        "random",
        "random_weighted",
        "random_weighted_inv",
        "nonrandom",
        "nonrandom_weighted",
    ] = "random"
    sim_prop: float = 0.10
    sim_kwargs: dict | None = None


@dataclass
class NLPCAConfig:
    """Top-level configuration for ImputeNLPCA.

    Attributes:
        io (IOConfig): I/O configuration.
        model (ModelConfig): Model architecture configuration.
        train (TrainConfig): Training procedure configuration.
        tune (TuneConfig): Hyperparameter tuning configuration.
        evaluate (EvalConfig): Evaluation configuration.
        plot (PlotConfig): Plotting configuration.
        sim (SimConfig): Simulation configuration.
    """

    io: IOConfig = field(default_factory=IOConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    evaluate: EvalConfig = field(default_factory=EvalConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)
    sim: SimConfig = field(default_factory=SimConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "NLPCAConfig":
        """Build a NLPCAConfig from a named preset."""
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()

        # Common baselines
        cfg.io.verbose = False
        cfg.train.validation_split = 0.20
        cfg.model.hidden_activation = "relu"
        cfg.model.layer_schedule = "pyramid"
        cfg.model.latent_init = "random"
        cfg.evaluate.eval_latent_lr = 1e-2
        cfg.evaluate.eval_latent_weight_decay = 0.0
        cfg.sim.simulate_missing = True
        cfg.sim.sim_strategy = "random"
        cfg.sim.sim_prop = 0.2

        if preset == "fast":
            # Model
            cfg.model.latent_dim = 4
            cfg.model.num_hidden_layers = 1
            cfg.model.layer_scaling_factor = 2.0
            cfg.model.dropout_rate = 0.10
            cfg.model.gamma = 1.5
            # Train
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 1e-3
            cfg.train.early_stop_gen = 5
            cfg.train.min_epochs = 10
            cfg.train.max_epochs = 120
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 25
            cfg.tune.epochs = 120
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 512
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 20
            cfg.tune.infer_epochs = 20
            cfg.tune.patience = 5
            cfg.tune.proxy_metric_batch = 0
            # Eval
            cfg.evaluate.eval_latent_steps = 20

        elif preset == "balanced":
            # Model
            cfg.model.latent_dim = 8
            cfg.model.num_hidden_layers = 2
            cfg.model.layer_scaling_factor = 3.0
            cfg.model.dropout_rate = 0.20
            cfg.model.gamma = 2.0
            # Train
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 8e-4
            cfg.train.early_stop_gen = 15
            cfg.train.min_epochs = 50
            cfg.train.max_epochs = 600
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 75
            cfg.tune.epochs = 300
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 2048
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 20
            cfg.tune.infer_epochs = 40
            cfg.tune.patience = 10
            cfg.tune.proxy_metric_batch = 0
            # Eval
            cfg.evaluate.eval_latent_steps = 30

        else:  # thorough
            # Model
            cfg.model.latent_dim = 16
            cfg.model.num_hidden_layers = 3
            cfg.model.layer_scaling_factor = 5.0
            cfg.model.dropout_rate = 0.30
            cfg.model.gamma = 2.5
            # Train
            cfg.train.batch_size = 64
            cfg.train.learning_rate = 6e-4
            cfg.train.early_stop_gen = 20  # Reduced from 30
            cfg.train.min_epochs = 100
            cfg.train.max_epochs = 800  # Reduced from 1200
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = False
            cfg.tune.n_trials = 150
            cfg.tune.epochs = 600
            cfg.tune.batch_size = 64
            cfg.tune.max_samples = 5000  # Capped from 0
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 10
            cfg.tune.infer_epochs = 80
            cfg.tune.patience = 15  # Reduced from 20
            cfg.tune.proxy_metric_batch = 0
            # Eval
            cfg.evaluate.eval_latent_steps = 50

        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "NLPCAConfig":
        """Apply flat dot-key overrides."""
        if not overrides:
            return self
        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                raise KeyError(f"Unknown config key: {k}")
        return self

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

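# --- Editorial usage sketch (not part of the released file) -----------------
# NLPCAConfig composes the section dataclasses above. Presets fill the fields
# and apply_overrides() accepts flat dot-keys; unknown leaf keys raise
# KeyError. The import path is an assumption based on the file listing
# (pgsui/data_processing/containers.py):
#
#     from pgsui.data_processing.containers import NLPCAConfig
#
#     cfg = NLPCAConfig.from_preset("fast")
#     cfg.apply_overrides({"train.batch_size": 64, "io.prefix": "run1"})
#     cfg.to_dict()["train"]["batch_size"]     # -> 64
#     cfg.apply_overrides({"train.bogus": 1})  # -> KeyError: Unknown config key
# -----------------------------------------------------------------------------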
@dataclass
class UBPConfig:
    """Top-level configuration for ImputeUBP.

    Attributes:
        io (IOConfig): I/O configuration.
        model (ModelConfig): Model architecture configuration.
        train (TrainConfig): Training procedure configuration.
        tune (TuneConfig): Hyperparameter tuning configuration.
        evaluate (EvalConfig): Evaluation configuration.
        plot (PlotConfig): Plotting configuration.
        sim (SimConfig): Simulated-missing configuration.
    """

    io: IOConfig = field(default_factory=IOConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    evaluate: EvalConfig = field(default_factory=EvalConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)
    sim: SimConfig = field(default_factory=SimConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "UBPConfig":
        """Build a UBPConfig from a named preset."""
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()

        # Common baselines
        cfg.io.verbose = False
        cfg.model.hidden_activation = "relu"
        cfg.model.layer_schedule = "pyramid"
        cfg.model.latent_init = "random"
        cfg.sim.simulate_missing = True
        cfg.sim.sim_strategy = "random"
        cfg.sim.sim_prop = 0.2

        if preset == "fast":
            # Model
            cfg.model.latent_dim = 4
            cfg.model.num_hidden_layers = 1
            cfg.model.layer_scaling_factor = 2.0
            cfg.model.dropout_rate = 0.10
            cfg.model.gamma = 1.5
            # Train
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 1e-3
            cfg.train.early_stop_gen = 5
            cfg.train.min_epochs = 10
            cfg.train.max_epochs = 120
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 25
            cfg.tune.epochs = 120
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 512
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 20
            cfg.tune.infer_epochs = 20
            cfg.tune.patience = 5
            cfg.tune.proxy_metric_batch = 0
            # Eval
            cfg.evaluate.eval_latent_steps = 20
            cfg.evaluate.eval_latent_lr = 1e-2
            cfg.evaluate.eval_latent_weight_decay = 0.0

        elif preset == "balanced":
            # Model
            cfg.model.latent_dim = 8
            cfg.model.num_hidden_layers = 2
            cfg.model.layer_scaling_factor = 3.0
            cfg.model.dropout_rate = 0.20
            cfg.model.gamma = 2.0
            # Train
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 8e-4
            cfg.train.early_stop_gen = 15
            cfg.train.min_epochs = 50
            cfg.train.max_epochs = 600
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 75
            cfg.tune.epochs = 300
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 2048
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 20
            cfg.tune.infer_epochs = 40
            cfg.tune.patience = 10
            cfg.tune.proxy_metric_batch = 0
            # Eval
            cfg.evaluate.eval_latent_steps = 30
            cfg.evaluate.eval_latent_lr = 1e-2
            cfg.evaluate.eval_latent_weight_decay = 0.0

        else:  # thorough
            # Model
            cfg.model.latent_dim = 16
            cfg.model.num_hidden_layers = 3
            cfg.model.layer_scaling_factor = 5.0
            cfg.model.dropout_rate = 0.30
            cfg.model.gamma = 2.5
            # Train
            cfg.train.batch_size = 64
            cfg.train.learning_rate = 6e-4
            cfg.train.early_stop_gen = 20  # Reduced from 30
            cfg.train.min_epochs = 100
            cfg.train.max_epochs = 800  # Reduced from 1200
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = False
            cfg.tune.n_trials = 150
            cfg.tune.epochs = 600
            cfg.tune.batch_size = 64
            cfg.tune.max_samples = 5000  # Capped from 0
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 10
            cfg.tune.infer_epochs = 80
            cfg.tune.patience = 15  # Reduced from 20
            cfg.tune.proxy_metric_batch = 0
            # Eval
            cfg.evaluate.eval_latent_steps = 50
            cfg.evaluate.eval_latent_lr = 1e-2
            cfg.evaluate.eval_latent_weight_decay = 0.0

        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "UBPConfig":
        """Apply flat dot-key overrides."""
        if not overrides:
            return self

        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                raise KeyError(f"Unknown config key: {k}")
        return self

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


@dataclass
class AutoencoderConfig:
    """Top-level configuration for ImputeAutoencoder.

    Attributes:
        io (IOConfig): I/O configuration.
        model (ModelConfig): Model architecture configuration.
        train (TrainConfig): Training procedure configuration.
        tune (TuneConfig): Hyperparameter tuning configuration.
        evaluate (EvalConfig): Evaluation configuration.
        plot (PlotConfig): Plotting configuration.
        sim (SimConfig): Simulated-missing configuration.
    """

    io: IOConfig = field(default_factory=IOConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    evaluate: EvalConfig = field(default_factory=EvalConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)
    sim: SimConfig = field(default_factory=SimConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "AutoencoderConfig":
        """Build a AutoencoderConfig from a named preset."""
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()

        # Common baselines (no latent refinement at eval)
        cfg.io.verbose = False
        cfg.train.validation_split = 0.20
        cfg.model.hidden_activation = "relu"
        cfg.model.layer_schedule = "pyramid"
        cfg.evaluate.eval_latent_steps = 0
        cfg.evaluate.eval_latent_lr = 0.0
        cfg.evaluate.eval_latent_weight_decay = 0.0
        cfg.sim.simulate_missing = True
        cfg.sim.sim_strategy = "random"
        cfg.sim.sim_prop = 0.2

        if preset == "fast":
            cfg.model.latent_dim = 4
            cfg.model.num_hidden_layers = 1
            cfg.model.layer_scaling_factor = 2.0
            cfg.model.dropout_rate = 0.10
            cfg.model.gamma = 1.5
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 1e-3
            cfg.train.early_stop_gen = 5
            cfg.train.min_epochs = 10
            cfg.train.max_epochs = 120
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 25
            cfg.tune.epochs = 120
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 512
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 20
            cfg.tune.patience = 5
            cfg.tune.proxy_metric_batch = 0
            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        elif preset == "balanced":
            cfg.model.latent_dim = 8
            cfg.model.num_hidden_layers = 2
            cfg.model.layer_scaling_factor = 3.0
            cfg.model.dropout_rate = 0.20
            cfg.model.gamma = 2.0
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 8e-4
            cfg.train.early_stop_gen = 15
            cfg.train.min_epochs = 50
            cfg.train.max_epochs = 600
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 75
            cfg.tune.epochs = 300
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 2048
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 20
            cfg.tune.patience = 10
            cfg.tune.proxy_metric_batch = 0
            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        else:  # thorough
            cfg.model.latent_dim = 16
            cfg.model.num_hidden_layers = 3
            cfg.model.layer_scaling_factor = 5.0
            cfg.model.dropout_rate = 0.30
            cfg.model.gamma = 2.5
            cfg.train.batch_size = 64
            cfg.train.learning_rate = 6e-4
            cfg.train.early_stop_gen = 20  # Reduced from 30
            cfg.train.min_epochs = 100
            cfg.train.max_epochs = 800  # Reduced from 1200
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            cfg.tune.enabled = True
            cfg.tune.fast = False
            cfg.tune.n_trials = 150
            cfg.tune.epochs = 600
            cfg.tune.batch_size = 64
            cfg.tune.max_samples = 5000  # Capped from 0
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 10
            cfg.tune.patience = 15  # Reduced from 20
            cfg.tune.proxy_metric_batch = 0
            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "AutoencoderConfig":
        """Apply flat dot-key overrides."""
        if not overrides:
            return self
        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                raise KeyError(f"Unknown config key: {k}")
        return self

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


@dataclass
class VAEExtraConfig:
    """VAE-specific knobs.

    Attributes:
        kl_beta (float): Final β for KL divergence term.
        kl_warmup (int): Number of epochs with β=0 (warm-up period).
        kl_ramp (int): Number of epochs for linear ramp to final β.
    """

    kl_beta: float = 1.0
    kl_warmup: int = 50
    kl_ramp: int = 200


@dataclass
class VAEConfig:
    """Top-level configuration for ImputeVAE (AE-parity + VAE extras).

    Attributes:
        io (IOConfig): I/O configuration.
        model (ModelConfig): Model architecture configuration.
        train (TrainConfig): Training procedure configuration.
        tune (TuneConfig): Hyperparameter tuning configuration.
        evaluate (EvalConfig): Evaluation configuration.
        plot (PlotConfig): Plotting configuration.
        vae (VAEExtraConfig): VAE-specific configuration.
        sim (SimConfig): Simulated-missing configuration.
    """

    io: IOConfig = field(default_factory=IOConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    evaluate: EvalConfig = field(default_factory=EvalConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)
    vae: VAEExtraConfig = field(default_factory=VAEExtraConfig)
    sim: SimConfig = field(default_factory=SimConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "VAEConfig":
        """Build a VAEConfig from a named preset."""
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()

        # Common baselines (match AE; no latent refinement at eval)
        cfg.io.verbose = False
        cfg.train.validation_split = 0.20
        cfg.model.hidden_activation = "relu"
        cfg.model.layer_schedule = "pyramid"
        cfg.evaluate.eval_latent_steps = 0
        cfg.evaluate.eval_latent_lr = 0.0
        cfg.evaluate.eval_latent_weight_decay = 0.0
        cfg.sim.simulate_missing = True
        cfg.sim.sim_strategy = "random"
        cfg.sim.sim_prop = 0.2

        # VAE KL schedules, shortened for speed
        cfg.vae.kl_beta = 1.0
        cfg.vae.kl_warmup = 25
        cfg.vae.kl_ramp = 100

        if preset == "fast":
            cfg.model.latent_dim = 4
            cfg.model.num_hidden_layers = 1
            cfg.model.layer_scaling_factor = 2.0
            cfg.model.dropout_rate = 0.10
            cfg.model.gamma = 1.5
            cfg.vae.kl_beta = 0.5  # Lower beta for fast training
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 1e-3
            cfg.train.early_stop_gen = 5
            cfg.train.min_epochs = 10
            cfg.train.max_epochs = 120
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 25
            cfg.tune.epochs = 120
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 512
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 20
            cfg.tune.patience = 5
            cfg.tune.proxy_metric_batch = 0
            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        elif preset == "balanced":
            cfg.model.latent_dim = 8
            cfg.model.num_hidden_layers = 2
            cfg.model.layer_scaling_factor = 3.0
            cfg.model.dropout_rate = 0.20
            cfg.model.gamma = 2.0
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 8e-4
            cfg.train.early_stop_gen = 15
            cfg.train.min_epochs = 50
            cfg.train.max_epochs = 600
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 75
            cfg.tune.epochs = 300
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 2048
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 20
            cfg.tune.patience = 10
            cfg.tune.proxy_metric_batch = 0
            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        else:  # thorough
            cfg.model.latent_dim = 16
            cfg.model.num_hidden_layers = 3
            cfg.model.layer_scaling_factor = 5.0
            cfg.model.dropout_rate = 0.30
            cfg.model.gamma = 2.5
            cfg.train.batch_size = 64
            cfg.train.learning_rate = 6e-4
            cfg.train.early_stop_gen = 20  # Reduced from 30
            cfg.train.min_epochs = 100
            cfg.train.max_epochs = 800  # Reduced from 1200
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0
            cfg.tune.enabled = True
            cfg.tune.fast = False
            cfg.tune.n_trials = 150
            cfg.tune.epochs = 600
            cfg.tune.batch_size = 64
            cfg.tune.max_samples = 5000  # Capped from 0
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 10
            cfg.tune.patience = 15  # Reduced from 20
            cfg.tune.proxy_metric_batch = 0
            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "VAEConfig":
        """Apply flat dot-key overrides."""
        if not overrides:
            return self
        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                raise KeyError(f"Unknown config key: {k}")
        return self

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

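# --- Editorial note (not part of the released file) --------------------------
# VAEConfig keeps parity with AutoencoderConfig and adds the VAEExtraConfig KL
# schedule. Every preset shortens the schedule to kl_warmup=25 and kl_ramp=100,
# and the "fast" preset also lowers kl_beta, so under the code above:
#
#     cfg = VAEConfig.from_preset("fast")
#     (cfg.vae.kl_beta, cfg.vae.kl_warmup, cfg.vae.kl_ramp)  # -> (0.5, 25, 100)
# -----------------------------------------------------------------------------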
@dataclass
class MostFrequentAlgoConfig:
    """Algorithmic knobs for ImputeMostFrequent.

    Attributes:
        by_populations (bool): Whether to compute per-population modes.
        default (int): Fallback mode if no valid entries in a locus.
        missing (int): Code for missing genotypes in 0/1/2.
    """

    by_populations: bool = False
    default: int = 0
    missing: int = -1


@dataclass
class DeterministicSplitConfig:
    """Evaluation split configuration shared by deterministic imputers.

    Attributes:
        test_size (float): Proportion of data to use as the test set.
        test_indices (Optional[Sequence[int]]): Specific indices to use as the test set.
    """

    test_size: float = 0.2
    test_indices: Optional[Sequence[int]] = None


@dataclass
class MostFrequentConfig:
    """Top-level configuration for ImputeMostFrequent.

    Attributes:
        io (IOConfig): I/O configuration.
        plot (PlotConfig): Plotting configuration.
        split (DeterministicSplitConfig): Data splitting configuration.
        algo (MostFrequentAlgoConfig): Algorithmic configuration.
        sim (SimConfig): Simulation configuration.
        tune (TuneConfig): Hyperparameter tuning configuration.
        train (TrainConfig): Training configuration.
    """

    io: IOConfig = field(default_factory=IOConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)
    split: DeterministicSplitConfig = field(default_factory=DeterministicSplitConfig)
    algo: MostFrequentAlgoConfig = field(default_factory=MostFrequentAlgoConfig)
    sim: SimConfig = field(default_factory=SimConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    @classmethod
    def from_preset(
        cls,
        preset: Literal["fast", "balanced", "thorough"] = "balanced",
    ) -> "MostFrequentConfig":
        """Construct a preset configuration."""
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()
        cfg.io.verbose = False
        cfg.split.test_size = 0.2
        cfg.sim.simulate_missing = True
        cfg.sim.sim_strategy = "random"
        cfg.sim.sim_prop = 0.2

        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "MostFrequentConfig":
        """Apply dot-key overrides."""
        if not overrides:
            return self
        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                pass
        return self

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


@dataclass
class RefAlleleAlgoConfig:
    """Algorithmic knobs for ImputeRefAllele.

    Attributes:
        missing (int): Code for missing genotypes in 0/1/2.
    """

    missing: int = -1


@dataclass
class RefAlleleConfig:
    """Top-level configuration for ImputeRefAllele.

    Attributes:
        io (IOConfig): I/O configuration.
        plot (PlotConfig): Plotting configuration.
        split (DeterministicSplitConfig): Data splitting configuration.
        algo (RefAlleleAlgoConfig): Algorithmic configuration.
        sim (SimConfig): Simulation configuration.
        tune (TuneConfig): Hyperparameter tuning configuration.
        train (TrainConfig): Training configuration.
    """

    io: IOConfig = field(default_factory=IOConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)
    split: DeterministicSplitConfig = field(default_factory=DeterministicSplitConfig)
    algo: RefAlleleAlgoConfig = field(default_factory=RefAlleleAlgoConfig)
    sim: SimConfig = field(default_factory=SimConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "RefAlleleConfig":
        """Presets mainly keep parity with logging/IO and split test_size."""
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()
        cfg.io.verbose = False
        cfg.split.test_size = 0.2
        cfg.sim.simulate_missing = True
        cfg.sim.sim_strategy = "random"
        cfg.sim.sim_prop = 0.2
        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "RefAlleleConfig":
        """Apply dot-key overrides."""
        if not overrides:
            return self
        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                pass
        return self

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

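# --- Editorial note (not part of the released file) --------------------------
# Unlike the neural-network configs, the deterministic configs above
# (MostFrequentConfig, RefAlleleConfig) silently ignore unknown override keys
# (the `else: pass` branch) instead of raising KeyError:
#
#     cfg = MostFrequentConfig.from_preset("fast")
#     cfg.apply_overrides({"algo.by_populations": True, "algo.bogus": 1})
#     cfg.algo.by_populations  # -> True; "algo.bogus" is dropped silently
# -----------------------------------------------------------------------------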
def _flatten_dict(
    d: Dict[str, Any], prefix: str = "", out: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Flatten a nested dictionary into dot-key format."""
    out = out or {}
    for k, v in d.items():
        kk = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            _flatten_dict(v, kk, out)
        else:
            out[kk] = v
    return out

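# --- Editorial note (not part of the released file) --------------------------
# _flatten_dict() produces the flat dot-key form that apply_overrides()
# consumes, e.g. to turn a nested mapping (such as one loaded from YAML) into
# overrides:
#
#     _flatten_dict({"train": {"batch_size": 64}, "io": {"prefix": "run1"}})
#     # -> {"train.batch_size": 64, "io.prefix": "run1"}
# -----------------------------------------------------------------------------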
|
|
1092
|
+
@dataclass
|
|
1093
|
+
class IOConfigSupervised:
|
|
1094
|
+
"""I/O, logging, and run identity.
|
|
1095
|
+
|
|
1096
|
+
Attributes:
|
|
1097
|
+
prefix (str): Prefix for output files and logs.
|
|
1098
|
+
seed (Optional[int]): Random seed for reproducibility.
|
|
1099
|
+
n_jobs (int): Number of parallel jobs to use.
|
|
1100
|
+
verbose (bool): Whether to enable verbose logging.
|
|
1101
|
+
debug (bool): Whether to enable debug mode.
|
|
1102
|
+
"""
|
|
1103
|
+
|
|
1104
|
+
prefix: str = "pgsui"
|
|
1105
|
+
seed: Optional[int] = None
|
|
1106
|
+
n_jobs: int = 1
|
|
1107
|
+
verbose: bool = False
|
|
1108
|
+
debug: bool = False
|
|
1109
|
+
|
|
1110
|
+
|
|
1111
|
+
@dataclass
|
|
1112
|
+
class PlotConfigSupervised:
|
|
1113
|
+
"""Plot/figure styling.
|
|
1114
|
+
|
|
1115
|
+
Attributes:
|
|
1116
|
+
fmt (Literal["pdf", "png", "jpg", "jpeg"]): File format.
|
|
1117
|
+
dpi (int): Resolution in dots per inch.
|
|
1118
|
+
fontsize (int): Base font size for plot text.
|
|
1119
|
+
despine (bool): Whether to remove top/right spines.
|
|
1120
|
+
show (bool): Whether to display plots interactively.
|
|
1121
|
+
"""
|
|
1122
|
+
|
|
1123
|
+
fmt: Literal["pdf", "png", "jpg", "jpeg"] = "pdf"
|
|
1124
|
+
dpi: int = 300
|
|
1125
|
+
fontsize: int = 18
|
|
1126
|
+
despine: bool = True
|
|
1127
|
+
show: bool = False
|
|
1128
|
+
|
|
1129
|
+
|
|
1130
|
+
@dataclass
|
|
1131
|
+
class TrainConfigSupervised:
|
|
1132
|
+
"""Training/evaluation split (by samples).
|
|
1133
|
+
|
|
1134
|
+
Attributes:
|
|
1135
|
+
validation_split (float): Proportion of data to use for validation.
|
|
1136
|
+
"""
|
|
1137
|
+
|
|
1138
|
+
validation_split: float = 0.20
|
|
1139
|
+
|
|
1140
|
+
def __post_init__(self):
|
|
1141
|
+
if not (0.0 < self.validation_split < 1.0):
|
|
1142
|
+
raise ValueError("validation_split must be between 0.0 and 1.0")
|
|
1143
|
+
|
|
1144
|
+
|
|
1145
|
+
@dataclass
|
|
1146
|
+
class ImputerConfigSupervised:
|
|
1147
|
+
"""IterativeImputer-like scaffolding used by current supervised wrappers.
|
|
1148
|
+
|
|
1149
|
+
Attributes:
|
|
1150
|
+
n_nearest_features (Optional[int]): Number of nearest features to use.
|
|
1151
|
+
max_iter (int): Maximum number of imputation iterations to perform.
|
|
1152
|
+
"""
|
|
1153
|
+
|
|
1154
|
+
+    n_nearest_features: Optional[int] = 10
+    max_iter: int = 10
+
+
+@dataclass
+class SimConfigSupervised:
+    """Simulation of missingness for evaluation.
+
+    Attributes:
+        prop_missing (float): Proportion of features to set as missing.
+        strategy (Literal["random", "random_inv_genotype"]): Strategy.
+        het_boost (float): Boosting factor for heterogeneity.
+        missing_val (int): Internal code for missing genotypes.
+    """
+
+    prop_missing: float = 0.5
+    strategy: Literal["random", "random_inv_genotype"] = "random_inv_genotype"
+    het_boost: float = 2.0
+    missing_val: int = -1
+
+
+@dataclass
+class TuningConfigSupervised:
+    """Optuna tuning envelope."""
+
+    enabled: bool = True
+    n_trials: int = 100
+    metric: str = "pr_macro"
+    n_jobs: int = 8
+    fast: bool = True
+
+
+@dataclass
+class RFModelConfig:
+    """Random Forest hyperparameters.
+
+    Attributes:
+        n_estimators (int): Number of trees in the forest.
+        max_depth (Optional[int]): Maximum depth of the trees.
+        min_samples_split (int): Minimum number of samples required to split.
+        min_samples_leaf (int): Minimum number of samples required at a leaf.
+        max_features (Literal["sqrt", "log2"] | float | None): Features to consider.
+        criterion (Literal["gini", "entropy", "log_loss"]): Split quality metric.
+        class_weight (Literal["balanced", "balanced_subsample", None]): Class weights.
+    """
+
+    n_estimators: int = 100
+    max_depth: Optional[int] = None
+    min_samples_split: int = 2
+    min_samples_leaf: int = 1
+    max_features: Literal["sqrt", "log2"] | float | None = "sqrt"
+    criterion: Literal["gini", "entropy", "log_loss"] = "gini"
+    class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+
+
+@dataclass
+class HGBModelConfig:
+    """Histogram-based Gradient Boosting hyperparameters.
+
+    Attributes:
+        n_estimators (int): Number of boosting iterations (max_iter).
+        learning_rate (float): Step size for each boosting iteration.
+        max_depth (Optional[int]): Maximum depth of each tree.
+        min_samples_leaf (int): Minimum number of samples required at a leaf.
+        max_features (float | None): Proportion of features to consider.
+        n_iter_no_change (int): Iterations to wait for early stopping.
+        tol (float): Minimum improvement in the loss.
+    """
+
+    n_estimators: int = 100  # maps to max_iter
+    learning_rate: float = 0.1
+    max_depth: Optional[int] = None
+    min_samples_leaf: int = 1
+    max_features: float | None = 1.0
+    n_iter_no_change: int = 10
+    tol: float = 1e-7
+
+    def __post_init__(self) -> None:
+        if isinstance(self.max_features, float):
+            if not (0.0 < self.max_features <= 1.0):
+                raise ValueError("max_features as float must be in (0.0, 1.0]")
+
+        if self.n_estimators <= 0:
+            raise ValueError("n_estimators must be a positive integer")
+
+
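The __post_init__ guard above fails fast on out-of-range values at construction time. A minimal sketch of that behavior follows; the import path is an assumption inferred from this diff (the dataclasses appear to live under pgsui/data_processing/), not something the hunk confirms.

# Editor's sketch (not part of the diff): exercises the __post_init__ checks.
# Import path is an assumption; adjust to wherever HGBModelConfig is defined.
from pgsui.data_processing.containers import HGBModelConfig

cfg = HGBModelConfig(max_features=0.5, n_estimators=200)  # passes both checks

try:
    HGBModelConfig(max_features=1.5)  # float outside (0.0, 1.0]
except ValueError as err:
    print(err)  # max_features as float must be in (0.0, 1.0]

try:
    HGBModelConfig(n_estimators=0)  # non-positive boosting iterations
except ValueError as err:
    print(err)  # n_estimators must be a positive integer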
+@dataclass
+class RFConfig:
+    """Configuration for ImputeRandomForest.
+
+    Attributes:
+        io (IOConfigSupervised): Run identity, logging, and seeds.
+        model (RFModelConfig): RandomForest hyperparameters.
+        train (TrainConfigSupervised): Sample split for validation.
+        imputer (ImputerConfigSupervised): IterativeImputer scaffolding.
+        sim (SimConfigSupervised): Simulated missingness.
+        plot (PlotConfigSupervised): Plot styling.
+        tune (TuningConfigSupervised): Optuna knobs.
+    """
+
+    io: IOConfigSupervised = field(default_factory=IOConfigSupervised)
+    model: RFModelConfig = field(default_factory=RFModelConfig)
+    train: TrainConfigSupervised = field(default_factory=TrainConfigSupervised)
+    imputer: ImputerConfigSupervised = field(default_factory=ImputerConfigSupervised)
+    sim: SimConfigSupervised = field(default_factory=SimConfigSupervised)
+    plot: PlotConfigSupervised = field(default_factory=PlotConfigSupervised)
+    tune: TuningConfigSupervised = field(default_factory=TuningConfigSupervised)
+
+    @classmethod
+    def from_preset(cls, preset: str = "balanced") -> "RFConfig":
+        """Build a config from a named preset."""
+        cfg = cls()
+        if preset == "fast":
+            cfg.model.n_estimators = 100  # Increased from 50
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 5
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+        elif preset == "balanced":
+            cfg.model.n_estimators = 200  # Increased from 100
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 10
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 100
+        elif preset == "thorough":
+            cfg.model.n_estimators = 500
+            cfg.model.max_depth = 50  # Added safety cap
+            cfg.imputer.max_iter = 15
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 250
+        else:
+            raise ValueError(f"Unknown preset: {preset}")
+
+        return cfg
+
+    @classmethod
+    def from_yaml(cls, path: str) -> "RFConfig":
+        """Load from YAML; honors optional top-level 'preset'."""
+        return load_yaml_to_dataclass(path, cls)
+
+    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "RFConfig":
+        """Apply flat dot-key overrides."""
+        if overrides:
+            apply_dot_overrides(self, overrides)
+        return self
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+    def to_imputer_kwargs(self) -> Dict[str, Any]:
+        return {
+            "prefix": self.io.prefix,
+            "seed": self.io.seed,
+            "n_jobs": self.io.n_jobs,
+            "verbose": self.io.verbose,
+            "debug": self.io.debug,
+            "model_n_estimators": self.model.n_estimators,
+            "model_max_depth": self.model.max_depth,
+            "model_min_samples_split": self.model.min_samples_split,
+            "model_min_samples_leaf": self.model.min_samples_leaf,
+            "model_max_features": self.model.max_features,
+            "model_criterion": self.model.criterion,
+            "model_validation_split": self.train.validation_split,
+            "model_n_nearest_features": self.imputer.n_nearest_features,
+            "model_max_iter": self.imputer.max_iter,
+            "sim_prop_missing": self.sim.prop_missing,
+            "sim_strategy": self.sim.strategy,
+            "sim_het_boost": self.sim.het_boost,
+            "plot_format": self.plot.fmt,
+            "plot_fontsize": self.plot.fontsize,
+            "plot_despine": self.plot.despine,
+            "plot_dpi": self.plot.dpi,
+            "plot_show_plots": self.plot.show,
+        }
+
+
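RFConfig groups the nested config sections and exposes preset, override, and kwargs helpers. A minimal usage sketch follows; the import path and the exact dot-key format accepted by apply_dot_overrides are inferred from the docstrings above and are assumptions, not confirmed by this diff.

# Editor's sketch (not part of the diff); import path and dot-key format assumed.
from pgsui.data_processing.containers import RFConfig

cfg = RFConfig.from_preset("balanced")        # 200 trees, imputer max_iter=10
cfg.apply_overrides({"model.max_depth": 30,   # flat dot-key overrides
                     "tune.enabled": True})

kwargs = cfg.to_imputer_kwargs()              # flattened kwargs for the imputer
print(kwargs["model_n_estimators"])           # 200
print(kwargs["model_max_depth"])              # 30 (after the override)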
+@dataclass
+class HGBConfig:
+    """Configuration for ImputeHistGradientBoosting.
+
+    Attributes:
+        io (IOConfigSupervised): Run identity, logging, and seeds.
+        model (HGBModelConfig): HistGradientBoosting hyperparameters.
+        train (TrainConfigSupervised): Sample split for validation.
+        imputer (ImputerConfigSupervised): IterativeImputer scaffolding.
+        sim (SimConfigSupervised): Simulated missingness.
+        plot (PlotConfigSupervised): Plot styling.
+        tune (TuningConfigSupervised): Optuna knobs.
+    """
+
+    io: IOConfigSupervised = field(default_factory=IOConfigSupervised)
+    model: HGBModelConfig = field(default_factory=HGBModelConfig)
+    train: TrainConfigSupervised = field(default_factory=TrainConfigSupervised)
+    imputer: ImputerConfigSupervised = field(default_factory=ImputerConfigSupervised)
+    sim: SimConfigSupervised = field(default_factory=SimConfigSupervised)
+    plot: PlotConfigSupervised = field(default_factory=PlotConfigSupervised)
+    tune: TuningConfigSupervised = field(default_factory=TuningConfigSupervised)
+
+    @classmethod
+    def from_preset(cls, preset: str = "balanced") -> "HGBConfig":
+        """Build a config from a named preset."""
+        cfg = cls()
+        if preset == "fast":
+            cfg.model.n_estimators = 50
+            cfg.model.learning_rate = 0.15
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 5
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 50
+        elif preset == "balanced":
+            cfg.model.n_estimators = 100
+            cfg.model.learning_rate = 0.1
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 10
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 100
+        elif preset == "thorough":
+            cfg.model.n_estimators = 500
+            cfg.model.learning_rate = 0.05  # Reduced from 0.08
+            cfg.model.n_iter_no_change = 20  # Increased patience
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 15
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 250
+        else:
+            raise ValueError(f"Unknown preset: {preset}")
+        return cfg
+
+    @classmethod
+    def from_yaml(cls, path: str) -> "HGBConfig":
+        return load_yaml_to_dataclass(path, cls)
+
+    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "HGBConfig":
+        if overrides:
+            apply_dot_overrides(self, overrides)
+        return self
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+    def to_imputer_kwargs(self) -> Dict[str, Any]:
+        return {
+            "prefix": self.io.prefix,
+            "seed": self.io.seed,
+            "n_jobs": self.io.n_jobs,
+            "verbose": self.io.verbose,
+            "debug": self.io.debug,
+            "model_n_estimators": self.model.n_estimators,
+            "model_learning_rate": self.model.learning_rate,
+            "model_n_iter_no_change": self.model.n_iter_no_change,
+            "model_tol": self.model.tol,
+            "model_max_depth": self.model.max_depth,
+            "model_min_samples_leaf": self.model.min_samples_leaf,
+            "model_max_features": self.model.max_features,
+            "model_validation_split": self.train.validation_split,
+            "model_n_nearest_features": self.imputer.n_nearest_features,
+            "model_max_iter": self.imputer.max_iter,
+            "sim_prop_missing": self.sim.prop_missing,
+            "sim_strategy": self.sim.strategy,
+            "sim_het_boost": self.sim.het_boost,
+            "plot_format": self.plot.fmt,
+            "plot_fontsize": self.plot.fontsize,
+            "plot_despine": self.plot.despine,
+            "plot_dpi": self.plot.dpi,
+            "plot_show_plots": self.plot.show,
+        }
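HGBConfig.from_yaml delegates to load_yaml_to_dataclass, so a YAML file mirroring the nested dataclass layout should load directly. The sketch below writes such a file and reads it back; the YAML schema (nested keys matching the dataclass fields) and the import path are assumptions, not confirmed by this hunk.

# Editor's sketch (not part of the diff); YAML layout and import path assumed.
from pathlib import Path
from pgsui.data_processing.containers import HGBConfig

yaml_text = """
model:
  n_estimators: 300
  learning_rate: 0.05
imputer:
  max_iter: 12
tune:
  enabled: false
"""
cfg_path = Path("hgb_config.yaml")
cfg_path.write_text(yaml_text)

cfg = HGBConfig.from_yaml(str(cfg_path))
# If the loader maps nested keys onto the dataclasses as assumed:
print(cfg.model.n_estimators, cfg.imputer.max_iter)  # 300 12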