pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +577 -125
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +203 -530
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1269 -534
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
- pgsui/impute/unsupervised/imputers/vae.py +931 -787
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
pgsui/cli.py
CHANGED
|
@@ -15,19 +15,21 @@ Notes
|
|
|
15
15
|
|
|
16
16
|
Examples
|
|
17
17
|
--------
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
--models
|
|
22
|
-
|
|
23
|
-
--include-pops EA GU TT ON --device cpu
|
|
18
|
+
pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix run1
|
|
19
|
+
pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix tuned --tune
|
|
20
|
+
pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix demo \
|
|
21
|
+
--models ImputeAutoencoder ImputeVAE ImputeMostFrequent --seed deterministic --verbose
|
|
22
|
+
pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix subset \
|
|
23
|
+
--include-pops EA GU TT ON --device cpu --sim-prop 0.3 --sim-strategy nonrandom
|
|
24
24
|
"""
|
|
25
25
|
|
|
26
26
|
from __future__ import annotations
|
|
27
27
|
|
|
28
28
|
import argparse
|
|
29
29
|
import ast
|
|
30
|
+
import json
|
|
30
31
|
import logging
|
|
32
|
+
import os
|
|
31
33
|
import sys
|
|
32
34
|
import time
|
|
33
35
|
from functools import wraps
|
|
@@ -46,20 +48,24 @@ from typing import (
|
|
|
46
48
|
cast,
|
|
47
49
|
)
|
|
48
50
|
|
|
49
|
-
from snpio import
|
|
51
|
+
from snpio import (
|
|
52
|
+
GenePopReader,
|
|
53
|
+
NRemover2,
|
|
54
|
+
PhylipReader,
|
|
55
|
+
SNPioMultiQC,
|
|
56
|
+
StructureReader,
|
|
57
|
+
TreeParser,
|
|
58
|
+
VCFReader,
|
|
59
|
+
)
|
|
50
60
|
|
|
51
61
|
from pgsui import (
|
|
52
62
|
AutoencoderConfig,
|
|
53
63
|
ImputeAutoencoder,
|
|
54
64
|
ImputeMostFrequent,
|
|
55
|
-
ImputeNLPCA,
|
|
56
65
|
ImputeRefAllele,
|
|
57
|
-
ImputeUBP,
|
|
58
66
|
ImputeVAE,
|
|
59
67
|
MostFrequentConfig,
|
|
60
|
-
NLPCAConfig,
|
|
61
68
|
RefAlleleConfig,
|
|
62
|
-
UBPConfig,
|
|
63
69
|
VAEConfig,
|
|
64
70
|
)
|
|
65
71
|
from pgsui.data_processing.config import (
|
|
@@ -71,10 +77,8 @@ from pgsui.data_processing.config import (
|
|
|
71
77
|
|
|
72
78
|
# Canonical model order used everywhere (default and subset ordering)
|
|
73
79
|
MODEL_ORDER: Tuple[str, ...] = (
|
|
74
|
-
"ImputeUBP",
|
|
75
80
|
"ImputeVAE",
|
|
76
81
|
"ImputeAutoencoder",
|
|
77
|
-
"ImputeNLPCA",
|
|
78
82
|
"ImputeMostFrequent",
|
|
79
83
|
"ImputeRefAllele",
|
|
80
84
|
)
|
|
@@ -93,6 +97,142 @@ R = TypeVar("R")
|
|
|
93
97
|
|
|
94
98
|
|
|
95
99
|
# ----------------------------- CLI Utilities ----------------------------- #
|
|
100
|
+
def _print_version() -> None:
|
|
101
|
+
"""Print PG-SUI version and exit."""
|
|
102
|
+
from pgsui import __version__ as version
|
|
103
|
+
|
|
104
|
+
logging.info(f"Using PG-SUI version: {version}")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _model_family(model_name: str) -> str:
|
|
108
|
+
"""Return output family folder name used by PG-SUI."""
|
|
109
|
+
if model_name in {"ImputeVAE", "ImputeAutoencoder"}:
|
|
110
|
+
return "Unsupervised"
|
|
111
|
+
if model_name in {"ImputeMostFrequent", "ImputeRefAllele"}:
|
|
112
|
+
return "Deterministic"
|
|
113
|
+
return "Unknown"
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _flatten_dict(d: dict, parent: str = "") -> dict:
|
|
117
|
+
"""Flatten a nested dict into dot keys."""
|
|
118
|
+
out: dict = {}
|
|
119
|
+
for k, v in (d or {}).items():
|
|
120
|
+
key = f"{parent}.{k}" if parent else str(k)
|
|
121
|
+
if isinstance(v, dict):
|
|
122
|
+
out.update(_flatten_dict(v, key))
|
|
123
|
+
else:
|
|
124
|
+
out[key] = v
|
|
125
|
+
return out
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _force_tuning_off(cfg: Any, model_name: str) -> Any:
|
|
129
|
+
"""Force tuning disabled on a config object (best-effort, but strict for tune-capable models)."""
|
|
130
|
+
# Prefer direct attribute mutation (avoids apply_dot_overrides edge-cases)
|
|
131
|
+
try:
|
|
132
|
+
if hasattr(cfg, "tune") and hasattr(cfg.tune, "enabled"):
|
|
133
|
+
cfg.tune.enabled = False
|
|
134
|
+
return cfg
|
|
135
|
+
except Exception:
|
|
136
|
+
pass
|
|
137
|
+
|
|
138
|
+
# Fallback to dot override
|
|
139
|
+
try:
|
|
140
|
+
return apply_dot_overrides(cfg, {"tune.enabled": False})
|
|
141
|
+
except Exception as e:
|
|
142
|
+
# Only strict for models that actually support tuning
|
|
143
|
+
if model_name in {"ImputeVAE", "ImputeAutoencoder"}:
|
|
144
|
+
raise RuntimeError(
|
|
145
|
+
f"Failed to force tuning off for {model_name}: {e}"
|
|
146
|
+
) from e
|
|
147
|
+
return cfg
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _find_best_params_json(prefix: str, model_name: str) -> Path | None:
|
|
151
|
+
"""Locate best parameter JSON (tuned or final) for a model.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
prefix (str): Output prefix used during the run.
|
|
155
|
+
model_name (str): Model name to look for.
|
|
156
|
+
|
|
157
|
+
Returns:
|
|
158
|
+
Path | None: Path to best_parameters.json / best_tuned_parameters.json if found; else None.
|
|
159
|
+
"""
|
|
160
|
+
families = ("Unsupervised", "Deterministic")
|
|
161
|
+
model_dir_candidates = (model_name, model_name.lower())
|
|
162
|
+
|
|
163
|
+
for fam in families:
|
|
164
|
+
for mdir in model_dir_candidates:
|
|
165
|
+
base = Path(f"{prefix}_output") / fam
|
|
166
|
+
candidates = (
|
|
167
|
+
base / "optimize" / mdir / "parameters" / "best_tuned_parameters.json",
|
|
168
|
+
base / "parameters" / mdir / "best_parameters.json",
|
|
169
|
+
)
|
|
170
|
+
for p in candidates:
|
|
171
|
+
if p.exists():
|
|
172
|
+
return p
|
|
173
|
+
return None
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _load_best_params(best_params_path: Path) -> dict:
|
|
177
|
+
"""Load best parameters JSON."""
|
|
178
|
+
with best_params_path.open("r", encoding="utf-8") as f:
|
|
179
|
+
data = json.load(f)
|
|
180
|
+
if not isinstance(data, dict):
|
|
181
|
+
raise ValueError(
|
|
182
|
+
f"best_parameters.json must be a JSON object, got {type(data)}"
|
|
183
|
+
)
|
|
184
|
+
return data
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _apply_best_params_to_cfg(cfg: Any, best_params: dict, model_name: str) -> Any:
|
|
188
|
+
"""Apply best params into cfg using dot-path keys or inferred dot-paths.
|
|
189
|
+
|
|
190
|
+
- If JSON is nested, flatten to dot keys.
|
|
191
|
+
- If key already contains '.', treat as dot-path and apply directly.
|
|
192
|
+
- If key has no '.', try common sections in order: model., train., sim., tune., io., plot.
|
|
193
|
+
- Unknown keys are ignored with a warning.
|
|
194
|
+
"""
|
|
195
|
+
# Flatten if nested (but keep existing dot keys as-is too)
|
|
196
|
+
flat = {}
|
|
197
|
+
for k, v in best_params.items():
|
|
198
|
+
if isinstance(v, dict):
|
|
199
|
+
flat.update(_flatten_dict(v, str(k)))
|
|
200
|
+
else:
|
|
201
|
+
flat[str(k)] = v
|
|
202
|
+
|
|
203
|
+
candidate_prefixes = ("", "model.", "train.", "sim.", "tune.", "io.", "plot.")
|
|
204
|
+
|
|
205
|
+
# Apply one by one so we can try multiple candidate destinations for
|
|
206
|
+
# non-dot keys
|
|
207
|
+
for raw_k, v in flat.items():
|
|
208
|
+
if "." in raw_k:
|
|
209
|
+
try:
|
|
210
|
+
cfg = apply_dot_overrides(cfg, {raw_k: v})
|
|
211
|
+
continue
|
|
212
|
+
except Exception as e:
|
|
213
|
+
logging.warning(
|
|
214
|
+
f"Could not apply best param '{raw_k}' to {model_name} (dot key). Skipping. Error: {e}"
|
|
215
|
+
)
|
|
216
|
+
continue
|
|
217
|
+
|
|
218
|
+
applied = False
|
|
219
|
+
for pref in candidate_prefixes:
|
|
220
|
+
k = f"{pref}{raw_k}" if pref else raw_k
|
|
221
|
+
try:
|
|
222
|
+
cfg = apply_dot_overrides(cfg, {k: v})
|
|
223
|
+
applied = True
|
|
224
|
+
break
|
|
225
|
+
except Exception:
|
|
226
|
+
continue
|
|
227
|
+
|
|
228
|
+
if not applied:
|
|
229
|
+
logging.warning(
|
|
230
|
+
f"Best param '{raw_k}' not recognized for {model_name}; leaving config unchanged for that key."
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
return cfg
|
|
234
|
+
|
|
235
|
+
|
|
96
236
|
def _configure_logging(verbose: bool, log_file: Optional[str] = None) -> None:
|
|
97
237
|
"""Configure root logger.
|
|
98
238
|
|
|
@@ -166,6 +306,96 @@ def _parse_overrides(pairs: list[str]) -> dict:
|
|
|
166
306
|
return out
|
|
167
307
|
|
|
168
308
|
|
|
309
|
+
def _parse_allele_encoding(arg: str) -> dict:
|
|
310
|
+
"""Parse STRUCTURE allele encoding dict from JSON or Python literal."""
|
|
311
|
+
try:
|
|
312
|
+
payload = json.loads(arg)
|
|
313
|
+
except Exception:
|
|
314
|
+
try:
|
|
315
|
+
payload = ast.literal_eval(arg)
|
|
316
|
+
except Exception as e:
|
|
317
|
+
raise argparse.ArgumentTypeError(
|
|
318
|
+
f"Invalid --structure-allele-encoding; must be a dict. Error: {e}"
|
|
319
|
+
) from e
|
|
320
|
+
|
|
321
|
+
if not isinstance(payload, dict):
|
|
322
|
+
raise argparse.ArgumentTypeError(
|
|
323
|
+
"--structure-allele-encoding must be a dict-like mapping."
|
|
324
|
+
)
|
|
325
|
+
|
|
326
|
+
out: dict = {}
|
|
327
|
+
for k, v in payload.items():
|
|
328
|
+
key = k
|
|
329
|
+
if isinstance(k, str):
|
|
330
|
+
k_strip = k.strip()
|
|
331
|
+
if k_strip.lstrip("-").isdigit():
|
|
332
|
+
try:
|
|
333
|
+
key = int(k_strip)
|
|
334
|
+
except Exception:
|
|
335
|
+
key = k
|
|
336
|
+
out[key] = str(v)
|
|
337
|
+
return out
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _normalize_input_format(fmt: str) -> str:
|
|
341
|
+
"""Normalize format aliases into canonical reader names."""
|
|
342
|
+
fmt = fmt.lower()
|
|
343
|
+
if fmt in {"vcf", "vcf.gz"}:
|
|
344
|
+
return "vcf"
|
|
345
|
+
if fmt in {"phy", "phylip"}:
|
|
346
|
+
return "phylip"
|
|
347
|
+
if fmt in {"gen", "genepop"}:
|
|
348
|
+
return "genepop"
|
|
349
|
+
if fmt in {"str", "structure"}:
|
|
350
|
+
return "structure"
|
|
351
|
+
return fmt
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def _normalize_plot_format(fmt: str) -> Literal["pdf", "png", "jpg", "svg"]:
|
|
355
|
+
"""Normalize plot format aliases to reader-supported values."""
|
|
356
|
+
fmt = fmt.lower()
|
|
357
|
+
if fmt == "jpeg":
|
|
358
|
+
return "jpg"
|
|
359
|
+
return cast(Literal["pdf", "png", "jpg", "svg"], fmt)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def _expand_path(path: str | None) -> str | None:
|
|
363
|
+
"""Expand ~ and env vars in a path-like string."""
|
|
364
|
+
if path is None:
|
|
365
|
+
return None
|
|
366
|
+
raw = str(path).strip()
|
|
367
|
+
if not raw:
|
|
368
|
+
return None
|
|
369
|
+
expanded = os.path.expandvars(raw)
|
|
370
|
+
return str(Path(expanded).expanduser())
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def _resolve_tree_paths(
|
|
374
|
+
args: argparse.Namespace,
|
|
375
|
+
) -> tuple[str | None, str | None, str | None]:
|
|
376
|
+
"""Resolve tree-related paths from CLI args."""
|
|
377
|
+
treefile = _expand_path(getattr(args, "treefile", None))
|
|
378
|
+
qmatrix = _expand_path(getattr(args, "qmatrix", None))
|
|
379
|
+
siterates = _expand_path(getattr(args, "siterates", None))
|
|
380
|
+
return treefile, qmatrix, siterates
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def _config_needs_tree(cfg: Any | None) -> bool:
|
|
384
|
+
"""Return True if config requires a tree parser for simulated missingness."""
|
|
385
|
+
if cfg is None:
|
|
386
|
+
return False
|
|
387
|
+
sim_cfg = getattr(cfg, "sim", None)
|
|
388
|
+
if sim_cfg is None:
|
|
389
|
+
return False
|
|
390
|
+
strategy = getattr(sim_cfg, "sim_strategy", None)
|
|
391
|
+
simulate = bool(getattr(sim_cfg, "simulate_missing", False))
|
|
392
|
+
return (
|
|
393
|
+
simulate
|
|
394
|
+
and isinstance(strategy, str)
|
|
395
|
+
and strategy in {"nonrandom", "nonrandom_weighted"}
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
|
|
169
399
|
def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
|
|
170
400
|
"""Convert explicitly provided CLI flags into config dot-overrides."""
|
|
171
401
|
overrides: dict = {}
|
|
@@ -174,10 +404,12 @@ def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
|
|
|
174
404
|
if hasattr(args, "prefix") and args.prefix is not None:
|
|
175
405
|
overrides["io.prefix"] = args.prefix
|
|
176
406
|
else:
|
|
177
|
-
#
|
|
178
|
-
|
|
179
|
-
if hasattr(args, "vcf"):
|
|
180
|
-
|
|
407
|
+
# Prefer --input stem; fallback to legacy --vcf stem
|
|
408
|
+
input_path = getattr(args, "input", None)
|
|
409
|
+
if input_path is None and hasattr(args, "vcf"):
|
|
410
|
+
input_path = getattr(args, "vcf", None)
|
|
411
|
+
if input_path:
|
|
412
|
+
overrides["io.prefix"] = str(Path(input_path).stem)
|
|
181
413
|
|
|
182
414
|
if hasattr(args, "verbose"):
|
|
183
415
|
overrides["io.verbose"] = bool(args.verbose)
|
|
@@ -200,20 +432,34 @@ def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
|
|
|
200
432
|
# Plot
|
|
201
433
|
if hasattr(args, "plot_format"):
|
|
202
434
|
overrides["plot.fmt"] = args.plot_format
|
|
435
|
+
if getattr(args, "disable_plotting", False):
|
|
436
|
+
logging.info(
|
|
437
|
+
"Disabling plotting for all models as per --disable-plotting flag."
|
|
438
|
+
)
|
|
439
|
+
overrides["plot.show"] = False
|
|
203
440
|
|
|
204
|
-
# Simulation overrides
|
|
441
|
+
# Simulation overrides
|
|
205
442
|
if hasattr(args, "sim_strategy"):
|
|
206
443
|
overrides["sim.sim_strategy"] = args.sim_strategy
|
|
207
444
|
if hasattr(args, "sim_prop"):
|
|
208
445
|
overrides["sim.sim_prop"] = float(args.sim_prop)
|
|
209
|
-
if hasattr(args, "simulate_missing"):
|
|
210
|
-
overrides["sim.simulate_missing"] = bool(args.simulate_missing)
|
|
211
446
|
|
|
212
447
|
# Tuning
|
|
213
|
-
if
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
448
|
+
if getattr(args, "load_best_params", False):
|
|
449
|
+
# Never allow CLI flags to re-enable tuning when loading params
|
|
450
|
+
if hasattr(args, "tune") and bool(getattr(args, "tune", False)):
|
|
451
|
+
logging.warning(
|
|
452
|
+
"--tune was supplied, but --load-best-params is active; ignoring --tune."
|
|
453
|
+
)
|
|
454
|
+
if hasattr(args, "tune_n_trials"):
|
|
455
|
+
logging.warning(
|
|
456
|
+
"--tune-n-trials was supplied, but --load-best-params is active; ignoring it."
|
|
457
|
+
)
|
|
458
|
+
else:
|
|
459
|
+
if hasattr(args, "tune"):
|
|
460
|
+
overrides["tune.enabled"] = bool(args.tune)
|
|
461
|
+
if hasattr(args, "tune_n_trials"):
|
|
462
|
+
overrides["tune.n_trials"] = int(args.tune_n_trials)
|
|
217
463
|
|
|
218
464
|
return overrides
|
|
219
465
|
|
|
@@ -257,35 +503,77 @@ def log_model_time(fn: Callable[P, R]) -> Callable[P, R]:
|
|
|
257
503
|
# ------------------------------ Core Runner ------------------------------ #
|
|
258
504
|
def build_genotype_data(
|
|
259
505
|
input_path: str,
|
|
260
|
-
fmt: Literal[
|
|
506
|
+
fmt: Literal[
|
|
507
|
+
"vcf",
|
|
508
|
+
"vcf.gz",
|
|
509
|
+
"phy",
|
|
510
|
+
"phylip",
|
|
511
|
+
"genepop",
|
|
512
|
+
"gen",
|
|
513
|
+
"structure",
|
|
514
|
+
"str",
|
|
515
|
+
],
|
|
261
516
|
popmap_path: str | None,
|
|
262
517
|
treefile: str | None,
|
|
263
518
|
qmatrix: str | None,
|
|
264
519
|
siterates: str | None,
|
|
265
520
|
force_popmap: bool,
|
|
266
|
-
|
|
521
|
+
debug: bool,
|
|
267
522
|
include_pops: List[str] | None,
|
|
268
|
-
plot_format: Literal["pdf", "png", "jpg", "jpeg"],
|
|
523
|
+
plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"],
|
|
524
|
+
structure_has_popids: bool = False,
|
|
525
|
+
structure_has_marker_names: bool = False,
|
|
526
|
+
structure_allele_start_col: int | None = None,
|
|
527
|
+
structure_allele_encoding: dict | None = None,
|
|
269
528
|
):
|
|
270
|
-
"""Load genotype data from heterogeneous inputs.
|
|
271
|
-
|
|
529
|
+
"""Load genotype data from heterogeneous inputs.
|
|
530
|
+
|
|
531
|
+
Args:
|
|
532
|
+
input_path (str): Path to genotype data file.
|
|
533
|
+
fmt (Literal): Format of genotype data file.
|
|
534
|
+
popmap_path (str | None): Optional path to population map file.
|
|
535
|
+
treefile (str | None): Optional path to phylogenetic tree file.
|
|
536
|
+
qmatrix (str | None): Optional path to IQ-TREE Q matrix file.
|
|
537
|
+
siterates (str | None): Optional path to SNP site rates file.
|
|
538
|
+
force_popmap (bool): Whether to force use of popmap even if samples don't match exactly.
|
|
539
|
+
debug (bool): Whether to enable debug-level logging in SNPio readers.
|
|
540
|
+
include_pops (List[str] | None): Optional list of population IDs to include.
|
|
541
|
+
plot_format (Literal): Figure format for SNPio plots.
|
|
542
|
+
structure_has_popids (bool): STRUCTURE only; whether pop IDs are present.
|
|
543
|
+
structure_has_marker_names (bool): STRUCTURE only; whether the first line has marker names.
|
|
544
|
+
structure_allele_start_col (int | None): STRUCTURE only; zero-based allele start column.
|
|
545
|
+
structure_allele_encoding (dict | None): STRUCTURE only; allele encoding map.
|
|
546
|
+
"""
|
|
547
|
+
fmt_norm = _normalize_input_format(fmt)
|
|
548
|
+
plot_format = _normalize_plot_format(cast(str, plot_format))
|
|
549
|
+
logging.info(f"Loading {fmt_norm.upper()} and popmap data...")
|
|
272
550
|
|
|
273
551
|
kwargs = {
|
|
274
552
|
"filename": input_path,
|
|
275
553
|
"popmapfile": popmap_path,
|
|
276
554
|
"force_popmap": force_popmap,
|
|
277
|
-
"verbose":
|
|
555
|
+
"verbose": debug,
|
|
278
556
|
"include_pops": include_pops if include_pops else None,
|
|
279
557
|
"prefix": f"snpio_{Path(input_path).stem}",
|
|
280
558
|
"plot_format": plot_format,
|
|
281
559
|
}
|
|
282
560
|
|
|
283
|
-
if
|
|
561
|
+
if fmt_norm == "vcf":
|
|
284
562
|
gd = VCFReader(**kwargs)
|
|
285
|
-
elif
|
|
563
|
+
elif fmt_norm == "phylip":
|
|
286
564
|
gd = PhylipReader(**kwargs)
|
|
287
|
-
elif
|
|
565
|
+
elif fmt_norm == "genepop":
|
|
288
566
|
gd = GenePopReader(**kwargs)
|
|
567
|
+
elif fmt_norm == "structure":
|
|
568
|
+
kwargs.update(
|
|
569
|
+
{
|
|
570
|
+
"has_popids": structure_has_popids,
|
|
571
|
+
"has_marker_names": structure_has_marker_names,
|
|
572
|
+
"allele_start_col": structure_allele_start_col,
|
|
573
|
+
"allele_encoding": structure_allele_encoding,
|
|
574
|
+
}
|
|
575
|
+
)
|
|
576
|
+
gd = StructureReader(**kwargs)
|
|
289
577
|
else:
|
|
290
578
|
raise ValueError(f"Unsupported genotype data format: {fmt}")
|
|
291
579
|
|
|
@@ -321,8 +609,6 @@ def run_model_safely(model_name: str, builder, *, warn_only: bool = True) -> Non
|
|
|
321
609
|
# -------------------------- Model Registry ------------------------------- #
|
|
322
610
|
# Add config-driven models here by listing the class and its config dataclass.
|
|
323
611
|
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
|
324
|
-
"ImputeUBP": {"cls": ImputeUBP, "config_cls": UBPConfig},
|
|
325
|
-
"ImputeNLPCA": {"cls": ImputeNLPCA, "config_cls": NLPCAConfig},
|
|
326
612
|
"ImputeAutoencoder": {"cls": ImputeAutoencoder, "config_cls": AutoencoderConfig},
|
|
327
613
|
"ImputeVAE": {"cls": ImputeVAE, "config_cls": VAEConfig},
|
|
328
614
|
"ImputeMostFrequent": {"cls": ImputeMostFrequent, "config_cls": MostFrequentConfig},
|
|
@@ -372,24 +658,65 @@ def _build_effective_config_for_model(
|
|
|
372
658
|
f"Loaded YAML config for {model_name} from {yaml_path} (ignored 'preset' in YAML if present)."
|
|
373
659
|
)
|
|
374
660
|
|
|
375
|
-
# 3)
|
|
661
|
+
# 3) Optional: load best parameters from a previous run and force tuning OFF.
|
|
662
|
+
if getattr(args, "load_best_params", False):
|
|
663
|
+
# Determine which prefix to look under for *_output
|
|
664
|
+
src_prefix = getattr(args, "best_params_prefix", None)
|
|
665
|
+
if src_prefix is None:
|
|
666
|
+
# Use the resolved prefix if provided; otherwise fall back to input
|
|
667
|
+
# stem behavior
|
|
668
|
+
src_prefix = getattr(args, "prefix", None)
|
|
669
|
+
|
|
670
|
+
if src_prefix is None and hasattr(args, "vcf"):
|
|
671
|
+
src_prefix = str(Path(args.vcf).stem)
|
|
672
|
+
|
|
673
|
+
if src_prefix is None:
|
|
674
|
+
# As a last resort, use current effective io.prefix if it exists in cfg
|
|
675
|
+
src_prefix = getattr(getattr(cfg, "io", object()), "prefix", None)
|
|
676
|
+
|
|
677
|
+
if getattr(args, "tune", False):
|
|
678
|
+
logging.warning(
|
|
679
|
+
"--tune was supplied, but --load-best-params is active; forcing tuning OFF."
|
|
680
|
+
)
|
|
681
|
+
|
|
682
|
+
# Force tuning disabled in config (even if CLI/YAML enabled it)
|
|
683
|
+
cfg = _force_tuning_off(cfg, model_name)
|
|
684
|
+
|
|
685
|
+
best_path = _find_best_params_json(str(src_prefix), model_name)
|
|
686
|
+
if best_path is None:
|
|
687
|
+
# For tune-capable (unsupervised) models, treat as an error; deterministic models warn only.
|
|
688
|
+
fam = _model_family(model_name)
|
|
689
|
+
msg = (
|
|
690
|
+
"Requested --load-best-params, but could not find a best parameters JSON "
|
|
691
|
+
f"for {model_name}. Looked under '.../optimize/<model>/parameters/best_tuned_parameters.json' and '{src_prefix}_output/{fam}/parameters/{model_name}/best_parameters.json'"
|
|
692
|
+
)
|
|
693
|
+
if model_name in {"ImputeVAE", "ImputeAutoencoder"}:
|
|
694
|
+
logging.error(msg)
|
|
695
|
+
raise FileNotFoundError(msg)
|
|
696
|
+
logging.warning(msg)
|
|
697
|
+
else:
|
|
698
|
+
logging.info(f"Loading best parameters for {model_name} from: {best_path}")
|
|
699
|
+
best_params = _load_best_params(best_path)
|
|
700
|
+
cfg = _apply_best_params_to_cfg(cfg, best_params, model_name)
|
|
701
|
+
cfg = _force_tuning_off(cfg, model_name)
|
|
702
|
+
|
|
703
|
+
# 4) Explicit CLI flags overlay YAML/best-params layers.
|
|
376
704
|
cli_overrides = _args_to_cli_overrides(args)
|
|
377
705
|
if cli_overrides:
|
|
378
706
|
cfg = apply_dot_overrides(cfg, cli_overrides)
|
|
379
707
|
|
|
380
|
-
#
|
|
708
|
+
# Keep tuning disabled if --load-best-params was requested, even if CLI flags tried to re-enable it.
|
|
709
|
+
if getattr(args, "load_best_params", False):
|
|
710
|
+
cfg = _force_tuning_off(cfg, model_name)
|
|
711
|
+
|
|
712
|
+
# 5) --set has highest precedence.
|
|
381
713
|
user_overrides = _parse_overrides(getattr(args, "set", []))
|
|
382
714
|
|
|
383
715
|
if user_overrides:
|
|
384
716
|
try:
|
|
385
717
|
cfg = apply_dot_overrides(cfg, user_overrides)
|
|
386
718
|
except Exception as e:
|
|
387
|
-
if model_name in {
|
|
388
|
-
"ImputeUBP",
|
|
389
|
-
"ImputeNLPCA",
|
|
390
|
-
"ImputeAutoencoder",
|
|
391
|
-
"ImputeVAE",
|
|
392
|
-
}:
|
|
719
|
+
if model_name in {"ImputeAutoencoder", "ImputeVAE"}:
|
|
393
720
|
logging.error(
|
|
394
721
|
f"Error applying --set overrides to {model_name} config: {e}"
|
|
395
722
|
)
|
|
@@ -397,6 +724,18 @@ def _build_effective_config_for_model(
|
|
|
397
724
|
else:
|
|
398
725
|
pass # non-config-driven models ignore --set
|
|
399
726
|
|
|
727
|
+
# FINAL GUARANTEE:
|
|
728
|
+
# --load-best-params always wins over
|
|
729
|
+
# --set, YAML, preset, and CLI flags.
|
|
730
|
+
if getattr(args, "load_best_params", False):
|
|
731
|
+
# If user explicitly tried to set tune.* via --set, warn and override.
|
|
732
|
+
if any(str(k).startswith("tune.") for k in (user_overrides or {}).keys()):
|
|
733
|
+
logging.warning(
|
|
734
|
+
f"{model_name}: '--set tune.*=...' was provided, but --load-best-params forces tuning OFF. "
|
|
735
|
+
"Ignoring any tune.* overrides."
|
|
736
|
+
)
|
|
737
|
+
cfg = _force_tuning_off(cfg, model_name)
|
|
738
|
+
|
|
400
739
|
return cfg
|
|
401
740
|
|
|
402
741
|
|
|
@@ -437,9 +776,19 @@ def _maybe_print_or_dump_configs(
|
|
|
437
776
|
|
|
438
777
|
|
|
439
778
|
def main(argv: Optional[List[str]] = None) -> int:
|
|
779
|
+
"""PG-SUI CLI main entry point.
|
|
780
|
+
|
|
781
|
+
The CLI supports running multiple imputation models on a single input file, with configuration handled via presets, YAML files, and CLI flags.
|
|
782
|
+
|
|
783
|
+
Args:
|
|
784
|
+
argv (Optional[List[str]]): List of CLI args (default: sys.argv[1:]).
|
|
785
|
+
|
|
786
|
+
Returns:
|
|
787
|
+
int: Exit code (0=success, 2=argparse error, 1=other error).
|
|
788
|
+
"""
|
|
440
789
|
parser = argparse.ArgumentParser(
|
|
441
790
|
prog="pg-sui",
|
|
442
|
-
description="Run PG-SUI imputation models on an input file. Handle configuration via presets, YAML, and CLI flags. The default is to run all models.",
|
|
791
|
+
description="Run PG-SUI imputation models on an input file. Handle configuration via presets, YAML, and CLI flags. The default is to run all models. The input file can be in VCF, PHYLIP, or GENEPOP format. Outputs include imputed genotype files and performance summaries.",
|
|
443
792
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
444
793
|
usage="%(prog)s [options]",
|
|
445
794
|
)
|
|
@@ -448,7 +797,7 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
448
797
|
parser.add_argument(
|
|
449
798
|
"--input",
|
|
450
799
|
default=argparse.SUPPRESS,
|
|
451
|
-
help="Path to input file (VCF/PHYLIP/
|
|
800
|
+
help="Path to input file (VCF/PHYLIP/GENEPOP). VCF file can be bgzipped or uncompressed.",
|
|
452
801
|
)
|
|
453
802
|
parser.add_argument(
|
|
454
803
|
"--format",
|
|
@@ -464,23 +813,23 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
464
813
|
"gen",
|
|
465
814
|
),
|
|
466
815
|
default=argparse.SUPPRESS,
|
|
467
|
-
help="Input format. If 'infer', deduced from file extension. The default is 'infer'.",
|
|
816
|
+
help="Input format. If 'infer', deduced from file extension. The default is 'infer'. Supported formats: VCF ('.vcf', '.vcf.gz'), PHYLIP ('.phy', '.phylip'), GENEPOP ('.genepop', '.gen').",
|
|
468
817
|
)
|
|
469
818
|
# Back-compat: --vcf retained; if both provided, --input wins.
|
|
470
819
|
parser.add_argument(
|
|
471
820
|
"--vcf",
|
|
472
821
|
default=argparse.SUPPRESS,
|
|
473
|
-
help="Path to input VCF file. Can be bgzipped or uncompressed.",
|
|
822
|
+
help="Path to input VCF file. Can be bgzipped or uncompressed. (Deprecated; use --input instead.)",
|
|
474
823
|
)
|
|
475
824
|
parser.add_argument(
|
|
476
825
|
"--popmap",
|
|
477
826
|
default=argparse.SUPPRESS,
|
|
478
|
-
help="Path to population map file. This is a two-column tab-delimited file with sample IDs and population IDs.",
|
|
827
|
+
help="Path to population map file. This is a two-column tab-delimited file with sample IDs and population IDs. If not provided, no population info is used.",
|
|
479
828
|
)
|
|
480
829
|
parser.add_argument(
|
|
481
830
|
"--treefile",
|
|
482
831
|
default=argparse.SUPPRESS,
|
|
483
|
-
help="Path to phylogenetic tree file. Can be in Newick (recommended) or Nexus format.",
|
|
832
|
+
help="Path to phylogenetic tree file. Can be in Newick (recommended) or Nexus format. Used with --qmatrix and --siterates.",
|
|
484
833
|
)
|
|
485
834
|
parser.add_argument(
|
|
486
835
|
"--qmatrix",
|
|
@@ -490,19 +839,19 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
490
839
|
parser.add_argument(
|
|
491
840
|
"--siterates",
|
|
492
841
|
default=argparse.SUPPRESS,
|
|
493
|
-
help="Path to SNP site rates file (has .rate extension). Used with --treefile and --qmatrix.",
|
|
842
|
+
help="Path to SNP site rates file (has .rate extension and can be produced with IQ-TREE). Used with --treefile and --qmatrix.",
|
|
494
843
|
)
|
|
495
844
|
parser.add_argument(
|
|
496
845
|
"--prefix",
|
|
497
846
|
default=argparse.SUPPRESS,
|
|
498
|
-
help="Output file prefix.",
|
|
847
|
+
help="Output file prefix. If not provided, defaults to the input file stem.",
|
|
499
848
|
)
|
|
500
849
|
|
|
501
850
|
# ---------------------- Generic Config Inputs -------------------------- #
|
|
502
851
|
parser.add_argument(
|
|
503
852
|
"--config",
|
|
504
853
|
default=argparse.SUPPRESS,
|
|
505
|
-
help="YAML config for config-driven models (
|
|
854
|
+
help="YAML config for config-driven models (Autoencoder, VAE). Overrides preset and defaults.",
|
|
506
855
|
)
|
|
507
856
|
parser.add_argument(
|
|
508
857
|
"--preset",
|
|
@@ -514,7 +863,7 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
514
863
|
"--set",
|
|
515
864
|
action="append",
|
|
516
865
|
default=argparse.SUPPRESS,
|
|
517
|
-
help="Dot-key overrides, e.g. --set model.latent_dim=4",
|
|
866
|
+
help="Dot-key overrides, e.g. --set model.latent_dim=4 --set train.epochs=100. Applies to all models.",
|
|
518
867
|
)
|
|
519
868
|
parser.add_argument(
|
|
520
869
|
"--print-config",
|
|
@@ -532,7 +881,7 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
532
881
|
"--tune",
|
|
533
882
|
action="store_true",
|
|
534
883
|
default=argparse.SUPPRESS,
|
|
535
|
-
help="Enable hyperparameter tuning (if supported).",
|
|
884
|
+
help="Enable hyperparameter tuning (if supported by model). Uses Optuna to optimize hyperparameters.",
|
|
536
885
|
)
|
|
537
886
|
parser.add_argument(
|
|
538
887
|
"--tune-n-trials",
|
|
@@ -562,8 +911,31 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
562
911
|
"--plot-format",
|
|
563
912
|
choices=("png", "pdf", "svg", "jpg", "jpeg"),
|
|
564
913
|
default=argparse.SUPPRESS,
|
|
565
|
-
help="Figure format for model plots.",
|
|
914
|
+
help="Figure format for model plots. Choices: png, pdf, svg, jpg, jpeg.",
|
|
915
|
+
)
|
|
916
|
+
parser.add_argument(
|
|
917
|
+
"--disable-plotting",
|
|
918
|
+
action="store_true",
|
|
919
|
+
default=False,
|
|
920
|
+
help="Disable plotting for all models. Overrides any config settings enabling plotting.",
|
|
921
|
+
)
|
|
922
|
+
|
|
923
|
+
parser.add_argument(
|
|
924
|
+
"--load-best-params",
|
|
925
|
+
action="store_true",
|
|
926
|
+
default=False,
|
|
927
|
+
help=(
|
|
928
|
+
"Load best hyperparameters from a previous run's best_parameters.json (or tuning best_tuned_parameters.json) for each selected model and apply them to the model configs. This forces tuning OFF."
|
|
929
|
+
),
|
|
566
930
|
)
|
|
931
|
+
parser.add_argument(
|
|
932
|
+
"--best-params-prefix",
|
|
933
|
+
default=argparse.SUPPRESS,
|
|
934
|
+
help=(
|
|
935
|
+
"Prefix of the PREVIOUS run to load best parameters from. If omitted, uses the current --prefix (or input stem)."
|
|
936
|
+
),
|
|
937
|
+
)
|
|
938
|
+
|
|
567
939
|
# ------------------------- Simulation Controls ------------------------ #
|
|
568
940
|
parser.add_argument(
|
|
569
941
|
"--sim-strategy",
|
|
@@ -577,19 +949,15 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
577
949
|
default=argparse.SUPPRESS,
|
|
578
950
|
help="Override the proportion of observed entries to mask during simulation (0-1).",
|
|
579
951
|
)
|
|
580
|
-
parser.add_argument(
|
|
581
|
-
"--simulate-missing",
|
|
582
|
-
action="store_false",
|
|
583
|
-
default=argparse.SUPPRESS,
|
|
584
|
-
help="Disable missing-data simulation regardless of preset/config (when provided).",
|
|
585
|
-
)
|
|
586
952
|
|
|
587
953
|
# --------------------------- Seed & logging ---------------------------- #
|
|
588
954
|
parser.add_argument(
|
|
589
955
|
"--seed",
|
|
590
956
|
default=argparse.SUPPRESS,
|
|
591
|
-
help="Random seed: 'random', 'deterministic', or an integer.",
|
|
957
|
+
help="Random seed: 'random', 'deterministic', or an integer. Default is 'random'.",
|
|
592
958
|
)
|
|
959
|
+
|
|
960
|
+
# ----------------------------- Logging --------------------------------- #
|
|
593
961
|
parser.add_argument("--verbose", action="store_true", help="Info-level logging.")
|
|
594
962
|
parser.add_argument("--debug", action="store_true", help="Debug-level logging.")
|
|
595
963
|
parser.add_argument(
|
|
@@ -607,7 +975,33 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
607
975
|
"--force-popmap",
|
|
608
976
|
action="store_true",
|
|
609
977
|
default=False,
|
|
610
|
-
help="
|
|
978
|
+
help="Force use of provided popmap even if samples don't match exactly. This will drop samples not in the popmap and vice versa.",
|
|
979
|
+
)
|
|
980
|
+
|
|
981
|
+
# -------------------------- STRUCTURE options ------------------------- #
|
|
982
|
+
parser.add_argument(
|
|
983
|
+
"--structure-has-popids",
|
|
984
|
+
action="store_true",
|
|
985
|
+
default=False,
|
|
986
|
+
help="STRUCTURE only: second column contains population IDs.",
|
|
987
|
+
)
|
|
988
|
+
parser.add_argument(
|
|
989
|
+
"--structure-has-marker-names",
|
|
990
|
+
action="store_true",
|
|
991
|
+
default=False,
|
|
992
|
+
help="STRUCTURE only: first row contains marker names.",
|
|
993
|
+
)
|
|
994
|
+
parser.add_argument(
|
|
995
|
+
"--structure-allele-start-col",
|
|
996
|
+
type=int,
|
|
997
|
+
default=argparse.SUPPRESS,
|
|
998
|
+
help="STRUCTURE only: zero-based column index where alleles begin.",
|
|
999
|
+
)
|
|
1000
|
+
parser.add_argument(
|
|
1001
|
+
"--structure-allele-encoding",
|
|
1002
|
+
type=_parse_allele_encoding,
|
|
1003
|
+
default=argparse.SUPPRESS,
|
|
1004
|
+
help="STRUCTURE only: allele encoding mapping as JSON or Python dict.",
|
|
611
1005
|
)
|
|
612
1006
|
|
|
613
1007
|
# ---------------------------- Model selection -------------------------- #
|
|
@@ -616,54 +1010,66 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
616
1010
|
nargs="+",
|
|
617
1011
|
default=argparse.SUPPRESS,
|
|
618
1012
|
help=(
|
|
619
|
-
"Which models to run. Choices:
|
|
1013
|
+
"Which models to run. Specify each model separated by a space. Choices: ImputeVAE ImputeAutoencoder ImputeMostFrequent ImputeRefAllele (Default is all models)."
|
|
620
1014
|
),
|
|
621
1015
|
)
|
|
622
1016
|
|
|
623
1017
|
# -------------------------- MultiQC integration ------------------------ #
|
|
624
1018
|
parser.add_argument(
|
|
625
|
-
"--multiqc",
|
|
1019
|
+
"--disable-multiqc",
|
|
626
1020
|
action="store_true",
|
|
1021
|
+
default=False,
|
|
627
1022
|
help=(
|
|
628
|
-
"
|
|
1023
|
+
"Disable MultiQC report generation after imputation. By default, a MultiQC report is generated unless this flag is set."
|
|
629
1024
|
),
|
|
630
1025
|
)
|
|
631
1026
|
parser.add_argument(
|
|
632
1027
|
"--multiqc-title",
|
|
633
1028
|
default=argparse.SUPPRESS,
|
|
634
|
-
help="Optional title for the MultiQC report (default: 'PG-SUI MultiQC Report - <prefix>').",
|
|
1029
|
+
help="Optional title for the MultiQC report (default: 'PG-SUI MultiQC Report - <prefix>'). ",
|
|
635
1030
|
)
|
|
636
1031
|
parser.add_argument(
|
|
637
1032
|
"--multiqc-output-dir",
|
|
638
1033
|
default=argparse.SUPPRESS,
|
|
639
|
-
help="Optional output directory for the MultiQC report (default: '<prefix>_output/multiqc').",
|
|
1034
|
+
help="Optional output directory for the MultiQC report (default: '<prefix>_output/multiqc'). This directory will be created if it does not exist.",
|
|
640
1035
|
)
|
|
641
1036
|
parser.add_argument(
|
|
642
1037
|
"--multiqc-overwrite",
|
|
643
1038
|
action="store_true",
|
|
644
1039
|
default=False,
|
|
645
|
-
help="Overwrite an existing MultiQC report if present.",
|
|
1040
|
+
help="Overwrite an existing MultiQC report if present. If not set and a report exists, an integer suffix will be added to avoid overwriting. NOTE: if running multiple times with this flag, it may append multiple suffixes to avoid overwriting previous reports.",
|
|
646
1041
|
)
|
|
647
1042
|
|
|
648
1043
|
# ------------------------------ Safety/UX ------------------------------ #
|
|
649
1044
|
parser.add_argument(
|
|
650
1045
|
"--dry-run",
|
|
651
1046
|
action="store_true",
|
|
652
|
-
help="Parse args and load data, but skip model training.",
|
|
1047
|
+
help="Parse args and load data, but skip model training. Useful for testing I/O and configs.",
|
|
1048
|
+
)
|
|
1049
|
+
parser.add_argument(
|
|
1050
|
+
"--version", action="store_true", help="Print PG-SUI version and exit."
|
|
653
1051
|
)
|
|
654
1052
|
|
|
655
1053
|
args = parser.parse_args(argv)
|
|
656
1054
|
|
|
1055
|
+
if getattr(args, "version", False):
|
|
1056
|
+
_print_version()
|
|
1057
|
+
return 0
|
|
1058
|
+
|
|
657
1059
|
# Logging (verbose default is False unless passed)
|
|
658
1060
|
_configure_logging(
|
|
659
1061
|
verbose=getattr(args, "verbose", False),
|
|
660
1062
|
log_file=getattr(args, "log_file", None),
|
|
661
1063
|
)
|
|
662
1064
|
|
|
1065
|
+
logging.info("Starting PG-SUI imputation...")
|
|
1066
|
+
_print_version()
|
|
1067
|
+
|
|
663
1068
|
# Models selection (default to all if not explicitly provided)
|
|
664
1069
|
try:
|
|
665
1070
|
selected_models = _parse_models(getattr(args, "models", ()))
|
|
666
1071
|
except argparse.ArgumentTypeError as e:
|
|
1072
|
+
logging.error(str(e))
|
|
667
1073
|
parser.error(str(e))
|
|
668
1074
|
return 2
|
|
669
1075
|
|
|
@@ -675,12 +1081,11 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
675
1081
|
setattr(args, "format", "vcf")
|
|
676
1082
|
|
|
677
1083
|
if input_path is None:
|
|
1084
|
+
logging.error("You must provide --input (or legacy --vcf).")
|
|
678
1085
|
parser.error("You must provide --input (or legacy --vcf).")
|
|
679
1086
|
return 2
|
|
680
1087
|
|
|
681
|
-
fmt
|
|
682
|
-
args, "format", "infer"
|
|
683
|
-
)
|
|
1088
|
+
fmt = getattr(args, "format", "infer")
|
|
684
1089
|
|
|
685
1090
|
if fmt == "infer":
|
|
686
1091
|
if input_path.endswith((".vcf", ".vcf.gz")):
|
|
@@ -689,28 +1094,59 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
689
1094
|
fmt_final = "phylip"
|
|
690
1095
|
elif input_path.endswith((".genepop", ".gen")):
|
|
691
1096
|
fmt_final = "genepop"
|
|
1097
|
+
elif input_path.endswith((".str", ".stru", ".structure")):
|
|
1098
|
+
fmt_final = "structure"
|
|
692
1099
|
else:
|
|
1100
|
+
logging.error(
|
|
1101
|
+
"Could not infer input format from file extension. Please provide --format."
|
|
1102
|
+
)
|
|
693
1103
|
parser.error(
|
|
694
1104
|
"Could not infer input format from file extension. Please provide --format."
|
|
695
1105
|
)
|
|
696
1106
|
return 2
|
|
697
1107
|
else:
|
|
698
|
-
fmt_final =
|
|
1108
|
+
fmt_final = cast(
|
|
1109
|
+
Literal[
|
|
1110
|
+
"vcf",
|
|
1111
|
+
"vcf.gz",
|
|
1112
|
+
"phy",
|
|
1113
|
+
"phylip",
|
|
1114
|
+
"genepop",
|
|
1115
|
+
"gen",
|
|
1116
|
+
"structure",
|
|
1117
|
+
"str",
|
|
1118
|
+
],
|
|
1119
|
+
fmt,
|
|
1120
|
+
)
|
|
1121
|
+
|
|
1122
|
+
fmt_final = _normalize_input_format(fmt_final)
|
|
699
1123
|
|
|
700
1124
|
popmap_path = getattr(args, "popmap", None)
|
|
701
1125
|
include_pops = getattr(args, "include_pops", None)
|
|
702
|
-
verbose_flag = getattr(args, "verbose", False)
|
|
703
1126
|
force_popmap = bool(getattr(args, "force_popmap", False))
|
|
1127
|
+
structure_has_popids = bool(getattr(args, "structure_has_popids", False))
|
|
1128
|
+
structure_has_marker_names = bool(
|
|
1129
|
+
getattr(args, "structure_has_marker_names", False)
|
|
1130
|
+
)
|
|
1131
|
+
structure_allele_start_col = getattr(args, "structure_allele_start_col", None)
|
|
1132
|
+
structure_allele_encoding = getattr(args, "structure_allele_encoding", None)
|
|
704
1133
|
|
|
705
1134
|
# Canonical prefix for this run (used for outputs and MultiQC)
|
|
706
1135
|
prefix: str = getattr(args, "prefix", str(Path(input_path).stem))
|
|
1136
|
+
# Ensure downstream config building sees the resolved prefix even if
|
|
1137
|
+
# --prefix was not provided.
|
|
1138
|
+
setattr(args, "prefix", prefix)
|
|
707
1139
|
|
|
708
|
-
treefile =
|
|
709
|
-
|
|
710
|
-
|
|
1140
|
+
treefile, qmatrix, siterates = _resolve_tree_paths(args)
|
|
1141
|
+
setattr(args, "treefile", treefile)
|
|
1142
|
+
setattr(args, "qmatrix", qmatrix)
|
|
1143
|
+
setattr(args, "siterates", siterates)
|
|
711
1144
|
|
|
712
1145
|
if any(x is not None for x in (treefile, qmatrix, siterates)):
|
|
713
1146
|
if not all(x is not None for x in (treefile, qmatrix, siterates)):
|
|
1147
|
+
logging.error(
|
|
1148
|
+
"--treefile, --qmatrix, and --siterates must all be provided together or they should all be omitted."
|
|
1149
|
+
)
|
|
714
1150
|
parser.error(
|
|
715
1151
|
"--treefile, --qmatrix, and --siterates must all be provided together or they should all be omitted."
|
|
716
1152
|
)
|
|
@@ -719,15 +1155,31 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
719
1155
|
# Load genotype data
|
|
720
1156
|
gd, tp = build_genotype_data(
|
|
721
1157
|
input_path=input_path,
|
|
722
|
-
fmt=
|
|
1158
|
+
fmt=cast(
|
|
1159
|
+
Literal[
|
|
1160
|
+
"vcf",
|
|
1161
|
+
"vcf.gz",
|
|
1162
|
+
"phy",
|
|
1163
|
+
"phylip",
|
|
1164
|
+
"genepop",
|
|
1165
|
+
"gen",
|
|
1166
|
+
"structure",
|
|
1167
|
+
"str",
|
|
1168
|
+
],
|
|
1169
|
+
fmt_final,
|
|
1170
|
+
),
|
|
723
1171
|
popmap_path=popmap_path,
|
|
724
1172
|
treefile=treefile,
|
|
725
1173
|
qmatrix=qmatrix,
|
|
726
1174
|
siterates=siterates,
|
|
727
1175
|
force_popmap=force_popmap,
|
|
728
|
-
verbose=verbose_flag,
|
|
729
1176
|
include_pops=include_pops,
|
|
1177
|
+
debug=getattr(args, "debug", False),
|
|
730
1178
|
plot_format=getattr(args, "plot_format", "pdf"),
|
|
1179
|
+
structure_has_popids=structure_has_popids,
|
|
1180
|
+
structure_has_marker_names=structure_has_marker_names,
|
|
1181
|
+
structure_allele_start_col=structure_allele_start_col,
|
|
1182
|
+
structure_allele_encoding=structure_allele_encoding,
|
|
731
1183
|
)
|
|
732
1184
|
|
|
733
1185
|
if getattr(args, "dry_run", False):
|
|
@@ -739,47 +1191,33 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
739
1191
|
m: _build_effective_config_for_model(m, args) for m in selected_models
|
|
740
1192
|
}
|
|
741
1193
|
|
|
1194
|
+
needs_tree = any(
|
|
1195
|
+
_config_needs_tree(cfg) for cfg in cfgs_by_model.values() if cfg is not None
|
|
1196
|
+
)
|
|
1197
|
+
if needs_tree and not all(x is not None for x in (treefile, qmatrix, siterates)):
|
|
1198
|
+
logging.error(
|
|
1199
|
+
"Nonrandom simulated missingness requires --treefile, --qmatrix, and --siterates."
|
|
1200
|
+
)
|
|
1201
|
+
parser.error(
|
|
1202
|
+
"Nonrandom simulated missingness requires --treefile, --qmatrix, and --siterates."
|
|
1203
|
+
)
|
|
1204
|
+
return 2
|
|
1205
|
+
if needs_tree and tp is None:
|
|
1206
|
+
logging.error(
|
|
1207
|
+
"Tree parser was not initialized for nonrandom simulation. "
|
|
1208
|
+
"Please verify --treefile, --qmatrix, and --siterates."
|
|
1209
|
+
)
|
|
1210
|
+
parser.error(
|
|
1211
|
+
"Tree parser was not initialized for nonrandom simulation. "
|
|
1212
|
+
"Please verify --treefile, --qmatrix, and --siterates."
|
|
1213
|
+
)
|
|
1214
|
+
return 2
|
|
1215
|
+
|
|
742
1216
|
# Maybe print/dump configs and exit
|
|
743
1217
|
if _maybe_print_or_dump_configs(cfgs_by_model, args):
|
|
744
1218
|
return 0
|
|
745
1219
|
|
|
746
1220
|
# ------------------------- Model Builders ------------------------------ #
|
|
747
|
-
def build_impute_ubp():
|
|
748
|
-
cfg = cfgs_by_model.get("ImputeUBP")
|
|
749
|
-
if cfg is None:
|
|
750
|
-
cfg = (
|
|
751
|
-
UBPConfig.from_preset(args.preset)
|
|
752
|
-
if hasattr(args, "preset")
|
|
753
|
-
else UBPConfig()
|
|
754
|
-
)
|
|
755
|
-
return ImputeUBP(
|
|
756
|
-
genotype_data=gd,
|
|
757
|
-
tree_parser=tp,
|
|
758
|
-
config=cfg,
|
|
759
|
-
simulate_missing=cfg.sim.simulate_missing,
|
|
760
|
-
sim_strategy=cfg.sim.sim_strategy,
|
|
761
|
-
sim_prop=cfg.sim.sim_prop,
|
|
762
|
-
sim_kwargs=cfg.sim.sim_kwargs,
|
|
763
|
-
)
|
|
764
|
-
|
|
765
|
-
def build_impute_nlpca():
|
|
766
|
-
cfg = cfgs_by_model.get("ImputeNLPCA")
|
|
767
|
-
if cfg is None:
|
|
768
|
-
cfg = (
|
|
769
|
-
NLPCAConfig.from_preset(args.preset)
|
|
770
|
-
if hasattr(args, "preset")
|
|
771
|
-
else NLPCAConfig()
|
|
772
|
-
)
|
|
773
|
-
return ImputeNLPCA(
|
|
774
|
-
genotype_data=gd,
|
|
775
|
-
tree_parser=tp,
|
|
776
|
-
config=cfg,
|
|
777
|
-
simulate_missing=cfg.sim.simulate_missing,
|
|
778
|
-
sim_strategy=cfg.sim.sim_strategy,
|
|
779
|
-
sim_prop=cfg.sim.sim_prop,
|
|
780
|
-
sim_kwargs=cfg.sim.sim_kwargs,
|
|
781
|
-
)
|
|
782
|
-
|
|
783
1221
|
def build_impute_vae():
|
|
784
1222
|
cfg = cfgs_by_model.get("ImputeVAE")
|
|
785
1223
|
if cfg is None:
|
|
@@ -792,7 +1230,6 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
792
1230
|
genotype_data=gd,
|
|
793
1231
|
tree_parser=tp,
|
|
794
1232
|
config=cfg,
|
|
795
|
-
simulate_missing=cfg.sim.simulate_missing,
|
|
796
1233
|
sim_strategy=cfg.sim.sim_strategy,
|
|
797
1234
|
sim_prop=cfg.sim.sim_prop,
|
|
798
1235
|
sim_kwargs=cfg.sim.sim_kwargs,
|
|
@@ -810,7 +1247,6 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
810
1247
|
genotype_data=gd,
|
|
811
1248
|
tree_parser=tp,
|
|
812
1249
|
config=cfg,
|
|
813
|
-
simulate_missing=cfg.sim.simulate_missing,
|
|
814
1250
|
sim_strategy=cfg.sim.sim_strategy,
|
|
815
1251
|
sim_prop=cfg.sim.sim_prop,
|
|
816
1252
|
sim_kwargs=cfg.sim.sim_kwargs,
|
|
@@ -828,7 +1264,7 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
828
1264
|
gd,
|
|
829
1265
|
tree_parser=tp,
|
|
830
1266
|
config=cfg,
|
|
831
|
-
simulate_missing=
|
|
1267
|
+
simulate_missing=True,
|
|
832
1268
|
sim_strategy=cfg.sim.sim_strategy,
|
|
833
1269
|
sim_prop=cfg.sim.sim_prop,
|
|
834
1270
|
sim_kwargs=cfg.sim.sim_kwargs,
|
|
@@ -846,34 +1282,37 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
846
1282
|
gd,
|
|
847
1283
|
tree_parser=tp,
|
|
848
1284
|
config=cfg,
|
|
849
|
-
simulate_missing=
|
|
1285
|
+
simulate_missing=True,
|
|
850
1286
|
sim_strategy=cfg.sim.sim_strategy,
|
|
851
1287
|
sim_prop=cfg.sim.sim_prop,
|
|
852
1288
|
sim_kwargs=cfg.sim.sim_kwargs,
|
|
853
1289
|
)
|
|
854
1290
|
|
|
855
1291
|
model_builders = {
|
|
856
|
-
"ImputeUBP": build_impute_ubp,
|
|
857
1292
|
"ImputeVAE": build_impute_vae,
|
|
858
1293
|
"ImputeAutoencoder": build_impute_autoencoder,
|
|
859
|
-
"ImputeNLPCA": build_impute_nlpca,
|
|
860
1294
|
"ImputeMostFrequent": build_impute_mostfreq,
|
|
861
1295
|
"ImputeRefAllele": build_impute_refallele,
|
|
862
1296
|
}
|
|
863
1297
|
|
|
864
1298
|
logging.info(f"Selected models: {', '.join(selected_models)}")
|
|
865
1299
|
for name in selected_models:
|
|
1300
|
+
logging.info("")
|
|
1301
|
+
logging.info("=" * 60)
|
|
1302
|
+
logging.info("")
|
|
1303
|
+
logging.info(f"Processing model: {name} ...")
|
|
866
1304
|
X_imputed = run_model_safely(name, model_builders[name], warn_only=False)
|
|
867
1305
|
gd_imp = gd.copy()
|
|
868
1306
|
gd_imp.snp_data = X_imputed
|
|
869
1307
|
|
|
870
|
-
if name in {"
|
|
1308
|
+
if name in {"ImputeVAE", "ImputeAutoencoder"}:
|
|
871
1309
|
family = "Unsupervised"
|
|
872
1310
|
elif name in {"ImputeMostFrequent", "ImputeRefAllele"}:
|
|
873
1311
|
family = "Deterministic"
|
|
874
1312
|
elif name in {"ImputeHistGradientBoosting", "ImputeRandomForest"}:
|
|
875
1313
|
family = "Supervised"
|
|
876
1314
|
else:
|
|
1315
|
+
logging.error(f"Unknown model family for {name}")
|
|
877
1316
|
raise ValueError(f"Unknown model family for {name}")
|
|
878
1317
|
|
|
879
1318
|
pth = Path(f"{prefix}_output/{family}/imputed/{name}")
|
|
@@ -892,7 +1331,19 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
892
1331
|
f"Output format {fmt_final} not supported for imputed data export."
|
|
893
1332
|
)
|
|
894
1333
|
|
|
895
|
-
|
|
1334
|
+
logging.info("")
|
|
1335
|
+
logging.info(f"Successfully finished imputation for model: {name}!")
|
|
1336
|
+
logging.info("")
|
|
1337
|
+
logging.info("=" * 60)
|
|
1338
|
+
|
|
1339
|
+
logging.info(f"All requested models processed for input: {input_path}")
|
|
1340
|
+
|
|
1341
|
+
disable_mqc = bool(getattr(args, "disable_multiqc", False))
|
|
1342
|
+
|
|
1343
|
+
if disable_mqc:
|
|
1344
|
+
logging.info("MultiQC report generation disabled via --disable-multiqc.")
|
|
1345
|
+
logging.info("PG-SUI imputation run complete!")
|
|
1346
|
+
return 0
|
|
896
1347
|
|
|
897
1348
|
# -------------------------- MultiQC builder ---------------------------- #
|
|
898
1349
|
|
|
@@ -912,9 +1363,10 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
912
1363
|
overwrite=overwrite,
|
|
913
1364
|
)
|
|
914
1365
|
logging.info("MultiQC report successfully built.")
|
|
915
|
-
except Exception as exc2:
|
|
1366
|
+
except Exception as exc2:
|
|
916
1367
|
logging.error(f"Failed to build MultiQC report: {exc2}", exc_info=True)
|
|
917
1368
|
|
|
1369
|
+
logging.info("PG-SUI imputation run complete!")
|
|
918
1370
|
return 0
|
|
919
1371
|
|
|
920
1372
|
|