pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
  2. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +577 -125
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +203 -530
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1269 -534
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
  16. pgsui/impute/unsupervised/imputers/vae.py +931 -787
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
  27. pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
  28. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  29. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  30. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  31. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  32. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  33. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
pgsui/impute/deterministic/imputers/ref_allele.py
@@ -1,4 +1,5 @@
  # Standard library
+ import copy
  import json
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
@@ -11,16 +12,16 @@ from matplotlib.figure import Figure
  from plotly.graph_objs import Figure as PlotlyFigure
  from sklearn.exceptions import NotFittedError
  from sklearn.metrics import (
- accuracy_score,
+ average_precision_score,
  classification_report,
- f1_score,
- precision_score,
- recall_score,
+ jaccard_score,
+ matthews_corrcoef,
  )

  # Project
  from snpio import GenotypeEncoder
  from snpio.utils.logging import LoggerManager
+ from snpio.utils.misc import validate_input_type

  from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
  from pgsui.data_processing.containers import RefAlleleConfig
@@ -58,6 +59,7 @@ def ensure_refallele_config(
  if isinstance(config, str):
  return load_yaml_to_dataclass(config, RefAlleleConfig)
  if isinstance(config, dict):
+ config = copy.deepcopy(config) # copy
  base = RefAlleleConfig()
  # honor optional top-level 'preset'
  preset = config.pop("preset", None)
@@ -82,9 +84,9 @@ def ensure_refallele_config(


  class ImputeRefAllele:
- """Deterministic imputer that replaces all missing 0/1/2 genotype values with the REF genotype (0).
+ """Deterministic imputer that fills missing genotypes with REF (0).

- The imputer works on 0/1/2 with -1 as missing. Evaluation splits samples into TRAIN/TEST once. Masks ALL originally observed cells on TEST rows for eval. Produces: 0/1/2 (zygosity) classification report + confusion matrix 10-class IUPAC classification report (via decode_012) + confusion matrix. Plots genotype distribution before/after imputation.
+ Operates on 0/1/2 encodings with missing values represented by any negative integer. Evaluation splits samples into TRAIN/TEST once, then evaluates on either all observed test cells or a simulated-missing subset (depending on config). Produces 0/1/2 (zygosity) and 10-class IUPAC reports plus confusion matrices, and plots genotype distributions before/after imputation. Output is returned as IUPAC strings via ``decode_012``.
  """

  def __init__(
@@ -107,16 +109,16 @@ class ImputeRefAllele:
  ) -> None:
  """Initialize the Ref-Allele imputer from a unified config.

- This constructor ensures that the provided configuration is valid and initializes the imputer's internal state. It sets up logging, random number generation, genotype encoding, and various parameters based on the configuration. The imputer is prepared to handle population-specific modes if specified in the configuration.
+ This constructor ensures that the provided configuration is valid and initializes the imputer's internal state. It sets up logging, random number generation, genotype encoding, and simulated-missing controls.

  Args:
  genotype_data (GenotypeData): Backing genotype data.
- tree_parser (Optional[TreeParser]): Optional SNPio phylogenetic tree parser for population-specific modes.
+ tree_parser (Optional[TreeParser]): Optional SNPio tree parser for nonrandom simulated-missing modes.
  config (RefAlleleConfig | dict | str | None): Configuration as a dataclass, nested dict, or YAML path. If None, defaults are used.
- overrides (dict | None): Flat dot-key overrides applied last with highest precedence, e.g. {'split.test_size': 0.25, 'algo.missing': -1}.
+ overrides (Optional[dict]): Flat dot-key overrides applied last with highest precedence, e.g. {'split.test_size': 0.25, 'algo.missing': -1}.
  simulate_missing (bool): Whether to simulate missing data during evaluation. Default is True.
- sim_strategy (Literal): Strategy for simulating missing data if enabled in config.
- sim_prop (float): Proportion of data to simulate as missing if enabled in config.
+ sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Strategy for simulating missing data if enabled in config.
+ sim_prop (float): Proportion of data to simulate as missing if enabled in config. Default is 0.2.
  sim_kwargs (Optional[dict]): Additional keyword arguments for the simulated missing data transformer.
  """
  # Normalize config then apply highest-precedence overrides
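The constructor above accepts the configuration as a dataclass, nested dict, or YAML path, with flat dot-key overrides applied last. A minimal usage sketch based only on the signature and docstring in this hunk; the import paths, the VCF loading line, and the no-argument fit() call are assumptions not confirmed by this diff:

    # Illustrative sketch only; lines marked "assumed" are not verified against pg-sui 1.7.0.
    from snpio import VCFReader          # assumed snpio reader producing a GenotypeData object
    from pgsui import ImputeRefAllele    # assumed public import path

    gd = VCFReader(filename="example.vcf.gz", popmapfile="example.popmap")  # assumed loader call
    imputer = ImputeRefAllele(
        genotype_data=gd,
        config=None,  # dataclass, nested dict, or YAML path; None uses defaults
        overrides={"split.test_size": 0.25, "algo.missing": -1},
        simulate_missing=True,
        sim_strategy="random",
        sim_prop=0.2,
    )
    imputer.fit()                  # assumed no-argument fit, per the NotFittedError message below
    iupac = imputer.transform()    # IUPAC strings decoded via decode_012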
@@ -153,7 +155,7 @@ class ImputeRefAllele:
  self.plots_dir: Path
  self.metrics_dir: Path
  self.parameters_dir: Path
- self.model_dir: Path
+ self.models_dir: Path
  self.optimize_dir: Path

  # Logger
@@ -174,7 +176,7 @@ class ImputeRefAllele:
  self.encoder = GenotypeEncoder(self.genotype_data)

  # Work in 0/1/2 with -1 for missing
- X012 = self.encoder.genotypes_012.astype(np.int16, copy=True)
+ X012 = self.encoder.genotypes_012.astype(np.int8, copy=True)
  X012[X012 < 0] = -1
  self.X012_ = X012
  self.num_features_ = X012.shape[1]
@@ -199,8 +201,8 @@ class ImputeRefAllele:
  self.metrics_: Dict[str, int | float] = {}

  # Ploidy heuristic for 0/1/2 scoring parity
- uniq = np.unique(self.X012_[self.X012_ != -1])
- self.is_haploid_ = np.array_equal(np.sort(uniq), np.array([0, 2]))
+ self.ploidy = self.cfg.io.ploidy
+ self.is_haploid_ = self.ploidy == 1

  # Plotting (use config)
  self.plot_format = cfg.plot.fmt
@@ -243,8 +245,7 @@ class ImputeRefAllele:

  # Use NaN for missing inside a DataFrame to leverage fillna
  df_all = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
- df_all = df_all.replace(self.missing, np.nan)
- df_all = df_all.replace(-9, np.nan) # Just in case
+ df_all[df_all < 0] = np.nan

  # Observed mask in the ORIGINAL data (before any simulated-missing)
  obs_mask = df_all.notna().to_numpy() # shape (n_samples, n_loci)
@@ -256,6 +257,9 @@ class ImputeRefAllele:

  # Decide how to build the sim mask: legacy vs simulated-missing
  if getattr(self, "simulate_missing", False):
+ X_for_sim = self.ground_truth012_.astype(np.float32, copy=True)
+ X_for_sim[X_for_sim < 0] = -9.0
+
  # Simulate missing on the full matrix; we only use the mask.
  tr = SimMissingTransformer(
  genotype_data=self.genotype_data,
@@ -267,7 +271,7 @@ class ImputeRefAllele:
  verbose=self.verbose,
  **(self.sim_kwargs or {}),
  )
- tr.fit(self.ground_truth012_.copy())
+ tr.fit(X_for_sim)
  sim_mask_global = tr.sim_missing_mask_.astype(bool)

  # Only consider cells that were originally observed
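The new branch simulates missingness on the full matrix but only keeps the resulting mask, and the surrounding code then restricts that mask to cells observed in the original data. A self-contained sketch of that masking idea with toy arrays (in the real code the simulated mask comes from SimMissingTransformer.sim_missing_mask_, not from a random draw):

    import numpy as np

    rng = np.random.default_rng(0)
    X012 = rng.integers(-1, 3, size=(8, 10))   # toy 0/1/2 matrix, -1 = missing

    obs_mask = X012 >= 0                       # cells observed in the original data
    sim_mask = rng.random(X012.shape) < 0.2    # stand-in for tr.sim_missing_mask_
    eval_mask = sim_mask & obs_mask            # evaluate only on observed, masked cells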
@@ -317,12 +321,17 @@ class ImputeRefAllele:
  NotFittedError: If the model has not been fitted yet.
  """
  if not self.is_fit_:
- raise NotFittedError("Model is not fitted. Call fit() before transform().")
- assert self.X_train_df_ is not None
+ msg = "ImputeRefAllele instance is not fitted yet. Call 'fit()' before 'transform()'."
+ self.logger.error(msg)
+ raise NotFittedError(msg)
+
+ assert (
+ self.X_train_df_ is not None
+ ), f"[{self.model_name}] X_train_df_ is not set after fit()."

  # 1) Impute the evaluation-masked copy (compute metrics)
  imputed_eval_df = self._impute_ref(df_in=self.X_train_df_)
- X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.int16)
+ X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.int8)
  self.X_imputed012_ = X_imputed_eval

  # Evaluate parity with DL models
@@ -330,23 +339,24 @@ class ImputeRefAllele:

  # 2) Impute the FULL dataset (only true missings)
  df_missingonly = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
- df_missingonly = df_missingonly.replace(self.missing, np.nan)
- df_missingonly = df_missingonly.replace(-9, np.nan) # Just in case
+ df_missingonly[df_missingonly < 0] = np.nan

  imputed_full_df = self._impute_ref(df_in=df_missingonly)
- X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.int16)
+ X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.int8)

  # Plot distributions (like DL .transform())

  if self.ground_truth012_ is None:
- msg = "ground_truth012_ is None; cannot plot distributions."
- self.logger.error(msg)
+ msg = "ground_truth012_ is NoneType; cannot plot distributions."
+ self.logger.error(msg, exc_info=True)
+ raise NotFittedError(msg)
+
+ imp_decoded = self.decode_012(X_imputed_full_012)

- raise NotFittedError("ground_truth012_ is None; cannot plot distributions.")
- gt_decoded = self.encoder.decode_012(self.ground_truth012_)
- imp_decoded = self.encoder.decode_012(X_imputed_full_012)
- self.plotter_.plot_gt_distribution(gt_decoded, is_imputed=False)
- self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)
+ if self.show_plots:
+ gt_decoded = self.decode_012(self.ground_truth012_)
+ self.plotter_.plot_gt_distribution(gt_decoded, is_imputed=False)
+ self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)

  # Return IUPAC strings
  return imp_decoded
@@ -365,7 +375,7 @@ class ImputeRefAllele:
  df = df_in.copy()
  # Fill all NaNs with 0 (homozygous REF) column-wise; constant so vectorized is fine
  df = df.fillna(0)
- return df.astype(np.int16)
+ return df.astype(np.int8)

  def _evaluate_and_report(self) -> None:
  """Evaluate imputed vs. ground truth on masked test cells; produce reports and plots.
@@ -394,8 +404,8 @@ class ImputeRefAllele:
  X_pred_eval = self.ground_truth012_.copy()
  X_pred_eval[self.sim_mask_] = self.X_imputed012_[self.sim_mask_]

- y_true_dec = self.encoder.decode_012(self.ground_truth012_)
- y_pred_dec = self.encoder.decode_012(X_pred_eval)
+ y_true_dec = self.decode_012(self.ground_truth012_)
+ y_pred_dec = self.decode_012(X_pred_eval)

  encodings_dict = {
  "A": 0,
@@ -418,43 +428,37 @@ class ImputeRefAllele:
  )
  y_true_10 = y_true_int[self.sim_mask_]
  y_pred_10 = y_pred_int[self.sim_mask_]
+
+ m = (y_true_10 >= 0) & (y_pred_10 >= 0)
+ y_true_10, y_pred_10 = y_true_10[m], y_pred_10[m]
+ if y_true_10.size == 0:
+ self.logger.warning(
+ "No valid IUPAC test cells; skipping 10-class evaluation."
+ )
+ return
+
  self._evaluate_iupac10_and_plot(y_true_10, y_pred_10)

  def _evaluate_012_and_plot(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
  """0/1/2 zygosity report & confusion matrix.

- This method generates a classification report and confusion matrix for genotypes encoded as 0 (REF), 1 (HET), and 2 (ALT). If the data is determined to be haploid (only 0 and 2 present), it folds the ALT genotype (2) into HET (1) for evaluation purposes. The method computes various performance metrics, logs the classification report, and creates visualizations of the results.
+ This method generates a classification report and confusion matrix for genotypes encoded as 0 (REF), 1 (HET), and 2 (ALT). If the data is haploid (only 0 and 2 present), it folds ALT (2) into the binary ALT/PRESENT class (1) for evaluation. The method computes metrics, logs the report, and creates visualizations of the results.

  Args:
  y_true (np.ndarray): True genotypes (0/1/2) for masked
- y_pred (np.ndarray): Predicted genotypes (0/1/2) for
+ y_pred (np.ndarray): Predicted genotypes (0/1/2) for masked
  """
- labels = [0, 1, 2]
- report_names = ["REF", "HET", "ALT"]
+ labels: list[int] = [0, 1, 2]
+ report_names: list[str] = ["REF", "HET", "ALT"]

- # Haploid parity: fold 2 -> 1
+ # Haploid parity: fold ALT (2) into ALT/Present (1)
  if self.is_haploid_:
- y_true[y_true == 2] = 1
- y_pred[y_pred == 2] = 1
- labels = [0, 1]
- report_names = ["REF", "ALT"]
-
- metrics = {
- "n_masked_test": int(y_true.size),
- "accuracy": accuracy_score(y_true, y_pred),
- "f1": f1_score(
- y_true, y_pred, average="weighted", labels=labels, zero_division=0
- ),
- "precision": precision_score(
- y_true, y_pred, average="weighted", labels=labels, zero_division=0
- ),
- "recall": recall_score(
- y_true, y_pred, average="weighted", labels=labels, zero_division=0
- ),
- }
- self.metrics_.update({f"zygosity_{k}": v for k, v in metrics.items()})
+ y_true = np.where(y_true == 2, 1, y_true)
+ y_pred = np.where(y_pred == 2, 1, y_pred)
+ labels: list[int] = [0, 1]
+ report_names: list[str] = ["REF", "ALT"]

- report: str | dict = classification_report(
+ report: dict | str = classification_report(
  y_true,
  y_pred,
  labels=labels,
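For haploid data the rewritten branch folds ALT (2) into class 1 with np.where instead of mutating the arrays in place, and the hand-rolled accuracy/F1/precision/recall block is replaced by the dict-form classification_report (augmented later by _additional_metrics). A small self-contained sketch of that folding plus the report call:

    import numpy as np
    from sklearn.metrics import classification_report

    y_true = np.array([0, 2, 2, 0, 2, 0])
    y_pred = np.array([0, 2, 0, 0, 2, 2])

    # Haploid parity: fold ALT (2) into a single "present" class (1).
    y_true = np.where(y_true == 2, 1, y_true)
    y_pred = np.where(y_pred == 2, 1, y_pred)

    report = classification_report(
        y_true,
        y_pred,
        labels=[0, 1],
        target_names=["REF", "ALT"],
        zero_division=0,
        output_dict=True,  # dict form is what gets plotted and saved as JSON
    )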
@@ -468,91 +472,69 @@ class ImputeRefAllele:
  self.logger.error(msg)
  raise TypeError(msg)

- report_subset = {}
- for k, v in report.items():
- tmp = {}
- if isinstance(v, dict) and "support" in v:
- for k2, v2 in v.items():
- if k2 != "support":
- tmp[k2] = v2
- if tmp:
- report_subset[k] = tmp
-
- if report_subset:
- pm = PrettyMetrics(
- report_subset,
- precision=3,
- title=f"{self.model_name} Zygosity Report",
- )
- pm.render()
-
- viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
+ if self.show_plots:
+ viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)

- if not isinstance(report, dict):
- msg = "classification_report did not return a dict as expected."
- self.logger.error(msg)
- raise TypeError(msg)
+ plots = viz.plot_all(
+ report,
+ title_prefix=f"{self.model_name} Zygosity Report",
+ show=self.show_plots,
+ heatmap_classes_only=True,
+ )

- plots = viz.plot_all(
- report,
- title_prefix=f"{self.model_name} Zygosity Report",
- show=getattr(self, "show_plots", False),
- heatmap_classes_only=True,
- )
+ for name, fig in plots.items():
+ fout = self.plots_dir / f"zygosity_report_{name}.{self.plot_format}"
+ if hasattr(fig, "savefig") and isinstance(fig, Figure):
+ fig.savefig(fout, dpi=300, facecolor="#111122")
+ plt.close(fig)
+ elif isinstance(fig, PlotlyFigure):
+ fig.write_html(file=fout.with_suffix(".html"))

- # Reset the style from Optuna's plotting.
- plt.rcParams.update(self.plotter_.param_dict)
+ viz._reset_mpl_style()

- for name, fig in plots.items():
- fout = self.plots_dir / f"zygosity_report_{name}.{self.plot_format}"
- if hasattr(fig, "savefig") and isinstance(fig, Figure):
- fig.savefig(fout, dpi=300, facecolor="#111122")
- plt.close(fig)
- elif isinstance(fig, PlotlyFigure):
- fig.write_html(file=fout.with_suffix(".html"))
+ # Confusion matrix
+ self.plotter_.plot_confusion_matrix(
+ y_true, y_pred, label_names=report_names, prefix="zygosity"
+ )

- viz._reset_mpl_style()
+ # ------ Additional metrics ------
+ report_full = self._additional_metrics(
+ y_true, y_pred, labels, report_names, report
+ )

- self._save_report(report, suffix="zygosity")
+ if self.verbose or self.debug:
+ pm = PrettyMetrics(
+ report_full,
+ precision=2,
+ title=f"{self.model_name} Zygosity Report",
+ )
+ pm.render()

- # Confusion matrix
- self.plotter_.plot_confusion_matrix(
- y_true, y_pred, label_names=report_names, prefix="zygosity"
- )
+ # Save JSON
+ self._save_report(report_full, suffix="zygosity")

  def _evaluate_iupac10_and_plot(
  self, y_true: np.ndarray, y_pred: np.ndarray
  ) -> None:
  """10-class IUPAC report & confusion matrix.

- This method generates a classification report and confusion matrix for genotypes encoded using the 10 IUPAC codes (0-9). The IUPAC codes represent various nucleotide combinations, including ambiguous bases.
+ This method generates a classification report and confusion matrix for genotypes encoded as 10-class IUPAC codes (0-9). It computes various performance metrics, logs the classification report, and creates visualizations of the results.

  Args:
- y_true (np.ndarray): True genotypes (0-9) for masked test cells.
- y_pred (np.ndarray): Predicted genotypes (0-9) for masked test cells.
+ y_true (np.ndarray): True genotypes (0-9) for masked
+ y_pred (np.ndarray): Predicted genotypes (0-9) for masked
  """
  labels_idx = list(range(10))
- labels_names = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
-
- metrics = {
- "accuracy": accuracy_score(y_true, y_pred),
- "f1": f1_score(
- y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
- ),
- "precision": precision_score(
- y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
- ),
- "recall": recall_score(
- y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
- ),
- }
- self.metrics_.update({f"iupac_{k}": v for k, v in metrics.items()})
+ report_names = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]

- report = classification_report(
+ # Create an identity matrix and use the targets array as indices
+ y_score = np.eye(len(report_names))[y_pred]
+
+ report: dict | str = classification_report(
  y_true,
  y_pred,
  labels=labels_idx,
- target_names=labels_names,
+ target_names=report_names,
  zero_division=0,
  output_dict=True,
  )
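The IUPAC branch now builds y_score by indexing an identity matrix with the hard predictions, i.e. one-hot "scores" usable with ranking metrics such as average precision. A self-contained sketch of the same trick; the true labels are additionally binarized with label_binarize here so that average_precision_score receives indicator-format targets:

    import numpy as np
    from sklearn.metrics import average_precision_score
    from sklearn.preprocessing import label_binarize

    labels = [0, 1, 2]
    y_true = np.array([0, 1, 2, 1, 0, 2])
    y_pred = np.array([0, 1, 1, 1, 0, 2])

    # One-hot "scores" from hard predictions: row i selects column y_pred[i].
    y_score = np.eye(len(labels))[y_pred]

    # Indicator-format targets for per-class / averaged average precision.
    Y_true = label_binarize(y_true, classes=labels)

    ap_per_class = average_precision_score(Y_true, y_score, average=None)
    ap_macro = average_precision_score(Y_true, y_score, average="macro")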
@@ -562,30 +544,50 @@ class ImputeRefAllele:
  self.logger.error(msg)
  raise TypeError(msg)

- report_subset = {}
- for k, v in report.items():
- tmp = {}
- if isinstance(v, dict) and "support" in v:
- for k2, v2 in v.items():
- if k2 != "support":
- tmp[k2] = v2
- if tmp:
- report_subset[k] = tmp
-
- if report_subset:
+ if self.show_plots:
+ viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
+
+ plots = viz.plot_all(
+ report,
+ title_prefix=f"{self.model_name} IUPAC Report",
+ show=self.show_plots,
+ heatmap_classes_only=True,
+ )
+
+ # Reset the style from Optuna's plotting.
+ plt.rcParams.update(self.plotter_.param_dict)
+
+ for name, fig in plots.items():
+ fout = self.plots_dir / f"iupac_report_{name}.{self.plot_format}"
+ if hasattr(fig, "savefig") and isinstance(fig, Figure):
+ fig.savefig(fout, dpi=300, facecolor="#111122")
+ plt.close(fig)
+ elif isinstance(fig, PlotlyFigure):
+ fig.write_html(file=fout.with_suffix(".html"))
+
+ # Reset the style
+ viz._reset_mpl_style()
+
+ # Confusion matrix
+ self.plotter_.plot_confusion_matrix(
+ y_true, y_pred, label_names=report_names, prefix="iupac"
+ )
+
+ # ------ Additional metrics ------
+ report_full = self._additional_metrics(
+ y_true, y_pred, labels_idx, report_names, report
+ )
+
+ if self.verbose or self.debug:
  pm = PrettyMetrics(
- report_subset,
- precision=3,
+ report_full,
+ precision=2,
  title=f"{self.model_name} IUPAC 10-Class Report",
  )
  pm.render()

- self._save_report(report, suffix="iupac")
-
- # Confusion matrix
- self.plotter_.plot_confusion_matrix(
- y_true, y_pred, label_names=labels_names, prefix="iupac"
- )
+ # Save JSON
+ self._save_report(report_full, suffix="iupac")

  def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
  """Create train/test split indices.
@@ -623,25 +625,28 @@ class ImputeRefAllele:
  train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
  return train_idx, test_idx

- def _save_report(self, report_dict: Dict[str, float], suffix: str) -> None:
+ def _save_report(self, report_dict: Dict[str, Any], suffix: str) -> None:
  """Save classification report dictionary as a JSON file.

- This method saves the provided classification report dictionary to a JSON file in the metrics directory. The filename includes a suffix to distinguish between different types of reports (e.g., 'zygosity' or 'iupac').
+ This method saves the provided classification report dictionary to a JSON file in the metrics directory, appending the specified suffix to the filename.

  Args:
- report_dict (Dict[str, float]): The classification report dictionary to save.
+ report_dict (Dict[str, Any]): The classification report dictionary to save.
  suffix (str): Suffix to append to the filename (e.g., 'zygosity' or 'iupac').

  Raises:
  NotFittedError: If fit() and transform() have not been called.
  """
  if not self.is_fit_ or self.X_imputed012_ is None:
- raise NotFittedError("No report to save. Ensure fit() and transform() ran.")
+ msg = "No report to save. Ensure fit() and transform() have been called."
+ raise NotFittedError(msg)

  out_fp = self.metrics_dir / f"classification_report_{suffix}.json"
  with open(out_fp, "w") as f:
  json.dump(report_dict, f, indent=4)
- self.logger.info(f"{self.model_name} {suffix} report saved to {out_fp}.")
+
+ msg = f"{self.model_name} {suffix} report saved to {out_fp}."
+ self.logger.info(msg)

  def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
  """Creates the directory structure for storing model outputs.
@@ -667,3 +672,305 @@ class ImputeRefAllele:
  msg = f"Failed to create directory {getattr(self, f'{d}_dir')}: {e}"
  self.logger.error(msg)
  raise Exception(msg)
+
+ def decode_012(
+ self, X: np.ndarray | pd.DataFrame | list[list[int]], is_nuc: bool = False
+ ) -> np.ndarray:
+ """Decode 012-encodings to IUPAC chars with metadata repair.
+
+ This method converts genotype calls encoded as integers (0, 1, 2, etc.) into their corresponding IUPAC nucleotide codes. It supports two modes of decoding:
+ 1. Nucleotide mode (`is_nuc=True`): Decodes integer codes (0-9) directly to IUPAC nucleotide codes.
+ 2. Metadata mode (`is_nuc=False`): Uses reference and alternate allele metadata to determine the appropriate IUPAC codes. If metadata is missing or inconsistent, the method attempts to repair the decoding by scanning the source SNP data for valid IUPAC codes.
+
+ Args:
+ X (np.ndarray | pd.DataFrame | list[list[int]]): Input genotype calls as integers. Can be a NumPy array, Pandas DataFrame, or nested list.
+ is_nuc (bool): If True, decode 0-9 nucleotide codes; else use ref/alt metadata. Defaults to False.
+
+ Returns:
+ np.ndarray: IUPAC strings as a 2D array of shape (n_samples, n_snps).
+
+ Notes:
+ - The method normalizes input values to handle various formats, including strings, lists, and arrays.
+ - It uses a predefined mapping of IUPAC codes to nucleotide bases and vice versa.
+ - Missing or invalid codes are represented as 'N' if they can't be resolved.
+ - The method includes repair logic to infer missing metadata from the source SNP data when necessary.
+
+ Raises:
+ ValueError: If input is not a DataFrame.
+ """
+ df = validate_input_type(X, return_type="df")
+
+ if not isinstance(df, pd.DataFrame):
+ msg = f"Expected a pandas.DataFrame in 'decode_012', but got: {type(df)}."
+ self.logger.error(msg)
+ raise ValueError(msg)
+
+ # IUPAC Definitions
+ iupac_to_bases: dict[str, set[str]] = {
+ "A": {"A"},
+ "C": {"C"},
+ "G": {"G"},
+ "T": {"T"},
+ "R": {"A", "G"},
+ "Y": {"C", "T"},
+ "S": {"G", "C"},
+ "W": {"A", "T"},
+ "K": {"G", "T"},
+ "M": {"A", "C"},
+ "B": {"C", "G", "T"},
+ "D": {"A", "G", "T"},
+ "H": {"A", "C", "T"},
+ "V": {"A", "C", "G"},
+ "N": set(),
+ }
+ bases_to_iupac = {
+ frozenset(v): k for k, v in iupac_to_bases.items() if k != "N"
+ }
+ missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|.", "NAN", "nan"}
+
+ def _normalize_iupac(value: object) -> str | None:
+ """Normalize an input into a single IUPAC code token or None."""
+ if value is None:
+ return None
+
+ # Bytes -> str (make type narrowing explicit)
+ if isinstance(value, (bytes, np.bytes_)):
+ value = bytes(value).decode("utf-8", errors="ignore")
+
+ # Handle list/tuple/array/Series: take first valid
+ if isinstance(value, (list, tuple, pd.Series, np.ndarray)):
+ # Convert Series to numpy array for consistent behavior
+ if isinstance(value, pd.Series):
+ arr = value.to_numpy()
+ else:
+ arr = value
+
+ # Scalar numpy array fast path
+ if isinstance(arr, np.ndarray) and arr.ndim == 0:
+ return _normalize_iupac(arr.item())
+
+ # Empty sequence/array
+ if len(arr) == 0:
+ return None
+
+ # First valid element wins
+ for item in arr:
+ code = _normalize_iupac(item)
+ if code is not None:
+ return code
+ return None
+
+ s = str(value).upper().strip()
+ if not s or s in missing_codes:
+ return None
+
+ if "," in s:
+ for tok in (t.strip() for t in s.split(",")):
+ if tok and tok not in missing_codes and tok in iupac_to_bases:
+ return tok
+ return None
+
+ return s if s in iupac_to_bases else None
+
+ codes_df = df.apply(pd.to_numeric, errors="coerce")
+ codes = codes_df.fillna(-1).astype(np.int8).to_numpy()
+ n_rows, n_cols = codes.shape
+
+ if is_nuc:
+ iupac_list = np.array(
+ ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"], dtype="<U1"
+ )
+ out = np.full((n_rows, n_cols), "N", dtype="<U1")
+ mask = (codes >= 0) & (codes <= 9)
+ out[mask] = iupac_list[codes[mask]]
+ return out
+
+ # Metadata fetch
+ ref_alleles = getattr(self.genotype_data, "ref", [])
+ alt_alleles = getattr(self.genotype_data, "alt", [])
+
+ if len(ref_alleles) != n_cols:
+ ref_alleles = getattr(self, "_ref", [None] * n_cols)
+ if len(alt_alleles) != n_cols:
+ alt_alleles = getattr(self, "_alt", [None] * n_cols)
+
+ # Ensure list length matches
+ if len(ref_alleles) != n_cols:
+ ref_alleles = [None] * n_cols
+ if len(alt_alleles) != n_cols:
+ alt_alleles = [None] * n_cols
+
+ out = np.full((n_rows, n_cols), "N", dtype="<U1")
+ source_snp_data = None
+
+ for j in range(n_cols):
+ ref = _normalize_iupac(ref_alleles[j])
+ alt = _normalize_iupac(alt_alleles[j])
+
+ # --- REPAIR LOGIC ---
+ # If metadata is missing, scan the source column.
+ if ref is None or alt is None:
+ if source_snp_data is None and self.genotype_data.snp_data is not None:
+ try:
+ source_snp_data = np.asarray(self.genotype_data.snp_data)
+ except Exception:
+ pass # if lazy loading fails
+
+ if source_snp_data is not None:
+ try:
+ col_data = source_snp_data[:, j]
+ uniques = set()
+ # Optimization: check up to 200 non-empty values
+ count = 0
+ for val in col_data:
+ norm = _normalize_iupac(val)
+ if norm:
+ uniques.add(norm)
+ count += 1
+ if len(uniques) >= 2 or count > 200:
+ break
+
+ sorted_u = sorted(list(uniques))
+ if len(sorted_u) >= 1 and ref is None:
+ ref = sorted_u[0]
+ if len(sorted_u) >= 2 and alt is None:
+ alt = sorted_u[1]
+ except Exception:
+ pass
+
+ # --- DEFAULTS FOR MISSING ---
+ # If still missing, we cannot decode.
+ if ref is None and alt is None:
+ ref = "N"
+ alt = "N"
+ elif ref is None:
+ ref = alt
+ elif alt is None:
+ alt = ref # Monomorphic site: ALT becomes REF
+
+ # --- COMPUTE HET CODE ---
+ if ref == alt:
+ het_code = ref
+ else:
+ ref_set = iupac_to_bases.get(ref, set()) if ref is not None else set()
+ alt_set = iupac_to_bases.get(alt, set()) if alt is not None else set()
+ union_set = frozenset(ref_set | alt_set)
+ het_code = bases_to_iupac.get(union_set, "N")
+
+ # --- ASSIGNMENT WITH SAFETY FALLBACKS ---
+ col_codes = codes[:, j]
+
+ # Case 0: REF
+ if ref != "N":
+ out[col_codes == 0, j] = ref
+
+ # Case 1: HET
+ if het_code != "N":
+ out[col_codes == 1, j] = het_code
+ else:
+ # If HET code is invalid (e.g. ref='A', alt='N'),
+ # fallback to REF
+ # Fix for an issue where a HET prediction at a monomorphic site
+ # produced 'N'
+ if ref != "N":
+ out[col_codes == 1, j] = ref
+
+ # Case 2: ALT
+ if alt != "N":
+ out[col_codes == 2, j] = alt
+ else:
+ # If ALT is invalid (e.g. ref='A', alt='N'), fallback to REF
+ # Fix for an issue where an ALT prediction on a monomorphic site
+ # produced 'N'
+ if ref != "N":
+ out[col_codes == 2, j] = ref
+
+ return out
+
+ def _additional_metrics(
+ self,
+ y_true: np.ndarray,
+ y_pred: np.ndarray,
+ labels: list[int],
+ report_names: list[str],
+ report: dict[str, dict[str, float] | float],
+ ) -> dict[str, dict[str, float] | float]:
+ """Compute additional metrics and augment the report dictionary.
+
+ Args:
+ y_true (np.ndarray): True genotypes.
+ y_pred (np.ndarray): Predicted genotypes.
+ labels (list[int]): List of label indices.
+ report_names (list[str]): List of report names corresponding to labels.
+ report (dict[str, dict[str, float] | float]): Classification report dictionary to augment.
+
+ Returns:
+ dict[str, dict[str, float] | float]: Augmented report dictionary with additional metrics.
+ """
+ # Create an identity matrix and use the targets array as indices
+ y_score = np.eye(len(report_names))[y_pred]
+
+ # Per-class metrics
+ ap_pc = average_precision_score(y_true, y_score, average=None)
+ jaccard_pc = jaccard_score(
+ y_true, y_pred, average=None, labels=labels, zero_division=0
+ )
+
+ # Macro/weighted metrics
+ ap_macro = average_precision_score(y_true, y_score, average="macro")
+ ap_weighted = average_precision_score(y_true, y_score, average="weighted")
+ jaccard_macro = jaccard_score(y_true, y_pred, average="macro", zero_division=0)
+ jaccard_weighted = jaccard_score(
+ y_true, y_pred, average="weighted", zero_division=0
+ )
+
+ # Matthews correlation coefficient (MCC)
+ mcc = matthews_corrcoef(y_true, y_pred)
+
+ if not isinstance(ap_pc, np.ndarray):
+ msg = "average_precision_score or f1_score did not return np.ndarray as expected."
+ self.logger.error(msg)
+ raise TypeError(msg)
+
+ if not isinstance(jaccard_pc, np.ndarray):
+ msg = "jaccard_score did not return np.ndarray as expected."
+ self.logger.error(msg)
+ raise TypeError(msg)
+
+ # Add per-class metrics
+ report_full = {}
+ dd_subset = {
+ k: v for k, v in report.items() if k in report_names and isinstance(v, dict)
+ }
+ for i, class_name in enumerate(report_names):
+ class_report: dict[str, float] = {}
+ if class_name in dd_subset:
+ class_report = dd_subset[class_name]
+
+ if isinstance(class_report, float) or not class_report:
+ continue
+
+ report_full[class_name] = dict(class_report)
+ report_full[class_name]["average-precision"] = float(ap_pc[i])
+ report_full[class_name]["jaccard"] = float(jaccard_pc[i])
+
+ macro_avg = report.get("macro avg")
+ if isinstance(macro_avg, dict):
+ report_full["macro avg"] = dict(macro_avg)
+ report_full["macro avg"]["average-precision"] = float(ap_macro)
+ report_full["macro avg"]["jaccard"] = float(jaccard_macro)
+
+ weighted_avg = report.get("weighted avg")
+ if isinstance(weighted_avg, dict):
+ report_full["weighted avg"] = dict(weighted_avg)
+ report_full["weighted avg"]["average-precision"] = float(ap_weighted)
+ report_full["weighted avg"]["jaccard"] = float(jaccard_weighted)
+
+ # Add scalar summary metrics
+ report_full["mcc"] = float(mcc)
+ accuracy_val = report.get("accuracy")
+
+ if isinstance(accuracy_val, (int, float)):
+ report_full["accuracy"] = float(accuracy_val)
+
+ return report_full
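In metadata mode, decode_012 maps 0/1/2 to the REF, heterozygote, and ALT IUPAC codes, deriving the heterozygote code from the union of the REF and ALT base sets. A condensed sketch of that lookup for a single biallelic site (the full method above adds input normalization, metadata repair, and fallbacks to 'N'):

    # Single-site sketch of the REF/ALT -> heterozygote IUPAC lookup.
    iupac_to_bases = {
        "A": {"A"}, "C": {"C"}, "G": {"G"}, "T": {"T"},
        "R": {"A", "G"}, "Y": {"C", "T"}, "S": {"G", "C"},
        "W": {"A", "T"}, "K": {"G", "T"}, "M": {"A", "C"},
    }
    bases_to_iupac = {frozenset(v): k for k, v in iupac_to_bases.items()}

    ref, alt = "A", "G"
    if ref == alt:
        het = ref
    else:
        het = bases_to_iupac.get(frozenset(iupac_to_bases[ref] | iupac_to_bases[alt]), "N")

    genotype_to_iupac = {0: ref, 1: het, 2: alt}   # {0: 'A', 1: 'R', 2: 'G'}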