pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127) hide show
  1. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/RECORD +0 -75
  83. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
pgsui/cli.py ADDED
@@ -0,0 +1,909 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """PG-SUI Imputation CLI
5
+
6
+ Argument-precedence model:
7
+ code defaults < preset (--preset) < YAML (--config) < explicit CLI flags < --set key=value
8
+
9
+ Notes
10
+ -----
11
+ - Preset is a CLI-only choice and will be respected unless overridden by YAML or CLI.
12
+ - YAML entries override preset (a 'preset' key in YAML is ignored with a warning).
13
+ - CLI flags only override when explicitly provided (argparse uses SUPPRESS).
14
+ - --set key=value has the highest precedence and applies dot-path overrides.
15
+
16
+ Examples
17
+ --------
18
+ python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix run1
19
+ python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix tuned --tune
20
+ python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix demo \
21
+ --models ImputeUBP ImputeVAE ImputeMostFrequent --seed deterministic --verbose
22
+ python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix subset \
23
+ --include-pops EA GU TT ON --device cpu
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import argparse
29
+ import ast
30
+ import logging
31
+ import sys
32
+ import time
33
+ from functools import wraps
34
+ from pathlib import Path
35
+ from typing import (
36
+ Any,
37
+ Callable,
38
+ Dict,
39
+ Iterable,
40
+ List,
41
+ Literal,
42
+ Optional,
43
+ ParamSpec,
44
+ Tuple,
45
+ TypeVar,
46
+ cast,
47
+ )
48
+
49
+ from snpio import GenePopReader, PhylipReader, SNPioMultiQC, VCFReader, TreeParser
50
+
51
+ from pgsui import (
52
+ AutoencoderConfig,
53
+ ImputeAutoencoder,
54
+ ImputeMostFrequent,
55
+ ImputeNLPCA,
56
+ ImputeRefAllele,
57
+ ImputeUBP,
58
+ ImputeVAE,
59
+ MostFrequentConfig,
60
+ NLPCAConfig,
61
+ RefAlleleConfig,
62
+ UBPConfig,
63
+ VAEConfig,
64
+ )
65
+ from pgsui.data_processing.config import (
66
+ apply_dot_overrides,
67
+ dataclass_to_yaml,
68
+ load_yaml_to_dataclass,
69
+ save_dataclass_yaml,
70
+ )
71
+
72
+ # Canonical model order used everywhere (default and subset ordering)
73
+ MODEL_ORDER: Tuple[str, ...] = (
74
+ "ImputeUBP",
75
+ "ImputeVAE",
76
+ "ImputeAutoencoder",
77
+ "ImputeNLPCA",
78
+ "ImputeMostFrequent",
79
+ "ImputeRefAllele",
80
+ )
81
+
82
+ # Strategies supported by SimMissingTransformer + SimConfig.
83
+ SIM_STRATEGY_CHOICES: Tuple[str, ...] = (
84
+ "random",
85
+ "random_weighted",
86
+ "random_weighted_inv",
87
+ "nonrandom",
88
+ "nonrandom_weighted",
89
+ )
90
+
91
+ P = ParamSpec("P")
92
+ R = TypeVar("R")
93
+
94
+
95
+ # ----------------------------- CLI Utilities ----------------------------- #
96
+ def _configure_logging(verbose: bool, log_file: Optional[str] = None) -> None:
97
+ """Configure root logger.
98
+
99
+ Args:
100
+ verbose (bool): If True, INFO; else ERROR.
101
+ log_file (Optional[str]): Optional file to tee logs to.
102
+ """
103
+ level = logging.INFO if verbose else logging.ERROR
104
+ handlers: List[logging.Handler] = [logging.StreamHandler(sys.stdout)]
105
+ if log_file:
106
+ handlers.append(logging.FileHandler(log_file, mode="w", encoding="utf-8"))
107
+ logging.basicConfig(
108
+ level=level,
109
+ format="%(asctime)s - %(levelname)s - %(message)s",
110
+ handlers=handlers,
111
+ )
112
+
113
+
114
+ def _parse_seed(seed_arg: str) -> Optional[int]:
115
+ """Parse --seed argument into an int or None."""
116
+ s = seed_arg.strip().lower()
117
+ if s == "random":
118
+ return None
119
+ if s == "deterministic":
120
+ return 42
121
+ try:
122
+ return int(seed_arg)
123
+ except ValueError as e:
124
+ raise argparse.ArgumentTypeError(
125
+ "Invalid --seed. Use 'random', 'deterministic', or an integer."
126
+ ) from e
127
+
128
+
129
def _parse_models(models: Iterable[str]) -> Tuple[str, ...]:
    """Validate requested model names and return them in canonical order.

    Args:
        models: Model names from ``--models``; any iterable, including a generator.

    Returns:
        ``MODEL_ORDER`` itself when no names were requested. Otherwise, the
        requested names, de-duplicated and ordered as in ``MODEL_ORDER``.

    Raises:
        argparse.ArgumentTypeError: If any requested name is not in ``MODEL_ORDER``.
    """
    requested = tuple(models)  # materialize once: a generator can only be read once
    known = frozenset(MODEL_ORDER)

    bad_names = [name for name in requested if name not in known]
    if bad_names:
        raise argparse.ArgumentTypeError(
            f"Unknown model(s): {bad_names}. Valid options: {list(MODEL_ORDER)}"
        )

    if not requested:
        return MODEL_ORDER

    wanted = set(requested)
    return tuple(name for name in MODEL_ORDER if name in wanted)
152
+
153
+
154
+ def _parse_overrides(pairs: list[str]) -> dict:
155
+ """Parse --set key=value into typed values via literal_eval."""
156
+ out: dict = {}
157
+ for kv in pairs or []:
158
+ if "=" not in kv:
159
+ raise argparse.ArgumentTypeError(f"--set expects key=value, got '{kv}'")
160
+ k, v = kv.split("=", 1)
161
+ v = v.strip()
162
+ try:
163
+ out[k] = ast.literal_eval(v)
164
+ except Exception:
165
+ out[k] = v # raw string fallback
166
+ return out
167
+
168
+
169
+ def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
170
+ """Convert explicitly provided CLI flags into config dot-overrides."""
171
+ overrides: dict = {}
172
+
173
+ # IO / top-level controls
174
+ if hasattr(args, "prefix") and args.prefix is not None:
175
+ overrides["io.prefix"] = args.prefix
176
+ else:
177
+ # Note: we don't know input_path here; prefix default is handled later.
178
+ # This fallback is preserved to avoid changing semantics.
179
+ if hasattr(args, "vcf"):
180
+ overrides["io.prefix"] = str(Path(args.vcf).stem)
181
+
182
+ if hasattr(args, "verbose"):
183
+ overrides["io.verbose"] = bool(args.verbose)
184
+ if hasattr(args, "n_jobs"):
185
+ overrides["io.n_jobs"] = int(args.n_jobs)
186
+ if hasattr(args, "seed"):
187
+ overrides["io.seed"] = _parse_seed(args.seed)
188
+
189
+ # Train
190
+ if hasattr(args, "batch_size"):
191
+ overrides["train.batch_size"] = int(args.batch_size)
192
+ if hasattr(args, "device"):
193
+ dev = args.device
194
+ if dev == "cuda":
195
+ dev = "gpu"
196
+ overrides["train.device"] = dev
197
+
198
+ # Plot
199
+ if hasattr(args, "plot_format"):
200
+ overrides["plot.fmt"] = args.plot_format
201
+
202
+ # Simulation overrides (shared across config-driven models)
203
+ if hasattr(args, "sim_strategy"):
204
+ overrides["sim.sim_strategy"] = args.sim_strategy
205
+ if hasattr(args, "sim_prop"):
206
+ overrides["sim.sim_prop"] = float(args.sim_prop)
207
+ if hasattr(args, "simulate_missing"):
208
+ overrides["sim.simulate_missing"] = bool(args.simulate_missing)
209
+
210
+ # Tuning
211
+ if hasattr(args, "tune"):
212
+ overrides["tune.enabled"] = bool(args.tune)
213
+ if hasattr(args, "tune_n_trials"):
214
+ overrides["tune.n_trials"] = int(args.tune_n_trials)
215
+
216
+ return overrides
217
+
218
+
219
+ def _format_seconds(seconds: float) -> str:
220
+ total = int(round(seconds))
221
+ minutes, secs = divmod(total, 60)
222
+ hours, minutes = divmod(minutes, 60)
223
+ if hours:
224
+ return f"{hours:d}:{minutes:02d}:{secs:02d}"
225
+ return f"{minutes:d}:{secs:02d}"
226
+
227
+
228
def log_model_time(fn: Callable[P, R]) -> Callable[P, R]:
    """Wrap ``fn`` so that its wall-clock runtime is logged.

    The first positional argument is used as the model label in the log line.
    When there are no positional arguments, ``"<unknown model>"`` is used. On
    success an INFO record is emitted. If ``fn`` raises, an ERROR record with the
    traceback is emitted and the exception is re-raised unchanged.
    """

    @wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        label = "<unknown model>"
        if args:
            label = str(args[0])
        t0 = time.perf_counter()
        try:
            outcome = fn(*args, **kwargs)
        except Exception:
            dt = time.perf_counter() - t0
            logging.error(
                f"{label} failed after {dt:0.2f}s ({_format_seconds(dt)}).",
                exc_info=True,
            )
            raise
        dt = time.perf_counter() - t0
        logging.info(f"{label} finished in {dt:0.2f}s ({_format_seconds(dt)}).")
        return outcome

    return cast(Callable[P, R], wrapper)
253
+
254
+
255
+ # ------------------------------ Core Runner ------------------------------ #
256
+ def build_genotype_data(
257
+ input_path: str,
258
+ fmt: Literal["vcf", "vcf.gz", "phy", "phylip", "genepop", "gen"],
259
+ popmap_path: str | None,
260
+ treefile: str | None,
261
+ qmatrix: str | None,
262
+ siterates: str | None,
263
+ force_popmap: bool,
264
+ verbose: bool,
265
+ include_pops: List[str] | None,
266
+ plot_format: Literal["pdf", "png", "jpg", "jpeg"],
267
+ ):
268
+ """Load genotype data from heterogeneous inputs."""
269
+ logging.info(f"Loading {fmt.upper()} and popmap data...")
270
+
271
+ kwargs = {
272
+ "filename": input_path,
273
+ "popmapfile": popmap_path,
274
+ "force_popmap": force_popmap,
275
+ "verbose": verbose,
276
+ "include_pops": include_pops if include_pops else None,
277
+ "prefix": f"snpio_{Path(input_path).stem}",
278
+ "plot_format": plot_format,
279
+ }
280
+
281
+ if fmt == "vcf":
282
+ gd = VCFReader(**kwargs)
283
+ elif fmt == "phylip":
284
+ gd = PhylipReader(**kwargs)
285
+ elif fmt == "genepop":
286
+ gd = GenePopReader(**kwargs)
287
+ else:
288
+ raise ValueError(f"Unsupported genotype data format: {fmt}")
289
+
290
+ tp = None
291
+ if treefile is not None:
292
+ logging.info("Parsing phylogenetic tree...")
293
+
294
+ tp = TreeParser(
295
+ gd, treefile=treefile, qmatrix=qmatrix, siterates=siterates, verbose=True
296
+ )
297
+
298
+ logging.info("Loaded genotype data.")
299
+ return gd, tp
300
+
301
+
302
@log_model_time
def run_model_safely(model_name: str, builder, *, warn_only: bool = True) -> Any:
    """Build, fit, and transform one imputer, isolating failures.

    Args:
        model_name: Label used in log messages. ``log_model_time`` reads it as the
            first positional argument, so pass it positionally.
        builder: Zero-argument callable returning an object that exposes
            ``fit()`` and ``transform()``.
        warn_only: If True, any exception is logged at WARNING level (with
            traceback) and swallowed. If False, it propagates to the caller.

    Returns:
        The value of ``model.transform()`` on success, or ``None`` if the run
        failed and ``warn_only`` is True.
    """
    logging.info(f"▶ Running {model_name} ...")
    try:
        model = builder()
        model.fit()
        X_imputed = model.transform()
        logging.info(f"✓ {model_name} completed.")
        return X_imputed
    except Exception as e:
        if not warn_only:
            raise
        # NOTE(review): WARNING is below the ERROR level that _configure_logging
        # uses without --verbose, so swallowed failures are not shown unless
        # --verbose is passed. log_model_time will also report "finished" here.
        logging.warning(f"⚠ {model_name} failed: {e}", exc_info=True)
        return None
317
+
318
+
319
+ # -------------------------- Model Registry ------------------------------- #
320
# Registry of runnable models: CLI name -> imputer class and its config
# dataclass. To register another config-driven model, append a
# (name, imputer class, config class) row below.
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
    name: {"cls": imputer_cls, "config_cls": config_cls}
    for name, imputer_cls, config_cls in (
        ("ImputeUBP", ImputeUBP, UBPConfig),
        ("ImputeNLPCA", ImputeNLPCA, NLPCAConfig),
        ("ImputeAutoencoder", ImputeAutoencoder, AutoencoderConfig),
        ("ImputeVAE", ImputeVAE, VAEConfig),
        ("ImputeMostFrequent", ImputeMostFrequent, MostFrequentConfig),
        ("ImputeRefAllele", ImputeRefAllele, RefAlleleConfig),
    )
}
329
+
330
+
331
def _build_effective_config_for_model(
    model_name: str, args: argparse.Namespace
) -> Any | None:
    """Build the effective config object for a specific model (if it has one).

    Precedence (lowest -> highest):
        defaults < preset (--preset) < YAML (--config) < explicit CLI flags < --set

    Args:
        model_name: Key into ``MODEL_REGISTRY``.
        args: Parsed CLI namespace. Flags registered with ``argparse.SUPPRESS``
            are present as attributes only when the user passed them.

    Returns:
        Config dataclass instance, or None for registry entries without a
        ``config_cls``.

    Raises:
        KeyError: If ``model_name`` is not in ``MODEL_REGISTRY``.
        Exception: Errors from ``apply_dot_overrides`` while applying ``--set``
            are re-raised for ImputeUBP/ImputeNLPCA/ImputeAutoencoder/ImputeVAE.
            For other models they are logged as a warning and the ``--set``
            overrides are skipped.
    """
    # Models for which a failing --set is fatal; the others treat --set as
    # best-effort.
    strict_set_models = frozenset(
        {"ImputeUBP", "ImputeNLPCA", "ImputeAutoencoder", "ImputeVAE"}
    )

    cfg_cls = MODEL_REGISTRY[model_name].get("config_cls")
    if cfg_cls is None:
        return None

    # 0/1) --preset uses a SUPPRESS default, so it is present only when given.
    # Build from the preset directly rather than constructing throwaway defaults.
    if hasattr(args, "preset"):
        preset_name = args.preset
        cfg = cfg_cls.from_preset(preset_name)
        logging.info(f"Initialized {model_name} from '{preset_name}' preset.")
    else:
        cfg = cfg_cls()
        logging.info(f"Initialized {model_name} from dataclass defaults (no preset).")

    # 2) YAML overlays preset/defaults. Any 'preset' key inside the YAML is ignored.
    yaml_path = getattr(args, "config", None)
    if yaml_path:
        cfg = load_yaml_to_dataclass(
            yaml_path,
            cfg_cls,
            base=cfg,
            yaml_preset_behavior="ignore",  # 'preset' key in YAML ignored with warning
        )
        logging.info(
            f"Loaded YAML config for {model_name} from {yaml_path} (ignored 'preset' in YAML if present)."
        )

    # 3) Explicit CLI flags overlay YAML.
    cli_overrides = _args_to_cli_overrides(args)
    if cli_overrides:
        cfg = apply_dot_overrides(cfg, cli_overrides)

    # 4) --set has the highest precedence.
    user_overrides = _parse_overrides(getattr(args, "set", []))
    if user_overrides:
        try:
            cfg = apply_dot_overrides(cfg, user_overrides)
        except Exception as e:
            if model_name in strict_set_models:
                logging.error(
                    f"Error applying --set overrides to {model_name} config: {e}"
                )
                raise
            # Best-effort for the remaining models: keep the pre---set config,
            # but report the skip instead of swallowing it silently.
            logging.warning(
                f"Ignoring --set overrides for {model_name} (could not apply: {e})."
            )

    return cfg
399
+
400
+
401
def _maybe_print_or_dump_configs(
    cfgs_by_model: Dict[str, Any], args: argparse.Namespace
) -> bool:
    """Handle --print-config / --dump-config for all selected config-driven models.

    Models whose config is ``None`` are skipped.

    With ``--print-config``, each config is printed to stdout as YAML under a
    ``# --- <model> effective config ---`` header.

    With ``--dump-config PATH``, each config is written to its own file. The
    model name is inserted before the file extension (``cfg.yaml`` ->
    ``cfg.ImputeUBP.yaml``). An extensionless path gets ``.<model>.yaml``
    appended (``cfg`` -> ``cfg.ImputeUBP.yaml``).

    Args:
        cfgs_by_model: Mapping of model name to its effective config (or None).
        args: Parsed CLI namespace.

    Returns:
        True if printing or dumping was requested (caller should exit); else False.
    """
    did_io = False

    if getattr(args, "print_config", False):
        for m, cfg in cfgs_by_model.items():
            if cfg is None:
                continue
            print(f"# --- {m} effective config ---")
            print(dataclass_to_yaml(cfg))
            print()
        did_io = True

    dump_base = getattr(args, "dump_config", None)
    if dump_base:
        base = Path(dump_base)
        for m, cfg in cfgs_by_model.items():
            if cfg is None:
                continue
            # Path.suffix only looks at the final path component. Dots in
            # directory names (e.g. "./out" or "runs.v2/cfg") are therefore not
            # mistaken for a file extension, as they were with str.rsplit(".").
            if base.suffix:
                path = base.with_name(f"{base.stem}.{m}{base.suffix}")
            else:
                path = base.with_name(f"{base.name}.{m}.yaml")
            save_dataclass_yaml(cfg, str(path))
            logging.info(f"Saved {m} config to {path}")
        did_io = True

    return did_io
435
+
436
+
437
+ def main(argv: Optional[List[str]] = None) -> int:
438
+ parser = argparse.ArgumentParser(
439
+ prog="pg-sui",
440
+ description="Run PG-SUI imputation models on an input file. Handle configuration via presets, YAML, and CLI flags. The default is to run all models.",
441
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
442
+ usage="%(prog)s [options]",
443
+ )
444
+
445
+ # ----------------------------- Required I/O ----------------------------- #
446
+ parser.add_argument(
447
+ "--input",
448
+ default=argparse.SUPPRESS,
449
+ help="Path to input file (VCF/PHYLIP/STRUCTURE/GENEPOP).",
450
+ )
451
+ parser.add_argument(
452
+ "--format",
453
+ choices=(
454
+ "infer",
455
+ "vcf",
456
+ "vcf.gz",
457
+ "phy",
458
+ "phylip",
459
+ "str",
460
+ "structure",
461
+ "genepop",
462
+ "gen",
463
+ ),
464
+ default=argparse.SUPPRESS,
465
+ help="Input format. If 'infer', deduced from file extension. The default is 'infer'.",
466
+ )
467
+ # Back-compat: --vcf retained; if both provided, --input wins.
468
+ parser.add_argument(
469
+ "--vcf",
470
+ default=argparse.SUPPRESS,
471
+ help="Path to input VCF file. Can be bgzipped or uncompressed.",
472
+ )
473
+ parser.add_argument(
474
+ "--popmap",
475
+ default=argparse.SUPPRESS,
476
+ help="Path to population map file. This is a two-column tab-delimited file with sample IDs and population IDs.",
477
+ )
478
+ parser.add_argument(
479
+ "--treefile",
480
+ default=argparse.SUPPRESS,
481
+ help="Path to phylogenetic tree file. Can be in Newick (recommended) or Nexus format.",
482
+ )
483
+ parser.add_argument(
484
+ "--qmatrix",
485
+ default=argparse.SUPPRESS,
486
+ help="Path to IQ-TREE output file (has .iqtree extension) that contains Rate Matrix Q. Used with --treefile and --siterates.",
487
+ )
488
+ parser.add_argument(
489
+ "--siterates",
490
+ default=argparse.SUPPRESS,
491
+ help="Path to SNP site rates file (has .rate extension). Used with --treefile and --qmatrix.",
492
+ )
493
+ parser.add_argument(
494
+ "--prefix",
495
+ default=argparse.SUPPRESS,
496
+ help="Output file prefix.",
497
+ )
498
+
499
+ # ---------------------- Generic Config Inputs -------------------------- #
500
+ parser.add_argument(
501
+ "--config",
502
+ default=argparse.SUPPRESS,
503
+ help="YAML config for config-driven models (NLPCA/UBP/Autoencoder/VAE).",
504
+ )
505
+ parser.add_argument(
506
+ "--preset",
507
+ choices=("fast", "balanced", "thorough"),
508
+ default=argparse.SUPPRESS, # <-- no default; optional
509
+ help="If provided, initialize config(s) from this preset; otherwise start from dataclass defaults.",
510
+ )
511
+ parser.add_argument(
512
+ "--set",
513
+ action="append",
514
+ default=argparse.SUPPRESS,
515
+ help="Dot-key overrides, e.g. --set model.latent_dim=4",
516
+ )
517
+ parser.add_argument(
518
+ "--print-config",
519
+ action="store_true",
520
+ help="Print effective config(s) and exit.",
521
+ )
522
+ parser.add_argument(
523
+ "--dump-config",
524
+ default=argparse.SUPPRESS,
525
+ help="Write effective config(s) YAML to this path (multi-model gets suffixed).",
526
+ )
527
+
528
+ # ------------------------------ Toggles -------------------------------- #
529
+ parser.add_argument(
530
+ "--tune",
531
+ action="store_true",
532
+ default=argparse.SUPPRESS,
533
+ help="Enable hyperparameter tuning (if supported).",
534
+ )
535
+ parser.add_argument(
536
+ "--tune-n-trials",
537
+ type=int,
538
+ default=argparse.SUPPRESS,
539
+ help="Optuna trials when --tune is set.",
540
+ )
541
+ parser.add_argument(
542
+ "--batch-size",
543
+ type=int,
544
+ default=argparse.SUPPRESS,
545
+ help="Batch size for NN-based models.",
546
+ )
547
+ parser.add_argument(
548
+ "--device",
549
+ choices=("cpu", "cuda", "mps"),
550
+ default=argparse.SUPPRESS,
551
+ help="Compute device for NN-based models.",
552
+ )
553
+ parser.add_argument(
554
+ "--n-jobs",
555
+ type=int,
556
+ default=argparse.SUPPRESS,
557
+ help="Parallel workers for various steps.",
558
+ )
559
+ parser.add_argument(
560
+ "--plot-format",
561
+ choices=("png", "pdf", "svg", "jpg", "jpeg"),
562
+ default=argparse.SUPPRESS,
563
+ help="Figure format for model plots.",
564
+ )
565
+ # ------------------------- Simulation Controls ------------------------ #
566
+ parser.add_argument(
567
+ "--sim-strategy",
568
+ choices=SIM_STRATEGY_CHOICES,
569
+ default=argparse.SUPPRESS,
570
+ help="Override the missing-data simulation strategy for all config-driven models.",
571
+ )
572
+ parser.add_argument(
573
+ "--sim-prop",
574
+ type=float,
575
+ default=argparse.SUPPRESS,
576
+ help="Override the proportion of observed entries to mask during simulation (0-1).",
577
+ )
578
+ parser.add_argument(
579
+ "--simulate-missing",
580
+ action="store_false",
581
+ default=argparse.SUPPRESS,
582
+ help="Disable missing-data simulation regardless of preset/config (when provided).",
583
+ )
584
+
585
+ # --------------------------- Seed & logging ---------------------------- #
586
+ parser.add_argument(
587
+ "--seed",
588
+ default=argparse.SUPPRESS,
589
+ help="Random seed: 'random', 'deterministic', or an integer.",
590
+ )
591
+ parser.add_argument("--verbose", action="store_true", help="Info-level logging.")
592
+ parser.add_argument(
593
+ "--log-file", default=argparse.SUPPRESS, help="Also write logs to a file."
594
+ )
595
+
596
+ # ---------------------------- Data filtering --------------------------- #
597
+ parser.add_argument(
598
+ "--include-pops",
599
+ nargs="+",
600
+ default=argparse.SUPPRESS,
601
+ help="Optional list of population IDs to include.",
602
+ )
603
+ parser.add_argument(
604
+ "--force-popmap",
605
+ action="store_true",
606
+ default=False,
607
+ help="Require popmap (error if absent).",
608
+ )
609
+
610
+ # ---------------------------- Model selection -------------------------- #
611
+ parser.add_argument(
612
+ "--models",
613
+ nargs="+",
614
+ default=argparse.SUPPRESS,
615
+ help=(
616
+ "Which models to run. Choices: ImputeUBP ImputeVAE ImputeAutoencoder ImputeNLPCA ImputeMostFrequent ImputeRefAllele. Default is all."
617
+ ),
618
+ )
619
+
620
+ # -------------------------- MultiQC integration ------------------------ #
621
+ parser.add_argument(
622
+ "--multiqc",
623
+ action="store_true",
624
+ help=(
625
+ "Build a MultiQC HTML report at the end of the run, combining SNPio and PG-SUI plots (requires SNPio's MultiQC module)."
626
+ ),
627
+ )
628
+ parser.add_argument(
629
+ "--multiqc-title",
630
+ default=argparse.SUPPRESS,
631
+ help="Optional title for the MultiQC report (default: 'PG-SUI MultiQC Report - <prefix>').",
632
+ )
633
+ parser.add_argument(
634
+ "--multiqc-output-dir",
635
+ default=argparse.SUPPRESS,
636
+ help="Optional output directory for the MultiQC report (default: '<prefix>_output/multiqc').",
637
+ )
638
+ parser.add_argument(
639
+ "--multiqc-overwrite",
640
+ action="store_true",
641
+ default=False,
642
+ help="Overwrite an existing MultiQC report if present.",
643
+ )
644
+
645
+ # ------------------------------ Safety/UX ------------------------------ #
646
+ parser.add_argument(
647
+ "--dry-run",
648
+ action="store_true",
649
+ help="Parse args and load data, but skip model training.",
650
+ )
651
+
652
+ args = parser.parse_args(argv)
653
+
654
+ # Logging (verbose default is False unless passed)
655
+ _configure_logging(
656
+ verbose=getattr(args, "verbose", False),
657
+ log_file=getattr(args, "log_file", None),
658
+ )
659
+
660
+ # Models selection (default to all if not explicitly provided)
661
+ try:
662
+ selected_models = _parse_models(getattr(args, "models", ()))
663
+ except argparse.ArgumentTypeError as e:
664
+ parser.error(str(e))
665
+ return 2
666
+
667
+ # Input resolution
668
+ input_path = getattr(args, "input", None)
669
+ if input_path is None and hasattr(args, "vcf"):
670
+ input_path = args.vcf
671
+ if not hasattr(args, "format"):
672
+ setattr(args, "format", "vcf")
673
+
674
+ if input_path is None:
675
+ parser.error("You must provide --input (or legacy --vcf).")
676
+ return 2
677
+
678
+ fmt: Literal["infer", "vcf", "vcf.gz", "phy", "phylip", "genepop", "gen"] = getattr(
679
+ args, "format", "infer"
680
+ )
681
+
682
+ if fmt == "infer":
683
+ if input_path.endswith((".vcf", ".vcf.gz")):
684
+ fmt_final = "vcf"
685
+ elif input_path.endswith((".phy", ".phylip")):
686
+ fmt_final = "phylip"
687
+ elif input_path.endswith((".genepop", ".gen")):
688
+ fmt_final = "genepop"
689
+ else:
690
+ parser.error(
691
+ "Could not infer input format from file extension. Please provide --format."
692
+ )
693
+ return 2
694
+ else:
695
+ fmt_final = fmt
696
+
697
+ popmap_path = getattr(args, "popmap", None)
698
+ include_pops = getattr(args, "include_pops", None)
699
+ verbose_flag = getattr(args, "verbose", False)
700
+ force_popmap = bool(getattr(args, "force_popmap", False))
701
+
702
+ # Canonical prefix for this run (used for outputs and MultiQC)
703
+ prefix: str = getattr(args, "prefix", str(Path(input_path).stem))
704
+
705
+ treefile = getattr(args, "treefile", None)
706
+ qmatrix = getattr(args, "qmatrix", None)
707
+ siterates = getattr(args, "siterates", None)
708
+
709
+ if any(x is not None for x in (treefile, qmatrix, siterates)):
710
+ if not all(x is not None for x in (treefile, qmatrix, siterates)):
711
+ parser.error(
712
+ "--treefile, --qmatrix, and --siterates must all be provided together or they should all be omitted."
713
+ )
714
+ return 2
715
+
716
+ # Load genotype data
717
+ gd, tp = build_genotype_data(
718
+ input_path=input_path,
719
+ fmt=fmt_final,
720
+ popmap_path=popmap_path,
721
+ treefile=treefile,
722
+ qmatrix=qmatrix,
723
+ siterates=siterates,
724
+ force_popmap=force_popmap,
725
+ verbose=verbose_flag,
726
+ include_pops=include_pops,
727
+ plot_format=getattr(args, "plot_format", "pdf"),
728
+ )
729
+
730
+ if getattr(args, "dry_run", False):
731
+ logging.info("Dry run complete. Exiting without training models.")
732
+ return 0
733
+
734
+ # ---------------- Build config(s) per selected model ------------------- #
735
+ cfgs_by_model: Dict[str, Any] = {
736
+ m: _build_effective_config_for_model(m, args) for m in selected_models
737
+ }
738
+
739
+ # Maybe print/dump configs and exit
740
+ if _maybe_print_or_dump_configs(cfgs_by_model, args):
741
+ return 0
742
+
743
+ # ------------------------- Model Builders ------------------------------ #
744
+ def build_impute_ubp():
745
+ cfg = cfgs_by_model.get("ImputeUBP")
746
+ if cfg is None:
747
+ cfg = (
748
+ UBPConfig.from_preset(args.preset)
749
+ if hasattr(args, "preset")
750
+ else UBPConfig()
751
+ )
752
+ return ImputeUBP(
753
+ genotype_data=gd,
754
+ tree_parser=tp,
755
+ config=cfg,
756
+ simulate_missing=cfg.sim.simulate_missing,
757
+ sim_strategy=cfg.sim.sim_strategy,
758
+ sim_prop=cfg.sim.sim_prop,
759
+ sim_kwargs=cfg.sim.sim_kwargs,
760
+ )
761
+
762
+ def build_impute_nlpca():
763
+ cfg = cfgs_by_model.get("ImputeNLPCA")
764
+ if cfg is None:
765
+ cfg = (
766
+ NLPCAConfig.from_preset(args.preset)
767
+ if hasattr(args, "preset")
768
+ else NLPCAConfig()
769
+ )
770
+ return ImputeNLPCA(
771
+ genotype_data=gd,
772
+ tree_parser=tp,
773
+ config=cfg,
774
+ simulate_missing=cfg.sim.simulate_missing,
775
+ sim_strategy=cfg.sim.sim_strategy,
776
+ sim_prop=cfg.sim.sim_prop,
777
+ sim_kwargs=cfg.sim.sim_kwargs,
778
+ )
779
+
780
+ def build_impute_vae():
781
+ cfg = cfgs_by_model.get("ImputeVAE")
782
+ if cfg is None:
783
+ cfg = (
784
+ VAEConfig.from_preset(args.preset)
785
+ if hasattr(args, "preset")
786
+ else VAEConfig()
787
+ )
788
+ return ImputeVAE(
789
+ genotype_data=gd,
790
+ tree_parser=tp,
791
+ config=cfg,
792
+ simulate_missing=cfg.sim.simulate_missing,
793
+ sim_strategy=cfg.sim.sim_strategy,
794
+ sim_prop=cfg.sim.sim_prop,
795
+ sim_kwargs=cfg.sim.sim_kwargs,
796
+ )
797
+
798
+ def build_impute_autoencoder():
799
+ cfg = cfgs_by_model.get("ImputeAutoencoder")
800
+ if cfg is None:
801
+ cfg = (
802
+ AutoencoderConfig.from_preset(args.preset)
803
+ if hasattr(args, "preset")
804
+ else AutoencoderConfig()
805
+ )
806
+ return ImputeAutoencoder(
807
+ genotype_data=gd,
808
+ tree_parser=tp,
809
+ config=cfg,
810
+ simulate_missing=cfg.sim.simulate_missing,
811
+ sim_strategy=cfg.sim.sim_strategy,
812
+ sim_prop=cfg.sim.sim_prop,
813
+ sim_kwargs=cfg.sim.sim_kwargs,
814
+ )
815
+
816
+ def build_impute_mostfreq():
817
+ cfg = cfgs_by_model.get("ImputeMostFrequent")
818
+ if cfg is None:
819
+ cfg = (
820
+ MostFrequentConfig.from_preset(args.preset)
821
+ if hasattr(args, "preset")
822
+ else MostFrequentConfig()
823
+ )
824
+ return ImputeMostFrequent(
825
+ gd,
826
+ tree_parser=tp,
827
+ config=cfg,
828
+ simulate_missing=cfg.sim.simulate_missing,
829
+ sim_strategy=cfg.sim.sim_strategy,
830
+ sim_prop=cfg.sim.sim_prop,
831
+ sim_kwargs=cfg.sim.sim_kwargs,
832
+ )
833
+
834
+ def build_impute_refallele():
835
+ cfg = cfgs_by_model.get("ImputeRefAllele")
836
+ if cfg is None:
837
+ cfg = (
838
+ RefAlleleConfig.from_preset(args.preset)
839
+ if hasattr(args, "preset")
840
+ else RefAlleleConfig()
841
+ )
842
+ return ImputeRefAllele(
843
+ gd,
844
+ tree_parser=tp,
845
+ config=cfg,
846
+ simulate_missing=cfg.sim.simulate_missing,
847
+ sim_strategy=cfg.sim.sim_strategy,
848
+ sim_prop=cfg.sim.sim_prop,
849
+ sim_kwargs=cfg.sim.sim_kwargs,
850
+ )
851
+
852
+ model_builders = {
853
+ "ImputeUBP": build_impute_ubp,
854
+ "ImputeVAE": build_impute_vae,
855
+ "ImputeAutoencoder": build_impute_autoencoder,
856
+ "ImputeNLPCA": build_impute_nlpca,
857
+ "ImputeMostFrequent": build_impute_mostfreq,
858
+ "ImputeRefAllele": build_impute_refallele,
859
+ }
860
+
861
+ logging.info(f"Selected models: {', '.join(selected_models)}")
862
+ for name in selected_models:
863
+ X_imputed = run_model_safely(name, model_builders[name], warn_only=False)
864
+ gd_imp = gd.copy()
865
+ gd_imp.snp_data = X_imputed
866
+
867
+ if name in {"ImputeUBP", "ImputeVAE", "ImputeAutoencoder", "ImputeNLPCA"}:
868
+ family = "Unsupervised"
869
+ elif name in {"ImputeMostFrequent", "ImputeRefAllele"}:
870
+ family = "Deterministic"
871
+ elif name in {"ImputeHistGradientBoosting", "ImputeRandomForest"}:
872
+ family = "Supervised"
873
+ else:
874
+ raise ValueError(f"Unknown model family for {name}")
875
+
876
+ pth = Path(f"{prefix}_output/{family}/imputed/{name}")
877
+ pth.mkdir(parents=True, exist_ok=True)
878
+
879
+ logging.info(f"Writing imputed VCF for {name} to {pth} ...")
880
+ gd_imp.write_vcf(pth / f"{name.lower()}_imputed.vcf.gz")
881
+
882
+ logging.info("All requested models processed.")
883
+
884
+ # -------------------------- MultiQC builder ---------------------------- #
885
+
886
+ mqc_output_dir = getattr(args, "multiqc_output_dir", f"{prefix}_output/multiqc")
887
+ mqc_title = getattr(args, "multiqc_title", f"PG-SUI MultiQC Report - {prefix}")
888
+ overwrite = bool(getattr(args, "multiqc_overwrite", False))
889
+
890
+ logging.info(
891
+ f"Building MultiQC report in '{mqc_output_dir}' (title={mqc_title}, overwrite={overwrite})..."
892
+ )
893
+
894
+ try:
895
+ SNPioMultiQC.build(
896
+ prefix=prefix,
897
+ output_dir=mqc_output_dir,
898
+ title=mqc_title,
899
+ overwrite=overwrite,
900
+ )
901
+ logging.info("MultiQC report successfully built.")
902
+ except Exception as exc2: # pragma: no cover
903
+ logging.error(f"Failed to build MultiQC report: {exc2}", exc_info=True)
904
+
905
+ return 0
906
+
907
+
908
if __name__ == "__main__":
    # Script entry point: run the CLI and hand main()'s integer return
    # value (0 on the success path shown in this module) to the interpreter
    # as the process exit status.
    _exit_code = main()
    raise SystemExit(_exit_code)