pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
- pg_sui-1.6.8.dist-info/RECORD +78 -0
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
- pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
- pg_sui-1.6.8.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +635 -0
- pgsui/data_processing/config.py +576 -0
- pgsui/data_processing/containers.py +1782 -0
- pgsui/data_processing/transformers.py +121 -1103
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +189 -0
- pgsui/electron/app/package-lock.json +6893 -0
- pgsui/electron/app/package.json +50 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +146 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +130 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +59 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
- pgsui/impute/deterministic/imputers/mode.py +679 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +971 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
- pgsui/impute/supervised/base.py +339 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
- pgsui/impute/supervised/imputers/random_forest.py +287 -0
- pgsui/impute/unsupervised/base.py +924 -0
- pgsui/impute/unsupervised/callbacks.py +89 -263
- pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
- pgsui/impute/unsupervised/imputers/vae.py +957 -0
- pgsui/impute/unsupervised/loss_functions.py +158 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
- pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
- pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
- pgsui/impute/unsupervised/models/vae_model.py +259 -618
- pgsui/impute/unsupervised/nn_scorers.py +215 -0
- pgsui/utils/classification_viz.py +591 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +514 -824
- pgsui/utils/scorers.py +212 -438
- pg_sui-1.0.2.1.dist-info/RECORD +0 -75
- pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -735
- pgsui/impute/impute.py +0 -1486
- pgsui/impute/simple_imputers.py +0 -1439
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
- pgsui/impute/unsupervised/keras_classifiers.py +0 -702
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -297
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -214
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
- /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/data_processing/containers.py (added)
@@ -0,0 +1,1782 @@
from __future__ import annotations

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional, Sequence

from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass


@dataclass
class _SimParams:
    """Container for simulation hyperparameters.

    This class holds the hyperparameters for the simulation process, including the proportion of missing values, the imputation strategy, and other relevant settings.

    Attributes:
        prop_missing (float): Proportion of missing values to simulate.
        strategy (Literal["random", "random_inv_genotype"]): Strategy for simulating missing values.
        missing_val (int | float): Value to represent missing data.
        het_boost (float): Boost factor for heterozygous genotypes.
        seed (int | None): Random seed for reproducibility.

    Notes:
        - The `strategy` attribute determines how missing values are simulated.
          "random" selects missing values uniformly at random, while "random_inv_genotype" selects missing values based on the inverse of the genotype distribution.
    """

    prop_missing: float = 0.3
    strategy: Literal["random", "random_inv_genotype"] = "random_inv_genotype"
    missing_val: int | float = -1
    het_boost: float = 2.0
    seed: int | None = None

    def to_dict(self) -> dict:
        """Convert the simulation parameters to a dictionary.

        Uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.

        Returns:
            dict: A dictionary representation of the simulation parameters.
        """
        return asdict(self)


@dataclass
class _ImputerParams:
    """Container for imputer hyperparameters.

    This class holds the hyperparameters for the imputation process, including the number of nearest features to consider, the maximum number of iterations, and other relevant settings.

    Attributes:
        n_nearest_features (int | None): Number of nearest features to consider for imputation
        max_iter (int): Maximum number of iterations for the imputation algorithm.
        initial_strategy (Literal["mean", "median", "most_frequent", "constant"]): Strategy for initial imputation of missing values.
        keep_empty_features (bool): Whether to keep features that are entirely missing.
        random_state (int | None): Random seed for reproducibility.
        verbose (bool): If True, enables verbose logging during imputation.

    Notes:
        - The `initial_strategy` attribute determines how initial missing values are imputed before the iterative process begins.
    """

    n_nearest_features: int | None = 10
    max_iter: int = 10
    initial_strategy: Literal["mean", "median", "most_frequent", "constant"] = (
        "most_frequent"
    )
    keep_empty_features: bool = True
    random_state: int | None = None
    verbose: bool = False

    def to_dict(self) -> dict:
        """Convert the imputer parameters to a dictionary.

        Uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.

        Returns:
            dict: A dictionary representation of the imputer parameters.
        """

        return asdict(self)


@dataclass
class _RFParams:
    """Container for RandomForest hyperparameters.

    This class holds the hyperparameters for the RandomForest classifier, including the number of estimators, maximum depth, and other relevant settings.

    Attributes:
        n_estimators (int): Number of trees in the forest.
        max_depth (int | None): Maximum depth of the trees.
        min_samples_split (int): Minimum number of samples required to split an internal node.
        min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
        max_features (Literal["sqrt", "log2"] | float | None): Number of
            features to consider when looking for the best split.
        criterion (Literal["gini", "entropy", "log_loss"]): Function to measure
            the quality of a split.
        class_weight (Literal["balanced", "balanced_subsample", None]): Weights
            associated with classes.
    """

    n_estimators: int = 300
    max_depth: int | None = None
    min_samples_split: int = 2
    min_samples_leaf: int = 1
    max_features: Literal["sqrt", "log2"] | float | None = "sqrt"
    criterion: Literal["gini", "entropy", "log_loss"] = "gini"
    class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"

    def to_dict(self) -> dict:
        """Convert the RandomForest parameters to a dictionary.

        Uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.

        Returns:
            dict: A dictionary representation of the RandomForest parameters.
        """
        return asdict(self)


@dataclass
class _HGBParams:
    """Container for HistGradientBoosting hyperparameters.

    This class holds the hyperparameters for the HistGradientBoosting classifier, including the number of iterations, learning rate, and other relevant settings.

    Attributes:
        max_iter (int): Maximum number of iterations.
        learning_rate (float): Learning rate shrinks the contribution of each tree.
        max_depth (int | None): Maximum depth of the individual regression estimators.
        min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
        n_iter_no_change (int): Number of iterations with no improvement to wait before early stopping
        tol (float): Tolerance for the early stopping.
        max_features (float): The fraction of features to consider when looking for the best split.
        class_weight (Literal["balanced", "balanced_subsample", None]): Weights associated with classes.
        random_state (int | None): Random seed for reproducibility.
        verbose (bool): If True, enables verbose logging during training.

    Notes:
        - The `class_weight` attribute helps to handle class imbalance by adjusting the weights associated with classes.
    """

    max_iter: int = 100
    learning_rate: float = 0.1
    max_depth: int | None = None
    min_samples_leaf: int = 1
    n_iter_no_change: int = 10
    tol: float = 1e-7
    max_features: float = 1.0
    class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
    random_state: int | None = None
    verbose: bool = False

    def to_dict(self) -> dict:
        """Convert the HistGradientBoosting parameters to a dictionary.

        Uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.

        Returns:
            dict: A dictionary representation of the HistGradientBoosting parameters.
        """
        return asdict(self)
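Not part of the diff — the four underscore-prefixed containers above are internal parameter bundles, and each exposes a `to_dict()` that simply delegates to `dataclasses.asdict`. A minimal sketch, assuming the module path added in this release and that the internal names remain importable:

# Illustrative only (not part of the release diff): the parameter containers
# are plain dataclasses, so to_dict() is equivalent to dataclasses.asdict().
from dataclasses import asdict

from pgsui.data_processing.containers import _RFParams  # internal helper (leading underscore)

rf = _RFParams(n_estimators=500, max_depth=20)
assert rf.to_dict() == asdict(rf)
print(rf.to_dict()["criterion"])  # "gini" (default)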
@dataclass
class ModelConfig:
    """Model architecture configuration.

    This class contains configuration options for the model architecture, including latent space initialization, dimensionality, dropout rate, and other relevant settings.

    Attributes:
        latent_init (Literal["random", "pca"]): Method for initializing the latent space.
        latent_dim (int): Dimensionality of the latent space.
        dropout_rate (float): Dropout rate for regularization.
        num_hidden_layers (int): Number of hidden layers in the neural network.
        hidden_activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function for hidden layers.
        layer_scaling_factor (float): Scaling factor for the number of neurons in hidden layers.
        layer_schedule (Literal["pyramid", "constant", "linear"]): Schedule for scaling hidden layer sizes.
        gamma (float): Parameter for the loss function.

    Notes:
        - The `layer_schedule` attribute determines how the size of hidden layers changes across the network (e.g., "pyramid" means decreasing size).
        - The `latent_init` attribute specifies how the latent space is initialized, either randomly or using PCA.
    """

    latent_init: Literal["random", "pca"] = "random"
    latent_dim: int = 2
    dropout_rate: float = 0.2
    num_hidden_layers: int = 2
    hidden_activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
    layer_scaling_factor: float = 5.0
    layer_schedule: Literal["pyramid", "constant", "linear"] = "pyramid"
    gamma: float = 2.0


@dataclass
class TrainConfig:
    """Training procedure configuration.

    This class contains configuration options for the training procedure, including batch size, learning rate, early stopping criteria, and other relevant settings.

    Attributes:
        batch_size (int): Number of samples per training batch.
        learning_rate (float): Learning rate for the optimizer.
        lr_input_factor (float): Factor to scale the learning rate for input layer.
        l1_penalty (float): L1 regularization penalty.
        early_stop_gen (int): Number of generations with no improvement to wait before early stopping.
        min_epochs (int): Minimum number of epochs to train.
        max_epochs (int): Maximum number of epochs to train.
        validation_split (float): Proportion of data to use for validation.
        weights_beta (float): Smoothing factor for class weights.
        weights_max_ratio (float): Maximum ratio for class weights to prevent extreme values.
        device (Literal["gpu", "cpu", "mps"]): Device to use for computation.

    Notes:
        - The `device` attribute specifies the computation device to use, such as "gpu", "cpu", or "mps" (for Apple Silicon).
    """

    batch_size: int = 32
    learning_rate: float = 1e-3
    lr_input_factor: float = 1.0
    l1_penalty: float = 0.0
    early_stop_gen: int = 20
    min_epochs: int = 100
    max_epochs: int = 5000
    validation_split: float = 0.2
    weights_beta: float = 0.9999
    weights_max_ratio: float = 1.0
    device: Literal["gpu", "cpu", "mps"] = "cpu"


@dataclass
class TuneConfig:
    """Hyperparameter tuning configuration.

    This class contains configuration options for hyperparameter tuning, including the number of trials, evaluation metrics, and other relevant settings.

    Attributes:
        enabled (bool): If True, enables hyperparameter tuning.
        metric (Literal["f1", "accuracy", "pr_macro"]): Metric to optimize during tuning.
        n_trials (int): Number of hyperparameter trials to run.
        resume (bool): If True, resumes tuning from a previous state.
        save_db (bool): If True, saves the tuning results to a database.
        fast (bool): If True, uses a faster but less thorough tuning approach.
        max_samples (int): Maximum number of samples to use for tuning. 0 means all samples.
        max_loci (int): Maximum number of loci to use for tuning. 0 means all loci.
        epochs (int): Number of epochs to train each trial.
        batch_size (int): Batch size for training during tuning.
        eval_interval (int): Interval (in epochs) at which to evaluate the model during tuning.
        infer_epochs (int): Number of epochs for inference during tuning.
        patience (int): Number of evaluations with no improvement before stopping early.
        proxy_metric_batch (int): If > 0, uses a subset of data for proxy metric evaluation.
    """

    enabled: bool = False
    metric: Literal["f1", "accuracy", "pr_macro"] = "f1"
    n_trials: int = 100
    resume: bool = False
    save_db: bool = False
    fast: bool = True
    max_samples: int = 512
    max_loci: int = 0  # 0 = all
    epochs: int = 500
    batch_size: int = 64
    eval_interval: int = 1
    infer_epochs: int = 100
    patience: int = 10
    proxy_metric_batch: int = 0


@dataclass
class EvalConfig:
    """Evaluation configuration.

    This class contains configuration options for the evaluation process, including batch size, evaluation intervals, and other relevant settings.

    Attributes:
        eval_latent_steps (int): Number of optimization steps for latent space evaluation.
        eval_latent_lr (float): Learning rate for latent space optimization.
        eval_latent_weight_decay (float): Weight decay for latent space optimization.
    """

    eval_latent_steps: int = 50
    eval_latent_lr: float = 1e-2
    eval_latent_weight_decay: float = 0.0


@dataclass
class PlotConfig:
    """Plotting configuration.

    This class contains configuration options for plotting, including file format, resolution, and other relevant settings.

    Attributes:
        fmt (Literal["pdf", "png", "jpg", "jpeg", "svg"]): Output file format.
        dpi (int): Dots per inch for the output figure.
        fontsize (int): Font size for text in the plots.
        despine (bool): If True, removes the top and right spines from plots.
        show (bool): If True, displays the plot interactively.
    """

    fmt: Literal["pdf", "png", "jpg", "jpeg", "svg"] = "pdf"
    dpi: int = 300
    fontsize: int = 18
    despine: bool = True
    show: bool = False


@dataclass
class IOConfig:
    """I/O configuration.

    This class contains configuration options for input/output operations, including file prefixes, verbosity, random seed, and other relevant settings.

    Attributes:
        prefix (str): Prefix for output files. Default is "pgsui".
        verbose (bool): If True, enables verbose logging. Default is False.
        debug (bool): If True, enables debug mode. Default is False.
        seed (int | None): Random seed for reproducibility. Default is None.
        n_jobs (int): Number of parallel jobs to run. Default is 1.
        scoring_averaging (Literal["macro", "micro", "weighted"]): Averaging
            method for scoring metrics. Default is "macro".
    """

    prefix: str = "pgsui"
    verbose: bool = False
    debug: bool = False
    seed: int | None = None
    n_jobs: int = 1
    scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"
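Not part of the diff — the section dataclasses above (ModelConfig, TrainConfig, TuneConfig, EvalConfig, PlotConfig, IOConfig) are the building blocks that the top-level configs below compose. A minimal sketch of constructing sections directly, assuming they are importable from pgsui.data_processing.containers as added in this release:

# Illustrative only (not part of the release diff): section configs are ordinary
# dataclasses, so they can be built with keyword arguments and read via plain
# attribute access before being attached to a top-level config.
from pgsui.data_processing.containers import ModelConfig, TrainConfig

model = ModelConfig(latent_dim=8, num_hidden_layers=3, layer_schedule="constant")
train = TrainConfig(batch_size=64, learning_rate=5e-4, device="mps")
print(model.latent_dim, train.device)  # 8 mps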
@dataclass
class NLPCAConfig:
    """Top-level configuration for ImputeNLPCA.

    This class contains all the configuration options for the ImputeNLPCA model. The configuration is organized into several sections, each represented by a dataclass.

    Attributes:
        io (IOConfig): I/O configuration.
        model (ModelConfig): Model architecture configuration.
        train (TrainConfig): Training procedure configuration.
        tune (TuneConfig): Hyperparameter tuning configuration.
        evaluate (EvalConfig): Evaluation configuration.
        plot (PlotConfig): Plotting configuration.

    Notes:
        - fast: Quick baseline; tiny net; NO tuning by default.
        - balanced: Practical default balancing speed and model performance; moderate tuning.
        - thorough: Prioritizes model performance; deeper nets; extensive tuning.
        - Overrides: Overrides are applied after presets and can be used to fine-tune specific parameters. Specifically uses flat dot-keys like {"model.latent_dim": 8}.
    """

    io: IOConfig = field(default_factory=IOConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    evaluate: EvalConfig = field(default_factory=EvalConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "NLPCAConfig":
        """Build a config from a named preset.

        This method allows for easy construction of a NLPCAConfig instance with sensible defaults based on the chosen preset. Presets adjust both model capacity and training/tuning behavior across speed/quality tradeoffs.

        Args:
            preset (Literal["fast", "balanced", "thorough"]): One of {"fast", "balanced", "thorough"}.

        Returns:
            NLPCAConfig: Configuration instance with preset values applied.
        """
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()  # start from dataclass defaults

        # Common sensible baselines
        cfg.io.verbose = True
        cfg.train.validation_split = 0.2
        cfg.evaluate.eval_latent_steps = 50
        cfg.evaluate.eval_latent_lr = 1e-2
        cfg.evaluate.eval_latent_weight_decay = 0.0
        cfg.model.hidden_activation = "relu"
        cfg.model.layer_schedule = "pyramid"
        cfg.model.latent_init = "random"

        if preset == "fast":
            # Model
            cfg.model.latent_dim = 4
            cfg.model.num_hidden_layers = 1
            cfg.model.layer_scaling_factor = 2.0
            cfg.model.dropout_rate = 0.10
            cfg.model.gamma = 1.5

            # Train
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 1e-3
            cfg.train.early_stop_gen = 5
            cfg.train.min_epochs = 10
            cfg.train.max_epochs = 100
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 1.0  # no rebalancing pressure

            # Tuning (off for true "fast")
            cfg.tune.enabled = False
            cfg.tune.fast = True
            cfg.tune.n_trials = 50
            cfg.tune.epochs = 100
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 512  # cap data for speed
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 1
            cfg.tune.infer_epochs = 50
            cfg.tune.patience = 5
            cfg.tune.proxy_metric_batch = 0

        elif preset == "balanced":
            # Model
            cfg.model.latent_dim = 8
            cfg.model.num_hidden_layers = 2
            cfg.model.layer_scaling_factor = 4.0
            cfg.model.dropout_rate = 0.20
            cfg.model.gamma = 2.0

            # Train
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 8e-4
            cfg.train.early_stop_gen = 15
            cfg.train.min_epochs = 50
            cfg.train.max_epochs = 1000
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 1.0

            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = True  # favor speed with good coverage
            cfg.tune.n_trials = 100  # more trials
            cfg.tune.epochs = 250
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 1024
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 1
            cfg.tune.infer_epochs = 80
            cfg.tune.patience = 10
            cfg.tune.proxy_metric_batch = 0

        else:  # thorough
            # Model
            cfg.model.latent_dim = 16
            cfg.model.num_hidden_layers = 3
            cfg.model.layer_scaling_factor = 6.0
            cfg.model.dropout_rate = 0.30
            cfg.model.gamma = 2.5

            # Train
            cfg.train.batch_size = 64
            cfg.train.learning_rate = 6e-4
            cfg.train.early_stop_gen = 30
            cfg.train.min_epochs = 100
            cfg.train.max_epochs = 3000
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 1.0

            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = False
            cfg.tune.n_trials = 250
            cfg.tune.epochs = 1000
            cfg.tune.batch_size = 64
            cfg.tune.max_samples = 0  # use all samples
            cfg.tune.max_loci = 0  # use all loci
            cfg.tune.eval_interval = 1
            cfg.tune.infer_epochs = 120
            cfg.tune.patience = 20
            cfg.tune.proxy_metric_batch = 0

        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "NLPCAConfig":
        """Apply flat dot-key overrides (e.g. {'model.latent_dim': 4}).

        This method allows for easy modification of the configuration by specifying the keys to change in a flat dictionary format.

        Args:
            overrides (Dict[str, Any] | None): A mapping of dot-key paths to values to override.

        Returns:
            NLPCAConfig: The updated config instance (same as `self`).
        """
        if not overrides:
            return self
        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                raise KeyError(f"Unknown config key: {k}")
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a nested dictionary.

        This method uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.

        Returns:
            Dict[str, Any]: The config as a nested dictionary.
        """
        return asdict(self)
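Not part of the diff — a minimal usage sketch of the config lifecycle implemented above (preset construction, flat dot-key overrides, nested-dict export), assuming NLPCAConfig is importable from pgsui.data_processing.containers as added in this release:

# Illustrative only (not part of the release diff): preset + dot-key overrides,
# exactly as implemented by from_preset() and apply_overrides() above.
from pgsui.data_processing.containers import NLPCAConfig

cfg = NLPCAConfig.from_preset("balanced")
cfg.apply_overrides({"model.latent_dim": 4, "train.max_epochs": 500, "tune.enabled": False})
assert cfg.model.latent_dim == 4
nested = cfg.to_dict()  # {"io": {...}, "model": {...}, "train": {...}, ...}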
@dataclass
class UBPConfig:
    """Top-level configuration for ImputeUBP.

    This class contains all the configuration options for the ImputeUBP model. The configuration is organized into several sections, each represented by a dataclass.

    Attributes:
        io (IOConfig): I/O configuration.
        model (ModelConfig): Model architecture configuration.
        train (TrainConfig): Training procedure configuration.
        tune (TuneConfig): Hyperparameter tuning configuration.
        evaluate (EvalConfig): Evaluation configuration.
        plot (PlotConfig): Plotting configuration.

    Notes:
        - fast: Quick baseline; tiny net; NO tuning by default.
        - balanced: Practical default balancing speed and model performance; moderate tuning.
        - thorough: Prioritizes model performance; deeper nets; extensive tuning.
        - Overrides: Overrides are applied after presets and can be used to fine-tune specific parameters. Specifically uses flat dot-keys like {"model.latent_dim": 8}.
    """

    io: IOConfig = field(default_factory=IOConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    evaluate: EvalConfig = field(default_factory=EvalConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "UBPConfig":
        """Build a UBPConfig from a named preset.

        This method allows for easy construction of a UBPConfig instance with sensible defaults based on the chosen preset. UBP is often used when classes (genotype states) are imbalanced. Presets adjust both capacity and weighting behavior across speed/quality tradeoffs.

        Args:
            preset (Literal["fast", "balanced", "thorough"]): One of {"fast","balanced","thorough"}.

        Returns:
            UBPConfig: Populated config instance.

        Notes:
            - fast: Quick baseline; tiny net; NO tuning by default.
            - balanced: Practical default balancing speed and model performance; moderate tuning.
            - thorough: Prioritizes model performance; deeper nets; extensive tuning.
            - Overrides: Overrides are applied after presets and can be used to fine-tune specific parameters. Specifically uses flat dot-keys like {"model.latent_dim": 8}.

        Raises:
            ValueError: If an unknown preset is provided.
        """
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()

        # Shared baselines
        cfg.io.verbose = True
        cfg.model.hidden_activation = "relu"
        cfg.model.layer_schedule = "pyramid"
        cfg.model.latent_init = "random"

        if preset == "fast":
            # Model (slightly smaller than NLPCA fast)
            cfg.model.latent_dim = 3
            cfg.model.num_hidden_layers = 1
            cfg.model.layer_scaling_factor = 2.0
            cfg.model.dropout_rate = 0.10
            cfg.model.gamma = 1.5  # lighter focusing

            # Train
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 1e-3
            cfg.train.early_stop_gen = 5
            cfg.train.min_epochs = 10
            cfg.train.max_epochs = 100
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 2.0  # allow mild rebalancing

            # Tuning (off for true "fast")
            cfg.tune.enabled = False
            cfg.tune.fast = True
            cfg.tune.n_trials = 50
            cfg.tune.epochs = 100
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 512
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 1
            cfg.tune.infer_epochs = 50
            cfg.tune.patience = 5
            cfg.tune.proxy_metric_batch = 0

        elif preset == "balanced":
            # Model
            cfg.model.latent_dim = 6
            cfg.model.num_hidden_layers = 2
            cfg.model.layer_scaling_factor = 3.0
            cfg.model.dropout_rate = 0.20
            cfg.model.gamma = 2.0

            # Train
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 8e-4
            cfg.train.early_stop_gen = 15
            cfg.train.min_epochs = 50
            cfg.train.max_epochs = 1000
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 3.0  # moderate cap for imbalance

            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 100
            cfg.tune.epochs = 250
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 1024
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 1
            cfg.tune.infer_epochs = 80
            cfg.tune.patience = 10
            cfg.tune.proxy_metric_batch = 0

        else:  # thorough
            # Model
            cfg.model.latent_dim = 12
            cfg.model.num_hidden_layers = 3
            cfg.model.layer_scaling_factor = 5.0
            cfg.model.dropout_rate = 0.30
            cfg.model.gamma = 2.5  # stronger focusing for harder imbalance

            # Train
            cfg.train.batch_size = 64
            cfg.train.learning_rate = 6e-4
            cfg.train.early_stop_gen = 30
            cfg.train.min_epochs = 100
            cfg.train.max_epochs = 3000
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 5.0  # allow stronger class weighting

            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = False
            cfg.tune.n_trials = 250
            cfg.tune.epochs = 1000
            cfg.tune.batch_size = 64
            cfg.tune.max_samples = 0  # all samples
            cfg.tune.max_loci = 0  # all loci
            cfg.tune.eval_interval = 1
            cfg.tune.infer_epochs = 120
            cfg.tune.patience = 20
            cfg.tune.proxy_metric_batch = 0

        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "UBPConfig":
        """Apply flat dot-key overrides (e.g. {'model.latent_dim': 4}).

        Args:
            overrides (Dict[str, Any] | None): Mapping of dot-key paths to values to override.

        Returns:
            UBPConfig: This instance after applying overrides.
        """
        if overrides is None or not overrides:
            return self

        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                raise KeyError(f"Unknown config key: {k}")
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a nested dictionary.

        This method uses `asdict` from the `dataclasses` module to convert the dataclass instance into a dictionary.

        Returns:
            Dict[str, Any]: Nested dictionary.
        """
        return asdict(self)
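Not part of the diff — a small sketch of the failure mode of the apply_overrides() walk shown above: only the final attribute name is validated, so an unknown leaf key raises KeyError, while a mistyped section name would instead surface as an AttributeError from getattr(). Assumes UBPConfig is importable from pgsui.data_processing.containers:

# Illustrative only (not part of the release diff).
from pgsui.data_processing.containers import UBPConfig

cfg = UBPConfig.from_preset("fast")
try:
    cfg.apply_overrides({"model.latent_dims": 8})  # typo in the leaf attribute name
except KeyError as err:
    print(err)  # 'Unknown config key: model.latent_dims'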
@dataclass
class AutoencoderConfig:
    """Top-level configuration for ImputeAutoencoder.

    This class contains all the configuration options for the ImputeAutoencoder model. The configuration is organized into several sections, each represented by a dataclass.

    Attributes:
        io (IOConfig): I/O configuration.
        model (ModelConfig): Model architecture configuration.
        train (TrainConfig): Training procedure configuration.
        tune (TuneConfig): Hyperparameter tuning configuration.
        evaluate (EvalConfig): Evaluation configuration.
        plot (PlotConfig): Plotting configuration.

    Notes:
        - fast: Quick baseline; tiny net; NO tuning by default.
        - balanced: Practical default; moderate tuning.
        - thorough: Prioritizes model performance; deeper nets; extensive tuning.
        - Overrides: flat dot-keys like {"model.latent_dim": 8}.
    """

    io: IOConfig = field(default_factory=IOConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    evaluate: EvalConfig = field(default_factory=EvalConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "AutoencoderConfig":
        """Build an AutoencoderConfig from a named preset.

        This method allows for easy construction of an AutoencoderConfig instance with sensible defaults based on the chosen preset. Presets adjust both model capacity and training/tuning behavior across speed/quality tradeoffs.

        Args:
            preset (Literal["fast", "balanced", "thorough"]): One of {"fast","balanced", "thorough"}.

        Returns:
            AutoencoderConfig: Populated config instance.
        """
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()

        # Common sensible baselines (aligned with NLPCA)
        cfg.io.verbose = True
        cfg.train.validation_split = 0.2
        cfg.model.hidden_activation = "relu"
        cfg.model.layer_schedule = "pyramid"

        # AE difference: no latent refinement during eval
        cfg.evaluate.eval_latent_steps = 0
        cfg.evaluate.eval_latent_lr = 0.0
        cfg.evaluate.eval_latent_weight_decay = 0.0

        if preset == "fast":
            # Model
            cfg.model.latent_dim = 4
            cfg.model.num_hidden_layers = 1
            cfg.model.layer_scaling_factor = 2.0
            cfg.model.dropout_rate = 0.10
            cfg.model.gamma = 1.5
            # Train
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 1e-3
            cfg.train.early_stop_gen = 5
            cfg.train.min_epochs = 10
            cfg.train.max_epochs = 100
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 1.0
            # Tuning (off for true fast)
            cfg.tune.enabled = False
            cfg.tune.fast = True
            cfg.tune.n_trials = 50
            cfg.tune.epochs = 100
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 512
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 1
            cfg.tune.patience = 5
            cfg.tune.proxy_metric_batch = 0
            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        elif preset == "balanced":
            # Model
            cfg.model.latent_dim = 8
            cfg.model.num_hidden_layers = 2
            cfg.model.layer_scaling_factor = 4.0
            cfg.model.dropout_rate = 0.20
            cfg.model.gamma = 2.0
            # Train
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 8e-4
            cfg.train.early_stop_gen = 15
            cfg.train.min_epochs = 50
            cfg.train.max_epochs = 1000
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 1.0
            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 100
            cfg.tune.epochs = 250
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 1024
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 1
            cfg.tune.patience = 10
            cfg.tune.proxy_metric_batch = 0
            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        else:  # thorough
            # Model
            cfg.model.latent_dim = 16
            cfg.model.num_hidden_layers = 3
            cfg.model.layer_scaling_factor = 6.0
            cfg.model.dropout_rate = 0.30
            cfg.model.gamma = 2.5
            # Train
            cfg.train.batch_size = 64
            cfg.train.learning_rate = 6e-4
            cfg.train.early_stop_gen = 30
            cfg.train.min_epochs = 100
            cfg.train.max_epochs = 3000
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 1.0
            # Tuning
            cfg.tune.enabled = True
            cfg.tune.fast = False
            cfg.tune.n_trials = 250
            cfg.tune.epochs = 1000
            cfg.tune.batch_size = 64
            cfg.tune.max_samples = 0  # use all samples
            cfg.tune.max_loci = 0  # use all loci
            cfg.tune.eval_interval = 1
            cfg.tune.patience = 20
            cfg.tune.proxy_metric_batch = 0
            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "AutoencoderConfig":
        """Apply flat dot-key overrides (e.g. {'model.latent_dim': 4}).

        Args:
            overrides (Dict[str, Any] | None): Mapping of dot-key paths to values to override.

        Returns:
            AutoencoderConfig: This instance after applying overrides.
        """
        if not overrides:
            return self
        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                raise KeyError(f"Unknown config key: {k}")
        return self

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


@dataclass
class VAEExtraConfig:
    """VAE-specific knobs.

    This class contains additional configuration options specific to Variational Autoencoders (VAEs), particularly for controlling the KL divergence term in the loss function.

    Attributes:
        kl_beta (float): Final β for KL divergence term.
        kl_warmup (int): Number of epochs with β=0 (warm-up period
            to stabilize training).
        kl_ramp (int): Number of epochs for linear ramp to final β.

    Notes:
        - These parameters control the behavior of the KL divergence term in the VAE loss function.
        - The warm-up period helps to stabilize training by gradually introducing the KL term.
        - The ramp period defines how quickly the KL term reaches its final value.
    """

    kl_beta: float = 1.0  # final β for KL
    kl_warmup: int = 50  # epochs with β=0
    kl_ramp: int = 200  # linear ramp to β


@dataclass
class VAEConfig:
    """Top-level configuration for ImputeVAE (AE-parity + VAE extras).

    This class contains all the configuration options for the ImputeVAE model. The configuration is organized into several sections, each represented by a dataclass.

    Attributes:
        io (IOConfig): I/O configuration.
        model (ModelConfig): Model architecture configuration.
        train (TrainConfig): Training procedure configuration.
        tune (TuneConfig): Hyperparameter tuning configuration.
        evaluate (EvalConfig): Evaluation configuration.
        plot (PlotConfig): Plotting configuration.
        vae (VAEExtraConfig): VAE-specific configuration.

    Notes:
        - fast: Quick baseline; tiny net; NO tuning by default.
        - balanced: Practical default; moderate tuning.
        - thorough: Prioritizes model performance; deeper nets; extensive tuning.
        - Overrides: flat dot-keys like {"model.latent_dim": 8}.
    """

    io: IOConfig = field(default_factory=IOConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    evaluate: EvalConfig = field(default_factory=EvalConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)
    vae: VAEExtraConfig = field(default_factory=VAEExtraConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "VAEConfig":
        """Mirror AutoencoderConfig presets and add VAE defaults.

        This method allows for easy construction of a VAEConfig instance with sensible defaults based on the chosen preset. Presets adjust both model capacity and training/tuning behavior across speed/quality tradeoffs.

        Args:
            preset (Literal["fast", "balanced", "thorough"]): One of {"fast", "balanced", "thorough"}.

        Returns:
            VAEConfig: Configuration instance with preset values applied.
        """
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()

        # Common sensible baselines (match AE/NLPCA style)
        cfg.io.verbose = True
        cfg.train.validation_split = 0.2
        cfg.model.hidden_activation = "relu"
        cfg.model.layer_schedule = "pyramid"

        # Like AE, no latent refinement during eval
        cfg.evaluate.eval_latent_steps = 0
        cfg.evaluate.eval_latent_lr = 0.0
        cfg.evaluate.eval_latent_weight_decay = 0.0

        # VAE-specific schedule defaults (can be overridden)
        cfg.vae.kl_beta = 1.0
        cfg.vae.kl_warmup = 50
        cfg.vae.kl_ramp = 200

        if preset == "fast":
            cfg.model.latent_dim = 4
            cfg.model.num_hidden_layers = 1
            cfg.model.layer_scaling_factor = 2.0
            cfg.model.dropout_rate = 0.10
            cfg.model.gamma = 1.5

            cfg.train.batch_size = 128
            cfg.train.learning_rate = 1e-3
            cfg.train.early_stop_gen = 5
            cfg.train.min_epochs = 10
            cfg.train.max_epochs = 100
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 1.0

            cfg.tune.enabled = False
            cfg.tune.fast = True
            cfg.tune.n_trials = 50
            cfg.tune.epochs = 100
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 512
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 1
            cfg.tune.patience = 5

            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        elif preset == "balanced":
            cfg.model.latent_dim = 8
            cfg.model.num_hidden_layers = 2
            cfg.model.layer_scaling_factor = 4.0
            cfg.model.dropout_rate = 0.20
            cfg.model.gamma = 2.0

            cfg.train.batch_size = 128
            cfg.train.learning_rate = 8e-4
            cfg.train.early_stop_gen = 15
            cfg.train.min_epochs = 50
            cfg.train.max_epochs = 1000
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 1.0

            cfg.tune.enabled = True
            cfg.tune.fast = True
            cfg.tune.n_trials = 100
            cfg.tune.epochs = 250
            cfg.tune.batch_size = 128
            cfg.tune.max_samples = 1024
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 1
            cfg.tune.patience = 10

            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        else:  # thorough
            cfg.model.latent_dim = 16
            cfg.model.num_hidden_layers = 3
            cfg.model.layer_scaling_factor = 6.0
            cfg.model.dropout_rate = 0.30
            cfg.model.gamma = 2.5

            cfg.train.batch_size = 64
            cfg.train.learning_rate = 6e-4
            cfg.train.early_stop_gen = 30
            cfg.train.min_epochs = 100
            cfg.train.max_epochs = 3000
            cfg.train.weights_beta = 0.9999
            cfg.train.weights_max_ratio = 1.0

            cfg.tune.enabled = True
            cfg.tune.fast = False
            cfg.tune.n_trials = 250
            cfg.tune.epochs = 1000
            cfg.tune.batch_size = 64
            cfg.tune.max_samples = 0
            cfg.tune.max_loci = 0
            cfg.tune.eval_interval = 1
            cfg.tune.patience = 20

            if hasattr(cfg.tune, "infer_epochs"):
                cfg.tune.infer_epochs = 0

        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "VAEConfig":
        """Apply flat dot-key overrides (e.g., {'vae.kl_beta': 2.0})."""
        if not overrides:
            return self
        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                raise KeyError(f"Unknown config key: {k}")
        return self

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)
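Not part of the diff — a minimal sketch showing that the VAE-specific KL schedule added by VAEExtraConfig is overridable through the same dot-key mechanism, assuming VAEConfig is importable from pgsui.data_processing.containers:

# Illustrative only (not part of the release diff): the KL schedule lives under
# the "vae" section of VAEConfig.
from pgsui.data_processing.containers import VAEConfig

cfg = VAEConfig.from_preset("thorough")
cfg.apply_overrides({"vae.kl_beta": 0.5, "vae.kl_warmup": 25, "vae.kl_ramp": 100})
print(cfg.vae.kl_beta)  # 0.5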
@dataclass
class MostFrequentAlgoConfig:
    """Algorithmic knobs for ImputeMostFrequent.

    This class contains configuration options specific to the most frequent genotype imputation algorithm.

    Attributes:
        by_populations (bool): Whether to compute per-population modes when populations are available.
        default (int): Fallback mode if no valid entries in a locus.
        missing (int): Code for missing genotypes in 0/1/2.
    """

    by_populations: bool = False  # per-pop modes if pops available
    default: int = 0  # fallback mode if no valid entries in a locus
    missing: int = -1  # code for missing genotypes in 0/1/2


@dataclass
class DeterministicSplitConfig:
    """Evaluation split configuration shared by deterministic imputers.

    This class contains configuration options for splitting data into training and testing sets for deterministic imputation algorithms. The split can be defined by a proportion of the data or by specific indices.

    Attributes:
        test_size (float): Proportion of data to use as the test set.
        test_indices (Optional[Sequence[int]]): Specific indices to use as the test set. If provided, this overrides the `test_size` parameter.
    """

    test_size: float = 0.2
    # If provided, overrides test_size.
    test_indices: Optional[Sequence[int]] = None


@dataclass
class MostFrequentConfig:
    """Top-level configuration for ImputeMostFrequent.

    This class contains all the configuration options for the ImputeMostFrequent model. The configuration is organized into several sections, each represented by a dataclass.

    Attributes:
        io (IOConfig): I/O configuration.
        plot (PlotConfig): Plotting configuration.
        split (DeterministicSplitConfig): Data splitting configuration.
        algo (MostFrequentAlgoConfig): Algorithmic configuration.
    """

    io: IOConfig = field(default_factory=IOConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)
    split: DeterministicSplitConfig = field(default_factory=DeterministicSplitConfig)
    algo: MostFrequentAlgoConfig = field(default_factory=MostFrequentAlgoConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "MostFrequentConfig":
        """Presets mainly keep parity with logging/IO and split test_size.

        Deterministic imputers don't have model/train knobs; presets exist for interface symmetry and minor UX defaults.

        Args:
            preset (Literal["fast", "balanced", "thorough"]): One of {"fast", "balanced", "thorough"}.

        Returns:
            MostFrequentConfig: Populated config instance.
        """
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")

        cfg = cls()
        cfg.io.verbose = True
        cfg.split.test_size = 0.2  # keep stable across presets
        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "MostFrequentConfig":
        """Apply dot-key overrides (e.g., {'algo.by_populations': True}).

        Args:
            overrides (Dict[str, Any]): Mapping of dot-key paths to values to override.

        Returns:
            MostFrequentConfig: This instance after applying overrides.
        """
        if not overrides:
            return self
        for k, v in overrides.items():
            node = self
            parts = k.split(".")
            for p in parts[:-1]:
                node = getattr(node, p)
            last = parts[-1]
            if hasattr(node, last):
                setattr(node, last, v)
            else:
                pass
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a dictionary.

        Returns:
            Dict[str, Any]: The config as a nested dictionary.
        """
        return asdict(self)
|
|
1178
|
+
|
|
1179
|
+
|
|
1180
|
+
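The preset and dot-key override workflow above is shared by both deterministic configs. A minimal usage sketch, assuming these classes are importable from pgsui.data_processing.config (the module this diff adds):

from pgsui.data_processing.config import MostFrequentConfig  # import path assumed from this diff's layout

cfg = MostFrequentConfig.from_preset("balanced")  # verbose I/O, test_size kept at 0.2
cfg = cfg.apply_overrides({"algo.by_populations": True, "split.test_size": 0.3})

assert cfg.algo.by_populations is True
assert cfg.split.test_size == 0.3
nested = cfg.to_dict()  # e.g. nested["algo"]["by_populations"] is True

Note that apply_overrides silently ignores unknown dot-keys (the final else/pass branch), so a typo in an override name does not raise.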
+@dataclass
+class RefAlleleAlgoConfig:
+    """Algorithmic knobs for ImputeRefAllele.
+
+    This class contains configuration options specific to the reference allele imputation algorithm.
+
+    Attributes:
+        missing (int): Code for missing genotypes in 0/1/2.
+    """
+
+    missing: int = -1
+
+
+@dataclass
+class RefAlleleConfig:
+    """Top-level configuration for ImputeRefAllele.
+
+    This class contains all the configuration options for the ImputeRefAllele model. The configuration is organized into several sections, each represented by a dataclass.
+
+    Attributes:
+        io (IOConfig): I/O configuration.
+        plot (PlotConfig): Plotting configuration.
+        split (DeterministicSplitConfig): Data splitting configuration.
+        algo (RefAlleleAlgoConfig): Algorithmic configuration.
+    """
+
+    io: IOConfig = field(default_factory=IOConfig)
+    plot: PlotConfig = field(default_factory=PlotConfig)
+    split: DeterministicSplitConfig = field(default_factory=DeterministicSplitConfig)
+    algo: RefAlleleAlgoConfig = field(default_factory=RefAlleleAlgoConfig)
+
+    @classmethod
+    def from_preset(
+        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+    ) -> "RefAlleleConfig":
+        """Presets mainly keep parity with logging/IO and split test_size.
+
+        Deterministic imputers don't have model/train knobs; presets exist for interface symmetry and minor UX defaults.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): One of {"fast", "balanced", "thorough"}.
+
+        Returns:
+            RefAlleleConfig: Populated config instance.
+        """
+        if preset not in {"fast", "balanced", "thorough"}:
+            raise ValueError(f"Unknown preset: {preset}")
+
+        cfg = cls()
+        cfg.io.verbose = True
+        cfg.split.test_size = 0.2
+        return cfg
+
+    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "RefAlleleConfig":
+        """Apply dot-key overrides (e.g., {'split.test_size': 0.3}).
+
+        This method allows for easy modification of the configuration by specifying the keys to change in a flat dictionary format.
+
+        Args:
+            overrides (Dict[str, Any] | None): A mapping of dot-key paths to values to override.
+
+        Returns:
+            RefAlleleConfig: The updated config instance (same as `self`).
+        """
+        if not overrides:
+            return self
+        for k, v in overrides.items():
+            node = self
+            parts = k.split(".")
+            for p in parts[:-1]:
+                node = getattr(node, p)
+            last = parts[-1]
+            if hasattr(node, last):
+                setattr(node, last, v)
+            else:
+                pass
+        return self
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert the config to a dictionary.
+
+        Returns:
+            Dict[str, Any]: The config as a nested dictionary.
+        """
+        return asdict(self)
+
+
+def _flatten_dict(
+    d: Dict[str, Any], prefix: str = "", out: Optional[Dict[str, Any]] = None
+) -> Dict[str, Any]:
+    """Flatten a nested dictionary into dot-key format.
+
+    Args:
+        d (Dict[str, Any]): The nested dictionary to flatten.
+        prefix (str): The prefix to use for keys (used in recursion).
+        out (Optional[Dict[str, Any]]): The output dictionary to populate.
+
+    Returns:
+        Dict[str, Any]: The flattened dictionary with dot-key format.
+    """
+    out = out or {}
+    for k, v in d.items():
+        kk = f"{prefix}.{k}" if prefix else k
+        if isinstance(v, dict):
+            _flatten_dict(v, kk, out)
+        else:
+            out[kk] = v
+    return out
+
+
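_flatten_dict produces exactly the flat dot-key mapping that the apply_overrides methods consume. A short sketch of that round trip, using only the helper defined above:

nested = {"model": {"n_estimators": 500, "max_depth": None}, "io": {"verbose": True}}
flat = _flatten_dict(nested)
# {'model.n_estimators': 500, 'model.max_depth': None, 'io.verbose': True}
# This flat form is what the *Config.apply_overrides() methods and the
# apply_dot_overrides helper expect as input.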
+@dataclass
+class IOConfigSupervised:
+    """I/O, logging, and run identity.
+
+    This class contains configuration options for input/output operations, logging, and run identification.
+
+    Attributes:
+        prefix (str): Prefix for output files and logs.
+        seed (Optional[int]): Random seed for reproducibility.
+        n_jobs (int): Number of parallel jobs to use. -1 uses all available cores.
+        verbose (bool): Whether to enable verbose logging.
+        debug (bool): Whether to enable debug mode with more detailed logs.
+
+    Notes:
+        - The prefix is used to name output files and logs, helping to organize results from different runs.
+        - Setting a random seed ensures that results are reproducible across different runs.
+        - The number of jobs can be adjusted based on the available computational resources.
+        - Verbose and debug modes provide additional logging information, which can be useful for troubleshooting.
+    """
+
+    prefix: str = "pgsui"
+    seed: Optional[int] = None
+    n_jobs: int = 1
+    verbose: bool = False
+    debug: bool = False
+
+
+@dataclass
+class PlotConfigSupervised:
+    """Plot/figure styling.
+
+    This class contains parameters for controlling the appearance of plots generated during the imputation process.
+
+    Attributes:
+        fmt (Literal["pdf", "png", "jpg", "jpeg"]): File format for saving plots.
+        dpi (int): Resolution in dots per inch for raster formats.
+        fontsize (int): Base font size for plot text.
+        despine (bool): Whether to remove top/right spines from plots.
+        show (bool): Whether to display plots interactively.
+
+    Notes:
+        - Supported formats: "pdf", "png", "jpg", "jpeg".
+        - Higher DPI values yield better quality in raster images.
+        - Despining is a common aesthetic choice for cleaner plots.
+    """
+
+    fmt: Literal["pdf", "png", "jpg", "jpeg"] = "pdf"
+    dpi: int = 300
+    fontsize: int = 18
+    despine: bool = True
+    show: bool = False
+
+
+@dataclass
+class TrainConfigSupervised:
+    """Training/evaluation split (by samples).
+
+    This class contains configuration options for splitting the dataset into training and validation sets during the training process.
+
+    Attributes:
+        validation_split (float): Proportion of data to use for validation.
+
+    Notes:
+        - Value must be strictly between 0.0 and 1.0 (exclusive).
+    """
+
+    validation_split: float = 0.20
+
+    def __post_init__(self):
+        """Validate that validation_split is between 0.0 and 1.0."""
+        if not (0.0 < self.validation_split < 1.0):
+            raise ValueError("validation_split must be between 0.0 and 1.0")
+
+
+@dataclass
+class ImputerConfigSupervised:
+    """IterativeImputer-like scaffolding used by current supervised wrappers.
+
+    This class contains configuration options for the imputation process, specifically for iterative imputation methods.
+
+    Attributes:
+        n_nearest_features (Optional[int]): Number of nearest features to use for imputation. If None, all features are used.
+        max_iter (int): Maximum number of imputation iterations to perform.
+
+    Notes:
+        - n_nearest_features can help speed up imputation by limiting the number of features considered.
+        - max_iter controls how many times the imputation process is repeated to refine estimates.
+        - If n_nearest_features is None, the imputer will consider all features for each missing value.
+        - Default max_iter is set to 10, which is typically sufficient for convergence.
+        - Iterative imputation can be computationally intensive; consider adjusting n_nearest_features for large datasets.
+    """
+
+    n_nearest_features: Optional[int] = 10
+    max_iter: int = 10
+
+
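These two knobs mirror scikit-learn's IterativeImputer arguments of the same names. A minimal sketch of how they are typically used; how PG-SUI's supervised wrappers wire them internally may differ:

import numpy as np
from sklearn.experimental import enable_iterative_imputer  # noqa: F401  (required to expose IterativeImputer)
from sklearn.impute import IterativeImputer

X = np.array([[0.0, 1.0, 2.0], [1.0, np.nan, 2.0], [0.0, 1.0, np.nan]])
imputer = IterativeImputer(
    n_nearest_features=10,  # ImputerConfigSupervised.n_nearest_features
    max_iter=10,            # ImputerConfigSupervised.max_iter
    random_state=0,
)
X_imputed = imputer.fit_transform(X)  # NaNs replaced by iterative estimates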
+@dataclass
+class SimConfigSupervised:
+    """Simulation of missingness for evaluation.
+
+    This class contains configuration options for simulating missing data during the evaluation process.
+
+    Attributes:
+        prop_missing (float): Proportion of features to randomly set as missing.
+        strategy (Literal["random", "random_inv_genotype"]): Strategy for generating missingness.
+        het_boost (float): Boosting factor for heterogeneity in missingness.
+        missing_val (int): Internal code for missing genotypes (e.g., -1).
+
+    Notes:
+        - The choice of strategy can affect the realism of the missing data simulation.
+        - Heterogeneous missingness can be useful for testing model robustness.
+    """
+
+    prop_missing: float = 0.5
+    strategy: Literal["random", "random_inv_genotype"] = "random_inv_genotype"
+    het_boost: float = 2.0
+    missing_val: int = -1  # internal use; your wrappers expect -1
+
+
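As a rough illustration of what these fields control, the sketch below masks a random fraction of genotype calls with missing_val. This is a generic example of the plain "random" strategy only; PG-SUI's actual simulator also supports genotype-frequency-weighted masking via strategy and het_boost, which is not reproduced here.

import numpy as np

rng = np.random.default_rng(42)
X = rng.integers(0, 3, size=(6, 10))       # toy 0/1/2 genotype matrix
prop_missing, missing_val = 0.5, -1        # SimConfigSupervised defaults

mask = rng.random(X.shape) < prop_missing  # uniform "random" strategy
X_sim = X.copy()
X_sim[mask] = missing_val                  # masked entries become -1 for evaluation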
+@dataclass
+class TuningConfigSupervised:
+    """Optuna tuning envelope (kept for parity with unsupervised)."""
+
+    enabled: bool = True
+    n_trials: int = 100
+    metric: str = "pr_macro"
+    n_jobs: int = 8  # for parallel eval (model-dependent)
+    fast: bool = True  # placeholder; trees don't need it but kept for consistency
+
+
+@dataclass
+class RFModelConfig:
+    """Random Forest hyperparameters.
+
+    This class contains configuration options for the Random Forest model used in imputation.
+
+    Attributes:
+        n_estimators (int): Number of trees in the forest.
+        max_depth (Optional[int]): Maximum depth of the trees. If None, nodes are expanded until all leaves are pure or contain less than min_samples_leaf samples.
+        min_samples_split (int): Minimum number of samples required to split an internal node.
+        min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
+        max_features (Literal["sqrt", "log2"] | float | None): Number of features to consider when looking for the best split.
+        criterion (Literal["gini", "entropy", "log_loss"]): Function to measure the quality of a split.
+        class_weight (Literal["balanced", "balanced_subsample", None]): Weights associated with classes. If "balanced", the class weights will be adjusted inversely proportional to class frequencies in the input data. If "balanced_subsample", the weights will be adjusted based on the bootstrap sample for each tree. If None, all classes will have weight of 1.0.
+    """
+
+    n_estimators: int = 100
+    max_depth: Optional[int] = None
+    min_samples_split: int = 2
+    min_samples_leaf: int = 1
+    max_features: Literal["sqrt", "log2"] | float | None = "sqrt"
+    criterion: Literal["gini", "entropy", "log_loss"] = "gini"
+    class_weight: Literal["balanced", "balanced_subsample", None] = "balanced"
+
+
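These fields line up one-for-one with scikit-learn's RandomForestClassifier arguments. A sketch of that correspondence (the PG-SUI wrapper adds its own wiring for seeding, n_jobs, and the IterativeImputer scaffolding on top):

from sklearn.ensemble import RandomForestClassifier

model_cfg = RFModelConfig()  # assumes the dataclass above is in scope
clf = RandomForestClassifier(
    n_estimators=model_cfg.n_estimators,
    max_depth=model_cfg.max_depth,
    min_samples_split=model_cfg.min_samples_split,
    min_samples_leaf=model_cfg.min_samples_leaf,
    max_features=model_cfg.max_features,
    criterion=model_cfg.criterion,
    class_weight=model_cfg.class_weight,
)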
+@dataclass
+class HGBModelConfig:
+    """Histogram-based Gradient Boosting hyperparameters.
+
+    This class contains configuration options for the Histogram-based Gradient Boosting (HGB) model used in imputation.
+
+    Attributes:
+        n_estimators (int): Number of boosting iterations.
+        learning_rate (float): Step size for each boosting iteration.
+        max_depth (Optional[int]): Maximum depth of each tree. If None, nodes are expanded until all leaves are pure or contain less than min_samples_leaf samples.
+        min_samples_leaf (int): Minimum number of samples required to be at a leaf node.
+        max_features (float | None): Proportion of features to consider when looking for the best split. If None, all features are considered.
+        n_iter_no_change (int): Number of iterations with no improvement to wait before early stopping.
+        tol (float): Minimum improvement in the loss to qualify as an improvement.
+
+    Notes:
+        - These parameters control the complexity and learning behavior of the HGB model.
+        - Early stopping is implemented to prevent overfitting.
+        - The model is sensitive to the learning_rate; smaller values require more estimators.
+        - max_features can be set to a float between 0.0 and 1.0 to use a proportion of features.
+        - Early stopping is driven by ``n_iter_no_change / tol``; sklearn controls randomness via random_state.
+    """
+
+    # sklearn.HistGradientBoostingClassifier uses 'max_iter'
+    # as the number of boosting iterations
+    # instead of 'n_estimators'.
+    n_estimators: int = 100  # maps to max_iter
+    learning_rate: float = 0.1
+    max_depth: Optional[int] = None
+    min_samples_leaf: int = 1
+    max_features: float | None = 1.0
+    n_iter_no_change: int = 10
+    tol: float = 1e-7
+
+    def __post_init__(self) -> None:
+        """Validate max_features and n_estimators.
+
+        This method checks that `max_features`, when given as a float, falls within the valid range (0.0, 1.0], and that `n_estimators` is a positive integer.
+        """
+        if isinstance(self.max_features, float):
+            if not (0.0 < self.max_features <= 1.0):
+                raise ValueError("max_features as float must be in (0.0, 1.0]")
+
+        if self.n_estimators <= 0:
+            raise ValueError("n_estimators must be a positive integer")
+
+
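The n_estimators -> max_iter rename called out in the comments is the one wrinkle when handing these values to scikit-learn. A minimal sketch of that mapping; the wrapper itself may forward additional arguments:

from sklearn.ensemble import HistGradientBoostingClassifier

hgb_cfg = HGBModelConfig()  # assumes the dataclass above is in scope
clf = HistGradientBoostingClassifier(
    max_iter=hgb_cfg.n_estimators,  # note the rename: n_estimators -> max_iter
    learning_rate=hgb_cfg.learning_rate,
    max_depth=hgb_cfg.max_depth,
    min_samples_leaf=hgb_cfg.min_samples_leaf,
    n_iter_no_change=hgb_cfg.n_iter_no_change,
    tol=hgb_cfg.tol,
    random_state=42,
)
# max_features is omitted here because forwarding it depends on the
# scikit-learn version in use.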
+@dataclass
+class RFConfig:
+    """Configuration for ImputeRandomForest.
+
+    This dataclass mirrors the legacy ``__init__`` signature while supporting presets, YAML loading, and dot-key overrides. Use ``to_imputer_kwargs()`` to call the current constructor, or refactor the imputer to accept ``config: RFConfig``.
+
+    Attributes:
+        io (IOConfigSupervised): Run identity, logging, and seeds.
+        model (RFModelConfig): RandomForest hyperparameters.
+        train (TrainConfigSupervised): Sample split for validation.
+        imputer (ImputerConfigSupervised): IterativeImputer scaffolding (neighbors/iters).
+        sim (SimConfigSupervised): Simulated missingness used during evaluation.
+        plot (PlotConfigSupervised): Plot styling and export options.
+        tune (TuningConfigSupervised): Optuna knobs (not required by RF itself).
+    """
+
+    io: IOConfigSupervised = field(default_factory=IOConfigSupervised)
+    model: RFModelConfig = field(default_factory=RFModelConfig)
+    train: TrainConfigSupervised = field(default_factory=TrainConfigSupervised)
+    imputer: ImputerConfigSupervised = field(default_factory=ImputerConfigSupervised)
+    sim: SimConfigSupervised = field(default_factory=SimConfigSupervised)
+    plot: PlotConfigSupervised = field(default_factory=PlotConfigSupervised)
+    tune: TuningConfigSupervised = field(default_factory=TuningConfigSupervised)
+
+    @classmethod
+    def from_preset(
+        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+    ) -> "RFConfig":
+        """Build a config from a named preset.
+
+        This method allows for easy construction of an RFConfig instance with sensible defaults based on the chosen preset. Presets adjust both model capacity and training/tuning behavior across speed/quality tradeoffs.
+
+        Args:
+            preset: One of {"fast", "balanced", "thorough"}.
+                - fast: Quick baseline; fewer trees; fewer imputer iters.
+                - balanced: Balances speed and model performance; moderate trees and imputer iters.
+                - thorough: Prioritizes model performance; more trees; more imputer iters.
+
+        Returns:
+            RFConfig: Config with preset values applied.
+        """
+        cfg = cls()
+        if preset == "fast":
+            cfg.model.n_estimators = 50
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 5
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+        elif preset == "balanced":
+            cfg.model.n_estimators = 100
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 10
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 100
+        elif preset == "thorough":
+            cfg.model.n_estimators = 500
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 15
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 250
+        else:
+            raise ValueError(f"Unknown preset: {preset}")
+
+        return cfg
+
+    @classmethod
+    def from_yaml(cls, path: str) -> "RFConfig":
+        """Load from YAML; honors optional top-level 'preset' then merges keys.
+
+        This method allows for easy construction of an RFConfig instance from a YAML file, with support for presets. If the YAML file specifies a top-level 'preset', the corresponding preset values are applied first, and then any additional keys in the YAML file override those preset values.
+
+        Args:
+            path (str): Path to the YAML configuration file.
+
+        Returns:
+            RFConfig: Config instance populated from the YAML file.
+        """
+        return load_yaml_to_dataclass(path, cls, preset_builder=cls.from_preset)
+
+    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "RFConfig":
+        """Apply flat dot-key overrides (e.g., {'model.n_estimators': 500}).
+
+        This method allows for easy application of overrides to the config instance using a flat dictionary structure.
+
+        Args:
+            overrides (Dict[str, Any] | None): Mapping of dot-key paths to values to override.
+
+        Returns:
+            RFConfig: This instance after applying overrides.
+        """
+        if overrides:
+            apply_dot_overrides(self, overrides)
+        return self
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return as nested dictionary.
+
+        This method converts the config instance into a nested dictionary format, which can be useful for serialization or inspection.
+
+        Returns:
+            Dict[str, Any]: The config as a nested dictionary.
+        """
+        return asdict(self)
+
+    def to_imputer_kwargs(self) -> Dict[str, Any]:
+        """Map config fields to current ImputeRandomForest ``__init__`` kwargs.
+
+        This method extracts relevant configuration fields and maps them to keyword arguments suitable for initializing the ImputeRandomForest class.
+
+        Returns:
+            Dict[str, Any]: kwargs compatible with ImputeRandomForest(..., \*\*kwargs).
+        """
+        return {
+            # General
+            "prefix": self.io.prefix,
+            "seed": self.io.seed,
+            "n_jobs": self.io.n_jobs,
+            "verbose": self.io.verbose,
+            "debug": self.io.debug,
+            # Model hyperparameters
+            "model_n_estimators": self.model.n_estimators,
+            "model_max_depth": self.model.max_depth,
+            "model_min_samples_split": self.model.min_samples_split,
+            "model_min_samples_leaf": self.model.min_samples_leaf,
+            "model_max_features": self.model.max_features,
+            "model_criterion": self.model.criterion,
+            "model_validation_split": self.train.validation_split,
+            "model_n_nearest_features": self.imputer.n_nearest_features,
+            "model_max_iter": self.imputer.max_iter,
+            # Simulation
+            "sim_prop_missing": self.sim.prop_missing,
+            "sim_strategy": self.sim.strategy,
+            "sim_het_boost": self.sim.het_boost,
+            # Plotting
+            "plot_format": self.plot.fmt,
+            "plot_fontsize": self.plot.fontsize,
+            "plot_despine": self.plot.despine,
+            "plot_dpi": self.plot.dpi,
+            "plot_show_plots": self.plot.show,
+        }
+
+
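Putting the pieces together, the intended flow is preset, then overrides, then kwargs. A sketch, with the import path assumed from this diff's layout and the final constructor call shown only as the docstring describes it:

from pgsui.data_processing.config import RFConfig  # assumed import path

cfg = RFConfig.from_preset("thorough")
cfg.apply_overrides({"model.max_features": "log2", "io.seed": 42})

kwargs = cfg.to_imputer_kwargs()
# kwargs["model_n_estimators"] == 500 and kwargs["model_max_features"] == "log2"

# Per the docstring, these kwargs are meant to be splatted into the imputer,
# e.g. ImputeRandomForest(genotype_data, **kwargs), where genotype_data stands
# for whatever input object the imputer expects (illustrative, not a fixed signature).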
+@dataclass
+class HGBConfig:
+    """Configuration for ImputeHistGradientBoosting.
+
+    Mirrors the legacy __init__ signature and provides presets/YAML/overrides.
+    Use `to_imputer_kwargs()` now, or switch the imputer to accept `config: HGBConfig`.
+
+    Attributes:
+        io (IOConfigSupervised): Run identity, logging, and seeds.
+        model (HGBModelConfig): HistGradientBoosting hyperparameters.
+        train (TrainConfigSupervised): Sample split for validation.
+        imputer (ImputerConfigSupervised): IterativeImputer scaffolding (neighbors/iters).
+        sim (SimConfigSupervised): Simulated missingness used during evaluation.
+        plot (PlotConfigSupervised): Plot styling and export options.
+        tune (TuningConfigSupervised): Optuna knobs (not required by HGB itself).
+    """
+
+    io: IOConfigSupervised = field(default_factory=IOConfigSupervised)
+    model: HGBModelConfig = field(default_factory=HGBModelConfig)
+    train: TrainConfigSupervised = field(default_factory=TrainConfigSupervised)
+    imputer: ImputerConfigSupervised = field(default_factory=ImputerConfigSupervised)
+    sim: SimConfigSupervised = field(default_factory=SimConfigSupervised)
+    plot: PlotConfigSupervised = field(default_factory=PlotConfigSupervised)
+    tune: TuningConfigSupervised = field(default_factory=TuningConfigSupervised)
+
+    @classmethod
+    def from_preset(
+        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
+    ) -> "HGBConfig":
+        """Build a config from a named preset.
+
+        This class method allows for easy construction of a HGBConfig instance with sensible defaults based on the chosen preset. Presets adjust both model capacity and training/tuning behavior across speed/quality tradeoffs.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): One of {"fast", "balanced", "thorough"}.
+                - fast: Quick baseline; fewer trees; fewer imputer iters.
+                - balanced: Balances speed and model performance; moderate trees and imputer iters.
+                - thorough: Prioritizes model performance; more trees; more imputer iterations.
+
+        Returns:
+            HGBConfig: Config with preset values applied.
+        """
+        cfg = cls()
+        if preset == "fast":
+            cfg.model.n_estimators = 50
+            cfg.model.learning_rate = 0.15
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 5
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 50
+        elif preset == "balanced":
+            cfg.model.n_estimators = 100
+            cfg.model.learning_rate = 0.1
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 10
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 100
+        elif preset == "thorough":
+            cfg.model.n_estimators = 500
+            cfg.model.learning_rate = 0.08
+            cfg.model.max_depth = None
+            cfg.imputer.max_iter = 15
+            cfg.io.n_jobs = 1
+            cfg.tune.enabled = False
+            cfg.tune.n_trials = 250
+        else:
+            raise ValueError(f"Unknown preset: {preset}")
+        return cfg
+
+    @classmethod
+    def from_yaml(cls, path: str) -> "HGBConfig":
+        """Load from YAML; honors optional top-level 'preset' then merges keys.
+
+        This method allows for easy construction of a HGBConfig instance from a YAML file, with support for presets. If the YAML file specifies a top-level 'preset', the corresponding preset values are applied first, and then any additional keys in the YAML file override those preset values.
+
+        Args:
+            path (str): Path to the YAML configuration file.
+
+        Returns:
+            HGBConfig: Config instance populated from the YAML file.
+        """
+        return load_yaml_to_dataclass(path, cls, preset_builder=cls.from_preset)
+
+    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "HGBConfig":
+        """Apply flat dot-key overrides (e.g., {'model.learning_rate': 0.05}).
+
+        This method allows for easy application of overrides to the configuration fields using a flat dot-key notation.
+
+        Args:
+            overrides (Dict[str, Any] | None): Mapping of dot-key paths to values to override.
+
+        Returns:
+            HGBConfig: This instance after applying overrides.
+        """
+        if overrides:
+            apply_dot_overrides(self, overrides)
+        return self
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return as nested dict.
+
+        This method converts the configuration instance into a nested dictionary format, which can be useful for serialization or inspection.
+
+        Returns:
+            Dict[str, Any]: The config as a nested dictionary.
+        """
+        return asdict(self)
+
+    def to_imputer_kwargs(self) -> Dict[str, Any]:
+        """Map config fields to current ImputeHistGradientBoosting ``__init__`` kwargs.
+
+        This method maps the configuration fields to the keyword arguments expected by the ImputeHistGradientBoosting class.
+
+        Returns:
+            Dict[str, Any]: kwargs compatible with ImputeHistGradientBoosting(..., \*\*kwargs).
+        """
+        return {
+            # General
+            "prefix": self.io.prefix,
+            "seed": self.io.seed,
+            "n_jobs": self.io.n_jobs,
+            "verbose": self.io.verbose,
+            "debug": self.io.debug,
+            # Model hyperparameters (note the mapping to sklearn's HGB)
+            "model_n_estimators": self.model.n_estimators,  # -> max_iter
+            "model_learning_rate": self.model.learning_rate,
+            "model_n_iter_no_change": self.model.n_iter_no_change,
+            "model_tol": self.model.tol,
+            "model_max_depth": self.model.max_depth,
+            "model_min_samples_leaf": self.model.min_samples_leaf,
+            "model_max_features": self.model.max_features,
+            "model_validation_split": self.train.validation_split,
+            "model_n_nearest_features": self.imputer.n_nearest_features,
+            "model_max_iter": self.imputer.max_iter,
+            # Simulation
+            "sim_prop_missing": self.sim.prop_missing,
+            "sim_strategy": self.sim.strategy,
+            "sim_het_boost": self.sim.het_boost,
+            # Plotting
+            "plot_format": self.plot.fmt,
+            "plot_fontsize": self.plot.fontsize,
+            "plot_despine": self.plot.despine,
+            "plot_dpi": self.plot.dpi,
+            "plot_show_plots": self.plot.show,
+        }
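Both supervised configs also load from YAML with an optional top-level preset. A minimal sketch of that path, assuming the nested YAML keys mirror the dataclass fields above and the import path follows this diff's layout:

from pathlib import Path
from pgsui.data_processing.config import HGBConfig  # assumed import path

Path("hgb.yaml").write_text(
    "preset: fast\n"           # applied first via from_preset()
    "model:\n"
    "  learning_rate: 0.05\n"  # then merged over the preset values
    "io:\n"
    "  prefix: my_run\n"
    "  seed: 123\n"
)

cfg = HGBConfig.from_yaml("hgb.yaml")
# Expected per the docstring: preset values first, then YAML keys merged, so
# cfg.model.n_estimators == 50 (fast preset) and cfg.model.learning_rate == 0.05.
kwargs = cfg.to_imputer_kwargs()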