pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.7.0.dist-info/METADATA +288 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +591 -126
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +218 -533
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1268 -530
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
- pgsui/impute/unsupervised/imputers/vae.py +928 -696
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
pgsui/cli.py
CHANGED
|
@@ -15,19 +15,21 @@ Notes
|
|
|
15
15
|
|
|
16
16
|
Examples
|
|
17
17
|
--------
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
--models
|
|
22
|
-
|
|
23
|
-
--include-pops EA GU TT ON --device cpu
|
|
18
|
+
pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix run1
|
|
19
|
+
pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix tuned --tune
|
|
20
|
+
pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix demo \
|
|
21
|
+
--models ImputeAutoencoder ImputeVAE ImputeMostFrequent --seed deterministic --verbose
|
|
22
|
+
pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix subset \
|
|
23
|
+
--include-pops EA GU TT ON --device cpu --sim-prop 0.3 --sim-strategy nonrandom
|
|
24
24
|
"""
|
|
25
25
|
|
|
26
26
|
from __future__ import annotations
|
|
27
27
|
|
|
28
28
|
import argparse
|
|
29
29
|
import ast
|
|
30
|
+
import json
|
|
30
31
|
import logging
|
|
32
|
+
import os
|
|
31
33
|
import sys
|
|
32
34
|
import time
|
|
33
35
|
from functools import wraps
|
|
@@ -46,20 +48,24 @@ from typing import (
|
|
|
46
48
|
cast,
|
|
47
49
|
)
|
|
48
50
|
|
|
49
|
-
from snpio import
|
|
51
|
+
from snpio import (
|
|
52
|
+
GenePopReader,
|
|
53
|
+
NRemover2,
|
|
54
|
+
PhylipReader,
|
|
55
|
+
SNPioMultiQC,
|
|
56
|
+
StructureReader,
|
|
57
|
+
TreeParser,
|
|
58
|
+
VCFReader,
|
|
59
|
+
)
|
|
50
60
|
|
|
51
61
|
from pgsui import (
|
|
52
62
|
AutoencoderConfig,
|
|
53
63
|
ImputeAutoencoder,
|
|
54
64
|
ImputeMostFrequent,
|
|
55
|
-
ImputeNLPCA,
|
|
56
65
|
ImputeRefAllele,
|
|
57
|
-
ImputeUBP,
|
|
58
66
|
ImputeVAE,
|
|
59
67
|
MostFrequentConfig,
|
|
60
|
-
NLPCAConfig,
|
|
61
68
|
RefAlleleConfig,
|
|
62
|
-
UBPConfig,
|
|
63
69
|
VAEConfig,
|
|
64
70
|
)
|
|
65
71
|
from pgsui.data_processing.config import (
|
|
@@ -71,10 +77,8 @@ from pgsui.data_processing.config import (
|
|
|
71
77
|
|
|
72
78
|
# Canonical model order used everywhere (default and subset ordering)
|
|
73
79
|
MODEL_ORDER: Tuple[str, ...] = (
|
|
74
|
-
"ImputeUBP",
|
|
75
80
|
"ImputeVAE",
|
|
76
81
|
"ImputeAutoencoder",
|
|
77
|
-
"ImputeNLPCA",
|
|
78
82
|
"ImputeMostFrequent",
|
|
79
83
|
"ImputeRefAllele",
|
|
80
84
|
)
|
|
@@ -93,6 +97,142 @@ R = TypeVar("R")
|
|
|
93
97
|
|
|
94
98
|
|
|
95
99
|
# ----------------------------- CLI Utilities ----------------------------- #
|
|
100
|
+
def _print_version() -> None:
    """Log the installed PG-SUI version at INFO level.

    The version is read from ``pgsui.__version__``. This function only emits
    a log record; it does not write to stdout itself and it does not
    terminate the process.
    """
    # Imported lazily, inside the function, exactly as the original did.
    from pgsui import __version__ as version

    logging.info(f"Using PG-SUI version: {version}")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _model_family(model_name: str) -> str:
    """Return the output family folder name used by PG-SUI for a model.

    Args:
        model_name (str): Imputer class name, e.g. ``"ImputeVAE"``.

    Returns:
        str: ``"Unsupervised"`` for ImputeVAE / ImputeAutoencoder,
        ``"Deterministic"`` for ImputeMostFrequent / ImputeRefAllele,
        and ``"Unknown"`` for any other name.
    """
    if model_name in {"ImputeVAE", "ImputeAutoencoder"}:
        return "Unsupervised"
    if model_name in {"ImputeMostFrequent", "ImputeRefAllele"}:
        return "Deterministic"
    return "Unknown"
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _flatten_dict(d: dict, parent: str = "") -> dict:
    """Flatten a nested dict into a single-level dict with dot-joined keys.

    Args:
        d (dict): Possibly nested mapping. A falsy value (``None`` or ``{}``)
            is treated as an empty mapping.
        parent (str): Dot-path prefix prepended to every key; empty for the
            top level.

    Returns:
        dict: Mapping of ``"a.b.c"``-style string keys to leaf values. Nested
        dicts that are empty contribute no keys.
    """
    out: dict = {}
    for k, v in (d or {}).items():
        # Keys are always stringified, so non-str keys become their str() form.
        key = f"{parent}.{k}" if parent else str(k)
        if isinstance(v, dict):
            out.update(_flatten_dict(v, key))
        else:
            out[key] = v
    return out
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _force_tuning_off(cfg: Any, model_name: str) -> Any:
    """Force ``tune.enabled`` to False on a config object.

    This is best-effort for most models and strict for the tune-capable ones
    (ImputeVAE, ImputeAutoencoder).

    Args:
        cfg (Any): Config object; expected to expose ``cfg.tune.enabled``.
        model_name (str): Model name, used to decide strictness and for
            error/log messages.

    Returns:
        Any: The config with tuning disabled. This is the same object when
        direct mutation succeeded, otherwise whatever ``apply_dot_overrides``
        returned, or the unchanged ``cfg`` for models that do not support
        tuning.

    Raises:
        RuntimeError: If both strategies fail for a tune-capable model.
    """
    # Prefer direct attribute mutation (avoids apply_dot_overrides edge-cases).
    try:
        if hasattr(cfg, "tune") and hasattr(cfg.tune, "enabled"):
            cfg.tune.enabled = False
            return cfg
    except Exception as e:
        # Still best-effort, but leave a trace instead of swallowing silently.
        logging.debug(
            f"Direct tune.enabled mutation failed for {model_name}; "
            f"falling back to dot override. Error: {e}"
        )

    # Fallback to dot override.
    try:
        return apply_dot_overrides(cfg, {"tune.enabled": False})
    except Exception as e:
        # Only strict for models that actually support tuning.
        if model_name in {"ImputeVAE", "ImputeAutoencoder"}:
            raise RuntimeError(
                f"Failed to force tuning off for {model_name}: {e}"
            ) from e
        return cfg
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _find_best_params_json(prefix: str, model_name: str) -> Path | None:
|
|
151
|
+
"""Locate best parameter JSON (tuned or final) for a model.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
prefix (str): Output prefix used during the run.
|
|
155
|
+
model_name (str): Model name to look for.
|
|
156
|
+
|
|
157
|
+
Returns:
|
|
158
|
+
Path | None: Path to best_parameters.json / best_tuned_parameters.json if found; else None.
|
|
159
|
+
"""
|
|
160
|
+
families = ("Unsupervised", "Deterministic")
|
|
161
|
+
model_dir_candidates = (model_name, model_name.lower())
|
|
162
|
+
|
|
163
|
+
for fam in families:
|
|
164
|
+
for mdir in model_dir_candidates:
|
|
165
|
+
base = Path(f"{prefix}_output") / fam
|
|
166
|
+
candidates = (
|
|
167
|
+
base / "optimize" / mdir / "parameters" / "best_tuned_parameters.json",
|
|
168
|
+
base / "parameters" / mdir / "best_parameters.json",
|
|
169
|
+
)
|
|
170
|
+
for p in candidates:
|
|
171
|
+
if p.exists():
|
|
172
|
+
return p
|
|
173
|
+
return None
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _load_best_params(best_params_path: Path) -> dict:
    """Load a best-parameters JSON file and return its top-level object.

    Args:
        best_params_path (Path): Path to a ``best_parameters.json`` or
            ``best_tuned_parameters.json`` file.

    Returns:
        dict: The decoded JSON object.

    Raises:
        ValueError: If the file's top-level JSON value is not an object.
            ``json.JSONDecodeError``, itself a ``ValueError``, propagates for
            malformed JSON.
        OSError: If the file cannot be opened.
    """
    with best_params_path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        # Report the real file: it may be the tuned or the final parameter file.
        raise ValueError(
            f"{best_params_path} must contain a JSON object, got {type(data)}"
        )
    return data
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _apply_best_params_to_cfg(cfg: Any, best_params: dict, model_name: str) -> Any:
    """Apply best params into cfg using dot-path keys or inferred dot-paths.

    - Nested JSON is flattened to dot keys first.
    - A key that already contains '.' is treated as a dot-path and applied
      directly; if that fails, a warning is logged and the key is skipped.
    - A key without '.' is tried as-is (top level) and then under the common
      sections in order: model., train., sim., tune., io., plot. The first
      destination that ``apply_dot_overrides`` accepts wins.
    - Keys accepted nowhere are ignored with a warning.

    Args:
        cfg (Any): Config object accepted by ``apply_dot_overrides``.
        best_params (dict): Parameters loaded from a best-parameters JSON.
        model_name (str): Model name, used only in log messages.

    Returns:
        Any: The config after all applicable overrides were applied.
    """
    # _flatten_dict stringifies keys and keeps existing dot keys as-is.
    flat = _flatten_dict(best_params)

    candidate_prefixes = ("", "model.", "train.", "sim.", "tune.", "io.", "plot.")

    # Apply one by one so we can try multiple candidate destinations for
    # non-dot keys.
    for raw_k, v in flat.items():
        if "." in raw_k:
            try:
                cfg = apply_dot_overrides(cfg, {raw_k: v})
            except Exception as e:
                logging.warning(
                    f"Could not apply best param '{raw_k}' to {model_name} (dot key). Skipping. Error: {e}"
                )
            continue

        for pref in candidate_prefixes:
            try:
                cfg = apply_dot_overrides(cfg, {f"{pref}{raw_k}": v})
            except Exception:
                # Not a valid destination; try the next section.
                continue
            break
        else:
            # No candidate destination accepted the key.
            logging.warning(
                f"Best param '{raw_k}' not recognized for {model_name}; leaving config unchanged for that key."
            )

    return cfg
|
|
234
|
+
|
|
235
|
+
|
|
96
236
|
def _configure_logging(verbose: bool, log_file: Optional[str] = None) -> None:
|
|
97
237
|
"""Configure root logger.
|
|
98
238
|
|
|
@@ -166,6 +306,96 @@ def _parse_overrides(pairs: list[str]) -> dict:
|
|
|
166
306
|
return out
|
|
167
307
|
|
|
168
308
|
|
|
309
|
+
def _parse_allele_encoding(arg: str) -> dict:
    """Parse the STRUCTURE allele encoding dict from JSON or a Python literal.

    JSON is tried first; if that fails the text is parsed with
    ``ast.literal_eval`` (safe literal parsing, no code execution).

    Args:
        arg (str): Raw ``--structure-allele-encoding`` value.

    Returns:
        dict: Mapping whose values are all converted with ``str()``. String
        keys that look like integers (optional leading '-', then digits,
        surrounding whitespace ignored) are converted to ``int``; all other
        keys are kept unchanged.

    Raises:
        argparse.ArgumentTypeError: If the text cannot be parsed or does not
            evaluate to a dict.
    """
    try:
        payload = json.loads(arg)
    except (ValueError, TypeError, RecursionError):
        # Not valid JSON; fall back to Python literal syntax.
        try:
            payload = ast.literal_eval(arg)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            raise argparse.ArgumentTypeError(
                f"Invalid --structure-allele-encoding; must be a dict. Error: {e}"
            ) from e

    if not isinstance(payload, dict):
        raise argparse.ArgumentTypeError(
            "--structure-allele-encoding must be a dict-like mapping."
        )

    out: dict = {}
    for k, v in payload.items():
        key = k
        if isinstance(k, str):
            k_strip = k.strip()
            if k_strip.lstrip("-").isdigit():
                try:
                    key = int(k_strip)
                except ValueError:
                    # e.g. "--5" passes the digit check but is not an int.
                    key = k
        out[key] = str(v)
    return out
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _normalize_input_format(fmt: str) -> str:
    """Normalize input-format aliases into canonical reader names.

    Args:
        fmt (str): Format name or alias; matching is case-insensitive.

    Returns:
        str: One of ``"vcf"``, ``"phylip"``, ``"genepop"``, ``"structure"``
        for a known alias; otherwise the lower-cased input unchanged.
    """
    aliases = {
        "vcf": "vcf",
        "vcf.gz": "vcf",
        "phy": "phylip",
        "phylip": "phylip",
        "gen": "genepop",
        "genepop": "genepop",
        "str": "structure",
        "structure": "structure",
    }
    fmt = fmt.lower()
    # Unknown formats pass through lower-cased; the caller decides what to do.
    return aliases.get(fmt, fmt)
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def _normalize_plot_format(fmt: str) -> Literal["pdf", "png", "jpg", "svg"]:
    """Normalize plot-format aliases to the values passed on to the readers.

    Args:
        fmt (str): Plot format; matching is case-insensitive.

    Returns:
        Literal["pdf", "png", "jpg", "svg"]: ``"jpg"`` for ``"jpeg"``;
        otherwise the lower-cased input. The value is not validated here,
        so the cast is for the type checker only.
    """
    fmt = fmt.lower()
    if fmt == "jpeg":
        return "jpg"
    return cast(Literal["pdf", "png", "jpg", "svg"], fmt)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def _expand_path(path: str | None) -> str | None:
|
|
363
|
+
"""Expand ~ and env vars in a path-like string."""
|
|
364
|
+
if path is None:
|
|
365
|
+
return None
|
|
366
|
+
raw = str(path).strip()
|
|
367
|
+
if not raw:
|
|
368
|
+
return None
|
|
369
|
+
expanded = os.path.expandvars(raw)
|
|
370
|
+
return str(Path(expanded).expanduser())
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def _resolve_tree_paths(
    args: argparse.Namespace,
) -> tuple[str | None, str | None, str | None]:
    """Resolve tree-related paths from CLI args.

    Each of ``treefile``, ``qmatrix`` and ``siterates`` is read from ``args``,
    treated as None when the attribute is absent, and passed through
    ``_expand_path``.

    Args:
        args (argparse.Namespace): Parsed CLI arguments.

    Returns:
        tuple[str | None, str | None, str | None]: ``(treefile, qmatrix,
        siterates)`` after expansion.
    """
    treefile = _expand_path(getattr(args, "treefile", None))
    qmatrix = _expand_path(getattr(args, "qmatrix", None))
    siterates = _expand_path(getattr(args, "siterates", None))
    return treefile, qmatrix, siterates
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def _config_needs_tree(cfg: Any | None) -> bool:
|
|
384
|
+
"""Return True if config requires a tree parser for simulated missingness."""
|
|
385
|
+
if cfg is None:
|
|
386
|
+
return False
|
|
387
|
+
sim_cfg = getattr(cfg, "sim", None)
|
|
388
|
+
if sim_cfg is None:
|
|
389
|
+
return False
|
|
390
|
+
strategy = getattr(sim_cfg, "sim_strategy", None)
|
|
391
|
+
simulate = bool(getattr(sim_cfg, "simulate_missing", False))
|
|
392
|
+
return (
|
|
393
|
+
simulate
|
|
394
|
+
and isinstance(strategy, str)
|
|
395
|
+
and strategy in {"nonrandom", "nonrandom_weighted"}
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
|
|
169
399
|
def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
|
|
170
400
|
"""Convert explicitly provided CLI flags into config dot-overrides."""
|
|
171
401
|
overrides: dict = {}
|
|
@@ -174,10 +404,12 @@ def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
|
|
|
174
404
|
if hasattr(args, "prefix") and args.prefix is not None:
|
|
175
405
|
overrides["io.prefix"] = args.prefix
|
|
176
406
|
else:
|
|
177
|
-
#
|
|
178
|
-
|
|
179
|
-
if hasattr(args, "vcf"):
|
|
180
|
-
|
|
407
|
+
# Prefer --input stem; fallback to legacy --vcf stem
|
|
408
|
+
input_path = getattr(args, "input", None)
|
|
409
|
+
if input_path is None and hasattr(args, "vcf"):
|
|
410
|
+
input_path = getattr(args, "vcf", None)
|
|
411
|
+
if input_path:
|
|
412
|
+
overrides["io.prefix"] = str(Path(input_path).stem)
|
|
181
413
|
|
|
182
414
|
if hasattr(args, "verbose"):
|
|
183
415
|
overrides["io.verbose"] = bool(args.verbose)
|
|
@@ -185,6 +417,8 @@ def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
|
|
|
185
417
|
overrides["io.n_jobs"] = int(args.n_jobs)
|
|
186
418
|
if hasattr(args, "seed"):
|
|
187
419
|
overrides["io.seed"] = _parse_seed(args.seed)
|
|
420
|
+
if hasattr(args, "debug"):
|
|
421
|
+
overrides["io.debug"] = bool(args.debug)
|
|
188
422
|
|
|
189
423
|
# Train
|
|
190
424
|
if hasattr(args, "batch_size"):
|
|
@@ -198,20 +432,34 @@ def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
|
|
|
198
432
|
# Plot
|
|
199
433
|
if hasattr(args, "plot_format"):
|
|
200
434
|
overrides["plot.fmt"] = args.plot_format
|
|
435
|
+
if getattr(args, "disable_plotting", False):
|
|
436
|
+
logging.info(
|
|
437
|
+
"Disabling plotting for all models as per --disable-plotting flag."
|
|
438
|
+
)
|
|
439
|
+
overrides["plot.show"] = False
|
|
201
440
|
|
|
202
|
-
# Simulation overrides
|
|
441
|
+
# Simulation overrides
|
|
203
442
|
if hasattr(args, "sim_strategy"):
|
|
204
443
|
overrides["sim.sim_strategy"] = args.sim_strategy
|
|
205
444
|
if hasattr(args, "sim_prop"):
|
|
206
445
|
overrides["sim.sim_prop"] = float(args.sim_prop)
|
|
207
|
-
if hasattr(args, "simulate_missing"):
|
|
208
|
-
overrides["sim.simulate_missing"] = bool(args.simulate_missing)
|
|
209
446
|
|
|
210
447
|
# Tuning
|
|
211
|
-
if
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
448
|
+
if getattr(args, "load_best_params", False):
|
|
449
|
+
# Never allow CLI flags to re-enable tuning when loading params
|
|
450
|
+
if hasattr(args, "tune") and bool(getattr(args, "tune", False)):
|
|
451
|
+
logging.warning(
|
|
452
|
+
"--tune was supplied, but --load-best-params is active; ignoring --tune."
|
|
453
|
+
)
|
|
454
|
+
if hasattr(args, "tune_n_trials"):
|
|
455
|
+
logging.warning(
|
|
456
|
+
"--tune-n-trials was supplied, but --load-best-params is active; ignoring it."
|
|
457
|
+
)
|
|
458
|
+
else:
|
|
459
|
+
if hasattr(args, "tune"):
|
|
460
|
+
overrides["tune.enabled"] = bool(args.tune)
|
|
461
|
+
if hasattr(args, "tune_n_trials"):
|
|
462
|
+
overrides["tune.n_trials"] = int(args.tune_n_trials)
|
|
215
463
|
|
|
216
464
|
return overrides
|
|
217
465
|
|
|
@@ -255,35 +503,77 @@ def log_model_time(fn: Callable[P, R]) -> Callable[P, R]:
|
|
|
255
503
|
# ------------------------------ Core Runner ------------------------------ #
|
|
256
504
|
def build_genotype_data(
|
|
257
505
|
input_path: str,
|
|
258
|
-
fmt: Literal[
|
|
506
|
+
fmt: Literal[
|
|
507
|
+
"vcf",
|
|
508
|
+
"vcf.gz",
|
|
509
|
+
"phy",
|
|
510
|
+
"phylip",
|
|
511
|
+
"genepop",
|
|
512
|
+
"gen",
|
|
513
|
+
"structure",
|
|
514
|
+
"str",
|
|
515
|
+
],
|
|
259
516
|
popmap_path: str | None,
|
|
260
517
|
treefile: str | None,
|
|
261
518
|
qmatrix: str | None,
|
|
262
519
|
siterates: str | None,
|
|
263
520
|
force_popmap: bool,
|
|
264
|
-
|
|
521
|
+
debug: bool,
|
|
265
522
|
include_pops: List[str] | None,
|
|
266
|
-
plot_format: Literal["pdf", "png", "jpg", "jpeg"],
|
|
523
|
+
plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"],
|
|
524
|
+
structure_has_popids: bool = False,
|
|
525
|
+
structure_has_marker_names: bool = False,
|
|
526
|
+
structure_allele_start_col: int | None = None,
|
|
527
|
+
structure_allele_encoding: dict | None = None,
|
|
267
528
|
):
|
|
268
|
-
"""Load genotype data from heterogeneous inputs.
|
|
269
|
-
|
|
529
|
+
"""Load genotype data from heterogeneous inputs.
|
|
530
|
+
|
|
531
|
+
Args:
|
|
532
|
+
input_path (str): Path to genotype data file.
|
|
533
|
+
fmt (Literal): Format of genotype data file.
|
|
534
|
+
popmap_path (str | None): Optional path to population map file.
|
|
535
|
+
treefile (str | None): Optional path to phylogenetic tree file.
|
|
536
|
+
qmatrix (str | None): Optional path to IQ-TREE Q matrix file.
|
|
537
|
+
siterates (str | None): Optional path to SNP site rates file.
|
|
538
|
+
force_popmap (bool): Whether to force use of popmap even if samples don't match exactly.
|
|
539
|
+
debug (bool): Whether to enable debug-level logging in SNPio readers.
|
|
540
|
+
include_pops (List[str] | None): Optional list of population IDs to include.
|
|
541
|
+
plot_format (Literal): Figure format for SNPio plots.
|
|
542
|
+
structure_has_popids (bool): STRUCTURE only; whether pop IDs are present.
|
|
543
|
+
structure_has_marker_names (bool): STRUCTURE only; whether the first line has marker names.
|
|
544
|
+
structure_allele_start_col (int | None): STRUCTURE only; zero-based allele start column.
|
|
545
|
+
structure_allele_encoding (dict | None): STRUCTURE only; allele encoding map.
|
|
546
|
+
"""
|
|
547
|
+
fmt_norm = _normalize_input_format(fmt)
|
|
548
|
+
plot_format = _normalize_plot_format(cast(str, plot_format))
|
|
549
|
+
logging.info(f"Loading {fmt_norm.upper()} and popmap data...")
|
|
270
550
|
|
|
271
551
|
kwargs = {
|
|
272
552
|
"filename": input_path,
|
|
273
553
|
"popmapfile": popmap_path,
|
|
274
554
|
"force_popmap": force_popmap,
|
|
275
|
-
"verbose":
|
|
555
|
+
"verbose": debug,
|
|
276
556
|
"include_pops": include_pops if include_pops else None,
|
|
277
557
|
"prefix": f"snpio_{Path(input_path).stem}",
|
|
278
558
|
"plot_format": plot_format,
|
|
279
559
|
}
|
|
280
560
|
|
|
281
|
-
if
|
|
561
|
+
if fmt_norm == "vcf":
|
|
282
562
|
gd = VCFReader(**kwargs)
|
|
283
|
-
elif
|
|
563
|
+
elif fmt_norm == "phylip":
|
|
284
564
|
gd = PhylipReader(**kwargs)
|
|
285
|
-
elif
|
|
565
|
+
elif fmt_norm == "genepop":
|
|
286
566
|
gd = GenePopReader(**kwargs)
|
|
567
|
+
elif fmt_norm == "structure":
|
|
568
|
+
kwargs.update(
|
|
569
|
+
{
|
|
570
|
+
"has_popids": structure_has_popids,
|
|
571
|
+
"has_marker_names": structure_has_marker_names,
|
|
572
|
+
"allele_start_col": structure_allele_start_col,
|
|
573
|
+
"allele_encoding": structure_allele_encoding,
|
|
574
|
+
}
|
|
575
|
+
)
|
|
576
|
+
gd = StructureReader(**kwargs)
|
|
287
577
|
else:
|
|
288
578
|
raise ValueError(f"Unsupported genotype data format: {fmt}")
|
|
289
579
|
|
|
@@ -319,8 +609,6 @@ def run_model_safely(model_name: str, builder, *, warn_only: bool = True) -> Non
|
|
|
319
609
|
# -------------------------- Model Registry ------------------------------- #
|
|
320
610
|
# Add config-driven models here by listing the class and its config dataclass.
|
|
321
611
|
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
|
322
|
-
"ImputeUBP": {"cls": ImputeUBP, "config_cls": UBPConfig},
|
|
323
|
-
"ImputeNLPCA": {"cls": ImputeNLPCA, "config_cls": NLPCAConfig},
|
|
324
612
|
"ImputeAutoencoder": {"cls": ImputeAutoencoder, "config_cls": AutoencoderConfig},
|
|
325
613
|
"ImputeVAE": {"cls": ImputeVAE, "config_cls": VAEConfig},
|
|
326
614
|
"ImputeMostFrequent": {"cls": ImputeMostFrequent, "config_cls": MostFrequentConfig},
|
|
@@ -370,24 +658,65 @@ def _build_effective_config_for_model(
|
|
|
370
658
|
f"Loaded YAML config for {model_name} from {yaml_path} (ignored 'preset' in YAML if present)."
|
|
371
659
|
)
|
|
372
660
|
|
|
373
|
-
# 3)
|
|
661
|
+
# 3) Optional: load best parameters from a previous run and force tuning OFF.
|
|
662
|
+
if getattr(args, "load_best_params", False):
|
|
663
|
+
# Determine which prefix to look under for *_output
|
|
664
|
+
src_prefix = getattr(args, "best_params_prefix", None)
|
|
665
|
+
if src_prefix is None:
|
|
666
|
+
# Use the resolved prefix if provided; otherwise fall back to input
|
|
667
|
+
# stem behavior
|
|
668
|
+
src_prefix = getattr(args, "prefix", None)
|
|
669
|
+
|
|
670
|
+
if src_prefix is None and hasattr(args, "vcf"):
|
|
671
|
+
src_prefix = str(Path(args.vcf).stem)
|
|
672
|
+
|
|
673
|
+
if src_prefix is None:
|
|
674
|
+
# As a last resort, use current effective io.prefix if it exists in cfg
|
|
675
|
+
src_prefix = getattr(getattr(cfg, "io", object()), "prefix", None)
|
|
676
|
+
|
|
677
|
+
if getattr(args, "tune", False):
|
|
678
|
+
logging.warning(
|
|
679
|
+
"--tune was supplied, but --load-best-params is active; forcing tuning OFF."
|
|
680
|
+
)
|
|
681
|
+
|
|
682
|
+
# Force tuning disabled in config (even if CLI/YAML enabled it)
|
|
683
|
+
cfg = _force_tuning_off(cfg, model_name)
|
|
684
|
+
|
|
685
|
+
best_path = _find_best_params_json(str(src_prefix), model_name)
|
|
686
|
+
if best_path is None:
|
|
687
|
+
# For tune-capable (unsupervised) models, treat as an error; deterministic models warn only.
|
|
688
|
+
fam = _model_family(model_name)
|
|
689
|
+
msg = (
|
|
690
|
+
"Requested --load-best-params, but could not find a best parameters JSON "
|
|
691
|
+
f"for {model_name}. Looked under '.../optimize/<model>/parameters/best_tuned_parameters.json' and '{src_prefix}_output/{fam}/parameters/{model_name}/best_parameters.json'"
|
|
692
|
+
)
|
|
693
|
+
if model_name in {"ImputeVAE", "ImputeAutoencoder"}:
|
|
694
|
+
logging.error(msg)
|
|
695
|
+
raise FileNotFoundError(msg)
|
|
696
|
+
logging.warning(msg)
|
|
697
|
+
else:
|
|
698
|
+
logging.info(f"Loading best parameters for {model_name} from: {best_path}")
|
|
699
|
+
best_params = _load_best_params(best_path)
|
|
700
|
+
cfg = _apply_best_params_to_cfg(cfg, best_params, model_name)
|
|
701
|
+
cfg = _force_tuning_off(cfg, model_name)
|
|
702
|
+
|
|
703
|
+
# 4) Explicit CLI flags overlay YAML/best-params layers.
|
|
374
704
|
cli_overrides = _args_to_cli_overrides(args)
|
|
375
705
|
if cli_overrides:
|
|
376
706
|
cfg = apply_dot_overrides(cfg, cli_overrides)
|
|
377
707
|
|
|
378
|
-
#
|
|
708
|
+
# Keep tuning disabled if --load-best-params was requested, even if CLI flags tried to re-enable it.
|
|
709
|
+
if getattr(args, "load_best_params", False):
|
|
710
|
+
cfg = _force_tuning_off(cfg, model_name)
|
|
711
|
+
|
|
712
|
+
# 5) --set has highest precedence.
|
|
379
713
|
user_overrides = _parse_overrides(getattr(args, "set", []))
|
|
380
714
|
|
|
381
715
|
if user_overrides:
|
|
382
716
|
try:
|
|
383
717
|
cfg = apply_dot_overrides(cfg, user_overrides)
|
|
384
718
|
except Exception as e:
|
|
385
|
-
if model_name in {
|
|
386
|
-
"ImputeUBP",
|
|
387
|
-
"ImputeNLPCA",
|
|
388
|
-
"ImputeAutoencoder",
|
|
389
|
-
"ImputeVAE",
|
|
390
|
-
}:
|
|
719
|
+
if model_name in {"ImputeAutoencoder", "ImputeVAE"}:
|
|
391
720
|
logging.error(
|
|
392
721
|
f"Error applying --set overrides to {model_name} config: {e}"
|
|
393
722
|
)
|
|
@@ -395,6 +724,18 @@ def _build_effective_config_for_model(
|
|
|
395
724
|
else:
|
|
396
725
|
pass # non-config-driven models ignore --set
|
|
397
726
|
|
|
727
|
+
# FINAL GUARANTEE:
|
|
728
|
+
# --load-best-params always wins over
|
|
729
|
+
# --set, YAML, preset, and CLI flags.
|
|
730
|
+
if getattr(args, "load_best_params", False):
|
|
731
|
+
# If user explicitly tried to set tune.* via --set, warn and override.
|
|
732
|
+
if any(str(k).startswith("tune.") for k in (user_overrides or {}).keys()):
|
|
733
|
+
logging.warning(
|
|
734
|
+
f"{model_name}: '--set tune.*=...' was provided, but --load-best-params forces tuning OFF. "
|
|
735
|
+
"Ignoring any tune.* overrides."
|
|
736
|
+
)
|
|
737
|
+
cfg = _force_tuning_off(cfg, model_name)
|
|
738
|
+
|
|
398
739
|
return cfg
|
|
399
740
|
|
|
400
741
|
|
|
@@ -435,9 +776,19 @@ def _maybe_print_or_dump_configs(
|
|
|
435
776
|
|
|
436
777
|
|
|
437
778
|
def main(argv: Optional[List[str]] = None) -> int:
|
|
779
|
+
"""PG-SUI CLI main entry point.
|
|
780
|
+
|
|
781
|
+
The CLI supports running multiple imputation models on a single input file, with configuration handled via presets, YAML files, and CLI flags.
|
|
782
|
+
|
|
783
|
+
Args:
|
|
784
|
+
argv (Optional[List[str]]): List of CLI args (default: sys.argv[1:]).
|
|
785
|
+
|
|
786
|
+
Returns:
|
|
787
|
+
int: Exit code (0=success, 2=argparse error, 1=other error).
|
|
788
|
+
"""
|
|
438
789
|
parser = argparse.ArgumentParser(
|
|
439
790
|
prog="pg-sui",
|
|
440
|
-
description="Run PG-SUI imputation models on an input file. Handle configuration via presets, YAML, and CLI flags. The default is to run all models.",
|
|
791
|
+
description="Run PG-SUI imputation models on an input file. Handle configuration via presets, YAML, and CLI flags. The default is to run all models. The input file can be in VCF, PHYLIP, or GENEPOP format. Outputs include imputed genotype files and performance summaries.",
|
|
441
792
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
442
793
|
usage="%(prog)s [options]",
|
|
443
794
|
)
|
|
@@ -446,7 +797,7 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
446
797
|
parser.add_argument(
|
|
447
798
|
"--input",
|
|
448
799
|
default=argparse.SUPPRESS,
|
|
449
|
-
help="Path to input file (VCF/PHYLIP/
|
|
800
|
+
help="Path to input file (VCF/PHYLIP/GENEPOP). VCF file can be bgzipped or uncompressed.",
|
|
450
801
|
)
|
|
451
802
|
parser.add_argument(
|
|
452
803
|
"--format",
|
|
@@ -462,23 +813,23 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
462
813
|
"gen",
|
|
463
814
|
),
|
|
464
815
|
default=argparse.SUPPRESS,
|
|
465
|
-
help="Input format. If 'infer', deduced from file extension. The default is 'infer'.",
|
|
816
|
+
help="Input format. If 'infer', deduced from file extension. The default is 'infer'. Supported formats: VCF ('.vcf', '.vcf.gz'), PHYLIP ('.phy', '.phylip'), GENEPOP ('.genepop', '.gen').",
|
|
466
817
|
)
|
|
467
818
|
# Back-compat: --vcf retained; if both provided, --input wins.
|
|
468
819
|
parser.add_argument(
|
|
469
820
|
"--vcf",
|
|
470
821
|
default=argparse.SUPPRESS,
|
|
471
|
-
help="Path to input VCF file. Can be bgzipped or uncompressed.",
|
|
822
|
+
help="Path to input VCF file. Can be bgzipped or uncompressed. (Deprecated; use --input instead.)",
|
|
472
823
|
)
|
|
473
824
|
parser.add_argument(
|
|
474
825
|
"--popmap",
|
|
475
826
|
default=argparse.SUPPRESS,
|
|
476
|
-
help="Path to population map file. This is a two-column tab-delimited file with sample IDs and population IDs.",
|
|
827
|
+
help="Path to population map file. This is a two-column tab-delimited file with sample IDs and population IDs. If not provided, no population info is used.",
|
|
477
828
|
)
|
|
478
829
|
parser.add_argument(
|
|
479
830
|
"--treefile",
|
|
480
831
|
default=argparse.SUPPRESS,
|
|
481
|
-
help="Path to phylogenetic tree file. Can be in Newick (recommended) or Nexus format.",
|
|
832
|
+
help="Path to phylogenetic tree file. Can be in Newick (recommended) or Nexus format. Used with --qmatrix and --siterates.",
|
|
482
833
|
)
|
|
483
834
|
parser.add_argument(
|
|
484
835
|
"--qmatrix",
|
|
@@ -488,19 +839,19 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
488
839
|
parser.add_argument(
|
|
489
840
|
"--siterates",
|
|
490
841
|
default=argparse.SUPPRESS,
|
|
491
|
-
help="Path to SNP site rates file (has .rate extension). Used with --treefile and --qmatrix.",
|
|
842
|
+
help="Path to SNP site rates file (has .rate extension and can be produced with IQ-TREE). Used with --treefile and --qmatrix.",
|
|
492
843
|
)
|
|
493
844
|
parser.add_argument(
|
|
494
845
|
"--prefix",
|
|
495
846
|
default=argparse.SUPPRESS,
|
|
496
|
-
help="Output file prefix.",
|
|
847
|
+
help="Output file prefix. If not provided, defaults to the input file stem.",
|
|
497
848
|
)
|
|
498
849
|
|
|
499
850
|
# ---------------------- Generic Config Inputs -------------------------- #
|
|
500
851
|
parser.add_argument(
|
|
501
852
|
"--config",
|
|
502
853
|
default=argparse.SUPPRESS,
|
|
503
|
-
help="YAML config for config-driven models (
|
|
854
|
+
help="YAML config for config-driven models (Autoencoder, VAE). Overrides preset and defaults.",
|
|
504
855
|
)
|
|
505
856
|
parser.add_argument(
|
|
506
857
|
"--preset",
|
|
@@ -512,7 +863,7 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
512
863
|
"--set",
|
|
513
864
|
action="append",
|
|
514
865
|
default=argparse.SUPPRESS,
|
|
515
|
-
help="Dot-key overrides, e.g. --set model.latent_dim=4",
|
|
866
|
+
help="Dot-key overrides, e.g. --set model.latent_dim=4 --set train.epochs=100. Applies to all models.",
|
|
516
867
|
)
|
|
517
868
|
parser.add_argument(
|
|
518
869
|
"--print-config",
|
|
@@ -530,7 +881,7 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
530
881
|
"--tune",
|
|
531
882
|
action="store_true",
|
|
532
883
|
default=argparse.SUPPRESS,
|
|
533
|
-
help="Enable hyperparameter tuning (if supported).",
|
|
884
|
+
help="Enable hyperparameter tuning (if supported by model). Uses Optuna to optimize hyperparameters.",
|
|
534
885
|
)
|
|
535
886
|
parser.add_argument(
|
|
536
887
|
"--tune-n-trials",
|
|
@@ -560,8 +911,31 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
560
911
|
"--plot-format",
|
|
561
912
|
choices=("png", "pdf", "svg", "jpg", "jpeg"),
|
|
562
913
|
default=argparse.SUPPRESS,
|
|
563
|
-
help="Figure format for model plots.",
|
|
914
|
+
help="Figure format for model plots. Choices: png, pdf, svg, jpg, jpeg.",
|
|
915
|
+
)
|
|
916
|
+
parser.add_argument(
|
|
917
|
+
"--disable-plotting",
|
|
918
|
+
action="store_true",
|
|
919
|
+
default=False,
|
|
920
|
+
help="Disable plotting for all models. Overrides any config settings enabling plotting.",
|
|
564
921
|
)
|
|
922
|
+
|
|
923
|
+
parser.add_argument(
|
|
924
|
+
"--load-best-params",
|
|
925
|
+
action="store_true",
|
|
926
|
+
default=False,
|
|
927
|
+
help=(
|
|
928
|
+
"Load best hyperparameters from a previous run's best_parameters.json (or tuning best_tuned_parameters.json) for each selected model and apply them to the model configs. This forces tuning OFF."
|
|
929
|
+
),
|
|
930
|
+
)
|
|
931
|
+
parser.add_argument(
|
|
932
|
+
"--best-params-prefix",
|
|
933
|
+
default=argparse.SUPPRESS,
|
|
934
|
+
help=(
|
|
935
|
+
"Prefix of the PREVIOUS run to load best parameters from. If omitted, uses the current --prefix (or input stem)."
|
|
936
|
+
),
|
|
937
|
+
)
|
|
938
|
+
|
|
565
939
|
# ------------------------- Simulation Controls ------------------------ #
|
|
566
940
|
parser.add_argument(
|
|
567
941
|
"--sim-strategy",
|
|
@@ -575,20 +949,17 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
575
949
|
default=argparse.SUPPRESS,
|
|
576
950
|
help="Override the proportion of observed entries to mask during simulation (0-1).",
|
|
577
951
|
)
|
|
578
|
-
parser.add_argument(
|
|
579
|
-
"--simulate-missing",
|
|
580
|
-
action="store_false",
|
|
581
|
-
default=argparse.SUPPRESS,
|
|
582
|
-
help="Disable missing-data simulation regardless of preset/config (when provided).",
|
|
583
|
-
)
|
|
584
952
|
|
|
585
953
|
# --------------------------- Seed & logging ---------------------------- #
|
|
586
954
|
parser.add_argument(
|
|
587
955
|
"--seed",
|
|
588
956
|
default=argparse.SUPPRESS,
|
|
589
|
-
help="Random seed: 'random', 'deterministic', or an integer.",
|
|
957
|
+
help="Random seed: 'random', 'deterministic', or an integer. Default is 'random'.",
|
|
590
958
|
)
|
|
959
|
+
|
|
960
|
+
# ----------------------------- Logging --------------------------------- #
|
|
591
961
|
parser.add_argument("--verbose", action="store_true", help="Info-level logging.")
|
|
962
|
+
parser.add_argument("--debug", action="store_true", help="Debug-level logging.")
|
|
592
963
|
parser.add_argument(
|
|
593
964
|
"--log-file", default=argparse.SUPPRESS, help="Also write logs to a file."
|
|
594
965
|
)
|
|
@@ -604,7 +975,33 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
604
975
|
"--force-popmap",
|
|
605
976
|
action="store_true",
|
|
606
977
|
default=False,
|
|
607
|
-
help="
|
|
978
|
+
help="Force use of provided popmap even if samples don't match exactly. This will drop samples not in the popmap and vice versa.",
|
|
979
|
+
)
|
|
980
|
+
|
|
981
|
+
# -------------------------- STRUCTURE options ------------------------- #
|
|
982
|
+
parser.add_argument(
|
|
983
|
+
"--structure-has-popids",
|
|
984
|
+
action="store_true",
|
|
985
|
+
default=False,
|
|
986
|
+
help="STRUCTURE only: second column contains population IDs.",
|
|
987
|
+
)
|
|
988
|
+
parser.add_argument(
|
|
989
|
+
"--structure-has-marker-names",
|
|
990
|
+
action="store_true",
|
|
991
|
+
default=False,
|
|
992
|
+
help="STRUCTURE only: first row contains marker names.",
|
|
993
|
+
)
|
|
994
|
+
parser.add_argument(
|
|
995
|
+
"--structure-allele-start-col",
|
|
996
|
+
type=int,
|
|
997
|
+
default=argparse.SUPPRESS,
|
|
998
|
+
help="STRUCTURE only: zero-based column index where alleles begin.",
|
|
999
|
+
)
|
|
1000
|
+
parser.add_argument(
|
|
1001
|
+
"--structure-allele-encoding",
|
|
1002
|
+
type=_parse_allele_encoding,
|
|
1003
|
+
default=argparse.SUPPRESS,
|
|
1004
|
+
help="STRUCTURE only: allele encoding mapping as JSON or Python dict.",
|
|
608
1005
|
)
|
|
609
1006
|
|
|
610
1007
|
# ---------------------------- Model selection -------------------------- #
|
|
@@ -613,54 +1010,66 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
613
1010
|
nargs="+",
|
|
614
1011
|
default=argparse.SUPPRESS,
|
|
615
1012
|
help=(
|
|
616
|
-
"Which models to run. Choices:
|
|
1013
|
+
"Which models to run. Specify each model separated by a space. Choices: ImputeVAE ImputeAutoencoder ImputeMostFrequent ImputeRefAllele (Default is all models)."
|
|
617
1014
|
),
|
|
618
1015
|
)
|
|
619
1016
|
|
|
620
1017
|
# -------------------------- MultiQC integration ------------------------ #
|
|
621
1018
|
parser.add_argument(
|
|
622
|
-
"--multiqc",
|
|
1019
|
+
"--disable-multiqc",
|
|
623
1020
|
action="store_true",
|
|
1021
|
+
default=False,
|
|
624
1022
|
help=(
|
|
625
|
-
"
|
|
1023
|
+
"Disable MultiQC report generation after imputation. By default, a MultiQC report is generated unless this flag is set."
|
|
626
1024
|
),
|
|
627
1025
|
)
|
|
628
1026
|
parser.add_argument(
|
|
629
1027
|
"--multiqc-title",
|
|
630
1028
|
default=argparse.SUPPRESS,
|
|
631
|
-
help="Optional title for the MultiQC report (default: 'PG-SUI MultiQC Report - <prefix>').",
|
|
1029
|
+
help="Optional title for the MultiQC report (default: 'PG-SUI MultiQC Report - <prefix>'). ",
|
|
632
1030
|
)
|
|
633
1031
|
parser.add_argument(
|
|
634
1032
|
"--multiqc-output-dir",
|
|
635
1033
|
default=argparse.SUPPRESS,
|
|
636
|
-
help="Optional output directory for the MultiQC report (default: '<prefix>_output/multiqc').",
|
|
1034
|
+
help="Optional output directory for the MultiQC report (default: '<prefix>_output/multiqc'). This directory will be created if it does not exist.",
|
|
637
1035
|
)
|
|
638
1036
|
parser.add_argument(
|
|
639
1037
|
"--multiqc-overwrite",
|
|
640
1038
|
action="store_true",
|
|
641
1039
|
default=False,
|
|
642
|
-
help="Overwrite an existing MultiQC report if present.",
|
|
1040
|
+
help="Overwrite an existing MultiQC report if present. If not set and a report exists, an integer suffix will be added to avoid overwriting. NOTE: if running multiple times with this flag, it may append multiple suffixes to avoid overwriting previous reports.",
|
|
643
1041
|
)
|
|
644
1042
|
|
|
645
1043
|
# ------------------------------ Safety/UX ------------------------------ #
|
|
646
1044
|
parser.add_argument(
|
|
647
1045
|
"--dry-run",
|
|
648
1046
|
action="store_true",
|
|
649
|
-
help="Parse args and load data, but skip model training.",
|
|
1047
|
+
help="Parse args and load data, but skip model training. Useful for testing I/O and configs.",
|
|
1048
|
+
)
|
|
1049
|
+
parser.add_argument(
|
|
1050
|
+
"--version", action="store_true", help="Print PG-SUI version and exit."
|
|
650
1051
|
)
|
|
651
1052
|
|
|
652
1053
|
args = parser.parse_args(argv)
|
|
653
1054
|
|
|
1055
|
+
if getattr(args, "version", False):
|
|
1056
|
+
_print_version()
|
|
1057
|
+
return 0
|
|
1058
|
+
|
|
654
1059
|
# Logging (verbose default is False unless passed)
|
|
655
1060
|
_configure_logging(
|
|
656
1061
|
verbose=getattr(args, "verbose", False),
|
|
657
1062
|
log_file=getattr(args, "log_file", None),
|
|
658
1063
|
)
|
|
659
1064
|
|
|
1065
|
+
logging.info("Starting PG-SUI imputation...")
|
|
1066
|
+
_print_version()
|
|
1067
|
+
|
|
660
1068
|
# Models selection (default to all if not explicitly provided)
|
|
661
1069
|
try:
|
|
662
1070
|
selected_models = _parse_models(getattr(args, "models", ()))
|
|
663
1071
|
except argparse.ArgumentTypeError as e:
|
|
1072
|
+
logging.error(str(e))
|
|
664
1073
|
parser.error(str(e))
|
|
665
1074
|
return 2
|
|
666
1075
|
|
|
@@ -672,12 +1081,11 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
672
1081
|
setattr(args, "format", "vcf")
|
|
673
1082
|
|
|
674
1083
|
if input_path is None:
|
|
1084
|
+
logging.error("You must provide --input (or legacy --vcf).")
|
|
675
1085
|
parser.error("You must provide --input (or legacy --vcf).")
|
|
676
1086
|
return 2
|
|
677
1087
|
|
|
678
|
-
fmt
|
|
679
|
-
args, "format", "infer"
|
|
680
|
-
)
|
|
1088
|
+
fmt = getattr(args, "format", "infer")
|
|
681
1089
|
|
|
682
1090
|
if fmt == "infer":
|
|
683
1091
|
if input_path.endswith((".vcf", ".vcf.gz")):
|
|
@@ -686,28 +1094,59 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
686
1094
|
fmt_final = "phylip"
|
|
687
1095
|
elif input_path.endswith((".genepop", ".gen")):
|
|
688
1096
|
fmt_final = "genepop"
|
|
1097
|
+
elif input_path.endswith((".str", ".stru", ".structure")):
|
|
1098
|
+
fmt_final = "structure"
|
|
689
1099
|
else:
|
|
1100
|
+
logging.error(
|
|
1101
|
+
"Could not infer input format from file extension. Please provide --format."
|
|
1102
|
+
)
|
|
690
1103
|
parser.error(
|
|
691
1104
|
"Could not infer input format from file extension. Please provide --format."
|
|
692
1105
|
)
|
|
693
1106
|
return 2
|
|
694
1107
|
else:
|
|
695
|
-
fmt_final =
|
|
1108
|
+
fmt_final = cast(
|
|
1109
|
+
Literal[
|
|
1110
|
+
"vcf",
|
|
1111
|
+
"vcf.gz",
|
|
1112
|
+
"phy",
|
|
1113
|
+
"phylip",
|
|
1114
|
+
"genepop",
|
|
1115
|
+
"gen",
|
|
1116
|
+
"structure",
|
|
1117
|
+
"str",
|
|
1118
|
+
],
|
|
1119
|
+
fmt,
|
|
1120
|
+
)
|
|
1121
|
+
|
|
1122
|
+
fmt_final = _normalize_input_format(fmt_final)
|
|
696
1123
|
|
|
697
1124
|
popmap_path = getattr(args, "popmap", None)
|
|
698
1125
|
include_pops = getattr(args, "include_pops", None)
|
|
699
|
-
verbose_flag = getattr(args, "verbose", False)
|
|
700
1126
|
force_popmap = bool(getattr(args, "force_popmap", False))
|
|
1127
|
+
structure_has_popids = bool(getattr(args, "structure_has_popids", False))
|
|
1128
|
+
structure_has_marker_names = bool(
|
|
1129
|
+
getattr(args, "structure_has_marker_names", False)
|
|
1130
|
+
)
|
|
1131
|
+
structure_allele_start_col = getattr(args, "structure_allele_start_col", None)
|
|
1132
|
+
structure_allele_encoding = getattr(args, "structure_allele_encoding", None)
|
|
701
1133
|
|
|
702
1134
|
# Canonical prefix for this run (used for outputs and MultiQC)
|
|
703
1135
|
prefix: str = getattr(args, "prefix", str(Path(input_path).stem))
|
|
1136
|
+
# Ensure downstream config building sees the resolved prefix even if
|
|
1137
|
+
# --prefix was not provided.
|
|
1138
|
+
setattr(args, "prefix", prefix)
|
|
704
1139
|
|
|
705
|
-
treefile =
|
|
706
|
-
|
|
707
|
-
|
|
1140
|
+
treefile, qmatrix, siterates = _resolve_tree_paths(args)
|
|
1141
|
+
setattr(args, "treefile", treefile)
|
|
1142
|
+
setattr(args, "qmatrix", qmatrix)
|
|
1143
|
+
setattr(args, "siterates", siterates)
|
|
708
1144
|
|
|
709
1145
|
if any(x is not None for x in (treefile, qmatrix, siterates)):
|
|
710
1146
|
if not all(x is not None for x in (treefile, qmatrix, siterates)):
|
|
1147
|
+
logging.error(
|
|
1148
|
+
"--treefile, --qmatrix, and --siterates must all be provided together or they should all be omitted."
|
|
1149
|
+
)
|
|
711
1150
|
parser.error(
|
|
712
1151
|
"--treefile, --qmatrix, and --siterates must all be provided together or they should all be omitted."
|
|
713
1152
|
)
|
|
@@ -716,15 +1155,31 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
716
1155
|
# Load genotype data
|
|
717
1156
|
gd, tp = build_genotype_data(
|
|
718
1157
|
input_path=input_path,
|
|
719
|
-
fmt=
|
|
1158
|
+
fmt=cast(
|
|
1159
|
+
Literal[
|
|
1160
|
+
"vcf",
|
|
1161
|
+
"vcf.gz",
|
|
1162
|
+
"phy",
|
|
1163
|
+
"phylip",
|
|
1164
|
+
"genepop",
|
|
1165
|
+
"gen",
|
|
1166
|
+
"structure",
|
|
1167
|
+
"str",
|
|
1168
|
+
],
|
|
1169
|
+
fmt_final,
|
|
1170
|
+
),
|
|
720
1171
|
popmap_path=popmap_path,
|
|
721
1172
|
treefile=treefile,
|
|
722
1173
|
qmatrix=qmatrix,
|
|
723
1174
|
siterates=siterates,
|
|
724
1175
|
force_popmap=force_popmap,
|
|
725
|
-
verbose=verbose_flag,
|
|
726
1176
|
include_pops=include_pops,
|
|
1177
|
+
debug=getattr(args, "debug", False),
|
|
727
1178
|
plot_format=getattr(args, "plot_format", "pdf"),
|
|
1179
|
+
structure_has_popids=structure_has_popids,
|
|
1180
|
+
structure_has_marker_names=structure_has_marker_names,
|
|
1181
|
+
structure_allele_start_col=structure_allele_start_col,
|
|
1182
|
+
structure_allele_encoding=structure_allele_encoding,
|
|
728
1183
|
)
|
|
729
1184
|
|
|
730
1185
|
if getattr(args, "dry_run", False):
|
|
@@ -736,47 +1191,33 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
736
1191
|
m: _build_effective_config_for_model(m, args) for m in selected_models
|
|
737
1192
|
}
|
|
738
1193
|
|
|
1194
|
+
needs_tree = any(
|
|
1195
|
+
_config_needs_tree(cfg) for cfg in cfgs_by_model.values() if cfg is not None
|
|
1196
|
+
)
|
|
1197
|
+
if needs_tree and not all(x is not None for x in (treefile, qmatrix, siterates)):
|
|
1198
|
+
logging.error(
|
|
1199
|
+
"Nonrandom simulated missingness requires --treefile, --qmatrix, and --siterates."
|
|
1200
|
+
)
|
|
1201
|
+
parser.error(
|
|
1202
|
+
"Nonrandom simulated missingness requires --treefile, --qmatrix, and --siterates."
|
|
1203
|
+
)
|
|
1204
|
+
return 2
|
|
1205
|
+
if needs_tree and tp is None:
|
|
1206
|
+
logging.error(
|
|
1207
|
+
"Tree parser was not initialized for nonrandom simulation. "
|
|
1208
|
+
"Please verify --treefile, --qmatrix, and --siterates."
|
|
1209
|
+
)
|
|
1210
|
+
parser.error(
|
|
1211
|
+
"Tree parser was not initialized for nonrandom simulation. "
|
|
1212
|
+
"Please verify --treefile, --qmatrix, and --siterates."
|
|
1213
|
+
)
|
|
1214
|
+
return 2
|
|
1215
|
+
|
|
739
1216
|
# Maybe print/dump configs and exit
|
|
740
1217
|
if _maybe_print_or_dump_configs(cfgs_by_model, args):
|
|
741
1218
|
return 0
|
|
742
1219
|
|
|
743
1220
|
# ------------------------- Model Builders ------------------------------ #
|
|
744
|
-
def build_impute_ubp():
|
|
745
|
-
cfg = cfgs_by_model.get("ImputeUBP")
|
|
746
|
-
if cfg is None:
|
|
747
|
-
cfg = (
|
|
748
|
-
UBPConfig.from_preset(args.preset)
|
|
749
|
-
if hasattr(args, "preset")
|
|
750
|
-
else UBPConfig()
|
|
751
|
-
)
|
|
752
|
-
return ImputeUBP(
|
|
753
|
-
genotype_data=gd,
|
|
754
|
-
tree_parser=tp,
|
|
755
|
-
config=cfg,
|
|
756
|
-
simulate_missing=cfg.sim.simulate_missing,
|
|
757
|
-
sim_strategy=cfg.sim.sim_strategy,
|
|
758
|
-
sim_prop=cfg.sim.sim_prop,
|
|
759
|
-
sim_kwargs=cfg.sim.sim_kwargs,
|
|
760
|
-
)
|
|
761
|
-
|
|
762
|
-
def build_impute_nlpca():
|
|
763
|
-
cfg = cfgs_by_model.get("ImputeNLPCA")
|
|
764
|
-
if cfg is None:
|
|
765
|
-
cfg = (
|
|
766
|
-
NLPCAConfig.from_preset(args.preset)
|
|
767
|
-
if hasattr(args, "preset")
|
|
768
|
-
else NLPCAConfig()
|
|
769
|
-
)
|
|
770
|
-
return ImputeNLPCA(
|
|
771
|
-
genotype_data=gd,
|
|
772
|
-
tree_parser=tp,
|
|
773
|
-
config=cfg,
|
|
774
|
-
simulate_missing=cfg.sim.simulate_missing,
|
|
775
|
-
sim_strategy=cfg.sim.sim_strategy,
|
|
776
|
-
sim_prop=cfg.sim.sim_prop,
|
|
777
|
-
sim_kwargs=cfg.sim.sim_kwargs,
|
|
778
|
-
)
|
|
779
|
-
|
|
780
1221
|
def build_impute_vae():
|
|
781
1222
|
cfg = cfgs_by_model.get("ImputeVAE")
|
|
782
1223
|
if cfg is None:
|
|
@@ -789,7 +1230,6 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
789
1230
|
genotype_data=gd,
|
|
790
1231
|
tree_parser=tp,
|
|
791
1232
|
config=cfg,
|
|
792
|
-
simulate_missing=cfg.sim.simulate_missing,
|
|
793
1233
|
sim_strategy=cfg.sim.sim_strategy,
|
|
794
1234
|
sim_prop=cfg.sim.sim_prop,
|
|
795
1235
|
sim_kwargs=cfg.sim.sim_kwargs,
|
|
@@ -807,7 +1247,6 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
807
1247
|
genotype_data=gd,
|
|
808
1248
|
tree_parser=tp,
|
|
809
1249
|
config=cfg,
|
|
810
|
-
simulate_missing=cfg.sim.simulate_missing,
|
|
811
1250
|
sim_strategy=cfg.sim.sim_strategy,
|
|
812
1251
|
sim_prop=cfg.sim.sim_prop,
|
|
813
1252
|
sim_kwargs=cfg.sim.sim_kwargs,
|
|
@@ -825,7 +1264,7 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
825
1264
|
gd,
|
|
826
1265
|
tree_parser=tp,
|
|
827
1266
|
config=cfg,
|
|
828
|
-
simulate_missing=
|
|
1267
|
+
simulate_missing=True,
|
|
829
1268
|
sim_strategy=cfg.sim.sim_strategy,
|
|
830
1269
|
sim_prop=cfg.sim.sim_prop,
|
|
831
1270
|
sim_kwargs=cfg.sim.sim_kwargs,
|
|
@@ -843,43 +1282,68 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
843
1282
|
gd,
|
|
844
1283
|
tree_parser=tp,
|
|
845
1284
|
config=cfg,
|
|
846
|
-
simulate_missing=
|
|
1285
|
+
simulate_missing=True,
|
|
847
1286
|
sim_strategy=cfg.sim.sim_strategy,
|
|
848
1287
|
sim_prop=cfg.sim.sim_prop,
|
|
849
1288
|
sim_kwargs=cfg.sim.sim_kwargs,
|
|
850
1289
|
)
|
|
851
1290
|
|
|
852
1291
|
model_builders = {
|
|
853
|
-
"ImputeUBP": build_impute_ubp,
|
|
854
1292
|
"ImputeVAE": build_impute_vae,
|
|
855
1293
|
"ImputeAutoencoder": build_impute_autoencoder,
|
|
856
|
-
"ImputeNLPCA": build_impute_nlpca,
|
|
857
1294
|
"ImputeMostFrequent": build_impute_mostfreq,
|
|
858
1295
|
"ImputeRefAllele": build_impute_refallele,
|
|
859
1296
|
}
|
|
860
1297
|
|
|
861
1298
|
logging.info(f"Selected models: {', '.join(selected_models)}")
|
|
862
1299
|
for name in selected_models:
|
|
1300
|
+
logging.info("")
|
|
1301
|
+
logging.info("=" * 60)
|
|
1302
|
+
logging.info("")
|
|
1303
|
+
logging.info(f"Processing model: {name} ...")
|
|
863
1304
|
X_imputed = run_model_safely(name, model_builders[name], warn_only=False)
|
|
864
1305
|
gd_imp = gd.copy()
|
|
865
1306
|
gd_imp.snp_data = X_imputed
|
|
866
1307
|
|
|
867
|
-
if name in {"
|
|
1308
|
+
if name in {"ImputeVAE", "ImputeAutoencoder"}:
|
|
868
1309
|
family = "Unsupervised"
|
|
869
1310
|
elif name in {"ImputeMostFrequent", "ImputeRefAllele"}:
|
|
870
1311
|
family = "Deterministic"
|
|
871
1312
|
elif name in {"ImputeHistGradientBoosting", "ImputeRandomForest"}:
|
|
872
1313
|
family = "Supervised"
|
|
873
1314
|
else:
|
|
1315
|
+
logging.error(f"Unknown model family for {name}")
|
|
874
1316
|
raise ValueError(f"Unknown model family for {name}")
|
|
875
1317
|
|
|
876
1318
|
pth = Path(f"{prefix}_output/{family}/imputed/{name}")
|
|
877
1319
|
pth.mkdir(parents=True, exist_ok=True)
|
|
878
1320
|
|
|
879
1321
|
logging.info(f"Writing imputed VCF for {name} to {pth} ...")
|
|
880
|
-
gd_imp.write_vcf(pth / f"{name.lower()}_imputed.vcf.gz")
|
|
881
1322
|
|
|
882
|
-
|
|
1323
|
+
if fmt_final == "vcf":
|
|
1324
|
+
gd_imp.write_vcf(pth / f"{name.lower()}_imputed.vcf.gz")
|
|
1325
|
+
elif fmt_final == "phylip":
|
|
1326
|
+
gd_imp.write_phylip(pth / f"{name.lower()}_imputed.phy")
|
|
1327
|
+
elif fmt_final == "genepop":
|
|
1328
|
+
gd_imp.write_genepop(pth / f"{name.lower()}_imputed.gen")
|
|
1329
|
+
else:
|
|
1330
|
+
logging.warning(
|
|
1331
|
+
f"Output format {fmt_final} not supported for imputed data export."
|
|
1332
|
+
)
|
|
1333
|
+
|
|
1334
|
+
logging.info("")
|
|
1335
|
+
logging.info(f"Successfully finished imputation for model: {name}!")
|
|
1336
|
+
logging.info("")
|
|
1337
|
+
logging.info("=" * 60)
|
|
1338
|
+
|
|
1339
|
+
logging.info(f"All requested models processed for input: {input_path}")
|
|
1340
|
+
|
|
1341
|
+
disable_mqc = bool(getattr(args, "disable_multiqc", False))
|
|
1342
|
+
|
|
1343
|
+
if disable_mqc:
|
|
1344
|
+
logging.info("MultiQC report generation disabled via --disable-multiqc.")
|
|
1345
|
+
logging.info("PG-SUI imputation run complete!")
|
|
1346
|
+
return 0
|
|
883
1347
|
|
|
884
1348
|
# -------------------------- MultiQC builder ---------------------------- #
|
|
885
1349
|
|
|
@@ -899,9 +1363,10 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|
|
899
1363
|
overwrite=overwrite,
|
|
900
1364
|
)
|
|
901
1365
|
logging.info("MultiQC report successfully built.")
|
|
902
|
-
except Exception as exc2:
|
|
1366
|
+
except Exception as exc2:
|
|
903
1367
|
logging.error(f"Failed to build MultiQC report: {exc2}", exc_info=True)
|
|
904
1368
|
|
|
1369
|
+
logging.info("PG-SUI imputation run complete!")
|
|
905
1370
|
return 0
|
|
906
1371
|
|
|
907
1372
|
|