pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
pgsui/cli.py ADDED
@@ -0,0 +1,922 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """PG-SUI Imputation CLI
5
+
6
+ Argument-precedence model:
7
+ code defaults < preset (--preset) < YAML (--config) < explicit CLI flags < --set k=v
8
+
9
+ Notes
10
+ -----
11
+ - Preset is a CLI-only choice and will be respected unless overridden by YAML or CLI.
12
+ - YAML entries override preset (a 'preset' key in YAML is ignored with a warning).
13
+ - CLI flags only override when explicitly provided (argparse uses SUPPRESS).
14
+ - --set key=value has the highest precedence and applies dot-path overrides.
15
+
16
+ Examples
17
+ --------
18
+ python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix run1
19
+ python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix tuned --tune
20
+ python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix demo \
21
+ --models ImputeUBP ImputeVAE ImputeMostFrequent --seed deterministic --verbose
22
+ python cli.py --vcf data.vcf.gz --popmap pops.popmap --prefix subset \
23
+ --include-pops EA GU TT ON --device cpu
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import argparse
29
+ import ast
30
+ import logging
31
+ import sys
32
+ import time
33
+ from functools import wraps
34
+ from pathlib import Path
35
+ from typing import (
36
+ Any,
37
+ Callable,
38
+ Dict,
39
+ Iterable,
40
+ List,
41
+ Literal,
42
+ Optional,
43
+ ParamSpec,
44
+ Tuple,
45
+ TypeVar,
46
+ cast,
47
+ )
48
+
49
+ from snpio import GenePopReader, PhylipReader, SNPioMultiQC, VCFReader, TreeParser
50
+
51
+ from pgsui import (
52
+ AutoencoderConfig,
53
+ ImputeAutoencoder,
54
+ ImputeMostFrequent,
55
+ ImputeNLPCA,
56
+ ImputeRefAllele,
57
+ ImputeUBP,
58
+ ImputeVAE,
59
+ MostFrequentConfig,
60
+ NLPCAConfig,
61
+ RefAlleleConfig,
62
+ UBPConfig,
63
+ VAEConfig,
64
+ )
65
+ from pgsui.data_processing.config import (
66
+ apply_dot_overrides,
67
+ dataclass_to_yaml,
68
+ load_yaml_to_dataclass,
69
+ save_dataclass_yaml,
70
+ )
71
+
72
+ # Canonical model order used everywhere (default and subset ordering)
73
+ MODEL_ORDER: Tuple[str, ...] = (
74
+ "ImputeUBP",
75
+ "ImputeVAE",
76
+ "ImputeAutoencoder",
77
+ "ImputeNLPCA",
78
+ "ImputeMostFrequent",
79
+ "ImputeRefAllele",
80
+ )
81
+
82
+ # Strategies supported by SimMissingTransformer + SimConfig.
83
+ SIM_STRATEGY_CHOICES: Tuple[str, ...] = (
84
+ "random",
85
+ "random_weighted",
86
+ "random_weighted_inv",
87
+ "nonrandom",
88
+ "nonrandom_weighted",
89
+ )
90
+
91
+ P = ParamSpec("P")
92
+ R = TypeVar("R")
93
+
94
+
95
+ # ----------------------------- CLI Utilities ----------------------------- #
96
+ def _configure_logging(verbose: bool, log_file: Optional[str] = None) -> None:
97
+ """Configure root logger.
98
+
99
+ Args:
100
+ verbose (bool): If True, INFO; else ERROR.
101
+ log_file (Optional[str]): Optional file to tee logs to.
102
+ """
103
+ level = logging.INFO if verbose else logging.ERROR
104
+ handlers: List[logging.Handler] = [logging.StreamHandler(sys.stdout)]
105
+ if log_file:
106
+ handlers.append(logging.FileHandler(log_file, mode="w", encoding="utf-8"))
107
+ logging.basicConfig(
108
+ level=level,
109
+ format="%(asctime)s - %(levelname)s - %(message)s",
110
+ handlers=handlers,
111
+ )
112
+
113
+
114
+ def _parse_seed(seed_arg: str) -> Optional[int]:
115
+ """Parse --seed argument into an int or None."""
116
+ s = seed_arg.strip().lower()
117
+ if s == "random":
118
+ return None
119
+ if s == "deterministic":
120
+ return 42
121
+ try:
122
+ return int(seed_arg)
123
+ except ValueError as e:
124
+ raise argparse.ArgumentTypeError(
125
+ "Invalid --seed. Use 'random', 'deterministic', or an integer."
126
+ ) from e
127
+
128
+
129
def _parse_models(models: Iterable[str]) -> Tuple[str, ...]:
    """Validate requested model names and return them in canonical order.

    An empty selection means "run everything", so ``MODEL_ORDER`` is returned
    as-is. A non-empty selection is re-ordered to follow ``MODEL_ORDER``
    (duplicates collapse to one entry).

    Raises:
        argparse.ArgumentTypeError: If any name is not in ``MODEL_ORDER``.
    """
    requested = tuple(models)  # materialize generators; we iterate twice
    known = set(MODEL_ORDER)

    bad = [name for name in requested if name not in known]
    if bad:
        raise argparse.ArgumentTypeError(
            f"Unknown model(s): {bad}. Valid options: {list(MODEL_ORDER)}"
        )

    if not requested:
        return MODEL_ORDER

    wanted = set(requested)
    return tuple(name for name in MODEL_ORDER if name in wanted)
152
+
153
+
154
+ def _parse_overrides(pairs: list[str]) -> dict:
155
+ """Parse --set key=value into typed values via literal_eval."""
156
+ out: dict = {}
157
+ for kv in pairs or []:
158
+ if "=" not in kv:
159
+ raise argparse.ArgumentTypeError(f"--set expects key=value, got '{kv}'")
160
+ k, v = kv.split("=", 1)
161
+ v = v.strip()
162
+ try:
163
+ out[k] = ast.literal_eval(v)
164
+ except Exception:
165
+ out[k] = v # raw string fallback
166
+ return out
167
+
168
+
169
+ def _args_to_cli_overrides(args: argparse.Namespace) -> dict:
170
+ """Convert explicitly provided CLI flags into config dot-overrides."""
171
+ overrides: dict = {}
172
+
173
+ # IO / top-level controls
174
+ if hasattr(args, "prefix") and args.prefix is not None:
175
+ overrides["io.prefix"] = args.prefix
176
+ else:
177
+ # Note: we don't know input_path here; prefix default is handled later.
178
+ # This fallback is preserved to avoid changing semantics.
179
+ if hasattr(args, "vcf"):
180
+ overrides["io.prefix"] = str(Path(args.vcf).stem)
181
+
182
+ if hasattr(args, "verbose"):
183
+ overrides["io.verbose"] = bool(args.verbose)
184
+ if hasattr(args, "n_jobs"):
185
+ overrides["io.n_jobs"] = int(args.n_jobs)
186
+ if hasattr(args, "seed"):
187
+ overrides["io.seed"] = _parse_seed(args.seed)
188
+ if hasattr(args, "debug"):
189
+ overrides["io.debug"] = bool(args.debug)
190
+
191
+ # Train
192
+ if hasattr(args, "batch_size"):
193
+ overrides["train.batch_size"] = int(args.batch_size)
194
+ if hasattr(args, "device"):
195
+ dev = args.device
196
+ if dev == "cuda":
197
+ dev = "gpu"
198
+ overrides["train.device"] = dev
199
+
200
+ # Plot
201
+ if hasattr(args, "plot_format"):
202
+ overrides["plot.fmt"] = args.plot_format
203
+
204
+ # Simulation overrides (shared across config-driven models)
205
+ if hasattr(args, "sim_strategy"):
206
+ overrides["sim.sim_strategy"] = args.sim_strategy
207
+ if hasattr(args, "sim_prop"):
208
+ overrides["sim.sim_prop"] = float(args.sim_prop)
209
+ if hasattr(args, "simulate_missing"):
210
+ overrides["sim.simulate_missing"] = bool(args.simulate_missing)
211
+
212
+ # Tuning
213
+ if hasattr(args, "tune"):
214
+ overrides["tune.enabled"] = bool(args.tune)
215
+ if hasattr(args, "tune_n_trials"):
216
+ overrides["tune.n_trials"] = int(args.tune_n_trials)
217
+
218
+ return overrides
219
+
220
+
221
+ def _format_seconds(seconds: float) -> str:
222
+ total = int(round(seconds))
223
+ minutes, secs = divmod(total, 60)
224
+ hours, minutes = divmod(minutes, 60)
225
+ if hours:
226
+ return f"{hours:d}:{minutes:02d}:{secs:02d}"
227
+ return f"{minutes:d}:{secs:02d}"
228
+
229
+
230
def log_model_time(fn: Callable[P, R]) -> Callable[P, R]:
    """Decorate a model-run function so its wall-clock duration is logged.

    The model name shown in the log is taken from the first positional
    argument, or from a ``model_name`` keyword argument when the call passes it
    by keyword; otherwise ``"<unknown model>"`` is used. On normal return an
    INFO line is logged. If the wrapped call raises, an ERROR line (with
    traceback) is logged and the exception is re-raised unchanged.

    Args:
        fn: The function to wrap (e.g. ``run_model_safely``).

    Returns:
        A wrapper with the same signature as ``fn``.
    """

    @wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        if args:
            model_name = str(args[0])
        elif "model_name" in kwargs:
            model_name = str(kwargs["model_name"])
        else:
            model_name = "<unknown model>"

        start = time.perf_counter()
        try:
            result = fn(*args, **kwargs)
        except Exception:
            elapsed = time.perf_counter() - start
            logging.error(
                f"{model_name} failed after {elapsed:0.2f}s "
                f"({_format_seconds(elapsed)}).",
                exc_info=True,
            )
            raise
        elapsed = time.perf_counter() - start
        logging.info(
            f"{model_name} finished in {elapsed:0.2f}s "
            f"({_format_seconds(elapsed)})."
        )
        return result

    return cast(Callable[P, R], wrapper)
255
+
256
+
257
+ # ------------------------------ Core Runner ------------------------------ #
258
+ def build_genotype_data(
259
+ input_path: str,
260
+ fmt: Literal["vcf", "vcf.gz", "phy", "phylip", "genepop", "gen"],
261
+ popmap_path: str | None,
262
+ treefile: str | None,
263
+ qmatrix: str | None,
264
+ siterates: str | None,
265
+ force_popmap: bool,
266
+ verbose: bool,
267
+ include_pops: List[str] | None,
268
+ plot_format: Literal["pdf", "png", "jpg", "jpeg"],
269
+ ):
270
+ """Load genotype data from heterogeneous inputs."""
271
+ logging.info(f"Loading {fmt.upper()} and popmap data...")
272
+
273
+ kwargs = {
274
+ "filename": input_path,
275
+ "popmapfile": popmap_path,
276
+ "force_popmap": force_popmap,
277
+ "verbose": verbose,
278
+ "include_pops": include_pops if include_pops else None,
279
+ "prefix": f"snpio_{Path(input_path).stem}",
280
+ "plot_format": plot_format,
281
+ }
282
+
283
+ if fmt == "vcf":
284
+ gd = VCFReader(**kwargs)
285
+ elif fmt == "phylip":
286
+ gd = PhylipReader(**kwargs)
287
+ elif fmt == "genepop":
288
+ gd = GenePopReader(**kwargs)
289
+ else:
290
+ raise ValueError(f"Unsupported genotype data format: {fmt}")
291
+
292
+ tp = None
293
+ if treefile is not None:
294
+ logging.info("Parsing phylogenetic tree...")
295
+
296
+ tp = TreeParser(
297
+ gd, treefile=treefile, qmatrix=qmatrix, siterates=siterates, verbose=True
298
+ )
299
+
300
+ logging.info("Loaded genotype data.")
301
+ return gd, tp
302
+
303
+
304
@log_model_time
def run_model_safely(
    model_name: str, builder: Callable[[], Any], *, warn_only: bool = True
) -> Any:
    """Build a model, fit it, and return its imputed output, isolating errors.

    Args:
        model_name: Display name used in log messages. It is passed first and
            positionally so the timing decorator can read it.
        builder: Zero-argument callable returning an imputer object that
            exposes ``fit()`` and ``transform()``.
        warn_only: If True (default), a failure is logged as a warning with its
            traceback and ``None`` is returned. If False, the exception
            propagates.

    Returns:
        Whatever ``model.transform()`` returns, or ``None`` when the run failed
        and ``warn_only`` is True.
    """
    logging.info(f"▶ Running {model_name} ...")
    try:
        model = builder()
        model.fit()
        X_imputed = model.transform()
    except Exception as e:
        if not warn_only:
            raise
        logging.warning(f"⚠ {model_name} failed: {e}", exc_info=True)
        return None
    logging.info(f"✓ {model_name} completed.")
    return X_imputed
319
+
320
+
321
# -------------------------- Model Registry ------------------------------- #
# Maps each CLI model name to its imputer class ("cls") and the config
# dataclass ("config_cls") used to build its effective configuration.
# To register another config-driven model, add a (name, imputer, config) row.
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
    name: {"cls": imputer_cls, "config_cls": config_cls}
    for name, imputer_cls, config_cls in (
        ("ImputeUBP", ImputeUBP, UBPConfig),
        ("ImputeNLPCA", ImputeNLPCA, NLPCAConfig),
        ("ImputeAutoencoder", ImputeAutoencoder, AutoencoderConfig),
        ("ImputeVAE", ImputeVAE, VAEConfig),
        ("ImputeMostFrequent", ImputeMostFrequent, MostFrequentConfig),
        ("ImputeRefAllele", ImputeRefAllele, RefAlleleConfig),
    )
}
331
+
332
+
333
def _build_effective_config_for_model(
    model_name: str, args: argparse.Namespace
) -> Any | None:
    """Build the effective config object for a specific model (if it has one).

    Precedence (lowest → highest):
        defaults < preset (--preset) < YAML (--config) < explicit CLI flags < --set

    ``--set`` overrides are shared by every selected model. For the neural
    network models (UBP, NLPCA, Autoencoder, VAE), a failure to apply them is
    logged as an error and re-raised. For the other models the failure is
    logged as a warning and the run continues, presumably so keys that only
    exist on the NN configs do not abort the simpler models.

    Args:
        model_name: Key into ``MODEL_REGISTRY``.
        args: Parsed CLI namespace.

    Returns:
        Config dataclass instance, or None when the registry entry has no
        ``config_cls``.

    Raises:
        KeyError: If ``model_name`` is not in ``MODEL_REGISTRY``.
    """
    strict_set_models = frozenset(
        {"ImputeUBP", "ImputeNLPCA", "ImputeAutoencoder", "ImputeVAE"}
    )

    cfg_cls = MODEL_REGISTRY[model_name].get("config_cls")
    if cfg_cls is None:
        return None

    # 0/1) Preset only if --preset was passed (SUPPRESS default => attribute
    # absent otherwise); else start from pure dataclass defaults.
    if hasattr(args, "preset"):
        preset_name = args.preset
        cfg = cfg_cls.from_preset(preset_name)
        logging.info(f"Initialized {model_name} from '{preset_name}' preset.")
    else:
        cfg = cfg_cls()
        logging.info(f"Initialized {model_name} from dataclass defaults (no preset).")

    # 2) YAML overlays preset/defaults. Any 'preset' key in the YAML is ignored.
    yaml_path = getattr(args, "config", None)
    if yaml_path:
        cfg = load_yaml_to_dataclass(
            yaml_path,
            cfg_cls,
            base=cfg,
            yaml_preset_behavior="ignore",  # 'preset' key in YAML ignored with warning
        )
        logging.info(
            f"Loaded YAML config for {model_name} from {yaml_path} (ignored 'preset' in YAML if present)."
        )

    # 3) Explicit CLI flags overlay YAML.
    cli_overrides = _args_to_cli_overrides(args)
    if cli_overrides:
        cfg = apply_dot_overrides(cfg, cli_overrides)

    # 4) --set has highest precedence.
    user_overrides = _parse_overrides(getattr(args, "set", []))
    if user_overrides:
        try:
            cfg = apply_dot_overrides(cfg, user_overrides)
        except Exception as e:
            if model_name in strict_set_models:
                logging.error(
                    f"Error applying --set overrides to {model_name} config: {e}"
                )
                raise
            # Deliberate best-effort for the simpler models, but never silent.
            logging.warning(f"Ignoring --set overrides for {model_name}: {e}")

    return cfg
401
+
402
+
403
+ def _maybe_print_or_dump_configs(
404
+ cfgs_by_model: Dict[str, Any], args: argparse.Namespace
405
+ ) -> bool:
406
+ """Handle --print-config / --dump-config for ALL config-driven models selected.
407
+
408
+ Returns:
409
+ True if we printed/dumped and should exit; else False.
410
+ """
411
+ did_io = False
412
+ if getattr(args, "print_config", False):
413
+ for m, cfg in cfgs_by_model.items():
414
+ if cfg is None:
415
+ continue
416
+ print(f"# --- {m} effective config ---")
417
+ print(dataclass_to_yaml(cfg))
418
+ print()
419
+ did_io = True
420
+
421
+ if hasattr(args, "dump_config") and args.dump_config:
422
+ # If multiple models, add suffix per model (before extension if possible)
423
+ dump_base = args.dump_config
424
+ for m, cfg in cfgs_by_model.items():
425
+ if cfg is None:
426
+ continue
427
+ if "." in dump_base:
428
+ stem, ext = dump_base.rsplit(".", 1)
429
+ path = f"{stem}.{m}.{ext}"
430
+ else:
431
+ path = f"{dump_base}.{m}.yaml"
432
+ save_dataclass_yaml(cfg, path)
433
+ logging.info(f"Saved {m} config to {path}")
434
+ did_io = True
435
+
436
+ return did_io
437
+
438
+
439
+ def main(argv: Optional[List[str]] = None) -> int:
440
+ parser = argparse.ArgumentParser(
441
+ prog="pg-sui",
442
+ description="Run PG-SUI imputation models on an input file. Handle configuration via presets, YAML, and CLI flags. The default is to run all models.",
443
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
444
+ usage="%(prog)s [options]",
445
+ )
446
+
447
+ # ----------------------------- Required I/O ----------------------------- #
448
+ parser.add_argument(
449
+ "--input",
450
+ default=argparse.SUPPRESS,
451
+ help="Path to input file (VCF/PHYLIP/STRUCTURE/GENEPOP).",
452
+ )
453
+ parser.add_argument(
454
+ "--format",
455
+ choices=(
456
+ "infer",
457
+ "vcf",
458
+ "vcf.gz",
459
+ "phy",
460
+ "phylip",
461
+ "str",
462
+ "structure",
463
+ "genepop",
464
+ "gen",
465
+ ),
466
+ default=argparse.SUPPRESS,
467
+ help="Input format. If 'infer', deduced from file extension. The default is 'infer'.",
468
+ )
469
+ # Back-compat: --vcf retained; if both provided, --input wins.
470
+ parser.add_argument(
471
+ "--vcf",
472
+ default=argparse.SUPPRESS,
473
+ help="Path to input VCF file. Can be bgzipped or uncompressed.",
474
+ )
475
+ parser.add_argument(
476
+ "--popmap",
477
+ default=argparse.SUPPRESS,
478
+ help="Path to population map file. This is a two-column tab-delimited file with sample IDs and population IDs.",
479
+ )
480
+ parser.add_argument(
481
+ "--treefile",
482
+ default=argparse.SUPPRESS,
483
+ help="Path to phylogenetic tree file. Can be in Newick (recommended) or Nexus format.",
484
+ )
485
+ parser.add_argument(
486
+ "--qmatrix",
487
+ default=argparse.SUPPRESS,
488
+ help="Path to IQ-TREE output file (has .iqtree extension) that contains Rate Matrix Q. Used with --treefile and --siterates.",
489
+ )
490
+ parser.add_argument(
491
+ "--siterates",
492
+ default=argparse.SUPPRESS,
493
+ help="Path to SNP site rates file (has .rate extension). Used with --treefile and --qmatrix.",
494
+ )
495
+ parser.add_argument(
496
+ "--prefix",
497
+ default=argparse.SUPPRESS,
498
+ help="Output file prefix.",
499
+ )
500
+
501
+ # ---------------------- Generic Config Inputs -------------------------- #
502
+ parser.add_argument(
503
+ "--config",
504
+ default=argparse.SUPPRESS,
505
+ help="YAML config for config-driven models (NLPCA/UBP/Autoencoder/VAE).",
506
+ )
507
+ parser.add_argument(
508
+ "--preset",
509
+ choices=("fast", "balanced", "thorough"),
510
+ default=argparse.SUPPRESS, # <-- no default; optional
511
+ help="If provided, initialize config(s) from this preset; otherwise start from dataclass defaults.",
512
+ )
513
+ parser.add_argument(
514
+ "--set",
515
+ action="append",
516
+ default=argparse.SUPPRESS,
517
+ help="Dot-key overrides, e.g. --set model.latent_dim=4",
518
+ )
519
+ parser.add_argument(
520
+ "--print-config",
521
+ action="store_true",
522
+ help="Print effective config(s) and exit.",
523
+ )
524
+ parser.add_argument(
525
+ "--dump-config",
526
+ default=argparse.SUPPRESS,
527
+ help="Write effective config(s) YAML to this path (multi-model gets suffixed).",
528
+ )
529
+
530
+ # ------------------------------ Toggles -------------------------------- #
531
+ parser.add_argument(
532
+ "--tune",
533
+ action="store_true",
534
+ default=argparse.SUPPRESS,
535
+ help="Enable hyperparameter tuning (if supported).",
536
+ )
537
+ parser.add_argument(
538
+ "--tune-n-trials",
539
+ type=int,
540
+ default=argparse.SUPPRESS,
541
+ help="Optuna trials when --tune is set.",
542
+ )
543
+ parser.add_argument(
544
+ "--batch-size",
545
+ type=int,
546
+ default=argparse.SUPPRESS,
547
+ help="Batch size for NN-based models.",
548
+ )
549
+ parser.add_argument(
550
+ "--device",
551
+ choices=("cpu", "cuda", "mps"),
552
+ default=argparse.SUPPRESS,
553
+ help="Compute device for NN-based models.",
554
+ )
555
+ parser.add_argument(
556
+ "--n-jobs",
557
+ type=int,
558
+ default=argparse.SUPPRESS,
559
+ help="Parallel workers for various steps.",
560
+ )
561
+ parser.add_argument(
562
+ "--plot-format",
563
+ choices=("png", "pdf", "svg", "jpg", "jpeg"),
564
+ default=argparse.SUPPRESS,
565
+ help="Figure format for model plots.",
566
+ )
567
+ # ------------------------- Simulation Controls ------------------------ #
568
+ parser.add_argument(
569
+ "--sim-strategy",
570
+ choices=SIM_STRATEGY_CHOICES,
571
+ default=argparse.SUPPRESS,
572
+ help="Override the missing-data simulation strategy for all config-driven models.",
573
+ )
574
+ parser.add_argument(
575
+ "--sim-prop",
576
+ type=float,
577
+ default=argparse.SUPPRESS,
578
+ help="Override the proportion of observed entries to mask during simulation (0-1).",
579
+ )
580
+ parser.add_argument(
581
+ "--simulate-missing",
582
+ action="store_false",
583
+ default=argparse.SUPPRESS,
584
+ help="Disable missing-data simulation regardless of preset/config (when provided).",
585
+ )
586
+
587
+ # --------------------------- Seed & logging ---------------------------- #
588
+ parser.add_argument(
589
+ "--seed",
590
+ default=argparse.SUPPRESS,
591
+ help="Random seed: 'random', 'deterministic', or an integer.",
592
+ )
593
+ parser.add_argument("--verbose", action="store_true", help="Info-level logging.")
594
+ parser.add_argument("--debug", action="store_true", help="Debug-level logging.")
595
+ parser.add_argument(
596
+ "--log-file", default=argparse.SUPPRESS, help="Also write logs to a file."
597
+ )
598
+
599
+ # ---------------------------- Data filtering --------------------------- #
600
+ parser.add_argument(
601
+ "--include-pops",
602
+ nargs="+",
603
+ default=argparse.SUPPRESS,
604
+ help="Optional list of population IDs to include.",
605
+ )
606
+ parser.add_argument(
607
+ "--force-popmap",
608
+ action="store_true",
609
+ default=False,
610
+ help="Require popmap (error if absent).",
611
+ )
612
+
613
+ # ---------------------------- Model selection -------------------------- #
614
+ parser.add_argument(
615
+ "--models",
616
+ nargs="+",
617
+ default=argparse.SUPPRESS,
618
+ help=(
619
+ "Which models to run. Choices: ImputeUBP ImputeVAE ImputeAutoencoder ImputeNLPCA ImputeMostFrequent ImputeRefAllele. Default is all."
620
+ ),
621
+ )
622
+
623
+ # -------------------------- MultiQC integration ------------------------ #
624
+ parser.add_argument(
625
+ "--multiqc",
626
+ action="store_true",
627
+ help=(
628
+ "Build a MultiQC HTML report at the end of the run, combining SNPio and PG-SUI plots (requires SNPio's MultiQC module)."
629
+ ),
630
+ )
631
+ parser.add_argument(
632
+ "--multiqc-title",
633
+ default=argparse.SUPPRESS,
634
+ help="Optional title for the MultiQC report (default: 'PG-SUI MultiQC Report - <prefix>').",
635
+ )
636
+ parser.add_argument(
637
+ "--multiqc-output-dir",
638
+ default=argparse.SUPPRESS,
639
+ help="Optional output directory for the MultiQC report (default: '<prefix>_output/multiqc').",
640
+ )
641
+ parser.add_argument(
642
+ "--multiqc-overwrite",
643
+ action="store_true",
644
+ default=False,
645
+ help="Overwrite an existing MultiQC report if present.",
646
+ )
647
+
648
+ # ------------------------------ Safety/UX ------------------------------ #
649
+ parser.add_argument(
650
+ "--dry-run",
651
+ action="store_true",
652
+ help="Parse args and load data, but skip model training.",
653
+ )
654
+
655
+ args = parser.parse_args(argv)
656
+
657
+ # Logging (verbose default is False unless passed)
658
+ _configure_logging(
659
+ verbose=getattr(args, "verbose", False),
660
+ log_file=getattr(args, "log_file", None),
661
+ )
662
+
663
+ # Models selection (default to all if not explicitly provided)
664
+ try:
665
+ selected_models = _parse_models(getattr(args, "models", ()))
666
+ except argparse.ArgumentTypeError as e:
667
+ parser.error(str(e))
668
+ return 2
669
+
670
+ # Input resolution
671
+ input_path = getattr(args, "input", None)
672
+ if input_path is None and hasattr(args, "vcf"):
673
+ input_path = args.vcf
674
+ if not hasattr(args, "format"):
675
+ setattr(args, "format", "vcf")
676
+
677
+ if input_path is None:
678
+ parser.error("You must provide --input (or legacy --vcf).")
679
+ return 2
680
+
681
+ fmt: Literal["infer", "vcf", "vcf.gz", "phy", "phylip", "genepop", "gen"] = getattr(
682
+ args, "format", "infer"
683
+ )
684
+
685
+ if fmt == "infer":
686
+ if input_path.endswith((".vcf", ".vcf.gz")):
687
+ fmt_final = "vcf"
688
+ elif input_path.endswith((".phy", ".phylip")):
689
+ fmt_final = "phylip"
690
+ elif input_path.endswith((".genepop", ".gen")):
691
+ fmt_final = "genepop"
692
+ else:
693
+ parser.error(
694
+ "Could not infer input format from file extension. Please provide --format."
695
+ )
696
+ return 2
697
+ else:
698
+ fmt_final = fmt
699
+
700
+ popmap_path = getattr(args, "popmap", None)
701
+ include_pops = getattr(args, "include_pops", None)
702
+ verbose_flag = getattr(args, "verbose", False)
703
+ force_popmap = bool(getattr(args, "force_popmap", False))
704
+
705
+ # Canonical prefix for this run (used for outputs and MultiQC)
706
+ prefix: str = getattr(args, "prefix", str(Path(input_path).stem))
707
+
708
+ treefile = getattr(args, "treefile", None)
709
+ qmatrix = getattr(args, "qmatrix", None)
710
+ siterates = getattr(args, "siterates", None)
711
+
712
+ if any(x is not None for x in (treefile, qmatrix, siterates)):
713
+ if not all(x is not None for x in (treefile, qmatrix, siterates)):
714
+ parser.error(
715
+ "--treefile, --qmatrix, and --siterates must all be provided together or they should all be omitted."
716
+ )
717
+ return 2
718
+
719
+ # Load genotype data
720
+ gd, tp = build_genotype_data(
721
+ input_path=input_path,
722
+ fmt=fmt_final,
723
+ popmap_path=popmap_path,
724
+ treefile=treefile,
725
+ qmatrix=qmatrix,
726
+ siterates=siterates,
727
+ force_popmap=force_popmap,
728
+ verbose=verbose_flag,
729
+ include_pops=include_pops,
730
+ plot_format=getattr(args, "plot_format", "pdf"),
731
+ )
732
+
733
+ if getattr(args, "dry_run", False):
734
+ logging.info("Dry run complete. Exiting without training models.")
735
+ return 0
736
+
737
+ # ---------------- Build config(s) per selected model ------------------- #
738
+ cfgs_by_model: Dict[str, Any] = {
739
+ m: _build_effective_config_for_model(m, args) for m in selected_models
740
+ }
741
+
742
+ # Maybe print/dump configs and exit
743
+ if _maybe_print_or_dump_configs(cfgs_by_model, args):
744
+ return 0
745
+
746
+ # ------------------------- Model Builders ------------------------------ #
747
def _build_imputer(model_name, config_cls, model_cls, *, gd_positional=False):
    """Instantiate ``model_cls`` for ``model_name`` using its effective config.

    The config is looked up in ``cfgs_by_model``. If it is absent, the helper
    falls back to ``config_cls.from_preset(args.preset)`` when a non-None
    preset is available on ``args``, and to ``config_cls()`` otherwise.

    Args:
        model_name: Key into ``cfgs_by_model`` (e.g. ``"ImputeUBP"``).
        config_cls: Config class providing ``from_preset`` and a ``sim``
            section (``simulate_missing``, ``sim_strategy``, ``sim_prop``,
            ``sim_kwargs``).
        model_cls: Imputer class to construct.
        gd_positional: If True, pass the genotype data as the first positional
            argument instead of ``genotype_data=``. This mirrors how the
            deterministic imputers were originally called.

    Returns:
        The imputer instance returned by ``model_cls(...)``.
    """
    cfg = cfgs_by_model.get(model_name)
    if cfg is None:
        # hasattr() alone is true even when the attribute holds None (e.g. an
        # argparse option that was not given). In that case, use the config
        # defaults rather than passing None through to from_preset().
        preset = getattr(args, "preset", None)
        cfg = config_cls.from_preset(preset) if preset is not None else config_cls()

    sim = cfg.sim
    common = dict(
        tree_parser=tp,
        config=cfg,
        simulate_missing=sim.simulate_missing,
        sim_strategy=sim.sim_strategy,
        sim_prop=sim.sim_prop,
        sim_kwargs=sim.sim_kwargs,
    )
    if gd_positional:
        return model_cls(gd, **common)
    return model_cls(genotype_data=gd, **common)


def build_impute_ubp():
    """Build an ImputeUBP instance from its effective (or fallback) config."""
    return _build_imputer("ImputeUBP", UBPConfig, ImputeUBP)


def build_impute_nlpca():
    """Build an ImputeNLPCA instance from its effective (or fallback) config."""
    return _build_imputer("ImputeNLPCA", NLPCAConfig, ImputeNLPCA)


def build_impute_vae():
    """Build an ImputeVAE instance from its effective (or fallback) config."""
    return _build_imputer("ImputeVAE", VAEConfig, ImputeVAE)


def build_impute_autoencoder():
    """Build an ImputeAutoencoder instance from its effective (or fallback) config."""
    return _build_imputer("ImputeAutoencoder", AutoencoderConfig, ImputeAutoencoder)


def build_impute_mostfreq():
    """Build an ImputeMostFrequent instance from its effective (or fallback) config."""
    return _build_imputer(
        "ImputeMostFrequent",
        MostFrequentConfig,
        ImputeMostFrequent,
        gd_positional=True,
    )


def build_impute_refallele():
    """Build an ImputeRefAllele instance from its effective (or fallback) config."""
    return _build_imputer(
        "ImputeRefAllele",
        RefAlleleConfig,
        ImputeRefAllele,
        gd_positional=True,
    )
854
+
855
+ model_builders = {
856
+ "ImputeUBP": build_impute_ubp,
857
+ "ImputeVAE": build_impute_vae,
858
+ "ImputeAutoencoder": build_impute_autoencoder,
859
+ "ImputeNLPCA": build_impute_nlpca,
860
+ "ImputeMostFrequent": build_impute_mostfreq,
861
+ "ImputeRefAllele": build_impute_refallele,
862
+ }
863
+
864
+ logging.info(f"Selected models: {', '.join(selected_models)}")
865
+ for name in selected_models:
866
+ X_imputed = run_model_safely(name, model_builders[name], warn_only=False)
867
+ gd_imp = gd.copy()
868
+ gd_imp.snp_data = X_imputed
869
+
870
+ if name in {"ImputeUBP", "ImputeVAE", "ImputeAutoencoder", "ImputeNLPCA"}:
871
+ family = "Unsupervised"
872
+ elif name in {"ImputeMostFrequent", "ImputeRefAllele"}:
873
+ family = "Deterministic"
874
+ elif name in {"ImputeHistGradientBoosting", "ImputeRandomForest"}:
875
+ family = "Supervised"
876
+ else:
877
+ raise ValueError(f"Unknown model family for {name}")
878
+
879
+ pth = Path(f"{prefix}_output/{family}/imputed/{name}")
880
+ pth.mkdir(parents=True, exist_ok=True)
881
+
882
+ logging.info(f"Writing imputed VCF for {name} to {pth} ...")
883
+
884
+ if fmt_final == "vcf":
885
+ gd_imp.write_vcf(pth / f"{name.lower()}_imputed.vcf.gz")
886
+ elif fmt_final == "phylip":
887
+ gd_imp.write_phylip(pth / f"{name.lower()}_imputed.phy")
888
+ elif fmt_final == "genepop":
889
+ gd_imp.write_genepop(pth / f"{name.lower()}_imputed.gen")
890
+ else:
891
+ logging.warning(
892
+ f"Output format {fmt_final} not supported for imputed data export."
893
+ )
894
+
895
+ logging.info("All requested models processed.")
896
+
897
+ # -------------------------- MultiQC builder ---------------------------- #
898
+
899
+ mqc_output_dir = getattr(args, "multiqc_output_dir", f"{prefix}_output/multiqc")
900
+ mqc_title = getattr(args, "multiqc_title", f"PG-SUI MultiQC Report - {prefix}")
901
+ overwrite = bool(getattr(args, "multiqc_overwrite", False))
902
+
903
+ logging.info(
904
+ f"Building MultiQC report in '{mqc_output_dir}' (title={mqc_title}, overwrite={overwrite})..."
905
+ )
906
+
907
+ try:
908
+ SNPioMultiQC.build(
909
+ prefix=prefix,
910
+ output_dir=mqc_output_dir,
911
+ title=mqc_title,
912
+ overwrite=overwrite,
913
+ )
914
+ logging.info("MultiQC report successfully built.")
915
+ except Exception as exc2: # pragma: no cover
916
+ logging.error(f"Failed to build MultiQC report: {exc2}", exc_info=True)
917
+
918
+ return 0
919
+
920
+
921
if __name__ == "__main__":
    # Script entry point: main() returns an integer status (0 on success,
    # 2 on argument errors in the visible code), and SystemExit turns it into
    # the process exit code.
    raise SystemExit(main())