pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
  2. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +577 -125
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +203 -530
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1269 -534
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
  16. pgsui/impute/unsupervised/imputers/vae.py +931 -787
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
  27. pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
  28. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  29. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  30. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  31. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  32. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  33. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
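The diff below covers pgsui/impute/deterministic/imputers/mode.py (file 9 above); note that the NLPCA and UBP imputers and their models (files 26-29) are removed entirely in 1.7.0. For orientation, the config-driven API implied by the diff looks roughly like the following sketch. This is illustrative only: VCFReader is snpio's documented reader, but the ImputeMostFrequent call signature here is an assumption inferred from the docstrings in the diff, not verified against the released wheel.

# Hypothetical usage sketch inferred from the diff below; the
# ImputeMostFrequent signature is an assumption, not verified API.
from snpio import VCFReader

from pgsui.impute.deterministic.imputers.mode import ImputeMostFrequent

gd = VCFReader(filename="example.vcf.gz", popmapfile="example.popmap")
imputer = ImputeMostFrequent(
    gd,  # GenotypeData; assumed to be the first positional argument
    overrides={"algo.by_populations": True, "split.test_size": 0.3},
    simulate_missing=True,
    sim_strategy="random",
    sim_prop=0.2,
)
imputer.fit()
X_iupac = imputer.transform()  # 2D array of IUPAC strings (see decode_012)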
pgsui/impute/deterministic/imputers/mode.py
@@ -1,7 +1,8 @@
  # Standard library imports
+ import copy
  import json
  from pathlib import Path
- from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
 
  # Third-party imports
  import matplotlib.pyplot as plt
@@ -11,14 +12,14 @@ from matplotlib.figure import Figure
  from plotly.graph_objs import Figure as PlotlyFigure
  from sklearn.exceptions import NotFittedError
  from sklearn.metrics import (
- accuracy_score,
+ average_precision_score,
  classification_report,
- f1_score,
- precision_score,
- recall_score,
+ jaccard_score,
+ matthews_corrcoef,
  )
  from snpio import GenotypeEncoder
  from snpio.utils.logging import LoggerManager
+ from snpio.utils.misc import validate_input_type
 
  from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
  from pgsui.data_processing.containers import MostFrequentConfig
@@ -54,6 +55,7 @@ def ensure_mostfrequent_config(
  if isinstance(config, str):
  return load_yaml_to_dataclass(config, MostFrequentConfig)
  if isinstance(config, dict):
+ config = copy.deepcopy(config) # copy
  base = MostFrequentConfig()
  # honor optional top-level 'preset'
  preset = config.pop("preset", None)
@@ -77,9 +79,9 @@ def ensure_mostfrequent_config(
 
 
  class ImputeMostFrequent:
- """Most-frequent (mode) imputer that mirrors DL evaluation on 0/1/2.
+ """Most-frequent (mode) deterministic imputer for 0/1/2 genotypes.
 
- This imputer computes the most frequent genotype (mode) for each locus based on the training set and uses it to fill in missing values. It supports both global modes and population-specific modes if population data is provided. The imputer follows an evaluation protocol similar to deep learning models, including splitting the data into training and testing sets, masking observed cells in the test set for evaluation, and producing detailed classification reports and plots. It handles both diploid and haploid data, with special considerations for haploid scenarios. The imputer is designed to work seamlessly with genotype data encoded in 0/1/2 format, where -1 indicates missing values.
+ Computes the per-locus mode (globally or per population) from the training set and uses it to fill missing values. The evaluation protocol mirrors the DL imputers: train/test split with evaluation on either all observed test cells or a simulated-missing subset (depending on config), plus classification reports and plots. It handles both diploid and haploid data. Input genotypes are expected in 0/1/2 encoding with missing values represented by any negative integer. Output is returned as IUPAC strings via ``decode_012``.
  """
 
  def __init__(
@@ -109,14 +111,14 @@ class ImputeMostFrequent:
  tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for nonrandom sim_strategy modes.
  config (MostFrequentConfig | dict | str | None): Configuration as a dataclass,
  nested dict, or YAML path. If None, defaults are used.
- overrides (dict | None): Flat dot-key overrides applied last with highest precedence, e.g. {'algo.by_populations': True, 'split.test_size': 0.3}.
+ overrides (Optional[dict]): Flat dot-key overrides applied last with highest precedence, e.g. {'algo.by_populations': True, 'split.test_size': 0.3}.
  simulate_missing (bool): Whether to simulate missing data if enabled in config. Defaults to True.
- sim_strategy (Literal): Strategy for simulating missing data if enabled in config.
- sim_prop (float): Proportion of data to simulate as missing if enabled in config.
+ sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Strategy for simulating missing data if enabled in config.
+ sim_prop (float): Proportion of data to simulate as missing if enabled in config. Default is 0.2.
  sim_kwargs (Optional[dict]): Additional keyword arguments for the simulated missing data transformer.
 
  Notes:
- - This mirrors other config-driven models (AE/VAE/NLPCA/UBP).
+ - This mirrors other config-driven models (AE/VAE).
  - Evaluation split behavior uses cfg.split; plotting uses cfg.plot.
  - I/O/logging seeds and verbosity use cfg.io.
  """
@@ -151,18 +153,19 @@ class ImputeMostFrequent:
  self.rng = np.random.default_rng(cfg.io.seed)
  self.encoder = GenotypeEncoder(self.genotype_data)
 
- # Work in 0/1/2 with -1 for missing (parity with DL modules)
- X012 = self.encoder.genotypes_012.astype(np.int16, copy=False)
+ self.missing_internal = -1
 
- # 2. In-place replacement of NaNs
- # NOTE: X012 will be consumed to make ground_truth_
- np.nan_to_num(X012, nan=-1.0, copy=False)
+ # include common missing value aliases
+ self.missing_aliases = {int(cfg.algo.missing), -9, -1}
 
- X012[X012 < 0] = -1
- self.X012_ = X012
- self.num_features_ = X012.shape[1]
+ X = np.asarray(self.encoder.genotypes_012)
+ Xf = X.astype(np.float32, copy=False)
+ Xf = np.where(np.isnan(Xf), -1.0, Xf)
+ Xf[Xf < 0] = -1.0
+ self.X012_ = Xf.astype(np.int8, copy=False)
+ self.num_features_ = self.X012_.shape[1]
 
- # Simulated-missing controls (mirror VAE/AE/NLPCA semantics where possible)
+ # Simulated-missing controls (mirror VAE/AE semantics where possible)
  sim_cfg = getattr(self.cfg, "sim", None)
  sim_cfg_kwargs = dict(getattr(sim_cfg, "sim_kwargs", {}) or {})
@@ -226,12 +229,11 @@ class ImputeMostFrequent:
  self.test_idx_: Optional[np.ndarray] = None
  self.X_train_df_: Optional[pd.DataFrame] = None
  self.ground_truth012_: Optional[np.ndarray] = None
- self.metrics_: Dict[str, int | float] = {}
  self.X_imputed012_: Optional[np.ndarray] = None
 
  # Ploidy heuristic for 0/1/2 scoring parity
- uniq = np.unique(self.X012_[self.X012_ != -1])
- self.is_haploid_ = np.array_equal(np.sort(uniq), np.array([0, 2]))
+ self.ploidy = self.cfg.io.ploidy
+ self.is_haploid_ = self.ploidy == 1
 
  # Plotting (use config, not genotype_data fields)
  self.plot_format = cfg.plot.fmt
@@ -243,6 +245,11 @@ class ImputeMostFrequent:
  self.model_name = (
  "ImputeMostFrequentPerPop" if self.by_populations else "ImputeMostFrequent"
  )
+
+ # Output dirs
+ dirs = ["models", "plots", "metrics", "optimize", "parameters"]
+ self._create_model_directories(self.prefix, dirs)
+
  self.plotter_ = Plotting(
  self.model_name,
  prefix=self.prefix,
@@ -258,10 +265,6 @@ class ImputeMostFrequent:
  multiqc_section=f"PG-SUI: {self.model_name} Model Imputation",
  )
 
- # Output dirs
- dirs = ["models", "plots", "metrics", "optimize", "parameters"]
- self._create_model_directories(self.prefix, dirs)
-
  if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
  msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
  self.logger.error(msg)
@@ -280,14 +283,21 @@ class ImputeMostFrequent:
 
  # Work in DataFrame with NaN as missing for mode computation
  df_all = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
- df_all = df_all.replace(self.missing, np.nan)
- df_all = df_all.replace(-9, np.nan) # Just in case
+ df_all[df_all < 0] = np.nan
 
  # Modes from TRAIN rows only (per-locus)
  df_train = df_all.iloc[self.train_idx_].copy()
- self.global_modes_ = {
- col: self._series_mode(df_train[col]) for col in df_train.columns
- }
+
+ modes = {}
+ for col in df_train.columns:
+ s = df_train[col].dropna()
+ if s.empty:
+ modes[col] = self.default
+ else:
+ vc = s.value_counts()
+ # deterministic tie-break: smallest genotype among ties
+ modes[col] = int(vc.index[vc.to_numpy() == vc.to_numpy().max()].min())
+ self.global_modes_ = modes
 
  self.group_modes_.clear()
  if self.by_populations:
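The hunk above replaces the old `_series_mode`-based dict comprehension with an inline loop that makes the tie-break explicit: among equally frequent genotypes, the smallest 0/1/2 code wins. A standalone illustration of that rule (my example, not package code):

# Deterministic tie-break sketch: 0 and 2 tie at two counts each, so 0 wins.
import pandas as pd

s = pd.Series([0.0, 0.0, 2.0, 2.0, 1.0])
vc = s.value_counts()
mode = int(vc.index[vc.to_numpy() == vc.to_numpy().max()].min())
assert mode == 0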
@@ -304,13 +314,14 @@ class ImputeMostFrequent:
  self.logger.error(msg)
  raise ValueError(msg)
 
- # ------------------------------
  # Simulated-missing mask (global → test-only)
- # ------------------------------
  obs_mask = df_all.notna().to_numpy() # observed = not NaN
- n_samples, n_loci = obs_mask.shape
+ n_samples = obs_mask.shape[0]
 
  if self.simulate_missing:
+ X_for_sim = self.ground_truth012_.astype(np.float32, copy=True)
+ X_for_sim[X_for_sim < 0] = -9.0
+
  # Use the same transformer as VAE
  tr = SimMissingTransformer(
  genotype_data=self.genotype_data,
@@ -322,11 +333,7 @@ class ImputeMostFrequent:
  verbose=self.verbose,
  **self.sim_kwargs,
  )
- # Fit on 0/1/2 with -1 for missing, like VAE
- X_for_sim = self.ground_truth012_.astype(float, copy=True)
- X_for_sim[X_for_sim < 0] = np.nan
  tr.fit(X_for_sim)
-
  sim_mask_global = tr.sim_missing_mask_.astype(bool)
 
  # Don't simulate on already-missing cells
@@ -359,7 +366,7 @@ class ImputeMostFrequent:
  self.X_train_df_ = df_sim
  self.is_fit_ = True
 
- # Save parameters (unchanged)
+ # Save parameters
  best_params = self.cfg.to_dict()
  params_fp = self.parameters_dir / "best_parameters.json"
  with open(params_fp, "w") as f:
@@ -389,11 +396,14 @@ class ImputeMostFrequent:
  msg = "Model is not fitted. Call fit() before transform()."
  self.logger.error(msg)
  raise NotFittedError(msg)
- assert self.X_train_df_ is not None
+
+ assert (
+ self.X_train_df_ is not None
+ ), f"[{self.model_name}] X_train_df_ is not set after fit()."
 
  # 1) Impute the evaluation-masked copy (to compute metrics)
  imputed_eval_df = self._impute_df(self.X_train_df_)
- X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.int16)
+ X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.int8)
  self.X_imputed012_ = X_imputed_eval
 
  # Evaluate like DL models (0/1/2, then 10-class from decoded strings)
@@ -401,22 +411,31 @@ class ImputeMostFrequent:
 
  # 2) Impute the FULL dataset (only true missings)
  df_missingonly = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
- df_missingonly.replace(self.missing, np.nan, inplace=True)
+ df_missingonly[df_missingonly < 0] = np.nan
+
  imputed_full_df = self._impute_df(df_missingonly)
- X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.int16)
+ X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.int8)
+
+ neg = int(np.count_nonzero(X_imputed_full_012 < 0))
+ if neg:
+ msg = f"{neg} negative entries remain after REF imputation. Unique: {np.unique(X_imputed_full_012[X_imputed_full_012 < 0])[:10]}"
+ self.logger.error(msg)
+ raise RuntimeError(msg)
 
  # Plot distributions (parity with DL transform())
  if self.ground_truth012_ is None:
- raise NotFittedError(
- "ground_truth012_ is not set; cannot plot distributions."
- )
+ msg = "ground_truth012_ is not set; cannot plot distributions."
+ self.logger.error(msg)
+ raise NotFittedError(msg)
+
+ imp_decoded = self.decode_012(X_imputed_full_012)
 
- gt_decoded = self.encoder.decode_012(self.ground_truth012_)
- imp_decoded = self.encoder.decode_012(X_imputed_full_012)
- self.plotter_.plot_gt_distribution(gt_decoded, is_imputed=False)
- self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)
+ if self.show_plots:
+ gt_decoded = self.decode_012(self.ground_truth012_)
+ self.plotter_.plot_gt_distribution(gt_decoded, is_imputed=False)
+ self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)
 
- # Return IUPAC strings (same as DL .transform())
+ # Return IUPAC strings
  return imp_decoded
 
  def _impute_df(self, df_in: pd.DataFrame) -> pd.DataFrame:
@@ -452,7 +471,7 @@ class ImputeMostFrequent:
  df = df_in.fillna(modes)
  else:
  df = df_in.copy()
- return df.astype(np.int16)
+ return df.astype(np.int8)
 
  def _impute_by_population_mode(self, df_in: pd.DataFrame) -> pd.DataFrame:
  """Impute missing cells in df_in using population-specific modes.
@@ -466,7 +485,7 @@ class ImputeMostFrequent:
  pd.DataFrame: DataFrame with missing values imputed.
  """
  if not df_in.isnull().values.any():
- return df_in.astype(np.int16)
+ return df_in.astype(np.int8)
 
  df = df_in.copy()
  pops = pd.Series(self.pops, index=df.index)
@@ -489,7 +508,7 @@ class ImputeMostFrequent:
  mask = np.isnan(values)
  values[mask] = replacements[mask]
 
- return pd.DataFrame(values, columns=df.columns, index=df.index).astype(np.int16)
+ return pd.DataFrame(values, columns=df.columns, index=df.index).astype(np.int8)
 
  def _series_mode(self, s: pd.Series) -> int:
  """Compute the mode of a pandas Series, ignoring NaNs.
@@ -505,11 +524,13 @@ class ImputeMostFrequent:
  s_valid = s.dropna().astype(int)
  if s_valid.empty:
  return self.default
+
  # Mode among {0,1,2}; if ties, pandas picks the smallest (okay)
  mode_val = int(s_valid.mode().iloc[0])
  if mode_val not in (0, 1, 2):
  # Safety: clamp to valid zygosity in case of odd inputs
  mode_val = self.default if self.default in (0, 1, 2) else 0
+
  return mode_val
 
  def _evaluate_and_report(self) -> None:
@@ -540,8 +561,8 @@ class ImputeMostFrequent:
  X_pred_eval = self.ground_truth012_.copy()
  X_pred_eval[self.sim_mask_] = self.X_imputed012_[self.sim_mask_]
 
- y_true_dec = self.encoder.decode_012(self.ground_truth012_)
- y_pred_dec = self.encoder.decode_012(X_pred_eval)
+ y_true_dec = self.decode_012(self.ground_truth012_)
+ y_pred_dec = self.decode_012(X_pred_eval)
 
  encodings_dict = {
  "A": 0,
@@ -565,43 +586,35 @@ class ImputeMostFrequent:
 
  y_true_10 = y_true_int[self.sim_mask_]
  y_pred_10 = y_pred_int[self.sim_mask_]
+
+ m = (y_true_10 >= 0) & (y_pred_10 >= 0)
+ y_true_10, y_pred_10 = y_true_10[m], y_pred_10[m]
+ if y_true_10.size == 0:
+ self.logger.warning(
+ "No valid IUPAC test cells; skipping 10-class evaluation."
+ )
+ return
+
  self._evaluate_iupac10_and_plot(y_true_10, y_pred_10)
 
  def _evaluate_012_and_plot(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
  """0/1/2 zygosity report & confusion matrix.
 
- This method generates a classification report and confusion matrix for genotypes encoded as 0 (REF), 1 (HET), and 2 (ALT). If the data is determined to be haploid (only 0 and 2 present), it folds the ALT genotype (2) into HET (1) for evaluation purposes. The method computes various performance metrics, logs the classification report, and creates visualizations of the results.
+ This method generates a classification report and confusion matrix for genotypes encoded as 0 (REF), 1 (HET), and 2 (ALT). If the data is haploid (only 0 and 2 present), it folds ALT (2) into the binary ALT/PRESENT class (1) for evaluation. The method computes metrics, logs the report, and creates visualizations of the results.
 
  Args:
  y_true (np.ndarray): True genotypes (0/1/2) for masked
  y_pred (np.ndarray): Predicted genotypes (0/1/2) for masked
-
- Raises:
- NotFittedError: If fit() and transform() have not been called.
  """
- labels = [0, 1, 2]
+ labels: list[int] = [0, 1, 2]
+ report_names: list[str] = ["REF", "HET", "ALT"]
+
  # Haploid parity: fold ALT (2) into ALT/Present (1)
  if self.is_haploid_:
- y_true[y_true == 2] = 1
- y_pred[y_pred == 2] = 1
- labels = [0, 1]
-
- metrics = {
- "n_masked_test": int(y_true.size),
- "accuracy": accuracy_score(y_true, y_pred),
- "f1": f1_score(
- y_true, y_pred, average="macro", labels=labels, zero_division=0
- ),
- "precision": precision_score(
- y_true, y_pred, average="macro", labels=labels, zero_division=0
- ),
- "recall": recall_score(
- y_true, y_pred, average="macro", labels=labels, zero_division=0
- ),
- }
- self.metrics_.update({f"zygosity_{k}": v for k, v in metrics.items()})
-
- report_names = ["REF", "HET"] if self.is_haploid_ else ["REF", "HET", "ALT"]
+ y_true = np.where(y_true == 2, 1, y_true)
+ y_pred = np.where(y_pred == 2, 1, y_pred)
+ labels: list[int] = [0, 1]
+ report_names: list[str] = ["REF", "ALT"]
 
  report: dict | str = classification_report(
  y_true,
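The haploid fold in the hunk above now uses np.where instead of in-place masked assignment, so the caller's arrays are no longer mutated. A standalone illustration (my example, not package code):

# np.where returns a new array, leaving the caller's input untouched.
import numpy as np

y = np.array([0, 2, 2, 1])
folded = np.where(y == 2, 1, y)  # ALT (2) folded into class 1
assert folded.tolist() == [0, 1, 1, 1]
assert y.tolist() == [0, 2, 2, 1]  # original unchanged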
@@ -617,50 +630,46 @@ class ImputeMostFrequent:
  self.logger.error(msg)
  raise TypeError(msg)
 
- report_subset = {}
- for k, v in report.items():
- tmp = {}
- if isinstance(v, dict) and "support" in v:
- for k2, v2 in v.items():
- if k2 != "support":
- tmp[k2] = v2
- if tmp:
- report_subset[k] = tmp
-
- if report_subset:
- pm = PrettyMetrics(
- report_subset,
- precision=3,
- title=f"{self.model_name} Zygosity Report",
+ if self.show_plots:
+ viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
+
+ plots = viz.plot_all(
+ report,
+ title_prefix=f"{self.model_name} Zygosity Report",
+ show=self.show_plots,
+ heatmap_classes_only=True,
  )
- pm.render()
 
- viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
+ for name, fig in plots.items():
+ fout = self.plots_dir / f"zygosity_report_{name}.{self.plot_format}"
+ if hasattr(fig, "savefig") and isinstance(fig, Figure):
+ fig.savefig(fout, dpi=300, facecolor="#111122")
+ plt.close(fig)
+ elif isinstance(fig, PlotlyFigure):
+ fig.write_html(file=fout.with_suffix(".html"))
 
- plots = viz.plot_all(
- report,
- title_prefix=f"{self.model_name} Zygosity Report",
- show=getattr(self, "show_plots", False),
- heatmap_classes_only=True,
- )
+ viz._reset_mpl_style()
 
- for name, fig in plots.items():
- fout = self.plots_dir / f"zygosity_report_{name}.{self.plot_format}"
- if hasattr(fig, "savefig") and isinstance(fig, Figure):
- fig.savefig(fout, dpi=300, facecolor="#111122")
- plt.close(fig)
- elif isinstance(fig, PlotlyFigure):
- fig.write_html(file=fout.with_suffix(".html"))
+ # Confusion matrix
+ self.plotter_.plot_confusion_matrix(
+ y_true, y_pred, label_names=report_names, prefix="zygosity"
+ )
 
- viz._reset_mpl_style()
+ # ------ Additional metrics ------
+ report_full = self._additional_metrics(
+ y_true, y_pred, labels, report_names, report
+ )
 
- # Save JSON
- self._save_report(report, suffix="zygosity")
+ if self.verbose or self.debug:
+ pm = PrettyMetrics(
+ report_full,
+ precision=2,
+ title=f"{self.model_name} Zygosity Report",
+ )
+ pm.render()
 
- # Confusion matrix
- self.plotter_.plot_confusion_matrix(
- y_true, y_pred, label_names=report_names, prefix="zygosity"
- )
+ # Save JSON
+ self._save_report(report_full, suffix="zygosity")
 
  def _evaluate_iupac10_and_plot(
  self, y_true: np.ndarray, y_pred: np.ndarray
@@ -672,32 +681,18 @@ class ImputeMostFrequent:
  Args:
  y_true (np.ndarray): True genotypes (0-9) for masked
  y_pred (np.ndarray): Predicted genotypes (0-9) for masked
-
- Raises:
- NotFittedError: If fit() and transform() have not been called.
  """
  labels_idx = list(range(10))
- labels_names = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
-
- metrics = {
- "accuracy": accuracy_score(y_true, y_pred),
- "f1": f1_score(
- y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
- ),
- "precision": precision_score(
- y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
- ),
- "recall": recall_score(
- y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
- ),
- }
- self.metrics_.update({f"iupac_{k}": v for k, v in metrics.items()})
+ report_names = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
+
+ # Create an identity matrix and use the targets array as indices
+ y_score = np.eye(len(report_names))[y_pred]
 
  report: dict | str = classification_report(
  y_true,
  y_pred,
  labels=labels_idx,
- target_names=labels_names,
+ target_names=report_names,
  zero_division=0,
  output_dict=True,
  )
@@ -707,54 +702,50 @@ class ImputeMostFrequent:
  self.logger.error(msg)
  raise TypeError(msg)
 
- report_subset = {}
- for k, v in report.items():
- tmp = {}
- if isinstance(v, dict) and "support" in v:
- for k2, v2 in v.items():
- if k2 != "support":
- tmp[k2] = v2
- if tmp:
- report_subset[k] = tmp
-
- if report_subset:
- pm = PrettyMetrics(
- report_subset,
- precision=3,
- title=f"{self.model_name} IUPAC 10-Class Report",
+ if self.show_plots:
+ viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
+
+ plots = viz.plot_all(
+ report,
+ title_prefix=f"{self.model_name} IUPAC Report",
+ show=self.show_plots,
+ heatmap_classes_only=True,
  )
- pm.render()
 
- viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
+ # Reset the style from Optuna's plotting.
+ plt.rcParams.update(self.plotter_.param_dict)
 
- plots = viz.plot_all(
- report,
- title_prefix=f"{self.model_name} IUPAC Report",
- show=getattr(self, "show_plots", False),
- heatmap_classes_only=True,
- )
+ for name, fig in plots.items():
+ fout = self.plots_dir / f"iupac_report_{name}.{self.plot_format}"
+ if hasattr(fig, "savefig") and isinstance(fig, Figure):
+ fig.savefig(fout, dpi=300, facecolor="#111122")
+ plt.close(fig)
+ elif isinstance(fig, PlotlyFigure):
+ fig.write_html(file=fout.with_suffix(".html"))
 
- # Reset the style from Optuna's plotting.
- plt.rcParams.update(self.plotter_.param_dict)
+ # Reset the style
+ viz._reset_mpl_style()
 
- for name, fig in plots.items():
- fout = self.plots_dir / f"iupac_report_{name}.{self.plot_format}"
- if hasattr(fig, "savefig") and isinstance(fig, Figure):
- fig.savefig(fout, dpi=300, facecolor="#111122")
- plt.close(fig)
- elif isinstance(fig, PlotlyFigure):
- fig.write_html(file=fout.with_suffix(".html"))
+ # Confusion matrix
+ self.plotter_.plot_confusion_matrix(
+ y_true, y_pred, label_names=report_names, prefix="iupac"
+ )
 
- # Reset the style
- viz._reset_mpl_style()
+ # ------ Additional metrics ------
+ report_full = self._additional_metrics(
+ y_true, y_pred, labels_idx, report_names, report
+ )
 
- # Save JSON
- self._save_report(report, suffix="iupac")
+ if self.verbose or self.debug:
+ pm = PrettyMetrics(
+ report_full,
+ precision=2,
+ title=f"{self.model_name} IUPAC 10-Class Report",
+ )
+ pm.render()
 
- # Confusion matrix
- self.plotter_.plot_confusion_matrix(
- y_true, y_pred, label_names=labels_names, prefix="iupac"
- )
+ # Save JSON
+ self._save_report(report_full, suffix="iupac")
 
  def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
  """Create train/test split indices.
@@ -780,14 +771,14 @@ class ImputeMostFrequent:
  buckets = []
  for pop in np.unique(self.pops):
  rows = np.where(self.pops == pop)[0]
- k = int(round(self.test_size * rows.size))
+ k = max(1, int(round(self.test_size * rows.size)))
  if k > 0:
  buckets.append(self.rng.choice(rows, size=k, replace=False))
  test_idx = (
  np.sort(np.concatenate(buckets)) if buckets else np.array([], dtype=int)
  )
  else:
- k = int(round(self.test_size * n))
+ k = max(1, int(round(self.test_size * n)))
  test_idx = (
  self.rng.choice(n, size=k, replace=False)
  if k > 0
@@ -797,13 +788,13 @@ class ImputeMostFrequent:
  train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
  return train_idx, test_idx
 
- def _save_report(self, report_dict: Dict[str, float], suffix: str) -> None:
+ def _save_report(self, report_dict: Dict[str, Any], suffix: str) -> None:
  """Save classification report dictionary as a JSON file.
 
  This method saves the provided classification report dictionary to a JSON file in the metrics directory, appending the specified suffix to the filename.
 
  Args:
- report_dict (Dict[str, float]): The classification report dictionary to save.
+ report_dict (Dict[str, Any]): The classification report dictionary to save.
  suffix (str): Suffix to append to the filename (e.g., 'zygosity' or 'iupac').
 
  Raises:
@@ -842,3 +833,305 @@ class ImputeMostFrequent:
  msg = f"Failed to create directory {getattr(self, f'{d}_dir')}: {e}"
  self.logger.error(msg)
  raise Exception(msg)
+
+ def decode_012(
+ self, X: np.ndarray | pd.DataFrame | list[list[int]], is_nuc: bool = False
+ ) -> np.ndarray:
+ """Decode 012-encodings to IUPAC chars with metadata repair.
+
+ This method converts genotype calls encoded as integers (0, 1, 2, etc.) into their corresponding IUPAC nucleotide codes. It supports two modes of decoding:
+ 1. Nucleotide mode (`is_nuc=True`): Decodes integer codes (0-9) directly to IUPAC nucleotide codes.
+ 2. Metadata mode (`is_nuc=False`): Uses reference and alternate allele metadata to determine the appropriate IUPAC codes. If metadata is missing or inconsistent, the method attempts to repair the decoding by scanning the source SNP data for valid IUPAC codes.
+
+ Args:
+ X (np.ndarray | pd.DataFrame | list[list[int]]): Input genotype calls as integers. Can be a NumPy array, Pandas DataFrame, or nested list.
+ is_nuc (bool): If True, decode 0-9 nucleotide codes; else use ref/alt metadata. Defaults to False.
+
+ Returns:
+ np.ndarray: IUPAC strings as a 2D array of shape (n_samples, n_snps).
+
+ Notes:
+ - The method normalizes input values to handle various formats, including strings, lists, and arrays.
+ - It uses a predefined mapping of IUPAC codes to nucleotide bases and vice versa.
+ - Missing or invalid codes are represented as 'N' if they can't be resolved.
+ - The method includes repair logic to infer missing metadata from the source SNP data when necessary.
+
+ Raises:
+ ValueError: If input is not a DataFrame.
+ """
+ df = validate_input_type(X, return_type="df")
+
+ if not isinstance(df, pd.DataFrame):
+ msg = f"Expected a pandas.DataFrame in 'decode_012', but got: {type(df)}."
+ self.logger.error(msg)
+ raise ValueError(msg)
+
+ # IUPAC Definitions
+ iupac_to_bases: dict[str, set[str]] = {
+ "A": {"A"},
+ "C": {"C"},
+ "G": {"G"},
+ "T": {"T"},
+ "R": {"A", "G"},
+ "Y": {"C", "T"},
+ "S": {"G", "C"},
+ "W": {"A", "T"},
+ "K": {"G", "T"},
+ "M": {"A", "C"},
+ "B": {"C", "G", "T"},
+ "D": {"A", "G", "T"},
+ "H": {"A", "C", "T"},
+ "V": {"A", "C", "G"},
+ "N": set(),
+ }
+ bases_to_iupac = {
+ frozenset(v): k for k, v in iupac_to_bases.items() if k != "N"
+ }
+ missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|.", "NAN", "nan"}
+
+ def _normalize_iupac(value: object) -> str | None:
+ """Normalize an input into a single IUPAC code token or None."""
+ if value is None:
+ return None
+
+ # Bytes -> str (make type narrowing explicit)
+ if isinstance(value, (bytes, np.bytes_)):
+ value = bytes(value).decode("utf-8", errors="ignore")
+
+ # Handle list/tuple/array/Series: take first valid
+ if isinstance(value, (list, tuple, pd.Series, np.ndarray)):
+ # Convert Series to numpy array for consistent behavior
+ if isinstance(value, pd.Series):
+ arr = value.to_numpy()
+ else:
+ arr = value
+
+ # Scalar numpy array fast path
+ if isinstance(arr, np.ndarray) and arr.ndim == 0:
+ return _normalize_iupac(arr.item())
+
+ # Empty sequence/array
+ if len(arr) == 0:
+ return None
+
+ # First valid element wins
+ for item in arr:
+ code = _normalize_iupac(item)
+ if code is not None:
+ return code
+ return None
+
+ s = str(value).upper().strip()
+ if not s or s in missing_codes:
+ return None
+
+ if "," in s:
+ for tok in (t.strip() for t in s.split(",")):
+ if tok and tok not in missing_codes and tok in iupac_to_bases:
+ return tok
+ return None
+
+ return s if s in iupac_to_bases else None
+
+ codes_df = df.apply(pd.to_numeric, errors="coerce")
+ codes = codes_df.fillna(-1).astype(np.int8).to_numpy()
+ n_rows, n_cols = codes.shape
+
+ if is_nuc:
+ iupac_list = np.array(
+ ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"], dtype="<U1"
+ )
+ out = np.full((n_rows, n_cols), "N", dtype="<U1")
+ mask = (codes >= 0) & (codes <= 9)
+ out[mask] = iupac_list[codes[mask]]
+ return out
+
+ # Metadata fetch
+ ref_alleles = getattr(self.genotype_data, "ref", [])
+ alt_alleles = getattr(self.genotype_data, "alt", [])
+
+ if len(ref_alleles) != n_cols:
+ ref_alleles = getattr(self, "_ref", [None] * n_cols)
+ if len(alt_alleles) != n_cols:
+ alt_alleles = getattr(self, "_alt", [None] * n_cols)
+
+ # Ensure list length matches
+ if len(ref_alleles) != n_cols:
+ ref_alleles = [None] * n_cols
+ if len(alt_alleles) != n_cols:
+ alt_alleles = [None] * n_cols
+
+ out = np.full((n_rows, n_cols), "N", dtype="<U1")
+ source_snp_data = None
+
+ for j in range(n_cols):
+ ref = _normalize_iupac(ref_alleles[j])
+ alt = _normalize_iupac(alt_alleles[j])
+
+ # --- REPAIR LOGIC ---
+ # If metadata is missing, scan the source column.
+ if ref is None or alt is None:
+ if source_snp_data is None and self.genotype_data.snp_data is not None:
+ try:
+ source_snp_data = np.asarray(self.genotype_data.snp_data)
+ except Exception:
+ pass # if lazy loading fails
+
+ if source_snp_data is not None:
+ try:
+ col_data = source_snp_data[:, j]
+ uniques = set()
+ # Optimization: check up to 200 non-empty values
+ count = 0
+ for val in col_data:
+ norm = _normalize_iupac(val)
+ if norm:
+ uniques.add(norm)
+ count += 1
+ if len(uniques) >= 2 or count > 200:
+ break
+
+ sorted_u = sorted(list(uniques))
+ if len(sorted_u) >= 1 and ref is None:
+ ref = sorted_u[0]
+ if len(sorted_u) >= 2 and alt is None:
+ alt = sorted_u[1]
+ except Exception:
+ pass
+
+ # --- DEFAULTS FOR MISSING ---
+ # If still missing, we cannot decode.
+ if ref is None and alt is None:
+ ref = "N"
+ alt = "N"
+ elif ref is None:
+ ref = alt
+ elif alt is None:
+ alt = ref # Monomorphic site: ALT becomes REF
+
+ # --- COMPUTE HET CODE ---
+ if ref == alt:
+ het_code = ref
+ else:
+ ref_set = iupac_to_bases.get(ref, set()) if ref is not None else set()
+ alt_set = iupac_to_bases.get(alt, set()) if alt is not None else set()
+ union_set = frozenset(ref_set | alt_set)
+ het_code = bases_to_iupac.get(union_set, "N")
+
+ # --- ASSIGNMENT WITH SAFETY FALLBACKS ---
+ col_codes = codes[:, j]
+
+ # Case 0: REF
+ if ref != "N":
+ out[col_codes == 0, j] = ref
+
+ # Case 1: HET
+ if het_code != "N":
+ out[col_codes == 1, j] = het_code
+ else:
+ # If HET code is invalid (e.g. ref='A', alt='N'),
+ # fallback to REF
+ # Fix for an issue where a HET prediction at a monomorphic site
+ # produced 'N'
+ if ref != "N":
+ out[col_codes == 1, j] = ref
+
+ # Case 2: ALT
+ if alt != "N":
+ out[col_codes == 2, j] = alt
+ else:
+ # If ALT is invalid (e.g. ref='A', alt='N'), fallback to REF
+ # Fix for an issue where an ALT prediction on a monomorphic site
+ # produced 'N'
+ if ref != "N":
+ out[col_codes == 2, j] = ref
+
+ return out
+
+ def _additional_metrics(
+ self,
+ y_true: np.ndarray,
+ y_pred: np.ndarray,
+ labels: list[int],
+ report_names: list[str],
+ report: dict[str, dict[str, float] | float],
+ ) -> dict[str, dict[str, float] | float]:
+ """Compute additional metrics and augment the report dictionary.
+
+ Args:
+ y_true (np.ndarray): True genotypes.
+ y_pred (np.ndarray): Predicted genotypes.
+ labels (list[int]): List of label indices.
+ report_names (list[str]): List of report names corresponding to labels.
+ report (dict[str, dict[str, float] | float]): Classification report dictionary to augment.
+
+ Returns:
+ dict[str, dict[str, float] | float]: Augmented report dictionary with additional metrics.
+ """
+ # Create an identity matrix and use the targets array as indices
+ y_score = np.eye(len(report_names))[y_pred]
+
+ # Per-class metrics
+ ap_pc = average_precision_score(y_true, y_score, average=None)
+ jaccard_pc = jaccard_score(
+ y_true, y_pred, average=None, labels=labels, zero_division=0
+ )
+
+ # Macro/weighted metrics
+ ap_macro = average_precision_score(y_true, y_score, average="macro")
+ ap_weighted = average_precision_score(y_true, y_score, average="weighted")
+ jaccard_macro = jaccard_score(y_true, y_pred, average="macro", zero_division=0)
+ jaccard_weighted = jaccard_score(
+ y_true, y_pred, average="weighted", zero_division=0
+ )
+
+ # Matthews correlation coefficient (MCC)
+ mcc = matthews_corrcoef(y_true, y_pred)
+
+ if not isinstance(ap_pc, np.ndarray):
+ msg = "average_precision_score or f1_score did not return np.ndarray as expected."
+ self.logger.error(msg)
+ raise TypeError(msg)
+
+ if not isinstance(jaccard_pc, np.ndarray):
+ msg = "jaccard_score did not return np.ndarray as expected."
+ self.logger.error(msg)
+ raise TypeError(msg)
+
+ # Add per-class metrics
+ report_full = {}
+ dd_subset = {
+ k: v for k, v in report.items() if k in report_names and isinstance(v, dict)
+ }
+ for i, class_name in enumerate(report_names):
+ class_report: dict[str, float] = {}
+ if class_name in dd_subset:
+ class_report = dd_subset[class_name]
+
+ if isinstance(class_report, float) or not class_report:
+ continue
+
+ report_full[class_name] = dict(class_report)
+ report_full[class_name]["average-precision"] = float(ap_pc[i])
+ report_full[class_name]["jaccard"] = float(jaccard_pc[i])
+
+ macro_avg = report.get("macro avg")
+ if isinstance(macro_avg, dict):
+ report_full["macro avg"] = dict(macro_avg)
+ report_full["macro avg"]["average-precision"] = float(ap_macro)
+ report_full["macro avg"]["jaccard"] = float(jaccard_macro)
+
+ weighted_avg = report.get("weighted avg")
+ if isinstance(weighted_avg, dict):
+ report_full["weighted avg"] = dict(weighted_avg)
+ report_full["weighted avg"]["average-precision"] = float(ap_weighted)
+ report_full["weighted avg"]["jaccard"] = float(jaccard_weighted)
+
+ # Add scalar summary metrics
+ report_full["mcc"] = float(mcc)
+ accuracy_val = report.get("accuracy")
+
+ if isinstance(accuracy_val, (int, float)):
+ report_full["accuracy"] = float(accuracy_val)
+
+ return report_full
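The new _additional_metrics feeds hard class predictions to average_precision_score by one-hot indexing an identity matrix (np.eye(len(report_names))[y_pred]). A standalone sketch of the same metric calls on toy data (my example, not package code; note that sklearn's average_precision_score expects indicator-format targets, so the sketch binarizes y_true the same way):

# One-hot "scores" from hard predictions, plus the other 1.7.0 metrics.
import numpy as np
from sklearn.metrics import average_precision_score, jaccard_score, matthews_corrcoef

y_true = np.array([0, 1, 2, 2, 1])
y_pred = np.array([0, 1, 2, 1, 1])

y_score = np.eye(3)[y_pred]     # rows are one-hot over the 3 classes
y_true_ind = np.eye(3)[y_true]  # indicator-format targets for AP

ap_macro = average_precision_score(y_true_ind, y_score, average="macro")
jac_macro = jaccard_score(y_true, y_pred, average="macro", zero_division=0)
mcc = matthews_corrcoef(y_true, y_pred)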