pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. pg_sui-1.7.0.dist-info/METADATA +288 -0
  2. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +591 -126
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +218 -533
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1268 -530
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
  16. pgsui/impute/unsupervised/imputers/vae.py +928 -696
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
  27. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
  28. pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
  29. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  30. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  31. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  32. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  33. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  34. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
pgsui/cli.py CHANGED
@@ -15,19 +15,21 @@ Notes
15
15
 
16
16
  Examples
17
17
  --------
18
- python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix run1
19
- python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix tuned --tune
20
- python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix demo \
21
- --models ImputeUBP ImputeVAE ImputeMostFrequent --seed deterministic --verbose
22
- python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix subset \
23
- --include-pops EA GU TT ON --device cpu
18
+ pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix run1
19
+ pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix tuned --tune
20
+ pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix demo \
21
+ --models ImputeAutoencoder ImputeVAE ImputeMostFrequent --seed deterministic --verbose
22
+ pg-sui --vcf data.vcf.gz --popmap pops.popmap --prefix subset \
23
+ --include-pops EA GU TT ON --device cpu --sim-prop 0.3 --sim-strategy nonrandom
24
24
  """
25
25
 
26
26
  from __future__ import annotations
27
27
 
28
28
  import argparse
29
29
  import ast
30
+ import json
30
31
  import logging
32
+ import os
31
33
  import sys
32
34
  import time
33
35
  from functools import wraps
@@ -46,20 +48,24 @@ from typing import (
46
48
  cast,
47
49
  )
48
50
 
49
- from snpio import GenePopReader, PhylipReader, SNPioMultiQC, VCFReader, TreeParser
51
+ from snpio import (
52
+ GenePopReader,
53
+ NRemover2,
54
+ PhylipReader,
55
+ SNPioMultiQC,
56
+ StructureReader,
57
+ TreeParser,
58
+ VCFReader,
59
+ )
50
60
 
51
61
  from pgsui import (
52
62
  AutoencoderConfig,
53
63
  ImputeAutoencoder,
54
64
  ImputeMostFrequent,
55
- ImputeNLPCA,
56
65
  ImputeRefAllele,
57
- ImputeUBP,
58
66
  ImputeVAE,
59
67
  MostFrequentConfig,
60
- NLPCAConfig,
61
68
  RefAlleleConfig,
62
- UBPConfig,
63
69
  VAEConfig,
64
70
  )
65
71
  from pgsui.data_processing.config import (
@@ -71,10 +77,8 @@ from pgsui.data_processing.config import (
71
77
 
72
78
  # Canonical model order used everywhere (default and subset ordering)
73
79
  MODEL_ORDER: Tuple[str, ...] = (
74
- "ImputeUBP",
75
80
  "ImputeVAE",
76
81
  "ImputeAutoencoder",
77
- "ImputeNLPCA",
78
82
  "ImputeMostFrequent",
79
83
  "ImputeRefAllele",
80
84
  )
@@ -93,6 +97,142 @@ R = TypeVar("R")
93
97
 
94
98
 
95
99
  # ----------------------------- CLI Utilities ----------------------------- #
100
def _print_version() -> None:
    """Log the installed PG-SUI version at INFO level.

    Despite its name, this helper neither writes to stdout nor exits the
    process; it emits a single ``logging.info`` record so the version shows up
    in the run log alongside the other CLI messages.
    """
    # Function-scope import: the version is only looked up when this helper runs.
    from pgsui import __version__ as version

    logging.info(f"Using PG-SUI version: {version}")
105
+
106
+
107
+ def _model_family(model_name: str) -> str:
108
+ """Return output family folder name used by PG-SUI."""
109
+ if model_name in {"ImputeVAE", "ImputeAutoencoder"}:
110
+ return "Unsupervised"
111
+ if model_name in {"ImputeMostFrequent", "ImputeRefAllele"}:
112
+ return "Deterministic"
113
+ return "Unknown"
114
+
115
+
116
+ def _flatten_dict(d: dict, parent: str = "") -> dict:
117
+ """Flatten a nested dict into dot keys."""
118
+ out: dict = {}
119
+ for k, v in (d or {}).items():
120
+ key = f"{parent}.{k}" if parent else str(k)
121
+ if isinstance(v, dict):
122
+ out.update(_flatten_dict(v, key))
123
+ else:
124
+ out[key] = v
125
+ return out
126
+
127
+
128
+ def _force_tuning_off(cfg: Any, model_name: str) -> Any:
129
+ """Force tuning disabled on a config object (best-effort, but strict for tune-capable models)."""
130
+ # Prefer direct attribute mutation (avoids apply_dot_overrides edge-cases)
131
+ try:
132
+ if hasattr(cfg, "tune") and hasattr(cfg.tune, "enabled"):
133
+ cfg.tune.enabled = False
134
+ return cfg
135
+ except Exception:
136
+ pass
137
+
138
+ # Fallback to dot override
139
+ try:
140
+ return apply_dot_overrides(cfg, {"tune.enabled": False})
141
+ except Exception as e:
142
+ # Only strict for models that actually support tuning
143
+ if model_name in {"ImputeVAE", "ImputeAutoencoder"}:
144
+ raise RuntimeError(
145
+ f"Failed to force tuning off for {model_name}: {e}"
146
+ ) from e
147
+ return cfg
148
+
149
+
150
+ def _find_best_params_json(prefix: str, model_name: str) -> Path | None:
151
+ """Locate best parameter JSON (tuned or final) for a model.
152
+
153
+ Args:
154
+ prefix (str): Output prefix used during the run.
155
+ model_name (str): Model name to look for.
156
+
157
+ Returns:
158
+ Path | None: Path to best_parameters.json / best_tuned_parameters.json if found; else None.
159
+ """
160
+ families = ("Unsupervised", "Deterministic")
161
+ model_dir_candidates = (model_name, model_name.lower())
162
+
163
+ for fam in families:
164
+ for mdir in model_dir_candidates:
165
+ base = Path(f"{prefix}_output") / fam
166
+ candidates = (
167
+ base / "optimize" / mdir / "parameters" / "best_tuned_parameters.json",
168
+ base / "parameters" / mdir / "best_parameters.json",
169
+ )
170
+ for p in candidates:
171
+ if p.exists():
172
+ return p
173
+ return None
174
+
175
+
176
+ def _load_best_params(best_params_path: Path) -> dict:
177
+ """Load best parameters JSON."""
178
+ with best_params_path.open("r", encoding="utf-8") as f:
179
+ data = json.load(f)
180
+ if not isinstance(data, dict):
181
+ raise ValueError(
182
+ f"best_parameters.json must be a JSON object, got {type(data)}"
183
+ )
184
+ return data
185
+
186
+
187
def _apply_best_params_to_cfg(cfg: Any, best_params: dict, model_name: str) -> Any:
    """Overlay previously saved best parameters onto a model config.

    Keys are resolved as follows:

    - Nested JSON objects are flattened into dot keys with ``_flatten_dict``
      (``{"model": {"latent_dim": 4}}`` becomes ``"model.latent_dim"``).
    - A key that already contains ``'.'`` is treated as a full dot-path and
      applied once; if that fails it is skipped with a warning.
    - A key without ``'.'`` is tried as-is first, then under ``model.``,
      ``train.``, ``sim.``, ``tune.``, ``io.``, ``plot.`` in that order. The
      first destination ``apply_dot_overrides`` accepts wins. If none accepts
      it, the key is skipped with a warning.

    An empty (or ``None``) mapping leaves ``cfg`` unchanged.

    Args:
        cfg (Any): Model configuration object to update.
        best_params (dict): Parameters loaded from a best-parameters JSON file.
        model_name (str): Model name, used only in warning messages.

    Returns:
        Any: The config as returned by the last successful ``apply_dot_overrides``.
    """
    # Reuse the shared flattener so nested and top-level keys are handled
    # identically everywhere.
    flat = _flatten_dict(best_params)

    # "" (the bare key) is tried before any section prefix.
    candidate_prefixes = ("", "model.", "train.", "sim.", "tune.", "io.", "plot.")

    for raw_k, v in flat.items():
        if "." in raw_k:
            try:
                cfg = apply_dot_overrides(cfg, {raw_k: v})
            except Exception as e:
                logging.warning(
                    f"Could not apply best param '{raw_k}' to {model_name} (dot key). Skipping. Error: {e}"
                )
            continue

        # Try each candidate destination; stop at the first accepted one.
        for pref in candidate_prefixes:
            try:
                cfg = apply_dot_overrides(cfg, {f"{pref}{raw_k}": v})
            except Exception:
                continue
            break
        else:
            logging.warning(
                f"Best param '{raw_k}' not recognized for {model_name}; leaving config unchanged for that key."
            )

    return cfg
234
+
235
+
96
236
  def _configure_logging(verbose: bool, log_file: Optional[str] = None) -> None:
97
237
  """Configure root logger.
98
238
 
@@ -166,6 +306,96 @@ def _parse_overrides(pairs: list[str]) -> dict:
166
306
  return out
167
307
 
168
308
 
309
+ def _parse_allele_encoding(arg: str) -> dict:
310
+ """Parse STRUCTURE allele encoding dict from JSON or Python literal."""
311
+ try:
312
+ payload = json.loads(arg)
313
+ except Exception:
314
+ try:
315
+ payload = ast.literal_eval(arg)
316
+ except Exception as e:
317
+ raise argparse.ArgumentTypeError(
318
+ f"Invalid --structure-allele-encoding; must be a dict. Error: {e}"
319
+ ) from e
320
+
321
+ if not isinstance(payload, dict):
322
+ raise argparse.ArgumentTypeError(
323
+ "--structure-allele-encoding must be a dict-like mapping."
324
+ )
325
+
326
+ out: dict = {}
327
+ for k, v in payload.items():
328
+ key = k
329
+ if isinstance(k, str):
330
+ k_strip = k.strip()
331
+ if k_strip.lstrip("-").isdigit():
332
+ try:
333
+ key = int(k_strip)
334
+ except Exception:
335
+ key = k
336
+ out[key] = str(v)
337
+ return out
338
+
339
+
340
+ def _normalize_input_format(fmt: str) -> str:
341
+ """Normalize format aliases into canonical reader names."""
342
+ fmt = fmt.lower()
343
+ if fmt in {"vcf", "vcf.gz"}:
344
+ return "vcf"
345
+ if fmt in {"phy", "phylip"}:
346
+ return "phylip"
347
+ if fmt in {"gen", "genepop"}:
348
+ return "genepop"
349
+ if fmt in {"str", "structure"}:
350
+ return "structure"
351
+ return fmt
352
+
353
+
354
+ def _normalize_plot_format(fmt: str) -> Literal["pdf", "png", "jpg", "svg"]:
355
+ """Normalize plot format aliases to reader-supported values."""
356
+ fmt = fmt.lower()
357
+ if fmt == "jpeg":
358
+ return "jpg"
359
+ return cast(Literal["pdf", "png", "jpg", "svg"], fmt)
360
+
361
+
362
+ def _expand_path(path: str | None) -> str | None:
363
+ """Expand ~ and env vars in a path-like string."""
364
+ if path is None:
365
+ return None
366
+ raw = str(path).strip()
367
+ if not raw:
368
+ return None
369
+ expanded = os.path.expandvars(raw)
370
+ return str(Path(expanded).expanduser())
371
+
372
+
373
def _resolve_tree_paths(
    args: argparse.Namespace,
) -> tuple[str | None, str | None, str | None]:
    """Read and expand the optional tree-related paths from parsed CLI args.

    Args:
        args (argparse.Namespace): Parsed CLI arguments; attributes may be absent.

    Returns:
        tuple[str | None, str | None, str | None]: ``(treefile, qmatrix,
        siterates)``, each passed through ``_expand_path`` (absent or blank
        values become None).
    """
    treefile, qmatrix, siterates = (
        _expand_path(getattr(args, attr_name, None))
        for attr_name in ("treefile", "qmatrix", "siterates")
    )
    return treefile, qmatrix, siterates
381
+
382
+
383
+ def _config_needs_tree(cfg: Any | None) -> bool:
384
+ """Return True if config requires a tree parser for simulated missingness."""
385
+ if cfg is None:
386
+ return False
387
+ sim_cfg = getattr(cfg, "sim", None)
388
+ if sim_cfg is None:
389
+ return False
390
+ strategy = getattr(sim_cfg, "sim_strategy", None)
391
+ simulate = bool(getattr(sim_cfg, "simulate_missing", False))
392
+ return (
393
+ simulate
394
+ and isinstance(strategy, str)
395
+ and strategy in {"nonrandom", "nonrandom_weighted"}
396
+ )
397
+
398
+
169
399
  def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
170
400
  """Convert explicitly provided CLI flags into config dot-overrides."""
171
401
  overrides: dict = {}
@@ -174,10 +404,12 @@ def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
174
404
  if hasattr(args, "prefix") and args.prefix is not None:
175
405
  overrides["io.prefix"] = args.prefix
176
406
  else:
177
- # Note: we don't know input_path here; prefix default is handled later.
178
- # This fallback is preserved to avoid changing semantics.
179
- if hasattr(args, "vcf"):
180
- overrides["io.prefix"] = str(Path(args.vcf).stem)
407
+ # Prefer --input stem; fallback to legacy --vcf stem
408
+ input_path = getattr(args, "input", None)
409
+ if input_path is None and hasattr(args, "vcf"):
410
+ input_path = getattr(args, "vcf", None)
411
+ if input_path:
412
+ overrides["io.prefix"] = str(Path(input_path).stem)
181
413
 
182
414
  if hasattr(args, "verbose"):
183
415
  overrides["io.verbose"] = bool(args.verbose)
@@ -185,6 +417,8 @@ def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
185
417
  overrides["io.n_jobs"] = int(args.n_jobs)
186
418
  if hasattr(args, "seed"):
187
419
  overrides["io.seed"] = _parse_seed(args.seed)
420
+ if hasattr(args, "debug"):
421
+ overrides["io.debug"] = bool(args.debug)
188
422
 
189
423
  # Train
190
424
  if hasattr(args, "batch_size"):
@@ -198,20 +432,34 @@ def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
198
432
  # Plot
199
433
  if hasattr(args, "plot_format"):
200
434
  overrides["plot.fmt"] = args.plot_format
435
+ if getattr(args, "disable_plotting", False):
436
+ logging.info(
437
+ "Disabling plotting for all models as per --disable-plotting flag."
438
+ )
439
+ overrides["plot.show"] = False
201
440
 
202
- # Simulation overrides (shared across config-driven models)
441
+ # Simulation overrides
203
442
  if hasattr(args, "sim_strategy"):
204
443
  overrides["sim.sim_strategy"] = args.sim_strategy
205
444
  if hasattr(args, "sim_prop"):
206
445
  overrides["sim.sim_prop"] = float(args.sim_prop)
207
- if hasattr(args, "simulate_missing"):
208
- overrides["sim.simulate_missing"] = bool(args.simulate_missing)
209
446
 
210
447
  # Tuning
211
- if hasattr(args, "tune"):
212
- overrides["tune.enabled"] = bool(args.tune)
213
- if hasattr(args, "tune_n_trials"):
214
- overrides["tune.n_trials"] = int(args.tune_n_trials)
448
+ if getattr(args, "load_best_params", False):
449
+ # Never allow CLI flags to re-enable tuning when loading params
450
+ if hasattr(args, "tune") and bool(getattr(args, "tune", False)):
451
+ logging.warning(
452
+ "--tune was supplied, but --load-best-params is active; ignoring --tune."
453
+ )
454
+ if hasattr(args, "tune_n_trials"):
455
+ logging.warning(
456
+ "--tune-n-trials was supplied, but --load-best-params is active; ignoring it."
457
+ )
458
+ else:
459
+ if hasattr(args, "tune"):
460
+ overrides["tune.enabled"] = bool(args.tune)
461
+ if hasattr(args, "tune_n_trials"):
462
+ overrides["tune.n_trials"] = int(args.tune_n_trials)
215
463
 
216
464
  return overrides
217
465
 
@@ -255,35 +503,77 @@ def log_model_time(fn: Callable[P, R]) -> Callable[P, R]:
255
503
  # ------------------------------ Core Runner ------------------------------ #
256
504
  def build_genotype_data(
257
505
  input_path: str,
258
- fmt: Literal["vcf", "vcf.gz", "phy", "phylip", "genepop", "gen"],
506
+ fmt: Literal[
507
+ "vcf",
508
+ "vcf.gz",
509
+ "phy",
510
+ "phylip",
511
+ "genepop",
512
+ "gen",
513
+ "structure",
514
+ "str",
515
+ ],
259
516
  popmap_path: str | None,
260
517
  treefile: str | None,
261
518
  qmatrix: str | None,
262
519
  siterates: str | None,
263
520
  force_popmap: bool,
264
- verbose: bool,
521
+ debug: bool,
265
522
  include_pops: List[str] | None,
266
- plot_format: Literal["pdf", "png", "jpg", "jpeg"],
523
+ plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"],
524
+ structure_has_popids: bool = False,
525
+ structure_has_marker_names: bool = False,
526
+ structure_allele_start_col: int | None = None,
527
+ structure_allele_encoding: dict | None = None,
267
528
  ):
268
- """Load genotype data from heterogeneous inputs."""
269
- logging.info(f"Loading {fmt.upper()} and popmap data...")
529
+ """Load genotype data from heterogeneous inputs.
530
+
531
+ Args:
532
+ input_path (str): Path to genotype data file.
533
+ fmt (Literal): Format of genotype data file.
534
+ popmap_path (str | None): Optional path to population map file.
535
+ treefile (str | None): Optional path to phylogenetic tree file.
536
+ qmatrix (str | None): Optional path to IQ-TREE Q matrix file.
537
+ siterates (str | None): Optional path to SNP site rates file.
538
+ force_popmap (bool): Whether to force use of popmap even if samples don't match exactly.
539
+ debug (bool): Whether to enable debug-level logging in SNPio readers.
540
+ include_pops (List[str] | None): Optional list of population IDs to include.
541
+ plot_format (Literal): Figure format for SNPio plots.
542
+ structure_has_popids (bool): STRUCTURE only; whether pop IDs are present.
543
+ structure_has_marker_names (bool): STRUCTURE only; whether the first line has marker names.
544
+ structure_allele_start_col (int | None): STRUCTURE only; zero-based allele start column.
545
+ structure_allele_encoding (dict | None): STRUCTURE only; allele encoding map.
546
+ """
547
+ fmt_norm = _normalize_input_format(fmt)
548
+ plot_format = _normalize_plot_format(cast(str, plot_format))
549
+ logging.info(f"Loading {fmt_norm.upper()} and popmap data...")
270
550
 
271
551
  kwargs = {
272
552
  "filename": input_path,
273
553
  "popmapfile": popmap_path,
274
554
  "force_popmap": force_popmap,
275
- "verbose": verbose,
555
+ "verbose": debug,
276
556
  "include_pops": include_pops if include_pops else None,
277
557
  "prefix": f"snpio_{Path(input_path).stem}",
278
558
  "plot_format": plot_format,
279
559
  }
280
560
 
281
- if fmt == "vcf":
561
+ if fmt_norm == "vcf":
282
562
  gd = VCFReader(**kwargs)
283
- elif fmt == "phylip":
563
+ elif fmt_norm == "phylip":
284
564
  gd = PhylipReader(**kwargs)
285
- elif fmt == "genepop":
565
+ elif fmt_norm == "genepop":
286
566
  gd = GenePopReader(**kwargs)
567
+ elif fmt_norm == "structure":
568
+ kwargs.update(
569
+ {
570
+ "has_popids": structure_has_popids,
571
+ "has_marker_names": structure_has_marker_names,
572
+ "allele_start_col": structure_allele_start_col,
573
+ "allele_encoding": structure_allele_encoding,
574
+ }
575
+ )
576
+ gd = StructureReader(**kwargs)
287
577
  else:
288
578
  raise ValueError(f"Unsupported genotype data format: {fmt}")
289
579
 
@@ -319,8 +609,6 @@ def run_model_safely(model_name: str, builder, *, warn_only: bool = True) -> Non
319
609
  # -------------------------- Model Registry ------------------------------- #
320
610
  # Add config-driven models here by listing the class and its config dataclass.
321
611
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
322
- "ImputeUBP": {"cls": ImputeUBP, "config_cls": UBPConfig},
323
- "ImputeNLPCA": {"cls": ImputeNLPCA, "config_cls": NLPCAConfig},
324
612
  "ImputeAutoencoder": {"cls": ImputeAutoencoder, "config_cls": AutoencoderConfig},
325
613
  "ImputeVAE": {"cls": ImputeVAE, "config_cls": VAEConfig},
326
614
  "ImputeMostFrequent": {"cls": ImputeMostFrequent, "config_cls": MostFrequentConfig},
@@ -370,24 +658,65 @@ def _build_effective_config_for_model(
370
658
  f"Loaded YAML config for {model_name} from {yaml_path} (ignored 'preset' in YAML if present)."
371
659
  )
372
660
 
373
- # 3) Explicit CLI flags overlay YAML.
661
+ # 3) Optional: load best parameters from a previous run and force tuning OFF.
662
+ if getattr(args, "load_best_params", False):
663
+ # Determine which prefix to look under for *_output
664
+ src_prefix = getattr(args, "best_params_prefix", None)
665
+ if src_prefix is None:
666
+ # Use the resolved prefix if provided; otherwise fall back to input
667
+ # stem behavior
668
+ src_prefix = getattr(args, "prefix", None)
669
+
670
+ if src_prefix is None and hasattr(args, "vcf"):
671
+ src_prefix = str(Path(args.vcf).stem)
672
+
673
+ if src_prefix is None:
674
+ # As a last resort, use current effective io.prefix if it exists in cfg
675
+ src_prefix = getattr(getattr(cfg, "io", object()), "prefix", None)
676
+
677
+ if getattr(args, "tune", False):
678
+ logging.warning(
679
+ "--tune was supplied, but --load-best-params is active; forcing tuning OFF."
680
+ )
681
+
682
+ # Force tuning disabled in config (even if CLI/YAML enabled it)
683
+ cfg = _force_tuning_off(cfg, model_name)
684
+
685
+ best_path = _find_best_params_json(str(src_prefix), model_name)
686
+ if best_path is None:
687
+ # For tune-capable (unsupervised) models, treat as an error; deterministic models warn only.
688
+ fam = _model_family(model_name)
689
+ msg = (
690
+ "Requested --load-best-params, but could not find a best parameters JSON "
691
+ f"for {model_name}. Looked under '.../optimize/<model>/parameters/best_tuned_parameters.json' and '{src_prefix}_output/{fam}/parameters/{model_name}/best_parameters.json'"
692
+ )
693
+ if model_name in {"ImputeVAE", "ImputeAutoencoder"}:
694
+ logging.error(msg)
695
+ raise FileNotFoundError(msg)
696
+ logging.warning(msg)
697
+ else:
698
+ logging.info(f"Loading best parameters for {model_name} from: {best_path}")
699
+ best_params = _load_best_params(best_path)
700
+ cfg = _apply_best_params_to_cfg(cfg, best_params, model_name)
701
+ cfg = _force_tuning_off(cfg, model_name)
702
+
703
+ # 4) Explicit CLI flags overlay YAML/best-params layers.
374
704
  cli_overrides = _args_to_cli_overrides(args)
375
705
  if cli_overrides:
376
706
  cfg = apply_dot_overrides(cfg, cli_overrides)
377
707
 
378
- # 4) --set has highest precedence.
708
+ # Keep tuning disabled if --load-best-params was requested, even if CLI flags tried to re-enable it.
709
+ if getattr(args, "load_best_params", False):
710
+ cfg = _force_tuning_off(cfg, model_name)
711
+
712
+ # 5) --set has highest precedence.
379
713
  user_overrides = _parse_overrides(getattr(args, "set", []))
380
714
 
381
715
  if user_overrides:
382
716
  try:
383
717
  cfg = apply_dot_overrides(cfg, user_overrides)
384
718
  except Exception as e:
385
- if model_name in {
386
- "ImputeUBP",
387
- "ImputeNLPCA",
388
- "ImputeAutoencoder",
389
- "ImputeVAE",
390
- }:
719
+ if model_name in {"ImputeAutoencoder", "ImputeVAE"}:
391
720
  logging.error(
392
721
  f"Error applying --set overrides to {model_name} config: {e}"
393
722
  )
@@ -395,6 +724,18 @@ def _build_effective_config_for_model(
395
724
  else:
396
725
  pass # non-config-driven models ignore --set
397
726
 
727
+ # FINAL GUARANTEE:
728
+ # --load-best-params always wins over
729
+ # --set, YAML, preset, and CLI flags.
730
+ if getattr(args, "load_best_params", False):
731
+ # If user explicitly tried to set tune.* via --set, warn and override.
732
+ if any(str(k).startswith("tune.") for k in (user_overrides or {}).keys()):
733
+ logging.warning(
734
+ f"{model_name}: '--set tune.*=...' was provided, but --load-best-params forces tuning OFF. "
735
+ "Ignoring any tune.* overrides."
736
+ )
737
+ cfg = _force_tuning_off(cfg, model_name)
738
+
398
739
  return cfg
399
740
 
400
741
 
@@ -435,9 +776,19 @@ def _maybe_print_or_dump_configs(
435
776
 
436
777
 
437
778
  def main(argv: Optional[List[str]] = None) -> int:
779
+ """PG-SUI CLI main entry point.
780
+
781
+ The CLI supports running multiple imputation models on a single input file, with configuration handled via presets, YAML files, and CLI flags.
782
+
783
+ Args:
784
+ argv (Optional[List[str]]): List of CLI args (default: sys.argv[1:]).
785
+
786
+ Returns:
787
+ int: Exit code (0=success, 2=argparse error, 1=other error).
788
+ """
438
789
  parser = argparse.ArgumentParser(
439
790
  prog="pg-sui",
440
- description="Run PG-SUI imputation models on an input file. Handle configuration via presets, YAML, and CLI flags. The default is to run all models.",
791
+ description="Run PG-SUI imputation models on an input file. Handle configuration via presets, YAML, and CLI flags. The default is to run all models. The input file can be in VCF, PHYLIP, or GENEPOP format. Outputs include imputed genotype files and performance summaries.",
441
792
  formatter_class=argparse.ArgumentDefaultsHelpFormatter,
442
793
  usage="%(prog)s [options]",
443
794
  )
@@ -446,7 +797,7 @@ def main(argv: Optional[List[str]] = None) -> int:
446
797
  parser.add_argument(
447
798
  "--input",
448
799
  default=argparse.SUPPRESS,
449
- help="Path to input file (VCF/PHYLIP/STRUCTURE/GENEPOP).",
800
+ help="Path to input file (VCF/PHYLIP/GENEPOP). VCF file can be bgzipped or uncompressed.",
450
801
  )
451
802
  parser.add_argument(
452
803
  "--format",
@@ -462,23 +813,23 @@ def main(argv: Optional[List[str]] = None) -> int:
462
813
  "gen",
463
814
  ),
464
815
  default=argparse.SUPPRESS,
465
- help="Input format. If 'infer', deduced from file extension. The default is 'infer'.",
816
+ help="Input format. If 'infer', deduced from file extension. The default is 'infer'. Supported formats: VCF ('.vcf', '.vcf.gz'), PHYLIP ('.phy', '.phylip'), GENEPOP ('.genepop', '.gen').",
466
817
  )
467
818
  # Back-compat: --vcf retained; if both provided, --input wins.
468
819
  parser.add_argument(
469
820
  "--vcf",
470
821
  default=argparse.SUPPRESS,
471
- help="Path to input VCF file. Can be bgzipped or uncompressed.",
822
+ help="Path to input VCF file. Can be bgzipped or uncompressed. (Deprecated; use --input instead.)",
472
823
  )
473
824
  parser.add_argument(
474
825
  "--popmap",
475
826
  default=argparse.SUPPRESS,
476
- help="Path to population map file. This is a two-column tab-delimited file with sample IDs and population IDs.",
827
+ help="Path to population map file. This is a two-column tab-delimited file with sample IDs and population IDs. If not provided, no population info is used.",
477
828
  )
478
829
  parser.add_argument(
479
830
  "--treefile",
480
831
  default=argparse.SUPPRESS,
481
- help="Path to phylogenetic tree file. Can be in Newick (recommended) or Nexus format.",
832
+ help="Path to phylogenetic tree file. Can be in Newick (recommended) or Nexus format. Used with --qmatrix and --siterates.",
482
833
  )
483
834
  parser.add_argument(
484
835
  "--qmatrix",
@@ -488,19 +839,19 @@ def main(argv: Optional[List[str]] = None) -> int:
488
839
  parser.add_argument(
489
840
  "--siterates",
490
841
  default=argparse.SUPPRESS,
491
- help="Path to SNP site rates file (has .rate extension). Used with --treefile and --qmatrix.",
842
+ help="Path to SNP site rates file (has .rate extension and can be produced with IQ-TREE). Used with --treefile and --qmatrix.",
492
843
  )
493
844
  parser.add_argument(
494
845
  "--prefix",
495
846
  default=argparse.SUPPRESS,
496
- help="Output file prefix.",
847
+ help="Output file prefix. If not provided, defaults to the input file stem.",
497
848
  )
498
849
 
499
850
  # ---------------------- Generic Config Inputs -------------------------- #
500
851
  parser.add_argument(
501
852
  "--config",
502
853
  default=argparse.SUPPRESS,
503
- help="YAML config for config-driven models (NLPCA/UBP/Autoencoder/VAE).",
854
+ help="YAML config for config-driven models (Autoencoder, VAE). Overrides preset and defaults.",
504
855
  )
505
856
  parser.add_argument(
506
857
  "--preset",
@@ -512,7 +863,7 @@ def main(argv: Optional[List[str]] = None) -> int:
512
863
  "--set",
513
864
  action="append",
514
865
  default=argparse.SUPPRESS,
515
- help="Dot-key overrides, e.g. --set model.latent_dim=4",
866
+ help="Dot-key overrides, e.g. --set model.latent_dim=4 --set train.epochs=100. Applies to all models.",
516
867
  )
517
868
  parser.add_argument(
518
869
  "--print-config",
@@ -530,7 +881,7 @@ def main(argv: Optional[List[str]] = None) -> int:
530
881
  "--tune",
531
882
  action="store_true",
532
883
  default=argparse.SUPPRESS,
533
- help="Enable hyperparameter tuning (if supported).",
884
+ help="Enable hyperparameter tuning (if supported by model). Uses Optuna to optimize hyperparameters.",
534
885
  )
535
886
  parser.add_argument(
536
887
  "--tune-n-trials",
@@ -560,8 +911,31 @@ def main(argv: Optional[List[str]] = None) -> int:
560
911
  "--plot-format",
561
912
  choices=("png", "pdf", "svg", "jpg", "jpeg"),
562
913
  default=argparse.SUPPRESS,
563
- help="Figure format for model plots.",
914
+ help="Figure format for model plots. Choices: png, pdf, svg, jpg, jpeg.",
915
+ )
916
+ parser.add_argument(
917
+ "--disable-plotting",
918
+ action="store_true",
919
+ default=False,
920
+ help="Disable plotting for all models. Overrides any config settings enabling plotting.",
564
921
  )
922
+
923
+ parser.add_argument(
924
+ "--load-best-params",
925
+ action="store_true",
926
+ default=False,
927
+ help=(
928
+ "Load best hyperparameters from a previous run's best_parameters.json (or tuning best_tuned_parameters.json) for each selected model and apply them to the model configs. This forces tuning OFF."
929
+ ),
930
+ )
931
+ parser.add_argument(
932
+ "--best-params-prefix",
933
+ default=argparse.SUPPRESS,
934
+ help=(
935
+ "Prefix of the PREVIOUS run to load best parameters from. If omitted, uses the current --prefix (or input stem)."
936
+ ),
937
+ )
938
+
565
939
  # ------------------------- Simulation Controls ------------------------ #
566
940
  parser.add_argument(
567
941
  "--sim-strategy",
@@ -575,20 +949,17 @@ def main(argv: Optional[List[str]] = None) -> int:
575
949
  default=argparse.SUPPRESS,
576
950
  help="Override the proportion of observed entries to mask during simulation (0-1).",
577
951
  )
578
- parser.add_argument(
579
- "--simulate-missing",
580
- action="store_false",
581
- default=argparse.SUPPRESS,
582
- help="Disable missing-data simulation regardless of preset/config (when provided).",
583
- )
584
952
 
585
953
  # --------------------------- Seed & logging ---------------------------- #
586
954
  parser.add_argument(
587
955
  "--seed",
588
956
  default=argparse.SUPPRESS,
589
- help="Random seed: 'random', 'deterministic', or an integer.",
957
+ help="Random seed: 'random', 'deterministic', or an integer. Default is 'random'.",
590
958
  )
959
+
960
+ # ----------------------------- Logging --------------------------------- #
591
961
  parser.add_argument("--verbose", action="store_true", help="Info-level logging.")
962
+ parser.add_argument("--debug", action="store_true", help="Debug-level logging.")
592
963
  parser.add_argument(
593
964
  "--log-file", default=argparse.SUPPRESS, help="Also write logs to a file."
594
965
  )
@@ -604,7 +975,33 @@ def main(argv: Optional[List[str]] = None) -> int:
604
975
  "--force-popmap",
605
976
  action="store_true",
606
977
  default=False,
607
- help="Require popmap (error if absent).",
978
+ help="Force use of provided popmap even if samples don't match exactly. This will drop samples not in the popmap and vice versa.",
979
+ )
980
+
981
+ # -------------------------- STRUCTURE options ------------------------- #
982
+ parser.add_argument(
983
+ "--structure-has-popids",
984
+ action="store_true",
985
+ default=False,
986
+ help="STRUCTURE only: second column contains population IDs.",
987
+ )
988
+ parser.add_argument(
989
+ "--structure-has-marker-names",
990
+ action="store_true",
991
+ default=False,
992
+ help="STRUCTURE only: first row contains marker names.",
993
+ )
994
+ parser.add_argument(
995
+ "--structure-allele-start-col",
996
+ type=int,
997
+ default=argparse.SUPPRESS,
998
+ help="STRUCTURE only: zero-based column index where alleles begin.",
999
+ )
1000
+ parser.add_argument(
1001
+ "--structure-allele-encoding",
1002
+ type=_parse_allele_encoding,
1003
+ default=argparse.SUPPRESS,
1004
+ help="STRUCTURE only: allele encoding mapping as JSON or Python dict.",
608
1005
  )
609
1006
 
610
1007
  # ---------------------------- Model selection -------------------------- #
@@ -613,54 +1010,66 @@ def main(argv: Optional[List[str]] = None) -> int:
613
1010
  nargs="+",
614
1011
  default=argparse.SUPPRESS,
615
1012
  help=(
616
- "Which models to run. Choices: ImputeUBP ImputeVAE ImputeAutoencoder ImputeNLPCA ImputeMostFrequent ImputeRefAllele. Default is all."
1013
+ "Which models to run. Specify each model separated by a space. Choices: ImputeVAE ImputeAutoencoder ImputeMostFrequent ImputeRefAllele (Default is all models)."
617
1014
  ),
618
1015
  )
619
1016
 
620
1017
  # -------------------------- MultiQC integration ------------------------ #
621
1018
  parser.add_argument(
622
- "--multiqc",
1019
+ "--disable-multiqc",
623
1020
  action="store_true",
1021
+ default=False,
624
1022
  help=(
625
- "Build a MultiQC HTML report at the end of the run, combining SNPio and PG-SUI plots (requires SNPio's MultiQC module)."
1023
+ "Disable MultiQC report generation after imputation. By default, a MultiQC report is generated unless this flag is set."
626
1024
  ),
627
1025
  )
628
1026
  parser.add_argument(
629
1027
  "--multiqc-title",
630
1028
  default=argparse.SUPPRESS,
631
- help="Optional title for the MultiQC report (default: 'PG-SUI MultiQC Report - <prefix>').",
1029
+ help="Optional title for the MultiQC report (default: 'PG-SUI MultiQC Report - <prefix>'). ",
632
1030
  )
633
1031
  parser.add_argument(
634
1032
  "--multiqc-output-dir",
635
1033
  default=argparse.SUPPRESS,
636
- help="Optional output directory for the MultiQC report (default: '<prefix>_output/multiqc').",
1034
+ help="Optional output directory for the MultiQC report (default: '<prefix>_output/multiqc'). This directory will be created if it does not exist.",
637
1035
  )
638
1036
  parser.add_argument(
639
1037
  "--multiqc-overwrite",
640
1038
  action="store_true",
641
1039
  default=False,
642
- help="Overwrite an existing MultiQC report if present.",
1040
+ help="Overwrite an existing MultiQC report if present. If not set and a report exists, an integer suffix will be added to avoid overwriting. NOTE: if running multiple times with this flag, it may append multiple suffixes to avoid overwriting previous reports.",
643
1041
  )
644
1042
 
645
1043
  # ------------------------------ Safety/UX ------------------------------ #
646
1044
  parser.add_argument(
647
1045
  "--dry-run",
648
1046
  action="store_true",
649
- help="Parse args and load data, but skip model training.",
1047
+ help="Parse args and load data, but skip model training. Useful for testing I/O and configs.",
1048
+ )
1049
+ parser.add_argument(
1050
+ "--version", action="store_true", help="Print PG-SUI version and exit."
650
1051
  )
651
1052
 
652
1053
  args = parser.parse_args(argv)
653
1054
 
1055
+ if getattr(args, "version", False):
1056
+ _print_version()
1057
+ return 0
1058
+
654
1059
  # Logging (verbose default is False unless passed)
655
1060
  _configure_logging(
656
1061
  verbose=getattr(args, "verbose", False),
657
1062
  log_file=getattr(args, "log_file", None),
658
1063
  )
659
1064
 
1065
+ logging.info("Starting PG-SUI imputation...")
1066
+ _print_version()
1067
+
660
1068
  # Models selection (default to all if not explicitly provided)
661
1069
  try:
662
1070
  selected_models = _parse_models(getattr(args, "models", ()))
663
1071
  except argparse.ArgumentTypeError as e:
1072
+ logging.error(str(e))
664
1073
  parser.error(str(e))
665
1074
  return 2
666
1075
 
@@ -672,12 +1081,11 @@ def main(argv: Optional[List[str]] = None) -> int:
672
1081
  setattr(args, "format", "vcf")
673
1082
 
674
1083
  if input_path is None:
1084
+ logging.error("You must provide --input (or legacy --vcf).")
675
1085
  parser.error("You must provide --input (or legacy --vcf).")
676
1086
  return 2
677
1087
 
678
- fmt: Literal["infer", "vcf", "vcf.gz", "phy", "phylip", "genepop", "gen"] = getattr(
679
- args, "format", "infer"
680
- )
1088
+ fmt = getattr(args, "format", "infer")
681
1089
 
682
1090
  if fmt == "infer":
683
1091
  if input_path.endswith((".vcf", ".vcf.gz")):
@@ -686,28 +1094,59 @@ def main(argv: Optional[List[str]] = None) -> int:
686
1094
  fmt_final = "phylip"
687
1095
  elif input_path.endswith((".genepop", ".gen")):
688
1096
  fmt_final = "genepop"
1097
+ elif input_path.endswith((".str", ".stru", ".structure")):
1098
+ fmt_final = "structure"
689
1099
  else:
1100
+ logging.error(
1101
+ "Could not infer input format from file extension. Please provide --format."
1102
+ )
690
1103
  parser.error(
691
1104
  "Could not infer input format from file extension. Please provide --format."
692
1105
  )
693
1106
  return 2
694
1107
  else:
695
- fmt_final = fmt
1108
+ fmt_final = cast(
1109
+ Literal[
1110
+ "vcf",
1111
+ "vcf.gz",
1112
+ "phy",
1113
+ "phylip",
1114
+ "genepop",
1115
+ "gen",
1116
+ "structure",
1117
+ "str",
1118
+ ],
1119
+ fmt,
1120
+ )
1121
+
1122
+ fmt_final = _normalize_input_format(fmt_final)
696
1123
 
697
1124
  popmap_path = getattr(args, "popmap", None)
698
1125
  include_pops = getattr(args, "include_pops", None)
699
- verbose_flag = getattr(args, "verbose", False)
700
1126
  force_popmap = bool(getattr(args, "force_popmap", False))
1127
+ structure_has_popids = bool(getattr(args, "structure_has_popids", False))
1128
+ structure_has_marker_names = bool(
1129
+ getattr(args, "structure_has_marker_names", False)
1130
+ )
1131
+ structure_allele_start_col = getattr(args, "structure_allele_start_col", None)
1132
+ structure_allele_encoding = getattr(args, "structure_allele_encoding", None)
701
1133
 
702
1134
  # Canonical prefix for this run (used for outputs and MultiQC)
703
1135
  prefix: str = getattr(args, "prefix", str(Path(input_path).stem))
1136
+ # Ensure downstream config building sees the resolved prefix even if
1137
+ # --prefix was not provided.
1138
+ setattr(args, "prefix", prefix)
704
1139
 
705
- treefile = getattr(args, "treefile", None)
706
- qmatrix = getattr(args, "qmatrix", None)
707
- siterates = getattr(args, "siterates", None)
1140
+ treefile, qmatrix, siterates = _resolve_tree_paths(args)
1141
+ setattr(args, "treefile", treefile)
1142
+ setattr(args, "qmatrix", qmatrix)
1143
+ setattr(args, "siterates", siterates)
708
1144
 
709
1145
  if any(x is not None for x in (treefile, qmatrix, siterates)):
710
1146
  if not all(x is not None for x in (treefile, qmatrix, siterates)):
1147
+ logging.error(
1148
+ "--treefile, --qmatrix, and --siterates must all be provided together or they should all be omitted."
1149
+ )
711
1150
  parser.error(
712
1151
  "--treefile, --qmatrix, and --siterates must all be provided together or they should all be omitted."
713
1152
  )
@@ -716,15 +1155,31 @@ def main(argv: Optional[List[str]] = None) -> int:
716
1155
  # Load genotype data
717
1156
  gd, tp = build_genotype_data(
718
1157
  input_path=input_path,
719
- fmt=fmt_final,
1158
+ fmt=cast(
1159
+ Literal[
1160
+ "vcf",
1161
+ "vcf.gz",
1162
+ "phy",
1163
+ "phylip",
1164
+ "genepop",
1165
+ "gen",
1166
+ "structure",
1167
+ "str",
1168
+ ],
1169
+ fmt_final,
1170
+ ),
720
1171
  popmap_path=popmap_path,
721
1172
  treefile=treefile,
722
1173
  qmatrix=qmatrix,
723
1174
  siterates=siterates,
724
1175
  force_popmap=force_popmap,
725
- verbose=verbose_flag,
726
1176
  include_pops=include_pops,
1177
+ debug=getattr(args, "debug", False),
727
1178
  plot_format=getattr(args, "plot_format", "pdf"),
1179
+ structure_has_popids=structure_has_popids,
1180
+ structure_has_marker_names=structure_has_marker_names,
1181
+ structure_allele_start_col=structure_allele_start_col,
1182
+ structure_allele_encoding=structure_allele_encoding,
728
1183
  )
729
1184
 
730
1185
  if getattr(args, "dry_run", False):
@@ -736,47 +1191,33 @@ def main(argv: Optional[List[str]] = None) -> int:
736
1191
  m: _build_effective_config_for_model(m, args) for m in selected_models
737
1192
  }
738
1193
 
1194
+ needs_tree = any(
1195
+ _config_needs_tree(cfg) for cfg in cfgs_by_model.values() if cfg is not None
1196
+ )
1197
+ if needs_tree and not all(x is not None for x in (treefile, qmatrix, siterates)):
1198
+ logging.error(
1199
+ "Nonrandom simulated missingness requires --treefile, --qmatrix, and --siterates."
1200
+ )
1201
+ parser.error(
1202
+ "Nonrandom simulated missingness requires --treefile, --qmatrix, and --siterates."
1203
+ )
1204
+ return 2
1205
+ if needs_tree and tp is None:
1206
+ logging.error(
1207
+ "Tree parser was not initialized for nonrandom simulation. "
1208
+ "Please verify --treefile, --qmatrix, and --siterates."
1209
+ )
1210
+ parser.error(
1211
+ "Tree parser was not initialized for nonrandom simulation. "
1212
+ "Please verify --treefile, --qmatrix, and --siterates."
1213
+ )
1214
+ return 2
1215
+
739
1216
  # Maybe print/dump configs and exit
740
1217
  if _maybe_print_or_dump_configs(cfgs_by_model, args):
741
1218
  return 0
742
1219
 
743
1220
  # ------------------------- Model Builders ------------------------------ #
744
- def build_impute_ubp():
745
- cfg = cfgs_by_model.get("ImputeUBP")
746
- if cfg is None:
747
- cfg = (
748
- UBPConfig.from_preset(args.preset)
749
- if hasattr(args, "preset")
750
- else UBPConfig()
751
- )
752
- return ImputeUBP(
753
- genotype_data=gd,
754
- tree_parser=tp,
755
- config=cfg,
756
- simulate_missing=cfg.sim.simulate_missing,
757
- sim_strategy=cfg.sim.sim_strategy,
758
- sim_prop=cfg.sim.sim_prop,
759
- sim_kwargs=cfg.sim.sim_kwargs,
760
- )
761
-
762
- def build_impute_nlpca():
763
- cfg = cfgs_by_model.get("ImputeNLPCA")
764
- if cfg is None:
765
- cfg = (
766
- NLPCAConfig.from_preset(args.preset)
767
- if hasattr(args, "preset")
768
- else NLPCAConfig()
769
- )
770
- return ImputeNLPCA(
771
- genotype_data=gd,
772
- tree_parser=tp,
773
- config=cfg,
774
- simulate_missing=cfg.sim.simulate_missing,
775
- sim_strategy=cfg.sim.sim_strategy,
776
- sim_prop=cfg.sim.sim_prop,
777
- sim_kwargs=cfg.sim.sim_kwargs,
778
- )
779
-
780
1221
  def build_impute_vae():
781
1222
  cfg = cfgs_by_model.get("ImputeVAE")
782
1223
  if cfg is None:
@@ -789,7 +1230,6 @@ def main(argv: Optional[List[str]] = None) -> int:
789
1230
  genotype_data=gd,
790
1231
  tree_parser=tp,
791
1232
  config=cfg,
792
- simulate_missing=cfg.sim.simulate_missing,
793
1233
  sim_strategy=cfg.sim.sim_strategy,
794
1234
  sim_prop=cfg.sim.sim_prop,
795
1235
  sim_kwargs=cfg.sim.sim_kwargs,
@@ -807,7 +1247,6 @@ def main(argv: Optional[List[str]] = None) -> int:
807
1247
  genotype_data=gd,
808
1248
  tree_parser=tp,
809
1249
  config=cfg,
810
- simulate_missing=cfg.sim.simulate_missing,
811
1250
  sim_strategy=cfg.sim.sim_strategy,
812
1251
  sim_prop=cfg.sim.sim_prop,
813
1252
  sim_kwargs=cfg.sim.sim_kwargs,
@@ -825,7 +1264,7 @@ def main(argv: Optional[List[str]] = None) -> int:
825
1264
  gd,
826
1265
  tree_parser=tp,
827
1266
  config=cfg,
828
- simulate_missing=cfg.sim.simulate_missing,
1267
+ simulate_missing=True,
829
1268
  sim_strategy=cfg.sim.sim_strategy,
830
1269
  sim_prop=cfg.sim.sim_prop,
831
1270
  sim_kwargs=cfg.sim.sim_kwargs,
@@ -843,43 +1282,68 @@ def main(argv: Optional[List[str]] = None) -> int:
843
1282
  gd,
844
1283
  tree_parser=tp,
845
1284
  config=cfg,
846
- simulate_missing=cfg.sim.simulate_missing,
1285
+ simulate_missing=True,
847
1286
  sim_strategy=cfg.sim.sim_strategy,
848
1287
  sim_prop=cfg.sim.sim_prop,
849
1288
  sim_kwargs=cfg.sim.sim_kwargs,
850
1289
  )
851
1290
 
852
1291
  model_builders = {
853
- "ImputeUBP": build_impute_ubp,
854
1292
  "ImputeVAE": build_impute_vae,
855
1293
  "ImputeAutoencoder": build_impute_autoencoder,
856
- "ImputeNLPCA": build_impute_nlpca,
857
1294
  "ImputeMostFrequent": build_impute_mostfreq,
858
1295
  "ImputeRefAllele": build_impute_refallele,
859
1296
  }
860
1297
 
861
1298
  logging.info(f"Selected models: {', '.join(selected_models)}")
862
1299
  for name in selected_models:
1300
+ logging.info("")
1301
+ logging.info("=" * 60)
1302
+ logging.info("")
1303
+ logging.info(f"Processing model: {name} ...")
863
1304
  X_imputed = run_model_safely(name, model_builders[name], warn_only=False)
864
1305
  gd_imp = gd.copy()
865
1306
  gd_imp.snp_data = X_imputed
866
1307
 
867
- if name in {"ImputeUBP", "ImputeVAE", "ImputeAutoencoder", "ImputeNLPCA"}:
1308
+ if name in {"ImputeVAE", "ImputeAutoencoder"}:
868
1309
  family = "Unsupervised"
869
1310
  elif name in {"ImputeMostFrequent", "ImputeRefAllele"}:
870
1311
  family = "Deterministic"
871
1312
  elif name in {"ImputeHistGradientBoosting", "ImputeRandomForest"}:
872
1313
  family = "Supervised"
873
1314
  else:
1315
+ logging.error(f"Unknown model family for {name}")
874
1316
  raise ValueError(f"Unknown model family for {name}")
875
1317
 
876
1318
  pth = Path(f"{prefix}_output/{family}/imputed/{name}")
877
1319
  pth.mkdir(parents=True, exist_ok=True)
878
1320
 
879
1321
  logging.info(f"Writing imputed VCF for {name} to {pth} ...")
880
- gd_imp.write_vcf(pth / f"{name.lower()}_imputed.vcf.gz")
881
1322
 
882
- logging.info("All requested models processed.")
1323
+ if fmt_final == "vcf":
1324
+ gd_imp.write_vcf(pth / f"{name.lower()}_imputed.vcf.gz")
1325
+ elif fmt_final == "phylip":
1326
+ gd_imp.write_phylip(pth / f"{name.lower()}_imputed.phy")
1327
+ elif fmt_final == "genepop":
1328
+ gd_imp.write_genepop(pth / f"{name.lower()}_imputed.gen")
1329
+ else:
1330
+ logging.warning(
1331
+ f"Output format {fmt_final} not supported for imputed data export."
1332
+ )
1333
+
1334
+ logging.info("")
1335
+ logging.info(f"Successfully finished imputation for model: {name}!")
1336
+ logging.info("")
1337
+ logging.info("=" * 60)
1338
+
1339
+ logging.info(f"All requested models processed for input: {input_path}")
1340
+
1341
+ disable_mqc = bool(getattr(args, "disable_multiqc", False))
1342
+
1343
+ if disable_mqc:
1344
+ logging.info("MultiQC report generation disabled via --disable-multiqc.")
1345
+ logging.info("PG-SUI imputation run complete!")
1346
+ return 0
883
1347
 
884
1348
  # -------------------------- MultiQC builder ---------------------------- #
885
1349
 
@@ -899,9 +1363,10 @@ def main(argv: Optional[List[str]] = None) -> int:
899
1363
  overwrite=overwrite,
900
1364
  )
901
1365
  logging.info("MultiQC report successfully built.")
902
- except Exception as exc2: # pragma: no cover
1366
+ except Exception as exc2:
903
1367
  logging.error(f"Failed to build MultiQC report: {exc2}", exc_info=True)
904
1368
 
1369
+ logging.info("PG-SUI imputation run complete!")
905
1370
  return 0
906
1371
 
907
1372