pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.7.0.dist-info/METADATA +288 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +591 -126
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +218 -533
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1268 -530
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
- pgsui/impute/unsupervised/imputers/vae.py +928 -696
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -121,20 +121,18 @@ class ModelConfig:
         latent_dim (int): Dimensionality of the latent space.
         dropout_rate (float): Dropout rate for regularization.
         num_hidden_layers (int): Number of hidden layers in the neural network.
-
+        activation (Literal["relu", "elu", "selu", "leaky_relu"]): Activation function.
         layer_scaling_factor (float): Scaling factor for the number of neurons in hidden layers.
-        layer_schedule (Literal["pyramid", "
-        gamma (float): Parameter for the focal loss function.
+        layer_schedule (Literal["pyramid", "linear"]): Schedule for scaling hidden layer sizes.
     """

     latent_init: Literal["random", "pca"] = "random"
     latent_dim: int = 2
     dropout_rate: float = 0.2
     num_hidden_layers: int = 2
-
+    activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu"
     layer_scaling_factor: float = 5.0
-    layer_schedule: Literal["pyramid", "
-    gamma: float = 2.0
+    layer_schedule: Literal["pyramid", "linear"] = "pyramid"


 @dataclass
@@ -144,28 +142,39 @@ class TrainConfig:
     Attributes:
         batch_size (int): Number of samples per training batch.
         learning_rate (float): Learning rate for the optimizer.
-        lr_input_factor (float): Factor to scale the learning rate for input layer.
         l1_penalty (float): L1 regularization penalty.
         early_stop_gen (int): Number of generations with no improvement to wait before early stopping.
         min_epochs (int): Minimum number of epochs to train.
         max_epochs (int): Maximum number of epochs to train.
         validation_split (float): Proportion of data to use for validation.
-
-
+        weights_max_ratio (float | None): Maximum ratio for class weights to prevent extreme values.
+        gamma (float): Focusing parameter for focal loss.
         device (Literal["gpu", "cpu", "mps"]): Device to use for computation.
     """

-    batch_size: int =
+    batch_size: int = 64
     learning_rate: float = 1e-3
-    lr_input_factor: float = 1.0
     l1_penalty: float = 0.0
-    early_stop_gen: int =
+    early_stop_gen: int = 25
     min_epochs: int = 100
-    max_epochs: int =
+    max_epochs: int = 2000
     validation_split: float = 0.2
-    weights_beta: float = 0.9999
-    weights_max_ratio: float = 1.0
     device: Literal["gpu", "cpu", "mps"] = "cpu"
+    weights_max_ratio: Optional[float] = None
+    weights_power: float = 1.0
+    weights_normalize: bool = True
+    weights_inverse: bool = False
+    gamma: float = 0.0
+    gamma_schedule: bool = False
+
+
+def _default_train_config() -> TrainConfig:
+    """Typed default factory for TrainConfig (helps some type checkers).
+
+    Using the class object directly (default_factory=TrainConfig) is valid at runtime but certain type checkers can fail to match dataclasses.field overloads.
+    """
+
+    return TrainConfig()


 @dataclass
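The new module-level `_default_train_config` factory introduced above exists only to placate static type checkers; the runtime behavior is identical to `default_factory=TrainConfig`. A minimal self-contained sketch of the pattern (the dataclasses here are trimmed stand-ins, not the package's full definitions):

    from dataclasses import dataclass, field

    @dataclass
    class TrainConfig:
        batch_size: int = 64
        learning_rate: float = 1e-3

    def _default_train_config() -> TrainConfig:
        # A zero-argument callable with an explicit return annotation gives
        # type checkers an unambiguous match for the default_factory overload
        # of dataclasses.field(); passing the class object itself is valid at
        # runtime but can trip overload resolution in some checkers.
        return TrainConfig()

    @dataclass
    class TopLevelConfig:
        train: TrainConfig = field(default_factory=_default_train_config)

    cfg = TopLevelConfig()
    assert cfg.train.batch_size == 64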
@@ -174,19 +183,13 @@ class TuneConfig:

     Attributes:
         enabled (bool): If True, enables hyperparameter tuning.
-        metric (Literal["f1", "accuracy", "pr_macro"]): Metric to optimize during tuning.
+        metric (Literal["f1", "accuracy", "pr_macro", "average_precision", "roc_auc", "precision", "recall", "mcc", "jaccard"]): Metric to optimize during tuning.
         n_trials (int): Number of hyperparameter trials to run.
         resume (bool): If True, resumes tuning from a previous state.
         save_db (bool): If True, saves the tuning results to a database.
-        fast (bool): If True, uses a faster but less thorough tuning approach.
-        max_samples (int): Maximum number of samples to use for tuning. 0 means all samples.
-        max_loci (int): Maximum number of loci to use for tuning. 0 means all loci.
         epochs (int): Number of epochs to train each trial.
         batch_size (int): Batch size for training during tuning.
-        eval_interval (int): Interval (in epochs) at which to evaluate the model during tuning.
-        infer_epochs (int): Number of epochs for inference during tuning.
         patience (int): Number of evaluations with no improvement before stopping early.
-        proxy_metric_batch (int): If > 0, uses a subset of data for proxy metric evaluation.
     """

     enabled: bool = False
@@ -198,34 +201,15 @@ class TuneConfig:
         "roc_auc",
         "precision",
         "recall",
+        "mcc",
+        "jaccard",
     ] = "f1"
     n_trials: int = 100
     resume: bool = False
     save_db: bool = False
-    fast: bool = True
-    max_samples: int = 512
-    max_loci: int = 0 # 0 = all
     epochs: int = 500
     batch_size: int = 64
-    eval_interval: int = 20
-    infer_epochs: int = 100
     patience: int = 10
-    proxy_metric_batch: int = 0
-
-
-@dataclass
-class EvalConfig:
-    """Evaluation configuration.
-
-    Attributes:
-        eval_latent_steps (int): Number of optimization steps for latent space evaluation.
-        eval_latent_lr (float): Learning rate for latent space optimization.
-        eval_latent_weight_decay (float): Weight decay for latent space optimization.
-    """
-
-    eval_latent_steps: int = 50
-    eval_latent_lr: float = 1e-2
-    eval_latent_weight_decay: float = 0.0


 @dataclass
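With `"mcc"` and `"jaccard"` added to the `metric` literal and the `fast`/`max_samples`/`eval_interval`-style knobs removed along with `EvalConfig`, the tuning surface is smaller but the objective is more flexible. A usage sketch, assuming the config classes are importable from the pgsui package (the exact module path is not named in these hunks; `AutoencoderConfig` appears later in this diff):

    from pgsui.data_processing import config as pgcfg  # assumed import path

    cfg = pgcfg.AutoencoderConfig.from_preset("balanced")
    cfg.tune.enabled = True
    cfg.tune.metric = "mcc"   # newly accepted in 1.7.0; "jaccard" also works
    cfg.tune.n_trials = 50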
@@ -244,15 +228,18 @@ class PlotConfig:
     dpi: int = 300
     fontsize: int = 18
     despine: bool = True
-    show: bool =
+    show: bool = True


 @dataclass
 class IOConfig:
     """I/O configuration.

+    Dataclass that includes configuration settings for file naming, logging verbosity, random seed, and parallelism.
+
     Attributes:
         prefix (str): Prefix for output files. Default is "pgsui".
+        ploidy (int): Ploidy level of the organism. Default is 2.
         verbose (bool): If True, enables verbose logging. Default is False.
         debug (bool): If True, enables debug mode. Default is False.
         seed (int | None): Random seed for reproducibility. Default is None.
@@ -261,6 +248,7 @@ class IOConfig:
     """

     prefix: str = "pgsui"
+    ploidy: int = 2
     verbose: bool = False
     debug: bool = False
     seed: int | None = None
@@ -287,37 +275,46 @@ class SimConfig:
         "nonrandom",
         "nonrandom_weighted",
     ] = "random"
-    sim_prop: float = 0.
+    sim_prop: float = 0.20
    sim_kwargs: dict | None = None


 @dataclass
-class
-    """Top-level configuration for
+class AutoencoderConfig:
+    """Top-level configuration for ImputeAutoencoder.
+
+    This configuration class encapsulates all settings required for the
+    ImputeAutoencoder model, including I/O, model architecture, training,
+    hyperparameter tuning, plotting, and simulated-missing configuration.

     Attributes:
         io (IOConfig): I/O configuration.
         model (ModelConfig): Model architecture configuration.
         train (TrainConfig): Training procedure configuration.
         tune (TuneConfig): Hyperparameter tuning configuration.
-        evaluate (EvalConfig): Evaluation configuration.
         plot (PlotConfig): Plotting configuration.
-        sim (SimConfig):
+        sim (SimConfig): Simulated-missing configuration.
     """

     io: IOConfig = field(default_factory=IOConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
-    train: TrainConfig = field(default_factory=
+    train: TrainConfig = field(default_factory=_default_train_config)
     tune: TuneConfig = field(default_factory=TuneConfig)
-    evaluate: EvalConfig = field(default_factory=EvalConfig)
     plot: PlotConfig = field(default_factory=PlotConfig)
     sim: SimConfig = field(default_factory=SimConfig)

     @classmethod
     def from_preset(
         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
-    ) -> "
-        """Build a
+    ) -> "AutoencoderConfig":
+        """Build a AutoencoderConfig from a named preset.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+        Returns:
+            AutoencoderConfig: Configuration instance corresponding to the preset.
+        """
         if preset not in {"fast", "balanced", "thorough"}:
             raise ValueError(f"Unknown preset: {preset}")

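With `NLPCAConfig` and `UBPConfig` deleted in the next hunk, `AutoencoderConfig` above becomes the template the remaining top-level configs follow. A usage sketch under the same assumed import path as the earlier example:

    from pgsui.data_processing import config as pgcfg  # assumed import path

    cfg = pgcfg.AutoencoderConfig.from_preset("fast")
    print(cfg.model.latent_dim)   # 4 in the "fast" preset
    print(cfg.train.batch_size)   # 128 in the "fast" preset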
@@ -325,414 +322,87 @@ class NLPCAConfig:

         # Common baselines
         cfg.io.verbose = False
-        cfg.
-        cfg.
+        cfg.io.ploidy = 2
+        cfg.train.validation_split = 0.2
+        cfg.model.activation = "relu"
         cfg.model.layer_schedule = "pyramid"
-        cfg.model.
-        cfg.evaluate.eval_latent_lr = 1e-2
-        cfg.evaluate.eval_latent_weight_decay = 0.0
-        cfg.sim.simulate_missing = True
+        cfg.model.layer_scaling_factor = 2.0
         cfg.sim.sim_strategy = "random"
         cfg.sim.sim_prop = 0.2
+        cfg.plot.show = True
+
+        # Train settings
+        cfg.train.weights_max_ratio = None
+        cfg.train.weights_power = 1.0
+        cfg.train.weights_normalize = True
+        cfg.train.weights_inverse = False
+        cfg.train.gamma = 0.0
+        cfg.train.gamma_schedule = False
+        cfg.train.min_epochs = 100
+
+        # Tune
+        cfg.tune.enabled = False
+        cfg.tune.n_trials = 100

         if preset == "fast":
             # Model
             cfg.model.latent_dim = 4
             cfg.model.num_hidden_layers = 1
-            cfg.model.layer_scaling_factor = 2.0
             cfg.model.dropout_rate = 0.10
-            cfg.model.gamma = 1.5
-            # Train
-            cfg.train.batch_size = 128
-            cfg.train.learning_rate = 1e-3
-            cfg.train.early_stop_gen = 5
-            cfg.train.min_epochs = 10
-            cfg.train.max_epochs = 120
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 2.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 25
-            cfg.tune.epochs = 120
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 512
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.infer_epochs = 20
-            cfg.tune.patience = 5
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 20

-        elif preset == "balanced":
-            # Model
-            cfg.model.latent_dim = 8
-            cfg.model.num_hidden_layers = 2
-            cfg.model.layer_scaling_factor = 3.0
-            cfg.model.dropout_rate = 0.20
-            cfg.model.gamma = 2.0
             # Train
             cfg.train.batch_size = 128
-            cfg.train.learning_rate =
+            cfg.train.learning_rate = 2e-3
             cfg.train.early_stop_gen = 15
-            cfg.train.
-            cfg.train.
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 2.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 75
-            cfg.tune.epochs = 300
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 2048
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.infer_epochs = 40
-            cfg.tune.patience = 10
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 30
+            cfg.train.max_epochs = 200
+            cfg.train.weights_max_ratio = None

-
-
-            cfg.model.latent_dim = 16
-            cfg.model.num_hidden_layers = 3
-            cfg.model.layer_scaling_factor = 5.0
-            cfg.model.dropout_rate = 0.30
-            cfg.model.gamma = 2.5
-            # Train
-            cfg.train.batch_size = 64
-            cfg.train.learning_rate = 6e-4
-            cfg.train.early_stop_gen = 20 # Reduced from 30
-            cfg.train.min_epochs = 100
-            cfg.train.max_epochs = 800 # Reduced from 1200
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 2.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 150
-            cfg.tune.epochs = 600
-            cfg.tune.batch_size = 64
-            cfg.tune.max_samples = 5000 # Capped from 0
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.infer_epochs = 80
-            cfg.tune.patience = 15 # Reduced from 20
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 50
-
-        return cfg
-
-    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "NLPCAConfig":
-        """Apply flat dot-key overrides."""
-        if not overrides:
-            return self
-        for k, v in overrides.items():
-            node = self
-            parts = k.split(".")
-            for p in parts[:-1]:
-                node = getattr(node, p)
-            last = parts[-1]
-            if hasattr(node, last):
-                setattr(node, last, v)
-            else:
-                raise KeyError(f"Unknown config key: {k}")
-        return self
-
-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
-
-@dataclass
-class UBPConfig:
-    """Top-level configuration for ImputeUBP.
-
-    Attributes:
-        io (IOConfig): I/O configuration.
-        model (ModelConfig): Model architecture configuration.
-        train (TrainConfig): Training procedure configuration.
-        tune (TuneConfig): Hyperparameter tuning configuration.
-        evaluate (EvalConfig): Evaluation configuration.
-        plot (PlotConfig): Plotting configuration.
-        sim (SimConfig): Simulated-missing configuration.
-    """
-
-    io: IOConfig = field(default_factory=IOConfig)
-    model: ModelConfig = field(default_factory=ModelConfig)
-    train: TrainConfig = field(default_factory=TrainConfig)
-    tune: TuneConfig = field(default_factory=TuneConfig)
-    evaluate: EvalConfig = field(default_factory=EvalConfig)
-    plot: PlotConfig = field(default_factory=PlotConfig)
-    sim: SimConfig = field(default_factory=SimConfig)
-
-    @classmethod
-    def from_preset(
-        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
-    ) -> "UBPConfig":
-        """Build a UBPConfig from a named preset."""
-        if preset not in {"fast", "balanced", "thorough"}:
-            raise ValueError(f"Unknown preset: {preset}")
-
-        cfg = cls()
-
-        # Common baselines
-        cfg.io.verbose = False
-        cfg.model.hidden_activation = "relu"
-        cfg.model.layer_schedule = "pyramid"
-        cfg.model.latent_init = "random"
-        cfg.sim.simulate_missing = True
-        cfg.sim.sim_strategy = "random"
-        cfg.sim.sim_prop = 0.2
-
-        if preset == "fast":
-            # Model
-            cfg.model.latent_dim = 4
-            cfg.model.num_hidden_layers = 1
-            cfg.model.layer_scaling_factor = 2.0
-            cfg.model.dropout_rate = 0.10
-            cfg.model.gamma = 1.5
-            # Train
-            cfg.train.batch_size = 128
-            cfg.train.learning_rate = 1e-3
-            cfg.train.early_stop_gen = 5
-            cfg.train.min_epochs = 10
-            cfg.train.max_epochs = 120
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 2.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 25
-            cfg.tune.epochs = 120
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 512
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.infer_epochs = 20
-            cfg.tune.patience = 5
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 20
-            cfg.evaluate.eval_latent_lr = 1e-2
-            cfg.evaluate.eval_latent_weight_decay = 0.0
+            # Tune
+            cfg.tune.patience = 15

         elif preset == "balanced":
             # Model
             cfg.model.latent_dim = 8
             cfg.model.num_hidden_layers = 2
-            cfg.model.layer_scaling_factor = 3.0
             cfg.model.dropout_rate = 0.20
-            cfg.model.gamma = 2.0
-            # Train
-            cfg.train.batch_size = 128
-            cfg.train.learning_rate = 8e-4
-            cfg.train.early_stop_gen = 15
-            cfg.train.min_epochs = 50
-            cfg.train.max_epochs = 600
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 2.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 75
-            cfg.tune.epochs = 300
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 2048
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.infer_epochs = 40
-            cfg.tune.patience = 10
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 30
-            cfg.evaluate.eval_latent_lr = 1e-2
-            cfg.evaluate.eval_latent_weight_decay = 0.0

-        else: # thorough
-            # Model
-            cfg.model.latent_dim = 16
-            cfg.model.num_hidden_layers = 3
-            cfg.model.layer_scaling_factor = 5.0
-            cfg.model.dropout_rate = 0.30
-            cfg.model.gamma = 2.5
             # Train
             cfg.train.batch_size = 64
-            cfg.train.learning_rate = 6e-4
-            cfg.train.early_stop_gen = 20 # Reduced from 30
-            cfg.train.min_epochs = 100
-            cfg.train.max_epochs = 800 # Reduced from 1200
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 2.0
-            # Tuning
-            cfg.tune.enabled = True
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 150
-            cfg.tune.epochs = 600
-            cfg.tune.batch_size = 64
-            cfg.tune.max_samples = 5000 # Capped from 0
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.infer_epochs = 80
-            cfg.tune.patience = 15 # Reduced from 20
-            cfg.tune.proxy_metric_batch = 0
-            # Eval
-            cfg.evaluate.eval_latent_steps = 50
-            cfg.evaluate.eval_latent_lr = 1e-2
-            cfg.evaluate.eval_latent_weight_decay = 0.0
-
-        return cfg
-
-    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "UBPConfig":
-        """Apply flat dot-key overrides."""
-        if not overrides:
-            return self
-
-        for k, v in overrides.items():
-            node = self
-            parts = k.split(".")
-            for p in parts[:-1]:
-                node = getattr(node, p)
-            last = parts[-1]
-            if hasattr(node, last):
-                setattr(node, last, v)
-            else:
-                raise KeyError(f"Unknown config key: {k}")
-        return self
-
-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
-
-@dataclass
-class AutoencoderConfig:
-    """Top-level configuration for ImputeAutoencoder.
-
-    Attributes:
-        io (IOConfig): I/O configuration.
-        model (ModelConfig): Model architecture configuration.
-        train (TrainConfig): Training procedure configuration.
-        tune (TuneConfig): Hyperparameter tuning configuration.
-        evaluate (EvalConfig): Evaluation configuration.
-        plot (PlotConfig): Plotting configuration.
-        sim (SimConfig): Simulated-missing configuration.
-    """
-
-    io: IOConfig = field(default_factory=IOConfig)
-    model: ModelConfig = field(default_factory=ModelConfig)
-    train: TrainConfig = field(default_factory=TrainConfig)
-    tune: TuneConfig = field(default_factory=TuneConfig)
-    evaluate: EvalConfig = field(default_factory=EvalConfig)
-    plot: PlotConfig = field(default_factory=PlotConfig)
-    sim: SimConfig = field(default_factory=SimConfig)
-
-    @classmethod
-    def from_preset(
-        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
-    ) -> "AutoencoderConfig":
-        """Build a AutoencoderConfig from a named preset."""
-        if preset not in {"fast", "balanced", "thorough"}:
-            raise ValueError(f"Unknown preset: {preset}")
-
-        cfg = cls()
-
-        # Common baselines (no latent refinement at eval)
-        cfg.io.verbose = False
-        cfg.train.validation_split = 0.20
-        cfg.model.hidden_activation = "relu"
-        cfg.model.layer_schedule = "pyramid"
-        cfg.evaluate.eval_latent_steps = 0
-        cfg.evaluate.eval_latent_lr = 0.0
-        cfg.evaluate.eval_latent_weight_decay = 0.0
-        cfg.sim.simulate_missing = True
-        cfg.sim.sim_strategy = "random"
-        cfg.sim.sim_prop = 0.2
-
-        if preset == "fast":
-            cfg.model.latent_dim = 4
-            cfg.model.num_hidden_layers = 1
-            cfg.model.layer_scaling_factor = 2.0
-            cfg.model.dropout_rate = 0.10
-            cfg.model.gamma = 1.5
-            cfg.train.batch_size = 128
             cfg.train.learning_rate = 1e-3
-            cfg.train.early_stop_gen =
-            cfg.train.
-            cfg.train.
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 2.0
-            cfg.tune.enabled = True
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 25
-            cfg.tune.epochs = 120
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 512
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.patience = 5
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+            cfg.train.early_stop_gen = 25
+            cfg.train.max_epochs = 500
+            cfg.train.weights_max_ratio = None

-
-            cfg.
-            cfg.model.num_hidden_layers = 2
-            cfg.model.layer_scaling_factor = 3.0
-            cfg.model.dropout_rate = 0.20
-            cfg.model.gamma = 2.0
-            cfg.train.batch_size = 128
-            cfg.train.learning_rate = 8e-4
-            cfg.train.early_stop_gen = 15
-            cfg.train.min_epochs = 50
-            cfg.train.max_epochs = 600
-            cfg.train.weights_beta = 0.9999
-            cfg.train.weights_max_ratio = 2.0
-            cfg.tune.enabled = True
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 75
-            cfg.tune.epochs = 300
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 2048
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.patience = 10
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+            # Tune
+            cfg.tune.patience = 25

         else: # thorough
+            # Model
             cfg.model.latent_dim = 16
             cfg.model.num_hidden_layers = 3
-            cfg.model.layer_scaling_factor = 5.0
             cfg.model.dropout_rate = 0.30
-
+
+            # Train
             cfg.train.batch_size = 64
-            cfg.train.learning_rate =
-            cfg.train.early_stop_gen =
-            cfg.train.
-            cfg.train.
-
-
-            cfg.tune.
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 150
-            cfg.tune.epochs = 600
-            cfg.tune.batch_size = 64
-            cfg.tune.max_samples = 5000 # Capped from 0
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.patience = 15 # Reduced from 20
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+            cfg.train.learning_rate = 5e-4
+            cfg.train.early_stop_gen = 50
+            cfg.train.max_epochs = 1000
+            cfg.train.weights_max_ratio = None
+
+            # Tune
+            cfg.tune.patience = 50

         return cfg

     def apply_overrides(self, overrides: Dict[str, Any] | None) -> "AutoencoderConfig":
-        """Apply flat dot-key overrides.
+        """Apply flat dot-key overrides.
+
+        Args:
+            overrides (Dict[str, Any] | None): Dictionary of overrides with dot-separated keys.
+
+        Returns:
+            AutoencoderConfig: New configuration instance with overrides applied.
+        """
         if not overrides:
             return self
         for k, v in overrides.items():
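The retained `apply_overrides` walks dot-separated keys with `getattr` and finishes with `setattr`, so any nested field can be overridden from a flat dict; unknown keys raise `KeyError` rather than being silently dropped. A sketch under the same assumed import path:

    from pgsui.data_processing import config as pgcfg  # assumed import path

    cfg = pgcfg.AutoencoderConfig.from_preset("balanced").apply_overrides(
        {
            "model.latent_dim": 12,       # resolves cfg.model, then sets latent_dim
            "train.learning_rate": 5e-4,
            "tune.metric": "jaccard",
        }
    )
    # cfg.apply_overrides({"train.no_such_key": 1})  # -> KeyError: Unknown config key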
@@ -753,29 +423,22 @@ class AutoencoderConfig:

 @dataclass
 class VAEExtraConfig:
-    """VAE-specific knobs.
-
-    Attributes:
-        kl_beta (float): Final β for KL divergence term.
-        kl_warmup (int): Number of epochs with β=0 (warm-up period).
-        kl_ramp (int): Number of epochs for linear ramp to final β.
-    """
-
     kl_beta: float = 1.0
-
-    kl_ramp: int = 200
+    kl_beta_schedule: bool = False


 @dataclass
 class VAEConfig:
     """Top-level configuration for ImputeVAE (AE-parity + VAE extras).

+    Mirrors AutoencoderConfig sections and adds a ``vae`` block with KL-beta
+    controls for the VAE loss.
+
     Attributes:
         io (IOConfig): I/O configuration.
         model (ModelConfig): Model architecture configuration.
         train (TrainConfig): Training procedure configuration.
         tune (TuneConfig): Hyperparameter tuning configuration.
-        evaluate (EvalConfig): Evaluation configuration.
         plot (PlotConfig): Plotting configuration.
         vae (VAEExtraConfig): VAE-specific configuration.
         sim (SimConfig): Simulated-missing configuration.
@@ -783,9 +446,8 @@ class VAEConfig:

     io: IOConfig = field(default_factory=IOConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
-    train: TrainConfig = field(default_factory=
+    train: TrainConfig = field(default_factory=_default_train_config)
     tune: TuneConfig = field(default_factory=TuneConfig)
-    evaluate: EvalConfig = field(default_factory=EvalConfig)
     plot: PlotConfig = field(default_factory=PlotConfig)
     vae: VAEExtraConfig = field(default_factory=VAEExtraConfig)
     sim: SimConfig = field(default_factory=SimConfig)
@@ -794,107 +456,92 @@ class VAEConfig:
     def from_preset(
         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
     ) -> "VAEConfig":
-        """Build a VAEConfig from a named preset.
+        """Build a VAEConfig from a named preset.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+        Returns:
+            VAEConfig: Configuration instance corresponding to the preset.
+        """
         if preset not in {"fast", "balanced", "thorough"}:
             raise ValueError(f"Unknown preset: {preset}")

         cfg = cls()

-        #
+        # General settings
         cfg.io.verbose = False
-        cfg.
-        cfg.
+        cfg.io.ploidy = 2
+        cfg.train.validation_split = 0.2
+        cfg.model.activation = "relu"
         cfg.model.layer_schedule = "pyramid"
-        cfg.
-        cfg.evaluate.eval_latent_lr = 0.0
-        cfg.evaluate.eval_latent_weight_decay = 0.0
+        cfg.model.layer_scaling_factor = 2.0
         cfg.sim.simulate_missing = True
         cfg.sim.sim_strategy = "random"
         cfg.sim.sim_prop = 0.2
-
-
+        cfg.plot.show = True
+
+        # Train settings
+        cfg.train.weights_max_ratio = None
+        cfg.train.weights_power = 1.0
+        cfg.train.weights_normalize = True
+        cfg.train.weights_inverse = False
+        cfg.train.gamma = 0.0
+        cfg.train.gamma_schedule = False
+        cfg.train.min_epochs = 100
+
+        # VAE-specific
         cfg.vae.kl_beta = 1.0
-        cfg.vae.
-
+        cfg.vae.kl_beta_schedule = False
+
+        # Tune
+        cfg.tune.enabled = False
+        cfg.tune.n_trials = 100

         if preset == "fast":
+            # Model
             cfg.model.latent_dim = 4
-            cfg.model.num_hidden_layers =
-            cfg.model.layer_scaling_factor = 2.0
+            cfg.model.num_hidden_layers = 2
             cfg.model.dropout_rate = 0.10
-
-
+
+            # Train
             cfg.train.batch_size = 128
-            cfg.train.learning_rate =
-            cfg.train.early_stop_gen =
-            cfg.train.
-
-
-            cfg.
-            cfg.tune.enabled = True
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 25
-            cfg.tune.epochs = 120
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 512
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.patience = 5
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+            cfg.train.learning_rate = 2e-3
+            cfg.train.early_stop_gen = 15
+            cfg.train.max_epochs = 200
+
+            # Tune
+            cfg.tune.patience = 15

         elif preset == "balanced":
+            # Model
             cfg.model.latent_dim = 8
-            cfg.model.num_hidden_layers =
-            cfg.model.layer_scaling_factor = 3.0
+            cfg.model.num_hidden_layers = 4
             cfg.model.dropout_rate = 0.20
-
-
-            cfg.train.
-            cfg.train.
-            cfg.train.
-            cfg.train.max_epochs =
-
-
-            cfg.tune.
-            cfg.tune.fast = True
-            cfg.tune.n_trials = 75
-            cfg.tune.epochs = 300
-            cfg.tune.batch_size = 128
-            cfg.tune.max_samples = 2048
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 20
-            cfg.tune.patience = 10
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+
+            # Train
+            cfg.train.batch_size = 64
+            cfg.train.learning_rate = 1e-3
+            cfg.train.early_stop_gen = 25
+            cfg.train.max_epochs = 500
+
+            # Tune
+            cfg.tune.patience = 25

         else: # thorough
+            # Model
             cfg.model.latent_dim = 16
-            cfg.model.num_hidden_layers =
-            cfg.model.layer_scaling_factor = 5.0
+            cfg.model.num_hidden_layers = 8
             cfg.model.dropout_rate = 0.30
-
+
+            # Train
             cfg.train.batch_size = 64
-            cfg.train.learning_rate =
-            cfg.train.early_stop_gen =
-            cfg.train.
-
-
-            cfg.
-            cfg.tune.enabled = True
-            cfg.tune.fast = False
-            cfg.tune.n_trials = 150
-            cfg.tune.epochs = 600
-            cfg.tune.batch_size = 64
-            cfg.tune.max_samples = 5000 # Capped from 0
-            cfg.tune.max_loci = 0
-            cfg.tune.eval_interval = 10
-            cfg.tune.patience = 15 # Reduced from 20
-            cfg.tune.proxy_metric_batch = 0
-            if hasattr(cfg.tune, "infer_epochs"):
-                cfg.tune.infer_epochs = 0
+            cfg.train.learning_rate = 5e-4
+            cfg.train.early_stop_gen = 50
+            cfg.train.max_epochs = 1000
+
+            # Tune
+            cfg.tune.patience = 50

         return cfg

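The `vae` block is now just a final `kl_beta` plus a `kl_beta_schedule` flag, replacing the explicit warm-up/ramp epoch counts. A sketch under the same assumed import path (how the trainer anneals β when the flag is set is not shown in this diff):

    from pgsui.data_processing import config as pgcfg  # assumed import path

    cfg = pgcfg.VAEConfig.from_preset("thorough")
    cfg.vae.kl_beta = 0.5            # final weight on the KL term
    cfg.vae.kl_beta_schedule = True  # scheduling behavior is defined by the trainer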
@@ -923,9 +570,9 @@ class MostFrequentAlgoConfig:
     """Algorithmic knobs for ImputeMostFrequent.

     Attributes:
-        by_populations (bool): Whether to compute per-population modes.
-        default (int): Fallback mode if no valid entries in a locus.
-        missing (int): Code for missing genotypes in 0/1/2.
+        by_populations (bool): Whether to compute per-population modes. Default is False.
+        default (int): Fallback mode if no valid entries in a locus. Default is 0.
+        missing (int): Code for missing genotypes in 0/1/2. Default is -1.
     """

     by_populations: bool = False
@@ -938,8 +585,8 @@ class DeterministicSplitConfig:
     """Evaluation split configuration shared by deterministic imputers.

     Attributes:
-        test_size (float): Proportion of data to use as the test set.
-        test_indices (Optional[Sequence[int]]): Specific indices to use as the test set.
+        test_size (float): Proportion of data to use as the test set. Default is 0.2.
+        test_indices (Optional[Sequence[int]]): Specific indices to use as the test set. Default is None.
     """

     test_size: float = 0.2
@@ -950,6 +597,10 @@ class DeterministicSplitConfig:
 class MostFrequentConfig:
     """Top-level configuration for ImputeMostFrequent.

+    Deterministic imputers primarily use ``io``, ``plot``, ``split``, ``algo``,
+    and ``sim``. The ``train`` and ``tune`` sections are retained for schema
+    parity with NN models but are not currently used by ImputeMostFrequent.
+
     Attributes:
         io (IOConfig): I/O configuration.
         plot (PlotConfig): Plotting configuration.
@@ -966,19 +617,27 @@ class MostFrequentConfig:
     algo: MostFrequentAlgoConfig = field(default_factory=MostFrequentAlgoConfig)
     sim: SimConfig = field(default_factory=SimConfig)
     tune: TuneConfig = field(default_factory=TuneConfig)
-    train: TrainConfig = field(default_factory=
+    train: TrainConfig = field(default_factory=_default_train_config)

     @classmethod
     def from_preset(
         cls,
         preset: Literal["fast", "balanced", "thorough"] = "balanced",
     ) -> "MostFrequentConfig":
-        """Construct a preset configuration.
+        """Construct a preset configuration.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+        Returns:
+            MostFrequentConfig: Configuration instance corresponding to the preset.
+        """
         if preset not in {"fast", "balanced", "thorough"}:
             raise ValueError(f"Unknown preset: {preset}")

         cfg = cls()
         cfg.io.verbose = False
+        cfg.io.ploidy = 2
         cfg.split.test_size = 0.2
         cfg.sim.simulate_missing = True
         cfg.sim.sim_strategy = "random"
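For the deterministic imputers, presets mostly pin IO, split, and simulation defaults; the algorithm-specific knobs live in `algo`. A sketch under the same assumed import path:

    from pgsui.data_processing import config as pgcfg  # assumed import path

    cfg = pgcfg.MostFrequentConfig.from_preset("fast")
    cfg.algo.by_populations = True  # per-population modes instead of global ones
    cfg.split.test_size = 0.1       # hold out 10% of known genotypes for scoring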
@@ -1021,6 +680,10 @@ class RefAlleleAlgoConfig:
 class RefAlleleConfig:
     """Top-level configuration for ImputeRefAllele.

+    Deterministic imputers primarily use ``io``, ``plot``, ``split``, ``algo``,
+    and ``sim``. The ``train`` and ``tune`` sections are retained for schema
+    parity with NN models but are not currently used by ImputeRefAllele.
+
     Attributes:
         io (IOConfig): I/O configuration.
         plot (PlotConfig): Plotting configuration.
@@ -1037,18 +700,26 @@ class RefAlleleConfig:
     algo: RefAlleleAlgoConfig = field(default_factory=RefAlleleAlgoConfig)
     sim: SimConfig = field(default_factory=SimConfig)
     tune: TuneConfig = field(default_factory=TuneConfig)
-    train: TrainConfig = field(default_factory=
+    train: TrainConfig = field(default_factory=_default_train_config)

     @classmethod
     def from_preset(
         cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
     ) -> "RefAlleleConfig":
-        """Presets mainly keep parity with logging/IO and split test_size.
+        """Presets mainly keep parity with logging/IO and split test_size.
+
+        Args:
+            preset (Literal["fast", "balanced", "thorough"]): Preset name.
+
+        Returns:
+            RefAlleleConfig: Configuration instance corresponding to the preset.
+        """
         if preset not in {"fast", "balanced", "thorough"}:
             raise ValueError(f"Unknown preset: {preset}")

         cfg = cls()
         cfg.io.verbose = False
+        cfg.io.ploidy = 2
         cfg.split.test_size = 0.2
         cfg.sim.simulate_missing = True
         cfg.sim.sim_strategy = "random"
@@ -1261,16 +932,23 @@ class RFConfig:

     @classmethod
     def from_preset(cls, preset: str = "balanced") -> "RFConfig":
-        """Build a config from a named preset.
+        """Build a config from a named preset.
+
+        Args:
+            preset (str): Preset name.
+
+        Returns:
+            RFConfig: Configuration instance corresponding to the preset.
+        """
         cfg = cls()
         if preset == "fast":
-            cfg.model.n_estimators =
+            cfg.model.n_estimators = 50
             cfg.model.max_depth = None
             cfg.imputer.max_iter = 5
             cfg.io.n_jobs = 1
             cfg.tune.enabled = False
         elif preset == "balanced":
-            cfg.model.n_estimators = 200
+            cfg.model.n_estimators = 200
             cfg.model.max_depth = None
             cfg.imputer.max_iter = 10
             cfg.io.n_jobs = 1
@@ -1279,7 +957,7 @@ class RFConfig:
         elif preset == "thorough":
             cfg.model.n_estimators = 500
             cfg.model.max_depth = 50 # Added safety cap
-            cfg.imputer.max_iter =
+            cfg.imputer.max_iter = 20
             cfg.io.n_jobs = 1
             cfg.tune.enabled = False
             cfg.tune.n_trials = 250
@@ -1353,18 +1031,25 @@ class HGBConfig:

     @classmethod
     def from_preset(cls, preset: str = "balanced") -> "HGBConfig":
-        """Build a config from a named preset.
+        """Build a config from a named preset.
+
+        Args:
+            preset (str): Preset name.
+
+        Returns:
+            HGBConfig: Configuration instance corresponding to the preset.
+        """
         cfg = cls()
         if preset == "fast":
             cfg.model.n_estimators = 50
-            cfg.model.learning_rate = 0.
+            cfg.model.learning_rate = 0.2
             cfg.model.max_depth = None
             cfg.imputer.max_iter = 5
             cfg.io.n_jobs = 1
             cfg.tune.enabled = False
             cfg.tune.n_trials = 50
         elif preset == "balanced":
-            cfg.model.n_estimators =
+            cfg.model.n_estimators = 150
             cfg.model.learning_rate = 0.1
             cfg.model.max_depth = None
             cfg.imputer.max_iter = 10
@@ -1373,10 +1058,10 @@ class HGBConfig:
             cfg.tune.n_trials = 100
         elif preset == "thorough":
             cfg.model.n_estimators = 500
-            cfg.model.learning_rate = 0.05
+            cfg.model.learning_rate = 0.05
             cfg.model.n_iter_no_change = 20 # Increased patience
             cfg.model.max_depth = None
-            cfg.imputer.max_iter =
+            cfg.imputer.max_iter = 20
             cfg.io.n_jobs = 1
             cfg.tune.enabled = False
             cfg.tune.n_trials = 250