pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +577 -125
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +203 -530
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1269 -534
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
- pgsui/impute/unsupervised/imputers/vae.py +931 -787
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -121,20 +121,18 @@ class ModelConfig:
         latent_dim (int): Dimensionality of the latent space.
         dropout_rate (float): Dropout rate for regularization.
         num_hidden_layers (int): Number of hidden layers in the neural network.
-
+        activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function.
         layer_scaling_factor (float): Scaling factor for the number of neurons in hidden layers.
-        layer_schedule (Literal["pyramid", "
-        gamma (float): Parameter for the focal loss function.
+        layer_schedule (Literal["pyramid", "linear"]): Schedule for scaling hidden layer sizes.
     """

     latent_init: Literal["random", "pca"] = "random"
     latent_dim: int = 2
     dropout_rate: float = 0.2
     num_hidden_layers: int = 2
-
+    activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
     layer_scaling_factor: float = 5.0
-    layer_schedule: Literal["pyramid", "
-    gamma: float = 2.0
+    layer_schedule: Literal["pyramid", "linear"] = "pyramid"


 @dataclass
@@ -144,28 +142,39 @@ class TrainConfig:
     Attributes:
         batch_size (int): Number of samples per training batch.
         learning_rate (float): Learning rate for the optimizer.
-        lr_input_factor (float): Factor to scale the learning rate for input layer.
         l1_penalty (float): L1 regularization penalty.
         early_stop_gen (int): Number of generations with no improvement to wait before early stopping.
         min_epochs (int): Minimum number of epochs to train.
         max_epochs (int): Maximum number of epochs to train.
         validation_split (float): Proportion of data to use for validation.
-
-
+        weights_max_ratio (float | None): Maximum ratio for class weights to prevent extreme values.
+        gamma (float): Focusing parameter for focal loss.
         device (Literal["gpu", "cpu", "mps"]): Device to use for computation.
     """

-    batch_size: int =
+    batch_size: int = 64
     learning_rate: float = 1e-3
-    lr_input_factor: float = 1.0
     l1_penalty: float = 0.0
-    early_stop_gen: int =
+    early_stop_gen: int = 25
     min_epochs: int = 100
-    max_epochs: int =
+    max_epochs: int = 2000
     validation_split: float = 0.2
-    weights_beta: float = 0.9999
-    weights_max_ratio: float = 1.0
     device: Literal["gpu", "cpu", "mps"] = "cpu"
+    weights_max_ratio: Optional[float] = None
+    weights_power: float = 1.0
+    weights_normalize: bool = True
+    weights_inverse: bool = False
+    gamma: float = 0.0
+    gamma_schedule: bool = False
+
+
+def _default_train_config() -> TrainConfig:
+    """Typed default factory for TrainConfig (helps some type checkers).
+
+    Using the class object directly (default_factory=TrainConfig) is valid at runtime but certain type checkers can fail to match dataclasses.field overloads.
+    """
+
+    return TrainConfig()


 @dataclass
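
The new module-level `_default_train_config` exists for static typing only: `field(default_factory=TrainConfig)` behaves identically at runtime, but a zero-argument function with an explicit return annotation gives type checkers a concrete `Callable[[], TrainConfig]` to match against the `dataclasses.field` overloads. A minimal, self-contained sketch of the pattern (the `Inner`/`Top` names are illustrative, not from the package):

    from dataclasses import dataclass, field

    @dataclass
    class Inner:
        batch_size: int = 64

    def _default_inner() -> Inner:
        # Zero-argument factory with an explicit return type.
        return Inner()

    @dataclass
    class Top:
        # Same runtime defaults as default_factory=Inner, but friendlier to type checkers.
        inner: Inner = field(default_factory=_default_inner)

    assert Top().inner.batch_size == 64
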
@@ -174,19 +183,13 @@ class TuneConfig:

     Attributes:
         enabled (bool): If True, enables hyperparameter tuning.
-        metric (Literal["f1", "accuracy", "pr_macro"]): Metric to optimize during tuning.
+        metric (Literal["f1", "accuracy", "pr_macro", "average_precision", "roc_auc", "precision", "recall", "mcc", "jaccard"]): Metric to optimize during tuning.
         n_trials (int): Number of hyperparameter trials to run.
         resume (bool): If True, resumes tuning from a previous state.
         save_db (bool): If True, saves the tuning results to a database.
-        fast (bool): If True, uses a faster but less thorough tuning approach.
-        max_samples (int): Maximum number of samples to use for tuning. 0 means all samples.
-        max_loci (int): Maximum number of loci to use for tuning. 0 means all loci.
         epochs (int): Number of epochs to train each trial.
         batch_size (int): Batch size for training during tuning.
-        eval_interval (int): Interval (in epochs) at which to evaluate the model during tuning.
-        infer_epochs (int): Number of epochs for inference during tuning.
         patience (int): Number of evaluations with no improvement before stopping early.
-        proxy_metric_batch (int): If > 0, uses a subset of data for proxy metric evaluation.
     """

     enabled: bool = False
@@ -198,34 +201,15 @@ class TuneConfig:
         "roc_auc",
         "precision",
         "recall",
+        "mcc",
+        "jaccard",
     ] = "f1"
     n_trials: int = 100
     resume: bool = False
     save_db: bool = False
-    fast: bool = True
-    max_samples: int = 512
-    max_loci: int = 0 # 0 = all
     epochs: int = 500
     batch_size: int = 64
-    eval_interval: int = 20
-    infer_epochs: int = 100
     patience: int = 10
-    proxy_metric_batch: int = 0
-
-
-@dataclass
-class EvalConfig:
-    """Evaluation configuration.
-
-    Attributes:
-        eval_latent_steps (int): Number of optimization steps for latent space evaluation.
-        eval_latent_lr (float): Learning rate for latent space optimization.
-        eval_latent_weight_decay (float): Weight decay for latent space optimization.
-    """
-
-    eval_latent_steps: int = 50
-    eval_latent_lr: float = 1e-2
-    eval_latent_weight_decay: float = 0.0


 @dataclass
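
With "mcc" and "jaccard" added to the metric literal, and the fast/max_samples/eval_interval-style knobs removed along with EvalConfig, retargeting a tuning run is a single field assignment. An illustrative sketch using the fields shown in the hunks above (the preset call assumes the AutoencoderConfig defined later in this diff):

    cfg = AutoencoderConfig.from_preset("balanced")
    cfg.tune.enabled = True
    cfg.tune.metric = "mcc"   # newly allowed in 1.7.0, alongside "jaccard"
    cfg.tune.n_trials = 50
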
@@ -244,15 +228,18 @@ class PlotConfig:
     dpi: int = 300
     fontsize: int = 18
     despine: bool = True
-    show: bool =
+    show: bool = True


 @dataclass
 class IOConfig:
     """I/O configuration.

+    Dataclass that includes configuration settings for file naming, logging verbosity, random seed, and parallelism.
+
     Attributes:
         prefix (str): Prefix for output files. Default is "pgsui".
+        ploidy (int): Ploidy level of the organism. Default is 2.
         verbose (bool): If True, enables verbose logging. Default is False.
         debug (bool): If True, enables debug mode. Default is False.
         seed (int | None): Random seed for reproducibility. Default is None.
@@ -261,6 +248,7 @@ class IOConfig:
     """

     prefix: str = "pgsui"
+    ploidy: int = 2
     verbose: bool = False
     debug: bool = False
     seed: int | None = None
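
Because IOConfig is a plain dataclass, the new ploidy field is set like any other; a minimal sketch for tetraploid data:

    io = IOConfig(prefix="run01", ploidy=4, verbose=True, seed=42)
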
@@ -287,37 +275,46 @@ class SimConfig:
         "nonrandom",
         "nonrandom_weighted",
     ] = "random"
-    sim_prop: float = 0.
+    sim_prop: float = 0.20
     sim_kwargs: dict | None = None


 @dataclass
-class
-    """Top-level configuration for
+class AutoencoderConfig:
+    """Top-level configuration for ImputeAutoencoder.
+
+    This configuration class encapsulates all settings required for the
+    ImputeAutoencoder model, including I/O, model architecture, training,
+    hyperparameter tuning, plotting, and simulated-missing configuration.

     Attributes:
         io (IOConfig): I/O configuration.
         model (ModelConfig): Model architecture configuration.
         train (TrainConfig): Training procedure configuration.
         tune (TuneConfig): Hyperparameter tuning configuration.
-        evaluate (EvalConfig): Evaluation configuration.
         plot (PlotConfig): Plotting configuration.
-        sim (SimConfig):
+        sim (SimConfig): Simulated-missing configuration.
     """

     io: IOConfig = field(default_factory=IOConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
-    train: TrainConfig = field(default_factory=
+    train: TrainConfig = field(default_factory=_default_train_config)
     tune: TuneConfig = field(default_factory=TuneConfig)
-    evaluate: EvalConfig = field(default_factory=EvalConfig)
     plot: PlotConfig = field(default_factory=PlotConfig)
     sim: SimConfig = field(default_factory=SimConfig)

     @classmethod
     def from_preset(
         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
-    ) -> "
-        """Build a
+    ) -> "AutoencoderConfig":
+        """Build a AutoencoderConfig from a named preset.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+        Returns:
+            AutoencoderConfig: Configuration instance corresponding to the preset.
+        """
         if preset not in {"fast", "balanced", "thorough"}:
             raise ValueError(f"Unknown preset: {preset}")

@@ -325,414 +322,87 @@ class NLPCAConfig:

         # Common baselines
         cfg.io.verbose = False
-        cfg.
-        cfg.
+        cfg.io.ploidy = 2
+        cfg.train.validation_split = 0.2
+        cfg.model.activation = "relu"
         cfg.model.layer_schedule = "pyramid"
-        cfg.model.
-        cfg.evaluate.eval_latent_lr = 1e-2
-        cfg.evaluate.eval_latent_weight_decay = 0.0
-        cfg.sim.simulate_missing = True
+        cfg.model.layer_scaling_factor = 2.0
         cfg.sim.sim_strategy = "random"
         cfg.sim.sim_prop = 0.2
+        cfg.plot.show = True
+
+        # Train settings
+        cfg.train.weights_max_ratio = None
+        cfg.train.weights_power = 1.0
+        cfg.train.weights_normalize = True
+        cfg.train.weights_inverse = False
+        cfg.train.gamma = 0.0
+        cfg.train.gamma_schedule = False
+        cfg.train.min_epochs = 100
+
+        # Tune
+        cfg.tune.enabled = False
+        cfg.tune.n_trials = 100

         if preset == "fast":
             # Model
             cfg.model.latent_dim = 4
             cfg.model.num_hidden_layers = 1
-            cfg.model.layer_scaling_factor = 2.0
             cfg.model.dropout_rate = 0.10
-            cfg.model.gamma = 1.5
-            # Train
-            cfg.train.batch_size = 256
-            cfg.train.learning_rate = 2e-3
-            cfg.train.early_stop_gen = 5
-            cfg.train.min_epochs = 10
-            cfg.train.max_epochs = 150
-            cfg.train.weights_beta = 0.999
-            cfg.train.weights_max_ratio = 5.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 20
-            cfg.tune.epochs = 150
-            cfg.tune.batch_size = 256
-            cfg.tune.max_samples = 512
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.infer_epochs = 20
-            cfg.tune.patience = 5
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 20

-        elif preset == "balanced":
-            # Model
-            cfg.model.latent_dim = 8
-            cfg.model.num_hidden_layers = 2
-            cfg.model.layer_scaling_factor = 3.0
-            cfg.model.dropout_rate = 0.20
-            cfg.model.gamma = 2.0
             # Train
             cfg.train.batch_size = 128
-            cfg.train.learning_rate =
+            cfg.train.learning_rate = 2e-3
             cfg.train.early_stop_gen = 15
-            cfg.train.
-            cfg.train.
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 5.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 60
-            cfg.tune.epochs = 200
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 2048
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.infer_epochs = 50
-            cfg.tune.patience = 10
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 40
-
-        else: # thorough
-            # Model
-            cfg.model.latent_dim = 16
-            cfg.model.num_hidden_layers = 3
-            cfg.model.layer_scaling_factor = 5.0
-            cfg.model.dropout_rate = 0.30
-            cfg.model.gamma = 2.5
-            # Train
-            cfg.train.batch_size = 64
-            cfg.train.learning_rate = 5e-4
-            cfg.train.early_stop_gen = 30
-            cfg.train.min_epochs = 100
-            cfg.train.max_epochs = 2000
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 5.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = False # Full search
-            cfg.tune.n_trials = 100
-            cfg.tune.epochs = 600
-            cfg.tune.batch_size = 64
-            cfg.tune.max_samples = 0 # No limit
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.infer_epochs = 80
-            cfg.tune.patience = 20
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 100
-
-        return cfg
-
-    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "NLPCAConfig":
-        """Apply flat dot-key overrides."""
-        if not overrides:
-            return self
-        for k, v in overrides.items():
-            node = self
-            parts = k.split(".")
-            for p in parts[:-1]:
-                node = getattr(node, p)
-            last = parts[-1]
-            if hasattr(node, last):
-                setattr(node, last, v)
-            else:
-                raise KeyError(f"Unknown config key: {k}")
-        return self
-
-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
-
-@dataclass
-class UBPConfig:
-    """Top-level configuration for ImputeUBP.
-
-    Attributes:
-        io (IOConfig): I/O configuration.
-        model (ModelConfig): Model architecture configuration.
-        train (TrainConfig): Training procedure configuration.
-        tune (TuneConfig): Hyperparameter tuning configuration.
-        evaluate (EvalConfig): Evaluation configuration.
-        plot (PlotConfig): Plotting configuration.
-        sim (SimConfig): Simulated-missing configuration.
-    """
+            cfg.train.max_epochs = 200
+            cfg.train.weights_max_ratio = None

-
-
-    train: TrainConfig = field(default_factory=TrainConfig)
-    tune: TuneConfig = field(default_factory=TuneConfig)
-    evaluate: EvalConfig = field(default_factory=EvalConfig)
-    plot: PlotConfig = field(default_factory=PlotConfig)
-    sim: SimConfig = field(default_factory=SimConfig)
-
-    @classmethod
-    def from_preset(
-        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
-    ) -> "UBPConfig":
-        """Build a UBPConfig from a named preset."""
-        if preset not in {"fast", "balanced", "thorough"}:
-            raise ValueError(f"Unknown preset: {preset}")
-
-        cfg = cls()
-
-        # Common baselines
-        cfg.io.verbose = False
-        cfg.model.hidden_activation = "relu"
-        cfg.model.layer_schedule = "pyramid"
-        cfg.model.latent_init = "random"
-        cfg.sim.simulate_missing = True
-        cfg.sim.sim_strategy = "random"
-        cfg.sim.sim_prop = 0.2
-
-        if preset == "fast":
-            # Model
-            cfg.model.latent_dim = 4
-            cfg.model.num_hidden_layers = 1
-            cfg.model.layer_scaling_factor = 2.0
-            cfg.model.dropout_rate = 0.10
-            cfg.model.gamma = 1.5
-            # Train
-            cfg.train.batch_size = 256
-            cfg.train.learning_rate = 2e-3
-            cfg.train.early_stop_gen = 5
-            cfg.train.min_epochs = 10
-            cfg.train.max_epochs = 150
-            cfg.train.weights_beta = 0.999
-            cfg.train.weights_max_ratio = 5.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 20
-            cfg.tune.epochs = 150
-            cfg.tune.batch_size = 256
-            cfg.tune.max_samples = 512
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.infer_epochs = 20
-            cfg.tune.patience = 5
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 20
-            cfg.evaluate.eval_latent_lr = 1e-2
-            cfg.evaluate.eval_latent_weight_decay = 0.0
+            # Tune
+            cfg.tune.patience = 15

         elif preset == "balanced":
             # Model
             cfg.model.latent_dim = 8
             cfg.model.num_hidden_layers = 2
-            cfg.model.layer_scaling_factor = 3.0
             cfg.model.dropout_rate = 0.20
-
+
             # Train
-            cfg.train.batch_size =
+            cfg.train.batch_size = 64
             cfg.train.learning_rate = 1e-3
-            cfg.train.early_stop_gen =
-            cfg.train.
-            cfg.train.
-
-
-
-            cfg.tune.enabled = True
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 60
-            cfg.tune.epochs = 200
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 2048
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.infer_epochs = 50
-            cfg.tune.patience = 10
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 40
-            cfg.evaluate.eval_latent_lr = 1e-2
-            cfg.evaluate.eval_latent_weight_decay = 0.0
+            cfg.train.early_stop_gen = 25
+            cfg.train.max_epochs = 500
+            cfg.train.weights_max_ratio = None
+
+            # Tune
+            cfg.tune.patience = 25

         else: # thorough
             # Model
             cfg.model.latent_dim = 16
             cfg.model.num_hidden_layers = 3
-            cfg.model.layer_scaling_factor = 5.0
             cfg.model.dropout_rate = 0.30
-
+
             # Train
             cfg.train.batch_size = 64
             cfg.train.learning_rate = 5e-4
-            cfg.train.early_stop_gen =
-            cfg.train.
-            cfg.train.
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 5.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 100
-            cfg.tune.epochs = 600
-            cfg.tune.batch_size = 64
-            cfg.tune.max_samples = 0
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.infer_epochs = 80
-            cfg.tune.patience = 20
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 100
-            cfg.evaluate.eval_latent_lr = 1e-2
-            cfg.evaluate.eval_latent_weight_decay = 0.0
-
-        return cfg
-
-    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "UBPConfig":
-        """Apply flat dot-key overrides."""
-        if not overrides:
-            return self
-
-        for k, v in overrides.items():
-            node = self
-            parts = k.split(".")
-            for p in parts[:-1]:
-                node = getattr(node, p)
-            last = parts[-1]
-            if hasattr(node, last):
-                setattr(node, last, v)
-            else:
-                raise KeyError(f"Unknown config key: {k}")
-        return self
-
-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
-
-@dataclass
-class AutoencoderConfig:
-    """Top-level configuration for ImputeAutoencoder.
-
-    Attributes:
-        io (IOConfig): I/O configuration.
-        model (ModelConfig): Model architecture configuration.
-        train (TrainConfig): Training procedure configuration.
-        tune (TuneConfig): Hyperparameter tuning configuration.
-        evaluate (EvalConfig): Evaluation configuration.
-        plot (PlotConfig): Plotting configuration.
-        sim (SimConfig): Simulated-missing configuration.
-    """
-
-    io: IOConfig = field(default_factory=IOConfig)
-    model: ModelConfig = field(default_factory=ModelConfig)
-    train: TrainConfig = field(default_factory=TrainConfig)
-    tune: TuneConfig = field(default_factory=TuneConfig)
-    evaluate: EvalConfig = field(default_factory=EvalConfig)
-    plot: PlotConfig = field(default_factory=PlotConfig)
-    sim: SimConfig = field(default_factory=SimConfig)
-
-    @classmethod
-    def from_preset(
-        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
-    ) -> "AutoencoderConfig":
-        """Build a AutoencoderConfig from a named preset."""
-        if preset not in {"fast", "balanced", "thorough"}:
-            raise ValueError(f"Unknown preset: {preset}")
-
-        cfg = cls()
-
-        # Common baselines (no latent refinement at eval)
-        cfg.io.verbose = False
-        cfg.train.validation_split = 0.20
-        cfg.model.hidden_activation = "relu"
-        cfg.model.layer_schedule = "pyramid"
-        cfg.evaluate.eval_latent_steps = 0
-        cfg.evaluate.eval_latent_lr = 0.0
-        cfg.evaluate.eval_latent_weight_decay = 0.0
-        cfg.sim.simulate_missing = True
-        cfg.sim.sim_strategy = "random"
-        cfg.sim.sim_prop = 0.2
-
-        if preset == "fast":
-            cfg.model.latent_dim = 4
-            cfg.model.num_hidden_layers = 1
-            cfg.model.layer_scaling_factor = 2.0
-            cfg.model.dropout_rate = 0.10
-            cfg.model.gamma = 1.5
-            cfg.train.batch_size = 256
-            cfg.train.learning_rate = 2e-3
-            cfg.train.early_stop_gen = 5
-            cfg.train.min_epochs = 10
-            cfg.train.max_epochs = 150
-            cfg.train.weights_beta = 0.999
-            cfg.train.weights_max_ratio = 5.0
-            cfg.tune.enabled = True
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 20
-            cfg.tune.epochs = 150
-            cfg.tune.batch_size = 256
-            cfg.tune.max_samples = 512
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.patience = 5
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+            cfg.train.early_stop_gen = 50
+            cfg.train.max_epochs = 1000
+            cfg.train.weights_max_ratio = None

-
-            cfg.
-            cfg.model.num_hidden_layers = 2
-            cfg.model.layer_scaling_factor = 3.0
-            cfg.model.dropout_rate = 0.20
-            cfg.model.gamma = 2.0
-            cfg.train.batch_size = 128
-            cfg.train.learning_rate = 1e-3
-            cfg.train.early_stop_gen = 15
-            cfg.train.min_epochs = 50
-            cfg.train.max_epochs = 600
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 5.0
-            cfg.tune.enabled = True
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 60
-            cfg.tune.epochs = 200
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 2048
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.patience = 10
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
-
-        else: # thorough
-            cfg.model.latent_dim = 16
-            cfg.model.num_hidden_layers = 3
-            cfg.model.layer_scaling_factor = 5.0
-            cfg.model.dropout_rate = 0.30
-            cfg.model.gamma = 2.5
-            cfg.train.batch_size = 64
-            cfg.train.learning_rate = 5e-4
-            cfg.train.early_stop_gen = 30
-            cfg.train.min_epochs = 100
-            cfg.train.max_epochs = 2000
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 5.0
-            cfg.tune.enabled = True
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 100
-            cfg.tune.epochs = 600
-            cfg.tune.batch_size = 64
-            cfg.tune.max_samples = 0
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.patience = 20
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+            # Tune
+            cfg.tune.patience = 50

         return cfg

     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "AutoencoderConfig":
-        """Apply flat dot-key overrides.
+        """Apply flat dot-key overrides.
+
+        Args:
+            overrides (Dict[str, Any] | None): Dictionary of overrides with dot-separated keys.
+
+        Returns:
+            AutoencoderConfig: New configuration instance with overrides applied.
+        """
         if not overrides:
             return self
         for k, v in overrides.items():
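
The retained `apply_overrides` (shown in context above) walks each dot-separated key with `getattr`, sets the leaf with `setattr`, and raises `KeyError` for unknown keys, so presets can be adjusted without editing the dataclasses. A usage sketch grounded in that implementation (note that it mutates the config and returns `self` rather than a copy):

    cfg = AutoencoderConfig.from_preset("fast").apply_overrides(
        {
            "train.batch_size": 32,   # walks cfg.train, then sets batch_size
            "model.latent_dim": 8,
            "tune.metric": "pr_macro",
        }
    )
    cfg.apply_overrides({"train.no_such_key": 1})  # raises KeyError: Unknown config key
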
@@ -753,29 +423,22 @@ class AutoencoderConfig:

 @dataclass
 class VAEExtraConfig:
-    """VAE-specific knobs.
-
-    Attributes:
-        kl_beta (float): Final β for KL divergence term.
-        kl_warmup (int): Number of epochs with β=0 (warm-up period).
-        kl_ramp (int): Number of epochs for linear ramp to final β.
-    """
-
     kl_beta: float = 1.0
-
-    kl_ramp: int = 200
+    kl_beta_schedule: bool = False


 @dataclass
 class VAEConfig:
     """Top-level configuration for ImputeVAE (AE-parity + VAE extras).

+    Mirrors AutoencoderConfig sections and adds a ``vae`` block with KL-beta
+    controls for the VAE loss.
+
     Attributes:
         io (IOConfig): I/O configuration.
         model (ModelConfig): Model architecture configuration.
         train (TrainConfig): Training procedure configuration.
         tune (TuneConfig): Hyperparameter tuning configuration.
-        evaluate (EvalConfig): Evaluation configuration.
         plot (PlotConfig): Plotting configuration.
         vae (VAEExtraConfig): VAE-specific configuration.
         sim (SimConfig): Simulated-missing configuration.
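
The epoch-count knobs `kl_warmup` and `kl_ramp` are gone; 1.7.0 keeps the final KL weight `kl_beta` and adds a single `kl_beta_schedule` toggle, so the VAE objective remains reconstruction loss plus β times the KL term, with any annealing delegated to the trainer. A sketch of the two public knobs (how the schedule anneals internally is not shown in this diff):

    cfg = VAEConfig.from_preset("balanced")
    cfg.vae.kl_beta = 0.5            # final weight on the KL term
    cfg.vae.kl_beta_schedule = True  # anneal beta during training instead of holding it fixed
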
@@ -783,9 +446,8 @@ class VAEConfig:

     io: IOConfig = field(default_factory=IOConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
-    train: TrainConfig = field(default_factory=
+    train: TrainConfig = field(default_factory=_default_train_config)
     tune: TuneConfig = field(default_factory=TuneConfig)
-    evaluate: EvalConfig = field(default_factory=EvalConfig)
     plot: PlotConfig = field(default_factory=PlotConfig)
     vae: VAEExtraConfig = field(default_factory=VAEExtraConfig)
     sim: SimConfig = field(default_factory=SimConfig)
@@ -794,119 +456,92 @@ class VAEConfig:
     def from_preset(
         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
     ) -> "VAEConfig":
-        """Build a VAEConfig from a named preset.
+        """Build a VAEConfig from a named preset.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+        Returns:
+            VAEConfig: Configuration instance corresponding to the preset.
+        """
         if preset not in {"fast", "balanced", "thorough"}:
             raise ValueError(f"Unknown preset: {preset}")

         cfg = cls()

-        #
+        # General settings
         cfg.io.verbose = False
-        cfg.
-        cfg.
+        cfg.io.ploidy = 2
+        cfg.train.validation_split = 0.2
+        cfg.model.activation = "relu"
         cfg.model.layer_schedule = "pyramid"
-        cfg.
-        cfg.evaluate.eval_latent_lr = 0.0
-        cfg.evaluate.eval_latent_weight_decay = 0.0
+        cfg.model.layer_scaling_factor = 2.0
         cfg.sim.simulate_missing = True
         cfg.sim.sim_strategy = "random"
         cfg.sim.sim_prop = 0.2
+        cfg.plot.show = True
+
+        # Train settings
+        cfg.train.weights_max_ratio = None
+        cfg.train.weights_power = 1.0
+        cfg.train.weights_normalize = True
+        cfg.train.weights_inverse = False
+        cfg.train.gamma = 0.0
+        cfg.train.gamma_schedule = False
+        cfg.train.min_epochs = 100
+
+        # VAE-specific
+        cfg.vae.kl_beta = 1.0
+        cfg.vae.kl_beta_schedule = False
+
+        # Tune
+        cfg.tune.enabled = False
+        cfg.tune.n_trials = 100

         if preset == "fast":
+            # Model
             cfg.model.latent_dim = 4
-            cfg.model.num_hidden_layers =
-            cfg.model.layer_scaling_factor = 2.0
+            cfg.model.num_hidden_layers = 2
             cfg.model.dropout_rate = 0.10
-
-            # VAE specifics
-            cfg.vae.kl_beta = 0.5
-            cfg.vae.kl_warmup = 10
-            cfg.vae.kl_ramp = 40
+
             # Train
-            cfg.train.batch_size =
+            cfg.train.batch_size = 128
             cfg.train.learning_rate = 2e-3
-            cfg.train.early_stop_gen =
-            cfg.train.
-
-            cfg.train.weights_beta = 0.999
-            cfg.train.weights_max_ratio = 5.0
+            cfg.train.early_stop_gen = 15
+            cfg.train.max_epochs = 200
+
             # Tune
-            cfg.tune.
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 20
-            cfg.tune.epochs = 150
-            cfg.tune.batch_size = 256
-            cfg.tune.max_samples = 512
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.patience = 5
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+            cfg.tune.patience = 15

         elif preset == "balanced":
+            # Model
             cfg.model.latent_dim = 8
-            cfg.model.num_hidden_layers =
-            cfg.model.layer_scaling_factor = 3.0
+            cfg.model.num_hidden_layers = 4
             cfg.model.dropout_rate = 0.20
-
-            # VAE specifics
-            cfg.vae.kl_beta = 1.0
-            cfg.vae.kl_warmup = 50
-            cfg.vae.kl_ramp = 150
+
             # Train
-            cfg.train.batch_size =
+            cfg.train.batch_size = 64
             cfg.train.learning_rate = 1e-3
-            cfg.train.early_stop_gen =
-            cfg.train.
-
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 5.0
+            cfg.train.early_stop_gen = 25
+            cfg.train.max_epochs = 500
+
             # Tune
-            cfg.tune.
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 60
-            cfg.tune.epochs = 200
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 2048
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.patience = 10
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+            cfg.tune.patience = 25

         else: # thorough
+            # Model
             cfg.model.latent_dim = 16
-            cfg.model.num_hidden_layers =
-            cfg.model.layer_scaling_factor = 5.0
+            cfg.model.num_hidden_layers = 8
             cfg.model.dropout_rate = 0.30
-
-            # VAE specifics
-            cfg.vae.kl_beta = 1.0
-            cfg.vae.kl_warmup = 100
-            cfg.vae.kl_ramp = 400
+
             # Train
             cfg.train.batch_size = 64
             cfg.train.learning_rate = 5e-4
-            cfg.train.early_stop_gen =
-            cfg.train.
-
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 5.0
+            cfg.train.early_stop_gen = 50
+            cfg.train.max_epochs = 1000
+
             # Tune
-            cfg.tune.
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 100
-            cfg.tune.epochs = 600
-            cfg.tune.batch_size = 64
-            cfg.tune.max_samples = 0
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.patience = 20
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+            cfg.tune.patience = 50

         return cfg
@@ -935,9 +570,9 @@ class MostFrequentAlgoConfig:
     """Algorithmic knobs for ImputeMostFrequent.

     Attributes:
-        by_populations (bool): Whether to compute per-population modes.
-        default (int): Fallback mode if no valid entries in a locus.
-        missing (int): Code for missing genotypes in 0/1/2.
+        by_populations (bool): Whether to compute per-population modes. Default is False.
+        default (int): Fallback mode if no valid entries in a locus. Default is 0.
+        missing (int): Code for missing genotypes in 0/1/2. Default is -1.
     """

     by_populations: bool = False
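
For orientation, these three knobs describe plain per-locus mode imputation on 0/1/2 genotype codes. A standalone sketch of that behavior (illustrative only, not the package's implementation; `by_populations=True` would repeat this within each population):

    import numpy as np

    def mode_impute(gt: np.ndarray, missing: int = -1, default: int = 0) -> np.ndarray:
        """Fill missing codes with each locus's most frequent genotype."""
        out = gt.copy()
        for j in range(gt.shape[1]):          # loci as columns
            col = gt[:, j]
            valid = col[col != missing]
            fill = np.bincount(valid).argmax() if valid.size else default
            out[col == missing, j] = fill
        return out

    gt = np.array([[0, 2], [1, -1], [-1, 2]])
    print(mode_impute(gt))  # [[0 2] [1 2] [0 2]]
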
@@ -950,8 +585,8 @@ class DeterministicSplitConfig:
     """Evaluation split configuration shared by deterministic imputers.

     Attributes:
-        test_size (float): Proportion of data to use as the test set.
-        test_indices (Optional[Sequence[int]]): Specific indices to use as the test set.
+        test_size (float): Proportion of data to use as the test set. Default is 0.2.
+        test_indices (Optional[Sequence[int]]): Specific indices to use as the test set. Default is None.
     """

     test_size: float = 0.2
@@ -962,6 +597,10 @@ class DeterministicSplitConfig:
 class MostFrequentConfig:
     """Top-level configuration for ImputeMostFrequent.

+    Deterministic imputers primarily use ``io``, ``plot``, ``split``, ``algo``,
+    and ``sim``. The ``train`` and ``tune`` sections are retained for schema
+    parity with NN models but are not currently used by ImputeMostFrequent.
+
     Attributes:
         io (IOConfig): I/O configuration.
         plot (PlotConfig): Plotting configuration.
@@ -978,19 +617,27 @@ class MostFrequentConfig:
     algo: MostFrequentAlgoConfig = field(default_factory=MostFrequentAlgoConfig)
     sim: SimConfig = field(default_factory=SimConfig)
     tune: TuneConfig = field(default_factory=TuneConfig)
-    train: TrainConfig = field(default_factory=
+    train: TrainConfig = field(default_factory=_default_train_config)

     @classmethod
     def from_preset(
         cls,
         preset: Literal["fast", "balanced", "thorough"] = "balanced",
     ) -> "MostFrequentConfig":
-        """Construct a preset configuration.
+        """Construct a preset configuration.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+        Returns:
+            MostFrequentConfig: Configuration instance corresponding to the preset.
+        """
         if preset not in {"fast", "balanced", "thorough"}:
             raise ValueError(f"Unknown preset: {preset}")

         cfg = cls()
         cfg.io.verbose = False
+        cfg.io.ploidy = 2
         cfg.split.test_size = 0.2
         cfg.sim.simulate_missing = True
         cfg.sim.sim_strategy = "random"
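
A short usage sketch for the deterministic path, touching only the sections this config actually consumes (field names from the hunks above):

    cfg = MostFrequentConfig.from_preset("balanced")
    cfg.algo.by_populations = True   # per-population modes instead of a global mode
    cfg.split.test_size = 0.25       # overrides the preset's 0.2
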
@@ -1033,6 +680,10 @@ class RefAlleleAlgoConfig:
 class RefAlleleConfig:
     """Top-level configuration for ImputeRefAllele.

+    Deterministic imputers primarily use ``io``, ``plot``, ``split``, ``algo``,
+    and ``sim``. The ``train`` and ``tune`` sections are retained for schema
+    parity with NN models but are not currently used by ImputeRefAllele.
+
     Attributes:
         io (IOConfig): I/O configuration.
         plot (PlotConfig): Plotting configuration.
@@ -1049,18 +700,26 @@ class RefAlleleConfig:
     algo: RefAlleleAlgoConfig = field(default_factory=RefAlleleAlgoConfig)
     sim: SimConfig = field(default_factory=SimConfig)
     tune: TuneConfig = field(default_factory=TuneConfig)
-    train: TrainConfig = field(default_factory=
+    train: TrainConfig = field(default_factory=_default_train_config)

     @classmethod
     def from_preset(
         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
     ) -> "RefAlleleConfig":
-        """Presets mainly keep parity with logging/IO and split test_size.
+        """Presets mainly keep parity with logging/IO and split test_size.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+        Returns:
+            RefAlleleConfig: Configuration instance corresponding to the preset.
+        """
         if preset not in {"fast", "balanced", "thorough"}:
             raise ValueError(f"Unknown preset: {preset}")

         cfg = cls()
         cfg.io.verbose = False
+        cfg.io.ploidy = 2
         cfg.split.test_size = 0.2
         cfg.sim.simulate_missing = True
         cfg.sim.sim_strategy = "random"
@@ -1273,7 +932,14 @@ class RFConfig:

     @classmethod
     def from_preset(cls, preset: str = "balanced") -> "RFConfig":
-        """Build a config from a named preset.
+        """Build a config from a named preset.
+
+        Args:
+            preset (str): Preset name.
+
+        Returns:
+            RFConfig: Configuration instance corresponding to the preset.
+        """
         cfg = cls()
         if preset == "fast":
             cfg.model.n_estimators = 50
@@ -1365,7 +1031,14 @@ class HGBConfig:

     @classmethod
     def from_preset(cls, preset: str = "balanced") -> "HGBConfig":
-        """Build a config from a named preset.
+        """Build a config from a named preset.
+
+        Args:
+            preset (str): Preset name.
+
+        Returns:
+            HGBConfig: Configuration instance corresponding to the preset.
+        """
         cfg = cls()
         if preset == "fast":
             cfg.model.n_estimators = 50
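
Both supervised configs now document the same preset contract; per the hunks above, the "fast" preset shrinks the ensemble. A hedged sketch:

    rf = RFConfig.from_preset("fast")    # sets model.n_estimators = 50
    hgb = HGBConfig.from_preset("fast")  # likewise starts from n_estimators = 50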