pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.0.dist-info/RECORD +0 -75
- pg_sui-0.2.0.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
pgsui/cli.py
ADDED
|
@@ -0,0 +1,909 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
|
|
4
|
+
"""PG-SUI Imputation CLI
|
|
5
|
+
|
|
6
|
+
Argument-precedence model:
|
|
7
|
+
code defaults < preset (--preset) < YAML (--config) < explicit CLI flags < --set k=v
|
|
8
|
+
|
|
9
|
+
Notes
|
|
10
|
+
-----
|
|
11
|
+
- Preset is a CLI-only choice and will be respected unless overridden by YAML or CLI.
|
|
12
|
+
- YAML entries override preset (a 'preset' key in YAML is ignored with a warning).
|
|
13
|
+
- CLI flags only override when explicitly provided (argparse uses SUPPRESS).
|
|
14
|
+
- --set key=value has the highest precedence and applies dot-path overrides.
|
|
15
|
+
|
|
16
|
+
Examples
|
|
17
|
+
--------
|
|
18
|
+
python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix run1
|
|
19
|
+
python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix tuned --tune
|
|
20
|
+
python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix demo \
|
|
21
|
+
--models ImputeUBP ImputeVAE ImputeMostFrequent --seed deterministic --verbose
|
|
22
|
+
python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix subset \
|
|
23
|
+
--include-pops EA GU TT ON --device cpu
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import argparse
|
|
29
|
+
import ast
|
|
30
|
+
import logging
|
|
31
|
+
import sys
|
|
32
|
+
import time
|
|
33
|
+
from functools import wraps
|
|
34
|
+
from pathlib import Path
|
|
35
|
+
from typing import (
|
|
36
|
+
Any,
|
|
37
|
+
Callable,
|
|
38
|
+
Dict,
|
|
39
|
+
Iterable,
|
|
40
|
+
List,
|
|
41
|
+
Literal,
|
|
42
|
+
Optional,
|
|
43
|
+
ParamSpec,
|
|
44
|
+
Tuple,
|
|
45
|
+
TypeVar,
|
|
46
|
+
cast,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
from snpio import GenePopReader, PhylipReader, SNPioMultiQC, VCFReader, TreeParser
|
|
50
|
+
|
|
51
|
+
from pgsui import (
|
|
52
|
+
AutoencoderConfig,
|
|
53
|
+
ImputeAutoencoder,
|
|
54
|
+
ImputeMostFrequent,
|
|
55
|
+
ImputeNLPCA,
|
|
56
|
+
ImputeRefAllele,
|
|
57
|
+
ImputeUBP,
|
|
58
|
+
ImputeVAE,
|
|
59
|
+
MostFrequentConfig,
|
|
60
|
+
NLPCAConfig,
|
|
61
|
+
RefAlleleConfig,
|
|
62
|
+
UBPConfig,
|
|
63
|
+
VAEConfig,
|
|
64
|
+
)
|
|
65
|
+
from pgsui.data_processing.config import (
|
|
66
|
+
apply_dot_overrides,
|
|
67
|
+
dataclass_to_yaml,
|
|
68
|
+
load_yaml_to_dataclass,
|
|
69
|
+
save_dataclass_yaml,
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
# Canonical model order used everywhere (default and subset ordering)
|
|
73
|
+
MODEL_ORDER: Tuple[str, ...] = (
|
|
74
|
+
"ImputeUBP",
|
|
75
|
+
"ImputeVAE",
|
|
76
|
+
"ImputeAutoencoder",
|
|
77
|
+
"ImputeNLPCA",
|
|
78
|
+
"ImputeMostFrequent",
|
|
79
|
+
"ImputeRefAllele",
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
# Strategies supported by SimMissingTransformer + SimConfig.
|
|
83
|
+
SIM_STRATEGY_CHOICES: Tuple[str, ...] = (
|
|
84
|
+
"random",
|
|
85
|
+
"random_weighted",
|
|
86
|
+
"random_weighted_inv",
|
|
87
|
+
"nonrandom",
|
|
88
|
+
"nonrandom_weighted",
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
P = ParamSpec("P")
|
|
92
|
+
R = TypeVar("R")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# ----------------------------- CLI Utilities ----------------------------- #
|
|
96
|
+
def _configure_logging(verbose: bool, log_file: Optional[str] = None) -> None:
    """Configure the root logger for a CLI run.

    Args:
        verbose (bool): If True, log at INFO level; otherwise only ERROR and above.
        log_file (Optional[str]): Optional path to also write logs to. The file
            is truncated (mode ``"w"``) and written as UTF-8.

    Notes:
        ``force=True`` removes and closes any handlers already attached to the
        root logger before installing the new ones. Without it,
        ``logging.basicConfig`` does nothing at all when the root logger has
        already been configured (by an imported library or an earlier call),
        so the requested level, the stdout handler and the log file would all
        be silently ignored.
    """
    level = logging.INFO if verbose else logging.ERROR
    handlers: List[logging.Handler] = [logging.StreamHandler(sys.stdout)]
    if log_file:
        handlers.append(logging.FileHandler(log_file, mode="w", encoding="utf-8"))
    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=handlers,
        force=True,
    )
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _parse_seed(seed_arg: str) -> Optional[int]:
    """Translate the ``--seed`` CLI value into a concrete seed.

    Args:
        seed_arg (str): ``"random"`` (no fixed seed), ``"deterministic"`` (the
            fixed seed 42), or any string ``int()`` accepts. Keyword matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        Optional[int]: ``None`` for ``"random"``, otherwise an integer seed.

    Raises:
        argparse.ArgumentTypeError: If the value is neither a keyword nor an integer.
    """
    keyword_seeds: Dict[str, Optional[int]] = {"random": None, "deterministic": 42}
    normalized = seed_arg.strip().lower()
    if normalized in keyword_seeds:
        return keyword_seeds[normalized]

    try:
        numeric_seed = int(seed_arg)
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            "Invalid --seed. Use 'random', 'deterministic', or an integer."
        ) from err
    return numeric_seed
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _parse_models(models: Iterable[str]) -> Tuple[str, ...]:
    """Validate requested model names and return them in canonical order.

    Args:
        models (Iterable[str]): Model names from ``--models``; may be empty or a
            one-shot iterator.

    Returns:
        Tuple[str, ...]: ``MODEL_ORDER`` itself when nothing was requested;
            otherwise the requested names, de-duplicated and ordered as in
            ``MODEL_ORDER`` (the command-line order is not preserved).

    Raises:
        argparse.ArgumentTypeError: If any requested name is not in ``MODEL_ORDER``.
    """
    requested = tuple(models)  # materialize once; the input may be a generator

    known = frozenset(MODEL_ORDER)
    unknown = [name for name in requested if name not in known]
    if unknown:
        raise argparse.ArgumentTypeError(
            f"Unknown model(s): {unknown}. Valid options: {list(MODEL_ORDER)}"
        )

    if not requested:
        return MODEL_ORDER

    wanted = set(requested)
    return tuple(filter(wanted.__contains__, MODEL_ORDER))
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _parse_overrides(pairs: list[str]) -> dict:
|
|
155
|
+
"""Parse --set key=value into typed values via literal_eval."""
|
|
156
|
+
out: dict = {}
|
|
157
|
+
for kv in pairs or []:
|
|
158
|
+
if "=" not in kv:
|
|
159
|
+
raise argparse.ArgumentTypeError(f"--set expects key=value, got '{kv}'")
|
|
160
|
+
k, v = kv.split("=", 1)
|
|
161
|
+
v = v.strip()
|
|
162
|
+
try:
|
|
163
|
+
out[k] = ast.literal_eval(v)
|
|
164
|
+
except Exception:
|
|
165
|
+
out[k] = v # raw string fallback
|
|
166
|
+
return out
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
    """Convert explicitly provided CLI flags into config dot-path overrides.

    Most flags are registered with ``default=argparse.SUPPRESS``, so the
    attribute exists on ``args`` only when the user actually passed the flag.
    ``hasattr`` is therefore the "was it given?" test.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.

    Returns:
        dict: Mapping of dot-path keys (e.g. ``"train.batch_size"``) to values.

    Raises:
        argparse.ArgumentTypeError: If ``--seed`` is not a valid seed value.
    """
    overrides: dict = {}

    # IO / top-level controls
    prefix = getattr(args, "prefix", None)
    if prefix is None:
        # No explicit --prefix: derive it from the input path the same way
        # main() does (--input wins over the legacy --vcf). The config prefix
        # then matches the prefix used for the rest of the run.
        input_path = getattr(args, "input", None)
        if input_path is None:
            input_path = getattr(args, "vcf", None)
        if input_path is not None:
            prefix = str(Path(input_path).stem)
    if prefix is not None:
        overrides["io.prefix"] = prefix

    # --verbose is a plain store_true flag (no SUPPRESS), so the attribute is
    # always present. Only an explicit --verbose may override YAML/preset; its
    # absence must not force io.verbose back to False.
    if getattr(args, "verbose", False):
        overrides["io.verbose"] = True
    if hasattr(args, "n_jobs"):
        overrides["io.n_jobs"] = int(args.n_jobs)
    if hasattr(args, "seed"):
        overrides["io.seed"] = _parse_seed(args.seed)

    # Train
    if hasattr(args, "batch_size"):
        overrides["train.batch_size"] = int(args.batch_size)
    if hasattr(args, "device"):
        # The CLI offers "cuda"; the config expects "gpu".
        overrides["train.device"] = "gpu" if args.device == "cuda" else args.device

    # Plot
    if hasattr(args, "plot_format"):
        overrides["plot.fmt"] = args.plot_format

    # Simulation overrides (shared across config-driven models)
    if hasattr(args, "sim_strategy"):
        overrides["sim.sim_strategy"] = args.sim_strategy
    if hasattr(args, "sim_prop"):
        overrides["sim.sim_prop"] = float(args.sim_prop)
    if hasattr(args, "simulate_missing"):
        # --simulate-missing is store_false: present only as False when passed.
        overrides["sim.simulate_missing"] = bool(args.simulate_missing)

    # Tuning
    if hasattr(args, "tune"):
        overrides["tune.enabled"] = bool(args.tune)
    if hasattr(args, "tune_n_trials"):
        overrides["tune.n_trials"] = int(args.tune_n_trials)

    return overrides
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _format_seconds(seconds: float) -> str:
    """Render a duration as ``M:SS``, or ``H:MM:SS`` once it reaches an hour.

    The input is first rounded to a whole number of seconds with Python's
    ``round`` (ties go to the even integer).
    """
    whole = int(round(seconds))
    hrs, remainder = divmod(whole, 3600)
    mins, secs = divmod(remainder, 60)
    return f"{hrs:d}:{mins:02d}:{secs:02d}" if hrs else f"{mins:d}:{secs:02d}"
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def log_model_time(fn: Callable[P, R]) -> Callable[P, R]:
    """Decorate a model runner so each call logs its wall-clock duration.

    The model name is taken from the first positional argument. If there is
    none, it falls back to a ``model_name`` keyword argument and then to
    ``"<unknown model>"``. On success an INFO record is logged. If ``fn``
    raises, an ERROR record with the traceback is logged and the exception is
    re-raised unchanged.

    Note: if ``fn`` swallows its own errors (as ``run_model_safely`` does with
    ``warn_only=True``), the call is still reported as "finished".

    Args:
        fn (Callable[P, R]): Function to wrap (e.g. ``run_model_safely``).

    Returns:
        Callable[P, R]: Wrapper with the same signature and metadata as ``fn``.
    """

    @wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        if args:
            model_name = str(args[0])
        else:
            model_name = str(kwargs.get("model_name", "<unknown model>"))
        start = time.perf_counter()
        try:
            result = fn(*args, **kwargs)
        except Exception:
            elapsed = time.perf_counter() - start
            logging.error(
                f"{model_name} failed after {elapsed:0.2f}s "
                f"({_format_seconds(elapsed)}).",
                exc_info=True,
            )
            raise
        elapsed = time.perf_counter() - start
        logging.info(
            f"{model_name} finished in {elapsed:0.2f}s "
            f"({_format_seconds(elapsed)})."
        )
        return result

    return cast(Callable[P, R], wrapper)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
# ------------------------------ Core Runner ------------------------------ #
|
|
256
|
+
def build_genotype_data(
|
|
257
|
+
input_path: str,
|
|
258
|
+
fmt: Literal["vcf", "vcf.gz", "phy", "phylip", "genepop", "gen"],
|
|
259
|
+
popmap_path: str | None,
|
|
260
|
+
treefile: str | None,
|
|
261
|
+
qmatrix: str | None,
|
|
262
|
+
siterates: str | None,
|
|
263
|
+
force_popmap: bool,
|
|
264
|
+
verbose: bool,
|
|
265
|
+
include_pops: List[str] | None,
|
|
266
|
+
plot_format: Literal["pdf", "png", "jpg", "jpeg"],
|
|
267
|
+
):
|
|
268
|
+
"""Load genotype data from heterogeneous inputs."""
|
|
269
|
+
logging.info(f"Loading {fmt.upper()} and popmap data...")
|
|
270
|
+
|
|
271
|
+
kwargs = {
|
|
272
|
+
"filename": input_path,
|
|
273
|
+
"popmapfile": popmap_path,
|
|
274
|
+
"force_popmap": force_popmap,
|
|
275
|
+
"verbose": verbose,
|
|
276
|
+
"include_pops": include_pops if include_pops else None,
|
|
277
|
+
"prefix": f"snpio_{Path(input_path).stem}",
|
|
278
|
+
"plot_format": plot_format,
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
if fmt == "vcf":
|
|
282
|
+
gd = VCFReader(**kwargs)
|
|
283
|
+
elif fmt == "phylip":
|
|
284
|
+
gd = PhylipReader(**kwargs)
|
|
285
|
+
elif fmt == "genepop":
|
|
286
|
+
gd = GenePopReader(**kwargs)
|
|
287
|
+
else:
|
|
288
|
+
raise ValueError(f"Unsupported genotype data format: {fmt}")
|
|
289
|
+
|
|
290
|
+
tp = None
|
|
291
|
+
if treefile is not None:
|
|
292
|
+
logging.info("Parsing phylogenetic tree...")
|
|
293
|
+
|
|
294
|
+
tp = TreeParser(
|
|
295
|
+
gd, treefile=treefile, qmatrix=qmatrix, siterates=siterates, verbose=True
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
logging.info("Loaded genotype data.")
|
|
299
|
+
return gd, tp
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
@log_model_time
def run_model_safely(
    model_name: str,
    builder: Callable[[], Any],
    *,
    warn_only: bool = True,
) -> Any:
    """Build a model, fit it, and return its imputed output.

    Args:
        model_name (str): Display name used in log messages (and by the
            ``log_model_time`` decorator, which reads it from the first
            positional argument).
        builder (Callable[[], Any]): Zero-argument factory returning an object
            with ``fit()`` and ``transform()`` methods.
        warn_only (bool): If True, any exception is logged as a warning (with
            traceback) and ``None`` is returned; if False, it is re-raised.

    Returns:
        Any: The value returned by ``model.transform()``, or ``None`` when the
            run failed and ``warn_only`` is True.

    Raises:
        Exception: Whatever ``builder``, ``fit`` or ``transform`` raised, when
            ``warn_only`` is False.
    """
    logging.info(f"▶ Running {model_name} ...")
    try:
        model = builder()
        model.fit()
        X_imputed = model.transform()
    except Exception as e:
        if not warn_only:
            raise
        logging.warning(f"⚠ {model_name} failed: {e}", exc_info=True)
        return None
    logging.info(f"✓ {model_name} completed.")
    return X_imputed
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
# -------------------------- Model Registry ------------------------------- #
|
|
320
|
+
# Add config-driven models here by listing the class and its config dataclass.
# Each entry maps a model name (the names accepted by --models) to its imputer
# class under "cls" and its config dataclass under "config_cls".
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
    name: {"cls": imputer_cls, "config_cls": config_cls}
    for name, imputer_cls, config_cls in (
        ("ImputeUBP", ImputeUBP, UBPConfig),
        ("ImputeNLPCA", ImputeNLPCA, NLPCAConfig),
        ("ImputeAutoencoder", ImputeAutoencoder, AutoencoderConfig),
        ("ImputeVAE", ImputeVAE, VAEConfig),
        ("ImputeMostFrequent", ImputeMostFrequent, MostFrequentConfig),
        ("ImputeRefAllele", ImputeRefAllele, RefAlleleConfig),
    )
}
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _build_effective_config_for_model(
    model_name: str, args: argparse.Namespace
) -> Any | None:
    """Build the effective config object for a specific model (if it has one).

    Precedence (lowest → highest):
        defaults < preset (--preset) < YAML (--config) < explicit CLI flags < --set

    Args:
        model_name (str): Key into ``MODEL_REGISTRY``.
        args (argparse.Namespace): Parsed CLI arguments. ``preset`` exists only
            when ``--preset`` was passed (the flag uses ``argparse.SUPPRESS``).

    Returns:
        Config dataclass instance, or None for registry entries without a
        ``config_cls``.

    Raises:
        KeyError: If ``model_name`` is not in ``MODEL_REGISTRY``.
        Exception: Errors from applying ``--set`` overrides are logged and
            re-raised for ImputeUBP, ImputeNLPCA, ImputeAutoencoder and
            ImputeVAE. For other models they are logged as a warning and the
            overrides are skipped.
    """
    cfg_cls = MODEL_REGISTRY[model_name].get("config_cls")
    if cfg_cls is None:
        return None

    # 0/1) Start from the preset if one was explicitly given, else from defaults.
    if hasattr(args, "preset"):
        preset_name = args.preset
        cfg = cfg_cls.from_preset(preset_name)
        logging.info(f"Initialized {model_name} from '{preset_name}' preset.")
    else:
        cfg = cfg_cls()
        logging.info(f"Initialized {model_name} from dataclass defaults (no preset).")

    # 2) YAML overlays preset/defaults. Any 'preset' key in the YAML is ignored.
    yaml_path = getattr(args, "config", None)
    if yaml_path:
        cfg = load_yaml_to_dataclass(
            yaml_path,
            cfg_cls,
            base=cfg,
            yaml_preset_behavior="ignore",  # 'preset' key in YAML ignored with warning
        )
        logging.info(
            f"Loaded YAML config for {model_name} from {yaml_path} (ignored 'preset' in YAML if present)."
        )

    # 3) Explicit CLI flags overlay YAML.
    cli_overrides = _args_to_cli_overrides(args)
    if cli_overrides:
        cfg = apply_dot_overrides(cfg, cli_overrides)

    # 4) --set has the highest precedence.
    user_overrides = _parse_overrides(getattr(args, "set", []))
    if user_overrides:
        try:
            cfg = apply_dot_overrides(cfg, user_overrides)
        except Exception as e:
            if model_name in {
                "ImputeUBP",
                "ImputeNLPCA",
                "ImputeAutoencoder",
                "ImputeVAE",
            }:
                logging.error(
                    f"Error applying --set overrides to {model_name} config: {e}"
                )
                raise
            # The same --set keys are applied to every selected model, and
            # keys meant for the models above may not exist in other configs.
            # Stay best-effort here, but report it instead of dropping silently.
            logging.warning(f"Ignoring --set overrides for {model_name}: {e}")

    return cfg
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _dump_path_for_model(dump_base: str, model_name: str) -> str:
    """Return the per-model ``--dump-config`` path for ``model_name``.

    The model name is inserted before the extension of the final path
    component (``cfg.yaml`` → ``cfg.ImputeUBP.yaml``). When there is no
    extension, ``.yaml`` is used (``cfg`` → ``cfg.ImputeUBP.yaml``). Dots in
    directory names (``./out/run``, ``runs.v2/cfg``) are left untouched.
    """
    base = Path(dump_base)
    suffix = base.suffix or ".yaml"
    return str(base.with_name(f"{base.stem}.{model_name}{suffix}"))


def _maybe_print_or_dump_configs(
    cfgs_by_model: Dict[str, Any], args: argparse.Namespace
) -> bool:
    """Handle --print-config / --dump-config for ALL selected config-driven models.

    Models whose config is ``None`` are skipped. Every dumped file name gets
    the model name inserted, even when only one model is selected (see
    ``_dump_path_for_model``).

    Args:
        cfgs_by_model (Dict[str, Any]): Effective config per model name.
        args (argparse.Namespace): Parsed CLI arguments.

    Returns:
        bool: True if --print-config or --dump-config was requested (the caller
            should exit, as the flags' help promises); otherwise False.
    """
    did_io = False

    if getattr(args, "print_config", False):
        for model_name, cfg in cfgs_by_model.items():
            if cfg is None:
                continue
            print(f"# --- {model_name} effective config ---")
            print(dataclass_to_yaml(cfg))
            print()
        did_io = True

    dump_base = getattr(args, "dump_config", None)
    if dump_base:
        for model_name, cfg in cfgs_by_model.items():
            if cfg is None:
                continue
            path = _dump_path_for_model(dump_base, model_name)
            save_dataclass_yaml(cfg, path)
            logging.info(f"Saved {model_name} config to {path}")
        did_io = True

    return did_io
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def main(argv: Optional[List[str]] = None) -> int:
|
|
438
|
+
parser = argparse.ArgumentParser(
|
|
439
|
+
prog="pg-sui",
|
|
440
|
+
description="Run PG-SUI imputation models on an input file. Handle configuration via presets, YAML, and CLI flags. The default is to run all models.",
|
|
441
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
442
|
+
usage="%(prog)s [options]",
|
|
443
|
+
)
|
|
444
|
+
|
|
445
|
+
# ----------------------------- Required I/O ----------------------------- #
|
|
446
|
+
parser.add_argument(
|
|
447
|
+
"--input",
|
|
448
|
+
default=argparse.SUPPRESS,
|
|
449
|
+
help="Path to input file (VCF/PHYLIP/STRUCTURE/GENEPOP).",
|
|
450
|
+
)
|
|
451
|
+
parser.add_argument(
|
|
452
|
+
"--format",
|
|
453
|
+
choices=(
|
|
454
|
+
"infer",
|
|
455
|
+
"vcf",
|
|
456
|
+
"vcf.gz",
|
|
457
|
+
"phy",
|
|
458
|
+
"phylip",
|
|
459
|
+
"str",
|
|
460
|
+
"structure",
|
|
461
|
+
"genepop",
|
|
462
|
+
"gen",
|
|
463
|
+
),
|
|
464
|
+
default=argparse.SUPPRESS,
|
|
465
|
+
help="Input format. If 'infer', deduced from file extension. The default is 'infer'.",
|
|
466
|
+
)
|
|
467
|
+
# Back-compat: --vcf retained; if both provided, --input wins.
|
|
468
|
+
parser.add_argument(
|
|
469
|
+
"--vcf",
|
|
470
|
+
default=argparse.SUPPRESS,
|
|
471
|
+
help="Path to input VCF file. Can be bgzipped or uncompressed.",
|
|
472
|
+
)
|
|
473
|
+
parser.add_argument(
|
|
474
|
+
"--popmap",
|
|
475
|
+
default=argparse.SUPPRESS,
|
|
476
|
+
help="Path to population map file. This is a two-column tab-delimited file with sample IDs and population IDs.",
|
|
477
|
+
)
|
|
478
|
+
parser.add_argument(
|
|
479
|
+
"--treefile",
|
|
480
|
+
default=argparse.SUPPRESS,
|
|
481
|
+
help="Path to phylogenetic tree file. Can be in Newick (recommended) or Nexus format.",
|
|
482
|
+
)
|
|
483
|
+
parser.add_argument(
|
|
484
|
+
"--qmatrix",
|
|
485
|
+
default=argparse.SUPPRESS,
|
|
486
|
+
help="Path to IQ-TREE output file (has .iqtree extension) that contains Rate Matrix Q. Used with --treefile and --siterates.",
|
|
487
|
+
)
|
|
488
|
+
parser.add_argument(
|
|
489
|
+
"--siterates",
|
|
490
|
+
default=argparse.SUPPRESS,
|
|
491
|
+
help="Path to SNP site rates file (has .rate extension). Used with --treefile and --qmatrix.",
|
|
492
|
+
)
|
|
493
|
+
parser.add_argument(
|
|
494
|
+
"--prefix",
|
|
495
|
+
default=argparse.SUPPRESS,
|
|
496
|
+
help="Output file prefix.",
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
# ---------------------- Generic Config Inputs -------------------------- #
|
|
500
|
+
parser.add_argument(
|
|
501
|
+
"--config",
|
|
502
|
+
default=argparse.SUPPRESS,
|
|
503
|
+
help="YAML config for config-driven models (NLPCA/UBP/Autoencoder/VAE).",
|
|
504
|
+
)
|
|
505
|
+
parser.add_argument(
|
|
506
|
+
"--preset",
|
|
507
|
+
choices=("fast", "balanced", "thorough"),
|
|
508
|
+
default=argparse.SUPPRESS, # <-- no default; optional
|
|
509
|
+
help="If provided, initialize config(s) from this preset; otherwise start from dataclass defaults.",
|
|
510
|
+
)
|
|
511
|
+
parser.add_argument(
|
|
512
|
+
"--set",
|
|
513
|
+
action="append",
|
|
514
|
+
default=argparse.SUPPRESS,
|
|
515
|
+
help="Dot-key overrides, e.g. --set model.latent_dim=4",
|
|
516
|
+
)
|
|
517
|
+
parser.add_argument(
|
|
518
|
+
"--print-config",
|
|
519
|
+
action="store_true",
|
|
520
|
+
help="Print effective config(s) and exit.",
|
|
521
|
+
)
|
|
522
|
+
parser.add_argument(
|
|
523
|
+
"--dump-config",
|
|
524
|
+
default=argparse.SUPPRESS,
|
|
525
|
+
help="Write effective config(s) YAML to this path (multi-model gets suffixed).",
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
# ------------------------------ Toggles -------------------------------- #
|
|
529
|
+
parser.add_argument(
|
|
530
|
+
"--tune",
|
|
531
|
+
action="store_true",
|
|
532
|
+
default=argparse.SUPPRESS,
|
|
533
|
+
help="Enable hyperparameter tuning (if supported).",
|
|
534
|
+
)
|
|
535
|
+
parser.add_argument(
|
|
536
|
+
"--tune-n-trials",
|
|
537
|
+
type=int,
|
|
538
|
+
default=argparse.SUPPRESS,
|
|
539
|
+
help="Optuna trials when --tune is set.",
|
|
540
|
+
)
|
|
541
|
+
parser.add_argument(
|
|
542
|
+
"--batch-size",
|
|
543
|
+
type=int,
|
|
544
|
+
default=argparse.SUPPRESS,
|
|
545
|
+
help="Batch size for NN-based models.",
|
|
546
|
+
)
|
|
547
|
+
parser.add_argument(
|
|
548
|
+
"--device",
|
|
549
|
+
choices=("cpu", "cuda", "mps"),
|
|
550
|
+
default=argparse.SUPPRESS,
|
|
551
|
+
help="Compute device for NN-based models.",
|
|
552
|
+
)
|
|
553
|
+
parser.add_argument(
|
|
554
|
+
"--n-jobs",
|
|
555
|
+
type=int,
|
|
556
|
+
default=argparse.SUPPRESS,
|
|
557
|
+
help="Parallel workers for various steps.",
|
|
558
|
+
)
|
|
559
|
+
parser.add_argument(
|
|
560
|
+
"--plot-format",
|
|
561
|
+
choices=("png", "pdf", "svg", "jpg", "jpeg"),
|
|
562
|
+
default=argparse.SUPPRESS,
|
|
563
|
+
help="Figure format for model plots.",
|
|
564
|
+
)
|
|
565
|
+
# ------------------------- Simulation Controls ------------------------ #
|
|
566
|
+
parser.add_argument(
|
|
567
|
+
"--sim-strategy",
|
|
568
|
+
choices=SIM_STRATEGY_CHOICES,
|
|
569
|
+
default=argparse.SUPPRESS,
|
|
570
|
+
help="Override the missing-data simulation strategy for all config-driven models.",
|
|
571
|
+
)
|
|
572
|
+
parser.add_argument(
|
|
573
|
+
"--sim-prop",
|
|
574
|
+
type=float,
|
|
575
|
+
default=argparse.SUPPRESS,
|
|
576
|
+
help="Override the proportion of observed entries to mask during simulation (0-1).",
|
|
577
|
+
)
|
|
578
|
+
parser.add_argument(
|
|
579
|
+
"--simulate-missing",
|
|
580
|
+
action="store_false",
|
|
581
|
+
default=argparse.SUPPRESS,
|
|
582
|
+
help="Disable missing-data simulation regardless of preset/config (when provided).",
|
|
583
|
+
)
|
|
584
|
+
|
|
585
|
+
# --------------------------- Seed & logging ---------------------------- #
|
|
586
|
+
parser.add_argument(
|
|
587
|
+
"--seed",
|
|
588
|
+
default=argparse.SUPPRESS,
|
|
589
|
+
help="Random seed: 'random', 'deterministic', or an integer.",
|
|
590
|
+
)
|
|
591
|
+
parser.add_argument("--verbose", action="store_true", help="Info-level logging.")
|
|
592
|
+
parser.add_argument(
|
|
593
|
+
"--log-file", default=argparse.SUPPRESS, help="Also write logs to a file."
|
|
594
|
+
)
|
|
595
|
+
|
|
596
|
+
# ---------------------------- Data filtering --------------------------- #
|
|
597
|
+
parser.add_argument(
|
|
598
|
+
"--include-pops",
|
|
599
|
+
nargs="+",
|
|
600
|
+
default=argparse.SUPPRESS,
|
|
601
|
+
help="Optional list of population IDs to include.",
|
|
602
|
+
)
|
|
603
|
+
parser.add_argument(
|
|
604
|
+
"--force-popmap",
|
|
605
|
+
action="store_true",
|
|
606
|
+
default=False,
|
|
607
|
+
help="Require popmap (error if absent).",
|
|
608
|
+
)
|
|
609
|
+
|
|
610
|
+
# ---------------------------- Model selection -------------------------- #
|
|
611
|
+
parser.add_argument(
|
|
612
|
+
"--models",
|
|
613
|
+
nargs="+",
|
|
614
|
+
default=argparse.SUPPRESS,
|
|
615
|
+
help=(
|
|
616
|
+
"Which models to run. Choices: ImputeUBP ImputeVAE ImputeAutoencoder ImputeNLPCA ImputeMostFrequent ImputeRefAllele. Default is all."
|
|
617
|
+
),
|
|
618
|
+
)
|
|
619
|
+
|
|
620
|
+
# -------------------------- MultiQC integration ------------------------ #
|
|
621
|
+
parser.add_argument(
|
|
622
|
+
"--multiqc",
|
|
623
|
+
action="store_true",
|
|
624
|
+
help=(
|
|
625
|
+
"Build a MultiQC HTML report at the end of the run, combining SNPio and PG-SUI plots (requires SNPio's MultiQC module)."
|
|
626
|
+
),
|
|
627
|
+
)
|
|
628
|
+
parser.add_argument(
|
|
629
|
+
"--multiqc-title",
|
|
630
|
+
default=argparse.SUPPRESS,
|
|
631
|
+
help="Optional title for the MultiQC report (default: 'PG-SUI MultiQC Report - <prefix>').",
|
|
632
|
+
)
|
|
633
|
+
parser.add_argument(
|
|
634
|
+
"--multiqc-output-dir",
|
|
635
|
+
default=argparse.SUPPRESS,
|
|
636
|
+
help="Optional output directory for the MultiQC report (default: '<prefix>_output/multiqc').",
|
|
637
|
+
)
|
|
638
|
+
parser.add_argument(
|
|
639
|
+
"--multiqc-overwrite",
|
|
640
|
+
action="store_true",
|
|
641
|
+
default=False,
|
|
642
|
+
help="Overwrite an existing MultiQC report if present.",
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
# ------------------------------ Safety/UX ------------------------------ #
|
|
646
|
+
parser.add_argument(
|
|
647
|
+
"--dry-run",
|
|
648
|
+
action="store_true",
|
|
649
|
+
help="Parse args and load data, but skip model training.",
|
|
650
|
+
)
|
|
651
|
+
|
|
652
|
+
args = parser.parse_args(argv)
|
|
653
|
+
|
|
654
|
+
# Logging (verbose default is False unless passed)
|
|
655
|
+
_configure_logging(
|
|
656
|
+
verbose=getattr(args, "verbose", False),
|
|
657
|
+
log_file=getattr(args, "log_file", None),
|
|
658
|
+
)
|
|
659
|
+
|
|
660
|
+
# Models selection (default to all if not explicitly provided)
|
|
661
|
+
try:
|
|
662
|
+
selected_models = _parse_models(getattr(args, "models", ()))
|
|
663
|
+
except argparse.ArgumentTypeError as e:
|
|
664
|
+
parser.error(str(e))
|
|
665
|
+
return 2
|
|
666
|
+
|
|
667
|
+
# Input resolution
|
|
668
|
+
input_path = getattr(args, "input", None)
|
|
669
|
+
if input_path is None and hasattr(args, "vcf"):
|
|
670
|
+
input_path = args.vcf
|
|
671
|
+
if not hasattr(args, "format"):
|
|
672
|
+
setattr(args, "format", "vcf")
|
|
673
|
+
|
|
674
|
+
if input_path is None:
|
|
675
|
+
parser.error("You must provide --input (or legacy --vcf).")
|
|
676
|
+
return 2
|
|
677
|
+
|
|
678
|
+
fmt: Literal["infer", "vcf", "vcf.gz", "phy", "phylip", "genepop", "gen"] = getattr(
|
|
679
|
+
args, "format", "infer"
|
|
680
|
+
)
|
|
681
|
+
|
|
682
|
+
if fmt == "infer":
|
|
683
|
+
if input_path.endswith((".vcf", ".vcf.gz")):
|
|
684
|
+
fmt_final = "vcf"
|
|
685
|
+
elif input_path.endswith((".phy", ".phylip")):
|
|
686
|
+
fmt_final = "phylip"
|
|
687
|
+
elif input_path.endswith((".genepop", ".gen")):
|
|
688
|
+
fmt_final = "genepop"
|
|
689
|
+
else:
|
|
690
|
+
parser.error(
|
|
691
|
+
"Could not infer input format from file extension. Please provide --format."
|
|
692
|
+
)
|
|
693
|
+
return 2
|
|
694
|
+
else:
|
|
695
|
+
fmt_final = fmt
|
|
696
|
+
|
|
697
|
+
popmap_path = getattr(args, "popmap", None)
|
|
698
|
+
include_pops = getattr(args, "include_pops", None)
|
|
699
|
+
verbose_flag = getattr(args, "verbose", False)
|
|
700
|
+
force_popmap = bool(getattr(args, "force_popmap", False))
|
|
701
|
+
|
|
702
|
+
# Canonical prefix for this run (used for outputs and MultiQC)
|
|
703
|
+
prefix: str = getattr(args, "prefix", str(Path(input_path).stem))
|
|
704
|
+
|
|
705
|
+
treefile = getattr(args, "treefile", None)
|
|
706
|
+
qmatrix = getattr(args, "qmatrix", None)
|
|
707
|
+
siterates = getattr(args, "siterates", None)
|
|
708
|
+
|
|
709
|
+
if any(x is not None for x in (treefile, qmatrix, siterates)):
|
|
710
|
+
if not all(x is not None for x in (treefile, qmatrix, siterates)):
|
|
711
|
+
parser.error(
|
|
712
|
+
"--treefile, --qmatrix, and --siterates must all be provided together or they should all be omitted."
|
|
713
|
+
)
|
|
714
|
+
return 2
|
|
715
|
+
|
|
716
|
+
# Load genotype data
|
|
717
|
+
gd, tp = build_genotype_data(
|
|
718
|
+
input_path=input_path,
|
|
719
|
+
fmt=fmt_final,
|
|
720
|
+
popmap_path=popmap_path,
|
|
721
|
+
treefile=treefile,
|
|
722
|
+
qmatrix=qmatrix,
|
|
723
|
+
siterates=siterates,
|
|
724
|
+
force_popmap=force_popmap,
|
|
725
|
+
verbose=verbose_flag,
|
|
726
|
+
include_pops=include_pops,
|
|
727
|
+
plot_format=getattr(args, "plot_format", "pdf"),
|
|
728
|
+
)
|
|
729
|
+
|
|
730
|
+
if getattr(args, "dry_run", False):
|
|
731
|
+
logging.info("Dry run complete. Exiting without training models.")
|
|
732
|
+
return 0
|
|
733
|
+
|
|
734
|
+
# ---------------- Build config(s) per selected model ------------------- #
|
|
735
|
+
cfgs_by_model: Dict[str, Any] = {
|
|
736
|
+
m: _build_effective_config_for_model(m, args) for m in selected_models
|
|
737
|
+
}
|
|
738
|
+
|
|
739
|
+
# Maybe print/dump configs and exit
|
|
740
|
+
if _maybe_print_or_dump_configs(cfgs_by_model, args):
|
|
741
|
+
return 0
|
|
742
|
+
|
|
743
|
+
# ------------------------- Model Builders ------------------------------ #
|
|
744
|
+
def build_impute_ubp():
    """Construct an ``ImputeUBP`` model for the current CLI run.

    Uses the effective config stored under ``"ImputeUBP"`` in
    ``cfgs_by_model``. If no entry exists, it falls back to
    ``UBPConfig.from_preset(args.preset)`` when a preset value is set on
    ``args``, otherwise to a default ``UBPConfig()``.

    Returns:
        ImputeUBP: Model wired to the loaded genotype data (``gd``), the
        tree parser (``tp``), the config, and the config's ``sim``
        (simulated-missingness) settings.
    """
    cfg = cfgs_by_model.get("ImputeUBP")
    if cfg is None:
        # hasattr() is True even when argparse stored preset=None, which
        # would call from_preset(None). Only use the preset path when an
        # actual value is present.
        preset = getattr(args, "preset", None)
        cfg = UBPConfig.from_preset(preset) if preset is not None else UBPConfig()
    return ImputeUBP(
        genotype_data=gd,
        tree_parser=tp,
        config=cfg,
        simulate_missing=cfg.sim.simulate_missing,
        sim_strategy=cfg.sim.sim_strategy,
        sim_prop=cfg.sim.sim_prop,
        sim_kwargs=cfg.sim.sim_kwargs,
    )
|
|
761
|
+
|
|
762
|
+
def build_impute_nlpca():
    """Construct an ``ImputeNLPCA`` model for the current CLI run.

    Uses the effective config stored under ``"ImputeNLPCA"`` in
    ``cfgs_by_model``. If no entry exists, it falls back to
    ``NLPCAConfig.from_preset(args.preset)`` when a preset value is set on
    ``args``, otherwise to a default ``NLPCAConfig()``.

    Returns:
        ImputeNLPCA: Model wired to the loaded genotype data (``gd``), the
        tree parser (``tp``), the config, and the config's ``sim``
        (simulated-missingness) settings.
    """
    cfg = cfgs_by_model.get("ImputeNLPCA")
    if cfg is None:
        # hasattr() is True even when argparse stored preset=None, which
        # would call from_preset(None). Only use the preset path when an
        # actual value is present.
        preset = getattr(args, "preset", None)
        cfg = NLPCAConfig.from_preset(preset) if preset is not None else NLPCAConfig()
    return ImputeNLPCA(
        genotype_data=gd,
        tree_parser=tp,
        config=cfg,
        simulate_missing=cfg.sim.simulate_missing,
        sim_strategy=cfg.sim.sim_strategy,
        sim_prop=cfg.sim.sim_prop,
        sim_kwargs=cfg.sim.sim_kwargs,
    )
|
|
779
|
+
|
|
780
|
+
def build_impute_vae():
    """Construct an ``ImputeVAE`` model for the current CLI run.

    Uses the effective config stored under ``"ImputeVAE"`` in
    ``cfgs_by_model``. If no entry exists, it falls back to
    ``VAEConfig.from_preset(args.preset)`` when a preset value is set on
    ``args``, otherwise to a default ``VAEConfig()``.

    Returns:
        ImputeVAE: Model wired to the loaded genotype data (``gd``), the
        tree parser (``tp``), the config, and the config's ``sim``
        (simulated-missingness) settings.
    """
    cfg = cfgs_by_model.get("ImputeVAE")
    if cfg is None:
        # hasattr() is True even when argparse stored preset=None, which
        # would call from_preset(None). Only use the preset path when an
        # actual value is present.
        preset = getattr(args, "preset", None)
        cfg = VAEConfig.from_preset(preset) if preset is not None else VAEConfig()
    return ImputeVAE(
        genotype_data=gd,
        tree_parser=tp,
        config=cfg,
        simulate_missing=cfg.sim.simulate_missing,
        sim_strategy=cfg.sim.sim_strategy,
        sim_prop=cfg.sim.sim_prop,
        sim_kwargs=cfg.sim.sim_kwargs,
    )
|
|
797
|
+
|
|
798
|
+
def build_impute_autoencoder():
    """Construct an ``ImputeAutoencoder`` model for the current CLI run.

    Uses the effective config stored under ``"ImputeAutoencoder"`` in
    ``cfgs_by_model``. If no entry exists, it falls back to
    ``AutoencoderConfig.from_preset(args.preset)`` when a preset value is
    set on ``args``, otherwise to a default ``AutoencoderConfig()``.

    Returns:
        ImputeAutoencoder: Model wired to the loaded genotype data (``gd``),
        the tree parser (``tp``), the config, and the config's ``sim``
        (simulated-missingness) settings.
    """
    cfg = cfgs_by_model.get("ImputeAutoencoder")
    if cfg is None:
        # hasattr() is True even when argparse stored preset=None, which
        # would call from_preset(None). Only use the preset path when an
        # actual value is present.
        preset = getattr(args, "preset", None)
        cfg = (
            AutoencoderConfig.from_preset(preset)
            if preset is not None
            else AutoencoderConfig()
        )
    return ImputeAutoencoder(
        genotype_data=gd,
        tree_parser=tp,
        config=cfg,
        simulate_missing=cfg.sim.simulate_missing,
        sim_strategy=cfg.sim.sim_strategy,
        sim_prop=cfg.sim.sim_prop,
        sim_kwargs=cfg.sim.sim_kwargs,
    )
|
|
815
|
+
|
|
816
|
+
def build_impute_mostfreq():
    """Construct an ``ImputeMostFrequent`` model for the current CLI run.

    Uses the effective config stored under ``"ImputeMostFrequent"`` in
    ``cfgs_by_model``. If no entry exists, it falls back to
    ``MostFrequentConfig.from_preset(args.preset)`` when a preset value is
    set on ``args``, otherwise to a default ``MostFrequentConfig()``.

    Returns:
        ImputeMostFrequent: Model wired to the loaded genotype data (``gd``,
        passed positionally), the tree parser (``tp``), the config, and the
        config's ``sim`` (simulated-missingness) settings.
    """
    cfg = cfgs_by_model.get("ImputeMostFrequent")
    if cfg is None:
        # hasattr() is True even when argparse stored preset=None, which
        # would call from_preset(None). Only use the preset path when an
        # actual value is present.
        preset = getattr(args, "preset", None)
        cfg = (
            MostFrequentConfig.from_preset(preset)
            if preset is not None
            else MostFrequentConfig()
        )
    return ImputeMostFrequent(
        gd,
        tree_parser=tp,
        config=cfg,
        simulate_missing=cfg.sim.simulate_missing,
        sim_strategy=cfg.sim.sim_strategy,
        sim_prop=cfg.sim.sim_prop,
        sim_kwargs=cfg.sim.sim_kwargs,
    )
|
|
833
|
+
|
|
834
|
+
def build_impute_refallele():
    """Construct an ``ImputeRefAllele`` model for the current CLI run.

    Uses the effective config stored under ``"ImputeRefAllele"`` in
    ``cfgs_by_model``. If no entry exists, it falls back to
    ``RefAlleleConfig.from_preset(args.preset)`` when a preset value is set
    on ``args``, otherwise to a default ``RefAlleleConfig()``.

    Returns:
        ImputeRefAllele: Model wired to the loaded genotype data (``gd``,
        passed positionally), the tree parser (``tp``), the config, and the
        config's ``sim`` (simulated-missingness) settings.
    """
    cfg = cfgs_by_model.get("ImputeRefAllele")
    if cfg is None:
        # hasattr() is True even when argparse stored preset=None, which
        # would call from_preset(None). Only use the preset path when an
        # actual value is present.
        preset = getattr(args, "preset", None)
        cfg = (
            RefAlleleConfig.from_preset(preset)
            if preset is not None
            else RefAlleleConfig()
        )
    return ImputeRefAllele(
        gd,
        tree_parser=tp,
        config=cfg,
        simulate_missing=cfg.sim.simulate_missing,
        sim_strategy=cfg.sim.sim_strategy,
        sim_prop=cfg.sim.sim_prop,
        sim_kwargs=cfg.sim.sim_kwargs,
    )
|
|
851
|
+
|
|
852
|
+
model_builders = {
|
|
853
|
+
"ImputeUBP": build_impute_ubp,
|
|
854
|
+
"ImputeVAE": build_impute_vae,
|
|
855
|
+
"ImputeAutoencoder": build_impute_autoencoder,
|
|
856
|
+
"ImputeNLPCA": build_impute_nlpca,
|
|
857
|
+
"ImputeMostFrequent": build_impute_mostfreq,
|
|
858
|
+
"ImputeRefAllele": build_impute_refallele,
|
|
859
|
+
}
|
|
860
|
+
|
|
861
|
+
logging.info(f"Selected models: {', '.join(selected_models)}")
|
|
862
|
+
for name in selected_models:
|
|
863
|
+
X_imputed = run_model_safely(name, model_builders[name], warn_only=False)
|
|
864
|
+
gd_imp = gd.copy()
|
|
865
|
+
gd_imp.snp_data = X_imputed
|
|
866
|
+
|
|
867
|
+
if name in {"ImputeUBP", "ImputeVAE", "ImputeAutoencoder", "ImputeNLPCA"}:
|
|
868
|
+
family = "Unsupervised"
|
|
869
|
+
elif name in {"ImputeMostFrequent", "ImputeRefAllele"}:
|
|
870
|
+
family = "Deterministic"
|
|
871
|
+
elif name in {"ImputeHistGradientBoosting", "ImputeRandomForest"}:
|
|
872
|
+
family = "Supervised"
|
|
873
|
+
else:
|
|
874
|
+
raise ValueError(f"Unknown model family for {name}")
|
|
875
|
+
|
|
876
|
+
pth = Path(f"{prefix}_output/{family}/imputed/{name}")
|
|
877
|
+
pth.mkdir(parents=True, exist_ok=True)
|
|
878
|
+
|
|
879
|
+
logging.info(f"Writing imputed VCF for {name} to {pth} ...")
|
|
880
|
+
gd_imp.write_vcf(pth / f"{name.lower()}_imputed.vcf.gz")
|
|
881
|
+
|
|
882
|
+
logging.info("All requested models processed.")
|
|
883
|
+
|
|
884
|
+
# -------------------------- MultiQC builder ---------------------------- #
|
|
885
|
+
|
|
886
|
+
mqc_output_dir = getattr(args, "multiqc_output_dir", f"{prefix}_output/multiqc")
|
|
887
|
+
mqc_title = getattr(args, "multiqc_title", f"PG-SUI MultiQC Report - {prefix}")
|
|
888
|
+
overwrite = bool(getattr(args, "multiqc_overwrite", False))
|
|
889
|
+
|
|
890
|
+
logging.info(
|
|
891
|
+
f"Building MultiQC report in '{mqc_output_dir}' (title={mqc_title}, overwrite={overwrite})..."
|
|
892
|
+
)
|
|
893
|
+
|
|
894
|
+
try:
|
|
895
|
+
SNPioMultiQC.build(
|
|
896
|
+
prefix=prefix,
|
|
897
|
+
output_dir=mqc_output_dir,
|
|
898
|
+
title=mqc_title,
|
|
899
|
+
overwrite=overwrite,
|
|
900
|
+
)
|
|
901
|
+
logging.info("MultiQC report successfully built.")
|
|
902
|
+
except Exception as exc2: # pragma: no cover
|
|
903
|
+
logging.error(f"Failed to build MultiQC report: {exc2}", exc_info=True)
|
|
904
|
+
|
|
905
|
+
return 0
|
|
906
|
+
|
|
907
|
+
|
|
908
|
+
if __name__ == "__main__":
    # Script entry point: run the CLI and use main()'s return value as the
    # process exit status.
    _exit_code = main()
    raise SystemExit(_exit_code)
|