pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +577 -125
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +203 -530
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1269 -534
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
- pgsui/impute/unsupervised/imputers/vae.py +931 -787
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
# Standard library
|
|
2
|
+
import copy
|
|
2
3
|
import json
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
|
|
@@ -11,16 +12,16 @@ from matplotlib.figure import Figure
|
|
|
11
12
|
from plotly.graph_objs import Figure as PlotlyFigure
|
|
12
13
|
from sklearn.exceptions import NotFittedError
|
|
13
14
|
from sklearn.metrics import (
|
|
14
|
-
|
|
15
|
+
average_precision_score,
|
|
15
16
|
classification_report,
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
recall_score,
|
|
17
|
+
jaccard_score,
|
|
18
|
+
matthews_corrcoef,
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
# Project
|
|
22
22
|
from snpio import GenotypeEncoder
|
|
23
23
|
from snpio.utils.logging import LoggerManager
|
|
24
|
+
from snpio.utils.misc import validate_input_type
|
|
24
25
|
|
|
25
26
|
from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
|
|
26
27
|
from pgsui.data_processing.containers import RefAlleleConfig
|
|
@@ -58,6 +59,7 @@ def ensure_refallele_config(
|
|
|
58
59
|
if isinstance(config, str):
|
|
59
60
|
return load_yaml_to_dataclass(config, RefAlleleConfig)
|
|
60
61
|
if isinstance(config, dict):
|
|
62
|
+
config = copy.deepcopy(config) # copy
|
|
61
63
|
base = RefAlleleConfig()
|
|
62
64
|
# honor optional top-level 'preset'
|
|
63
65
|
preset = config.pop("preset", None)
|
|
@@ -82,9 +84,9 @@ def ensure_refallele_config(
|
|
|
82
84
|
|
|
83
85
|
|
|
84
86
|
class ImputeRefAllele:
|
|
85
|
-
"""Deterministic imputer that
|
|
87
|
+
"""Deterministic imputer that fills missing genotypes with REF (0).
|
|
86
88
|
|
|
87
|
-
|
|
89
|
+
Operates on 0/1/2 encodings with missing values represented by any negative integer. Evaluation splits samples into TRAIN/TEST once, then evaluates on either all observed test cells or a simulated-missing subset (depending on config). Produces 0/1/2 (zygosity) and 10-class IUPAC reports plus confusion matrices, and plots genotype distributions before/after imputation. Output is returned as IUPAC strings via ``decode_012``.
|
|
88
90
|
"""
|
|
89
91
|
|
|
90
92
|
def __init__(
|
|
@@ -107,16 +109,16 @@ class ImputeRefAllele:
|
|
|
107
109
|
) -> None:
|
|
108
110
|
"""Initialize the Ref-Allele imputer from a unified config.
|
|
109
111
|
|
|
110
|
-
This constructor ensures that the provided configuration is valid and initializes the imputer's internal state. It sets up logging, random number generation, genotype encoding, and
|
|
112
|
+
This constructor ensures that the provided configuration is valid and initializes the imputer's internal state. It sets up logging, random number generation, genotype encoding, and simulated-missing controls.
|
|
111
113
|
|
|
112
114
|
Args:
|
|
113
115
|
genotype_data (GenotypeData): Backing genotype data.
|
|
114
|
-
tree_parser (Optional[TreeParser]): Optional SNPio
|
|
116
|
+
tree_parser (Optional[TreeParser]): Optional SNPio tree parser for nonrandom simulated-missing modes.
|
|
115
117
|
config (RefAlleleConfig | dict | str | None): Configuration as a dataclass, nested dict, or YAML path. If None, defaults are used.
|
|
116
|
-
overrides (dict
|
|
118
|
+
overrides (Optional[dict]): Flat dot-key overrides applied last with highest precedence, e.g. {'split.test_size': 0.25, 'algo.missing': -1}.
|
|
117
119
|
simulate_missing (bool): Whether to simulate missing data during evaluation. Default is True.
|
|
118
|
-
sim_strategy (Literal): Strategy for simulating missing data if enabled in config.
|
|
119
|
-
sim_prop (float): Proportion of data to simulate as missing if enabled in config.
|
|
120
|
+
sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Strategy for simulating missing data if enabled in config.
|
|
121
|
+
sim_prop (float): Proportion of data to simulate as missing if enabled in config. Default is 0.2.
|
|
120
122
|
sim_kwargs (Optional[dict]): Additional keyword arguments for the simulated missing data transformer.
|
|
121
123
|
"""
|
|
122
124
|
# Normalize config then apply highest-precedence overrides
|
|
@@ -153,7 +155,7 @@ class ImputeRefAllele:
|
|
|
153
155
|
self.plots_dir: Path
|
|
154
156
|
self.metrics_dir: Path
|
|
155
157
|
self.parameters_dir: Path
|
|
156
|
-
self.
|
|
158
|
+
self.models_dir: Path
|
|
157
159
|
self.optimize_dir: Path
|
|
158
160
|
|
|
159
161
|
# Logger
|
|
@@ -174,7 +176,7 @@ class ImputeRefAllele:
|
|
|
174
176
|
self.encoder = GenotypeEncoder(self.genotype_data)
|
|
175
177
|
|
|
176
178
|
# Work in 0/1/2 with -1 for missing
|
|
177
|
-
X012 = self.encoder.genotypes_012.astype(np.
|
|
179
|
+
X012 = self.encoder.genotypes_012.astype(np.int8, copy=True)
|
|
178
180
|
X012[X012 < 0] = -1
|
|
179
181
|
self.X012_ = X012
|
|
180
182
|
self.num_features_ = X012.shape[1]
|
|
@@ -199,8 +201,8 @@ class ImputeRefAllele:
|
|
|
199
201
|
self.metrics_: Dict[str, int | float] = {}
|
|
200
202
|
|
|
201
203
|
# Ploidy heuristic for 0/1/2 scoring parity
|
|
202
|
-
|
|
203
|
-
self.is_haploid_ =
|
|
204
|
+
self.ploidy = self.cfg.io.ploidy
|
|
205
|
+
self.is_haploid_ = self.ploidy == 1
|
|
204
206
|
|
|
205
207
|
# Plotting (use config)
|
|
206
208
|
self.plot_format = cfg.plot.fmt
|
|
@@ -243,8 +245,7 @@ class ImputeRefAllele:
|
|
|
243
245
|
|
|
244
246
|
# Use NaN for missing inside a DataFrame to leverage fillna
|
|
245
247
|
df_all = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
|
|
246
|
-
df_all =
|
|
247
|
-
df_all = df_all.replace(-9, np.nan) # Just in case
|
|
248
|
+
df_all[df_all < 0] = np.nan
|
|
248
249
|
|
|
249
250
|
# Observed mask in the ORIGINAL data (before any simulated-missing)
|
|
250
251
|
obs_mask = df_all.notna().to_numpy() # shape (n_samples, n_loci)
|
|
@@ -256,6 +257,9 @@ class ImputeRefAllele:
|
|
|
256
257
|
|
|
257
258
|
# Decide how to build the sim mask: legacy vs simulated-missing
|
|
258
259
|
if getattr(self, "simulate_missing", False):
|
|
260
|
+
X_for_sim = self.ground_truth012_.astype(np.float32, copy=True)
|
|
261
|
+
X_for_sim[X_for_sim < 0] = -9.0
|
|
262
|
+
|
|
259
263
|
# Simulate missing on the full matrix; we only use the mask.
|
|
260
264
|
tr = SimMissingTransformer(
|
|
261
265
|
genotype_data=self.genotype_data,
|
|
@@ -267,7 +271,7 @@ class ImputeRefAllele:
|
|
|
267
271
|
verbose=self.verbose,
|
|
268
272
|
**(self.sim_kwargs or {}),
|
|
269
273
|
)
|
|
270
|
-
tr.fit(
|
|
274
|
+
tr.fit(X_for_sim)
|
|
271
275
|
sim_mask_global = tr.sim_missing_mask_.astype(bool)
|
|
272
276
|
|
|
273
277
|
# Only consider cells that were originally observed
|
|
@@ -317,12 +321,17 @@ class ImputeRefAllele:
|
|
|
317
321
|
NotFittedError: If the model has not been fitted yet.
|
|
318
322
|
"""
|
|
319
323
|
if not self.is_fit_:
|
|
320
|
-
|
|
321
|
-
|
|
324
|
+
msg = "ImputeRefAllele instance is not fitted yet. Call 'fit()' before 'transform()'."
|
|
325
|
+
self.logger.error(msg)
|
|
326
|
+
raise NotFittedError(msg)
|
|
327
|
+
|
|
328
|
+
assert (
|
|
329
|
+
self.X_train_df_ is not None
|
|
330
|
+
), f"[{self.model_name}] X_train_df_ is not set after fit()."
|
|
322
331
|
|
|
323
332
|
# 1) Impute the evaluation-masked copy (compute metrics)
|
|
324
333
|
imputed_eval_df = self._impute_ref(df_in=self.X_train_df_)
|
|
325
|
-
X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.
|
|
334
|
+
X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.int8)
|
|
326
335
|
self.X_imputed012_ = X_imputed_eval
|
|
327
336
|
|
|
328
337
|
# Evaluate parity with DL models
|
|
@@ -330,23 +339,24 @@ class ImputeRefAllele:
|
|
|
330
339
|
|
|
331
340
|
# 2) Impute the FULL dataset (only true missings)
|
|
332
341
|
df_missingonly = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
|
|
333
|
-
df_missingonly =
|
|
334
|
-
df_missingonly = df_missingonly.replace(-9, np.nan) # Just in case
|
|
342
|
+
df_missingonly[df_missingonly < 0] = np.nan
|
|
335
343
|
|
|
336
344
|
imputed_full_df = self._impute_ref(df_in=df_missingonly)
|
|
337
|
-
X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.
|
|
345
|
+
X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.int8)
|
|
338
346
|
|
|
339
347
|
# Plot distributions (like DL .transform())
|
|
340
348
|
|
|
341
349
|
if self.ground_truth012_ is None:
|
|
342
|
-
msg = "ground_truth012_ is
|
|
343
|
-
self.logger.error(msg)
|
|
350
|
+
msg = "ground_truth012_ is NoneType; cannot plot distributions."
|
|
351
|
+
self.logger.error(msg, exc_info=True)
|
|
352
|
+
raise NotFittedError(msg)
|
|
353
|
+
|
|
354
|
+
imp_decoded = self.decode_012(X_imputed_full_012)
|
|
344
355
|
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)
|
|
356
|
+
if self.show_plots:
|
|
357
|
+
gt_decoded = self.decode_012(self.ground_truth012_)
|
|
358
|
+
self.plotter_.plot_gt_distribution(gt_decoded, is_imputed=False)
|
|
359
|
+
self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)
|
|
350
360
|
|
|
351
361
|
# Return IUPAC strings
|
|
352
362
|
return imp_decoded
|
|
@@ -365,7 +375,7 @@ class ImputeRefAllele:
|
|
|
365
375
|
df = df_in.copy()
|
|
366
376
|
# Fill all NaNs with 0 (homozygous REF) column-wise; constant so vectorized is fine
|
|
367
377
|
df = df.fillna(0)
|
|
368
|
-
return df.astype(np.
|
|
378
|
+
return df.astype(np.int8)
|
|
369
379
|
|
|
370
380
|
def _evaluate_and_report(self) -> None:
|
|
371
381
|
"""Evaluate imputed vs. ground truth on masked test cells; produce reports and plots.
|
|
@@ -394,8 +404,8 @@ class ImputeRefAllele:
|
|
|
394
404
|
X_pred_eval = self.ground_truth012_.copy()
|
|
395
405
|
X_pred_eval[self.sim_mask_] = self.X_imputed012_[self.sim_mask_]
|
|
396
406
|
|
|
397
|
-
y_true_dec = self.
|
|
398
|
-
y_pred_dec = self.
|
|
407
|
+
y_true_dec = self.decode_012(self.ground_truth012_)
|
|
408
|
+
y_pred_dec = self.decode_012(X_pred_eval)
|
|
399
409
|
|
|
400
410
|
encodings_dict = {
|
|
401
411
|
"A": 0,
|
|
@@ -418,43 +428,37 @@ class ImputeRefAllele:
|
|
|
418
428
|
)
|
|
419
429
|
y_true_10 = y_true_int[self.sim_mask_]
|
|
420
430
|
y_pred_10 = y_pred_int[self.sim_mask_]
|
|
431
|
+
|
|
432
|
+
m = (y_true_10 >= 0) & (y_pred_10 >= 0)
|
|
433
|
+
y_true_10, y_pred_10 = y_true_10[m], y_pred_10[m]
|
|
434
|
+
if y_true_10.size == 0:
|
|
435
|
+
self.logger.warning(
|
|
436
|
+
"No valid IUPAC test cells; skipping 10-class evaluation."
|
|
437
|
+
)
|
|
438
|
+
return
|
|
439
|
+
|
|
421
440
|
self._evaluate_iupac10_and_plot(y_true_10, y_pred_10)
|
|
422
441
|
|
|
423
442
|
def _evaluate_012_and_plot(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
|
|
424
443
|
"""0/1/2 zygosity report & confusion matrix.
|
|
425
444
|
|
|
426
|
-
This method generates a classification report and confusion matrix for genotypes encoded as 0 (REF), 1 (HET), and 2 (ALT). If the data is
|
|
445
|
+
This method generates a classification report and confusion matrix for genotypes encoded as 0 (REF), 1 (HET), and 2 (ALT). If the data is haploid (only 0 and 2 present), it folds ALT (2) into the binary ALT/PRESENT class (1) for evaluation. The method computes metrics, logs the report, and creates visualizations of the results.
|
|
427
446
|
|
|
428
447
|
Args:
|
|
429
448
|
y_true (np.ndarray): True genotypes (0/1/2) for masked
|
|
430
|
-
y_pred (np.ndarray): Predicted genotypes (0/1/2) for
|
|
449
|
+
y_pred (np.ndarray): Predicted genotypes (0/1/2) for masked
|
|
431
450
|
"""
|
|
432
|
-
labels = [0, 1, 2]
|
|
433
|
-
report_names = ["REF", "HET", "ALT"]
|
|
451
|
+
labels: list[int] = [0, 1, 2]
|
|
452
|
+
report_names: list[str] = ["REF", "HET", "ALT"]
|
|
434
453
|
|
|
435
|
-
# Haploid parity: fold 2
|
|
454
|
+
# Haploid parity: fold ALT (2) into ALT/Present (1)
|
|
436
455
|
if self.is_haploid_:
|
|
437
|
-
y_true
|
|
438
|
-
y_pred
|
|
439
|
-
labels = [0, 1]
|
|
440
|
-
report_names = ["REF", "ALT"]
|
|
441
|
-
|
|
442
|
-
metrics = {
|
|
443
|
-
"n_masked_test": int(y_true.size),
|
|
444
|
-
"accuracy": accuracy_score(y_true, y_pred),
|
|
445
|
-
"f1": f1_score(
|
|
446
|
-
y_true, y_pred, average="weighted", labels=labels, zero_division=0
|
|
447
|
-
),
|
|
448
|
-
"precision": precision_score(
|
|
449
|
-
y_true, y_pred, average="weighted", labels=labels, zero_division=0
|
|
450
|
-
),
|
|
451
|
-
"recall": recall_score(
|
|
452
|
-
y_true, y_pred, average="weighted", labels=labels, zero_division=0
|
|
453
|
-
),
|
|
454
|
-
}
|
|
455
|
-
self.metrics_.update({f"zygosity_{k}": v for k, v in metrics.items()})
|
|
456
|
+
y_true = np.where(y_true == 2, 1, y_true)
|
|
457
|
+
y_pred = np.where(y_pred == 2, 1, y_pred)
|
|
458
|
+
labels: list[int] = [0, 1]
|
|
459
|
+
report_names: list[str] = ["REF", "ALT"]
|
|
456
460
|
|
|
457
|
-
report:
|
|
461
|
+
report: dict | str = classification_report(
|
|
458
462
|
y_true,
|
|
459
463
|
y_pred,
|
|
460
464
|
labels=labels,
|
|
@@ -468,91 +472,69 @@ class ImputeRefAllele:
|
|
|
468
472
|
self.logger.error(msg)
|
|
469
473
|
raise TypeError(msg)
|
|
470
474
|
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
tmp = {}
|
|
474
|
-
if isinstance(v, dict) and "support" in v:
|
|
475
|
-
for k2, v2 in v.items():
|
|
476
|
-
if k2 != "support":
|
|
477
|
-
tmp[k2] = v2
|
|
478
|
-
if tmp:
|
|
479
|
-
report_subset[k] = tmp
|
|
480
|
-
|
|
481
|
-
if report_subset:
|
|
482
|
-
pm = PrettyMetrics(
|
|
483
|
-
report_subset,
|
|
484
|
-
precision=3,
|
|
485
|
-
title=f"{self.model_name} Zygosity Report",
|
|
486
|
-
)
|
|
487
|
-
pm.render()
|
|
488
|
-
|
|
489
|
-
viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
|
|
475
|
+
if self.show_plots:
|
|
476
|
+
viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
|
|
490
477
|
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
478
|
+
plots = viz.plot_all(
|
|
479
|
+
report,
|
|
480
|
+
title_prefix=f"{self.model_name} Zygosity Report",
|
|
481
|
+
show=self.show_plots,
|
|
482
|
+
heatmap_classes_only=True,
|
|
483
|
+
)
|
|
495
484
|
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
485
|
+
for name, fig in plots.items():
|
|
486
|
+
fout = self.plots_dir / f"zygosity_report_{name}.{self.plot_format}"
|
|
487
|
+
if hasattr(fig, "savefig") and isinstance(fig, Figure):
|
|
488
|
+
fig.savefig(fout, dpi=300, facecolor="#111122")
|
|
489
|
+
plt.close(fig)
|
|
490
|
+
elif isinstance(fig, PlotlyFigure):
|
|
491
|
+
fig.write_html(file=fout.with_suffix(".html"))
|
|
502
492
|
|
|
503
|
-
|
|
504
|
-
plt.rcParams.update(self.plotter_.param_dict)
|
|
493
|
+
viz._reset_mpl_style()
|
|
505
494
|
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
plt.close(fig)
|
|
511
|
-
elif isinstance(fig, PlotlyFigure):
|
|
512
|
-
fig.write_html(file=fout.with_suffix(".html"))
|
|
495
|
+
# Confusion matrix
|
|
496
|
+
self.plotter_.plot_confusion_matrix(
|
|
497
|
+
y_true, y_pred, label_names=report_names, prefix="zygosity"
|
|
498
|
+
)
|
|
513
499
|
|
|
514
|
-
|
|
500
|
+
# ------ Additional metrics ------
|
|
501
|
+
report_full = self._additional_metrics(
|
|
502
|
+
y_true, y_pred, labels, report_names, report
|
|
503
|
+
)
|
|
515
504
|
|
|
516
|
-
self.
|
|
505
|
+
if self.verbose or self.debug:
|
|
506
|
+
pm = PrettyMetrics(
|
|
507
|
+
report_full,
|
|
508
|
+
precision=2,
|
|
509
|
+
title=f"{self.model_name} Zygosity Report",
|
|
510
|
+
)
|
|
511
|
+
pm.render()
|
|
517
512
|
|
|
518
|
-
#
|
|
519
|
-
self.
|
|
520
|
-
y_true, y_pred, label_names=report_names, prefix="zygosity"
|
|
521
|
-
)
|
|
513
|
+
# Save JSON
|
|
514
|
+
self._save_report(report_full, suffix="zygosity")
|
|
522
515
|
|
|
523
516
|
def _evaluate_iupac10_and_plot(
|
|
524
517
|
self, y_true: np.ndarray, y_pred: np.ndarray
|
|
525
518
|
) -> None:
|
|
526
519
|
"""10-class IUPAC report & confusion matrix.
|
|
527
520
|
|
|
528
|
-
This method generates a classification report and confusion matrix for genotypes encoded
|
|
521
|
+
This method generates a classification report and confusion matrix for genotypes encoded as 10-class IUPAC codes (0-9). It computes various performance metrics, logs the classification report, and creates visualizations of the results.
|
|
529
522
|
|
|
530
523
|
Args:
|
|
531
|
-
y_true (np.ndarray): True genotypes (0-9) for masked
|
|
532
|
-
y_pred (np.ndarray): Predicted genotypes (0-9) for masked
|
|
524
|
+
y_true (np.ndarray): True genotypes (0-9) for masked
|
|
525
|
+
y_pred (np.ndarray): Predicted genotypes (0-9) for masked
|
|
533
526
|
"""
|
|
534
527
|
labels_idx = list(range(10))
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
metrics = {
|
|
538
|
-
"accuracy": accuracy_score(y_true, y_pred),
|
|
539
|
-
"f1": f1_score(
|
|
540
|
-
y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
|
|
541
|
-
),
|
|
542
|
-
"precision": precision_score(
|
|
543
|
-
y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
|
|
544
|
-
),
|
|
545
|
-
"recall": recall_score(
|
|
546
|
-
y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
|
|
547
|
-
),
|
|
548
|
-
}
|
|
549
|
-
self.metrics_.update({f"iupac_{k}": v for k, v in metrics.items()})
|
|
528
|
+
report_names = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
|
|
550
529
|
|
|
551
|
-
|
|
530
|
+
# Create an identity matrix and use the targets array as indices
|
|
531
|
+
y_score = np.eye(len(report_names))[y_pred]
|
|
532
|
+
|
|
533
|
+
report: dict | str = classification_report(
|
|
552
534
|
y_true,
|
|
553
535
|
y_pred,
|
|
554
536
|
labels=labels_idx,
|
|
555
|
-
target_names=
|
|
537
|
+
target_names=report_names,
|
|
556
538
|
zero_division=0,
|
|
557
539
|
output_dict=True,
|
|
558
540
|
)
|
|
@@ -562,30 +544,50 @@ class ImputeRefAllele:
|
|
|
562
544
|
self.logger.error(msg)
|
|
563
545
|
raise TypeError(msg)
|
|
564
546
|
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
547
|
+
if self.show_plots:
|
|
548
|
+
viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
|
|
549
|
+
|
|
550
|
+
plots = viz.plot_all(
|
|
551
|
+
report,
|
|
552
|
+
title_prefix=f"{self.model_name} IUPAC Report",
|
|
553
|
+
show=self.show_plots,
|
|
554
|
+
heatmap_classes_only=True,
|
|
555
|
+
)
|
|
556
|
+
|
|
557
|
+
# Reset the style from Optuna's plotting.
|
|
558
|
+
plt.rcParams.update(self.plotter_.param_dict)
|
|
559
|
+
|
|
560
|
+
for name, fig in plots.items():
|
|
561
|
+
fout = self.plots_dir / f"iupac_report_{name}.{self.plot_format}"
|
|
562
|
+
if hasattr(fig, "savefig") and isinstance(fig, Figure):
|
|
563
|
+
fig.savefig(fout, dpi=300, facecolor="#111122")
|
|
564
|
+
plt.close(fig)
|
|
565
|
+
elif isinstance(fig, PlotlyFigure):
|
|
566
|
+
fig.write_html(file=fout.with_suffix(".html"))
|
|
567
|
+
|
|
568
|
+
# Reset the style
|
|
569
|
+
viz._reset_mpl_style()
|
|
570
|
+
|
|
571
|
+
# Confusion matrix
|
|
572
|
+
self.plotter_.plot_confusion_matrix(
|
|
573
|
+
y_true, y_pred, label_names=report_names, prefix="iupac"
|
|
574
|
+
)
|
|
575
|
+
|
|
576
|
+
# ------ Additional metrics ------
|
|
577
|
+
report_full = self._additional_metrics(
|
|
578
|
+
y_true, y_pred, labels_idx, report_names, report
|
|
579
|
+
)
|
|
580
|
+
|
|
581
|
+
if self.verbose or self.debug:
|
|
576
582
|
pm = PrettyMetrics(
|
|
577
|
-
|
|
578
|
-
precision=
|
|
583
|
+
report_full,
|
|
584
|
+
precision=2,
|
|
579
585
|
title=f"{self.model_name} IUPAC 10-Class Report",
|
|
580
586
|
)
|
|
581
587
|
pm.render()
|
|
582
588
|
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
# Confusion matrix
|
|
586
|
-
self.plotter_.plot_confusion_matrix(
|
|
587
|
-
y_true, y_pred, label_names=labels_names, prefix="iupac"
|
|
588
|
-
)
|
|
589
|
+
# Save JSON
|
|
590
|
+
self._save_report(report_full, suffix="iupac")
|
|
589
591
|
|
|
590
592
|
def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
|
|
591
593
|
"""Create train/test split indices.
|
|
@@ -623,25 +625,28 @@ class ImputeRefAllele:
|
|
|
623
625
|
train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
|
|
624
626
|
return train_idx, test_idx
|
|
625
627
|
|
|
626
|
-
def _save_report(self, report_dict: Dict[str,
|
|
628
|
+
def _save_report(self, report_dict: Dict[str, Any], suffix: str) -> None:
|
|
627
629
|
"""Save classification report dictionary as a JSON file.
|
|
628
630
|
|
|
629
|
-
This method saves the provided classification report dictionary to a JSON file in the metrics directory
|
|
631
|
+
This method saves the provided classification report dictionary to a JSON file in the metrics directory, appending the specified suffix to the filename.
|
|
630
632
|
|
|
631
633
|
Args:
|
|
632
|
-
report_dict (Dict[str,
|
|
634
|
+
report_dict (Dict[str, Any]): The classification report dictionary to save.
|
|
633
635
|
suffix (str): Suffix to append to the filename (e.g., 'zygosity' or 'iupac').
|
|
634
636
|
|
|
635
637
|
Raises:
|
|
636
638
|
NotFittedError: If fit() and transform() have not been called.
|
|
637
639
|
"""
|
|
638
640
|
if not self.is_fit_ or self.X_imputed012_ is None:
|
|
639
|
-
|
|
641
|
+
msg = "No report to save. Ensure fit() and transform() have been called."
|
|
642
|
+
raise NotFittedError(msg)
|
|
640
643
|
|
|
641
644
|
out_fp = self.metrics_dir / f"classification_report_{suffix}.json"
|
|
642
645
|
with open(out_fp, "w") as f:
|
|
643
646
|
json.dump(report_dict, f, indent=4)
|
|
644
|
-
|
|
647
|
+
|
|
648
|
+
msg = f"{self.model_name} {suffix} report saved to {out_fp}."
|
|
649
|
+
self.logger.info(msg)
|
|
645
650
|
|
|
646
651
|
def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
|
|
647
652
|
"""Creates the directory structure for storing model outputs.
|
|
@@ -667,3 +672,305 @@ class ImputeRefAllele:
|
|
|
667
672
|
msg = f"Failed to create directory {getattr(self, f'{d}_dir')}: {e}"
|
|
668
673
|
self.logger.error(msg)
|
|
669
674
|
raise Exception(msg)
|
|
675
|
+
|
|
676
|
+
def decode_012(
|
|
677
|
+
self, X: np.ndarray | pd.DataFrame | list[list[int]], is_nuc: bool = False
|
|
678
|
+
) -> np.ndarray:
|
|
679
|
+
"""Decode 012-encodings to IUPAC chars with metadata repair.
|
|
680
|
+
|
|
681
|
+
This method converts genotype calls encoded as integers (0, 1, 2, etc.) into their corresponding IUPAC nucleotide codes. It supports two modes of decoding:
|
|
682
|
+
1. Nucleotide mode (`is_nuc=True`): Decodes integer codes (0-9) directly to IUPAC nucleotide codes.
|
|
683
|
+
2. Metadata mode (`is_nuc=False`): Uses reference and alternate allele metadata to determine the appropriate IUPAC codes. If metadata is missing or inconsistent, the method attempts to repair the decoding by scanning the source SNP data for valid IUPAC codes.
|
|
684
|
+
|
|
685
|
+
Args:
|
|
686
|
+
X (np.ndarray | pd.DataFrame | list[list[int]]): Input genotype calls as integers. Can be a NumPy array, Pandas DataFrame, or nested list.
|
|
687
|
+
is_nuc (bool): If True, decode 0-9 nucleotide codes; else use ref/alt metadata. Defaults to False.
|
|
688
|
+
|
|
689
|
+
Returns:
|
|
690
|
+
np.ndarray: IUPAC strings as a 2D array of shape (n_samples, n_snps).
|
|
691
|
+
|
|
692
|
+
Notes:
|
|
693
|
+
- The method normalizes input values to handle various formats, including strings, lists, and arrays.
|
|
694
|
+
- It uses a predefined mapping of IUPAC codes to nucleotide bases and vice versa.
|
|
695
|
+
- Missing or invalid codes are represented as 'N' if they can't be resolved.
|
|
696
|
+
- The method includes repair logic to infer missing metadata from the source SNP data when necessary.
|
|
697
|
+
|
|
698
|
+
Raises:
|
|
699
|
+
ValueError: If input is not a DataFrame.
|
|
700
|
+
"""
|
|
701
|
+
df = validate_input_type(X, return_type="df")
|
|
702
|
+
|
|
703
|
+
if not isinstance(df, pd.DataFrame):
|
|
704
|
+
msg = f"Expected a pandas.DataFrame in 'decode_012', but got: {type(df)}."
|
|
705
|
+
self.logger.error(msg)
|
|
706
|
+
raise ValueError(msg)
|
|
707
|
+
|
|
708
|
+
# IUPAC Definitions
|
|
709
|
+
iupac_to_bases: dict[str, set[str]] = {
|
|
710
|
+
"A": {"A"},
|
|
711
|
+
"C": {"C"},
|
|
712
|
+
"G": {"G"},
|
|
713
|
+
"T": {"T"},
|
|
714
|
+
"R": {"A", "G"},
|
|
715
|
+
"Y": {"C", "T"},
|
|
716
|
+
"S": {"G", "C"},
|
|
717
|
+
"W": {"A", "T"},
|
|
718
|
+
"K": {"G", "T"},
|
|
719
|
+
"M": {"A", "C"},
|
|
720
|
+
"B": {"C", "G", "T"},
|
|
721
|
+
"D": {"A", "G", "T"},
|
|
722
|
+
"H": {"A", "C", "T"},
|
|
723
|
+
"V": {"A", "C", "G"},
|
|
724
|
+
"N": set(),
|
|
725
|
+
}
|
|
726
|
+
bases_to_iupac = {
|
|
727
|
+
frozenset(v): k for k, v in iupac_to_bases.items() if k != "N"
|
|
728
|
+
}
|
|
729
|
+
missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|.", "NAN", "nan"}
|
|
730
|
+
|
|
731
|
+
def _normalize_iupac(value: object) -> str | None:
|
|
732
|
+
"""Normalize an input into a single IUPAC code token or None."""
|
|
733
|
+
if value is None:
|
|
734
|
+
return None
|
|
735
|
+
|
|
736
|
+
# Bytes -> str (make type narrowing explicit)
|
|
737
|
+
if isinstance(value, (bytes, np.bytes_)):
|
|
738
|
+
value = bytes(value).decode("utf-8", errors="ignore")
|
|
739
|
+
|
|
740
|
+
# Handle list/tuple/array/Series: take first valid
|
|
741
|
+
if isinstance(value, (list, tuple, pd.Series, np.ndarray)):
|
|
742
|
+
# Convert Series to numpy array for consistent behavior
|
|
743
|
+
if isinstance(value, pd.Series):
|
|
744
|
+
arr = value.to_numpy()
|
|
745
|
+
else:
|
|
746
|
+
arr = value
|
|
747
|
+
|
|
748
|
+
# Scalar numpy array fast path
|
|
749
|
+
if isinstance(arr, np.ndarray) and arr.ndim == 0:
|
|
750
|
+
return _normalize_iupac(arr.item())
|
|
751
|
+
|
|
752
|
+
# Empty sequence/array
|
|
753
|
+
if len(arr) == 0:
|
|
754
|
+
return None
|
|
755
|
+
|
|
756
|
+
# First valid element wins
|
|
757
|
+
for item in arr:
|
|
758
|
+
code = _normalize_iupac(item)
|
|
759
|
+
if code is not None:
|
|
760
|
+
return code
|
|
761
|
+
return None
|
|
762
|
+
|
|
763
|
+
s = str(value).upper().strip()
|
|
764
|
+
if not s or s in missing_codes:
|
|
765
|
+
return None
|
|
766
|
+
|
|
767
|
+
if "," in s:
|
|
768
|
+
for tok in (t.strip() for t in s.split(",")):
|
|
769
|
+
if tok and tok not in missing_codes and tok in iupac_to_bases:
|
|
770
|
+
return tok
|
|
771
|
+
return None
|
|
772
|
+
|
|
773
|
+
return s if s in iupac_to_bases else None
|
|
774
|
+
|
|
775
|
+
codes_df = df.apply(pd.to_numeric, errors="coerce")
|
|
776
|
+
codes = codes_df.fillna(-1).astype(np.int8).to_numpy()
|
|
777
|
+
n_rows, n_cols = codes.shape
|
|
778
|
+
|
|
779
|
+
if is_nuc:
|
|
780
|
+
iupac_list = np.array(
|
|
781
|
+
["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"], dtype="<U1"
|
|
782
|
+
)
|
|
783
|
+
out = np.full((n_rows, n_cols), "N", dtype="<U1")
|
|
784
|
+
mask = (codes >= 0) & (codes <= 9)
|
|
785
|
+
out[mask] = iupac_list[codes[mask]]
|
|
786
|
+
return out
|
|
787
|
+
|
|
788
|
+
# Metadata fetch
|
|
789
|
+
ref_alleles = getattr(self.genotype_data, "ref", [])
|
|
790
|
+
alt_alleles = getattr(self.genotype_data, "alt", [])
|
|
791
|
+
|
|
792
|
+
if len(ref_alleles) != n_cols:
|
|
793
|
+
ref_alleles = getattr(self, "_ref", [None] * n_cols)
|
|
794
|
+
if len(alt_alleles) != n_cols:
|
|
795
|
+
alt_alleles = getattr(self, "_alt", [None] * n_cols)
|
|
796
|
+
|
|
797
|
+
# Ensure list length matches
|
|
798
|
+
if len(ref_alleles) != n_cols:
|
|
799
|
+
ref_alleles = [None] * n_cols
|
|
800
|
+
if len(alt_alleles) != n_cols:
|
|
801
|
+
alt_alleles = [None] * n_cols
|
|
802
|
+
|
|
803
|
+
out = np.full((n_rows, n_cols), "N", dtype="<U1")
|
|
804
|
+
source_snp_data = None
|
|
805
|
+
|
|
806
|
+
for j in range(n_cols):
|
|
807
|
+
ref = _normalize_iupac(ref_alleles[j])
|
|
808
|
+
alt = _normalize_iupac(alt_alleles[j])
|
|
809
|
+
|
|
810
|
+
# --- REPAIR LOGIC ---
|
|
811
|
+
# If metadata is missing, scan the source column.
|
|
812
|
+
if ref is None or alt is None:
|
|
813
|
+
if source_snp_data is None and self.genotype_data.snp_data is not None:
|
|
814
|
+
try:
|
|
815
|
+
source_snp_data = np.asarray(self.genotype_data.snp_data)
|
|
816
|
+
except Exception:
|
|
817
|
+
pass # if lazy loading fails
|
|
818
|
+
|
|
819
|
+
if source_snp_data is not None:
|
|
820
|
+
try:
|
|
821
|
+
col_data = source_snp_data[:, j]
|
|
822
|
+
uniques = set()
|
|
823
|
+
# Optimization: check up to 200 non-empty values
|
|
824
|
+
count = 0
|
|
825
|
+
for val in col_data:
|
|
826
|
+
norm = _normalize_iupac(val)
|
|
827
|
+
if norm:
|
|
828
|
+
uniques.add(norm)
|
|
829
|
+
count += 1
|
|
830
|
+
if len(uniques) >= 2 or count > 200:
|
|
831
|
+
break
|
|
832
|
+
|
|
833
|
+
sorted_u = sorted(list(uniques))
|
|
834
|
+
if len(sorted_u) >= 1 and ref is None:
|
|
835
|
+
ref = sorted_u[0]
|
|
836
|
+
if len(sorted_u) >= 2 and alt is None:
|
|
837
|
+
alt = sorted_u[1]
|
|
838
|
+
except Exception:
|
|
839
|
+
pass
|
|
840
|
+
|
|
841
|
+
# --- DEFAULTS FOR MISSING ---
|
|
842
|
+
# If still missing, we cannot decode.
|
|
843
|
+
if ref is None and alt is None:
|
|
844
|
+
ref = "N"
|
|
845
|
+
alt = "N"
|
|
846
|
+
elif ref is None:
|
|
847
|
+
ref = alt
|
|
848
|
+
elif alt is None:
|
|
849
|
+
alt = ref # Monomorphic site: ALT becomes REF
|
|
850
|
+
|
|
851
|
+
# --- COMPUTE HET CODE ---
|
|
852
|
+
if ref == alt:
|
|
853
|
+
het_code = ref
|
|
854
|
+
else:
|
|
855
|
+
ref_set = iupac_to_bases.get(ref, set()) if ref is not None else set()
|
|
856
|
+
alt_set = iupac_to_bases.get(alt, set()) if alt is not None else set()
|
|
857
|
+
union_set = frozenset(ref_set | alt_set)
|
|
858
|
+
het_code = bases_to_iupac.get(union_set, "N")
|
|
859
|
+
|
|
860
|
+
# --- ASSIGNMENT WITH SAFETY FALLBACKS ---
|
|
861
|
+
col_codes = codes[:, j]
|
|
862
|
+
|
|
863
|
+
# Case 0: REF
|
|
864
|
+
if ref != "N":
|
|
865
|
+
out[col_codes == 0, j] = ref
|
|
866
|
+
|
|
867
|
+
# Case 1: HET
|
|
868
|
+
if het_code != "N":
|
|
869
|
+
out[col_codes == 1, j] = het_code
|
|
870
|
+
else:
|
|
871
|
+
# If HET code is invalid (e.g. ref='A', alt='N'),
|
|
872
|
+
# fallback to REF
|
|
873
|
+
# Fix for an issue where a HET prediction at a monomorphic site
|
|
874
|
+
# produced 'N'
|
|
875
|
+
if ref != "N":
|
|
876
|
+
out[col_codes == 1, j] = ref
|
|
877
|
+
|
|
878
|
+
# Case 2: ALT
|
|
879
|
+
if alt != "N":
|
|
880
|
+
out[col_codes == 2, j] = alt
|
|
881
|
+
else:
|
|
882
|
+
# If ALT is invalid (e.g. ref='A', alt='N'), fallback to REF
|
|
883
|
+
# Fix for an issue where an ALT prediction on a monomorphic site
|
|
884
|
+
# produced 'N'
|
|
885
|
+
if ref != "N":
|
|
886
|
+
out[col_codes == 2, j] = ref
|
|
887
|
+
|
|
888
|
+
return out
|
|
889
|
+
|
|
890
|
+
def _additional_metrics(
    self,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: list[int],
    report_names: list[str],
    report: dict[str, dict[str, float] | float],
) -> dict[str, dict[str, float] | float]:
    """Compute additional metrics and augment the report dictionary.

    Adds per-class and averaged average-precision and Jaccard scores, plus the
    Matthews correlation coefficient, to a new copy of ``report``.

    Average precision is computed from one-hot encodings of ``y_true`` and
    ``y_pred`` built explicitly over ``labels``. This keeps each per-class score
    aligned with ``labels`` / ``report_names``. It also avoids the shape-mismatch
    errors scikit-learn's multiclass path raises when a class in ``labels`` does
    not occur in ``y_true``, or when only two classes are present.

    Classes with no true instances get an average precision of 0.0. They are
    excluded from the macro average and carry zero weight in the weighted
    average.

    Args:
        y_true (np.ndarray): True genotype label codes.
        y_pred (np.ndarray): Predicted genotype label codes.
        labels (list[int]): Label codes, in the same order as ``report_names``.
        report_names (list[str]): Report keys corresponding to ``labels``.
        report (dict[str, dict[str, float] | float]): Classification report
            dictionary to augment. It is not modified.

    Returns:
        dict[str, dict[str, float] | float]: New report dictionary. It contains:
            the per-class entries of ``report`` that are non-empty dicts;
            "macro avg" / "weighted avg" (if present as dicts), each augmented
            with "average-precision" and "jaccard"; "accuracy" (if numeric);
            and "mcc".

    Raises:
        TypeError: If ``jaccard_score`` with ``average=None`` does not return
            an ``np.ndarray``.
    """
    labels_arr = np.asarray(labels)
    y_true_flat = np.asarray(y_true).reshape(-1)
    y_pred_flat = np.asarray(y_pred).reshape(-1)

    # One-hot encode against `labels` so that column k always means labels[k],
    # regardless of which classes happen to occur in y_true or y_pred.
    y_true_onehot = (y_true_flat[:, None] == labels_arr[None, :]).astype(int)
    y_score = (y_pred_flat[:, None] == labels_arr[None, :]).astype(float)

    support = y_true_onehot.sum(axis=0)
    present = support > 0

    # Per-class AP. It is undefined for classes with no positives, which stay 0.0.
    ap_pc = np.zeros(labels_arr.shape[0], dtype=float)
    for k in np.flatnonzero(present):
        ap_pc[k] = float(
            average_precision_score(y_true_onehot[:, k], y_score[:, k])
        )

    # Macro averages over classes present in y_true; weighted uses true support.
    if present.any():
        ap_macro = float(np.mean(ap_pc[present]))
        ap_weighted = float(np.average(ap_pc[present], weights=support[present]))
    else:
        ap_macro = 0.0
        ap_weighted = 0.0

    # Jaccard scores (per-class aligned with `labels`).
    jaccard_pc = jaccard_score(
        y_true, y_pred, average=None, labels=labels, zero_division=0
    )
    jaccard_macro = jaccard_score(y_true, y_pred, average="macro", zero_division=0)
    jaccard_weighted = jaccard_score(
        y_true, y_pred, average="weighted", zero_division=0
    )

    # Matthews correlation coefficient (MCC)
    mcc = matthews_corrcoef(y_true, y_pred)

    if not isinstance(jaccard_pc, np.ndarray):
        msg = "jaccard_score did not return np.ndarray as expected."
        self.logger.error(msg)
        raise TypeError(msg)

    report_full: dict[str, dict[str, float] | float] = {}

    # Per-class metrics. Only classes with a non-empty dict entry in `report`
    # are carried over, and each entry is copied so `report` is not mutated.
    for class_name, ap_val, jac_val in zip(report_names, ap_pc, jaccard_pc):
        class_report = report.get(class_name)
        if not isinstance(class_report, dict) or not class_report:
            continue
        entry = dict(class_report)
        entry["average-precision"] = float(ap_val)
        entry["jaccard"] = float(jac_val)
        report_full[class_name] = entry

    macro_avg = report.get("macro avg")
    if isinstance(macro_avg, dict):
        report_full["macro avg"] = dict(macro_avg)
        report_full["macro avg"]["average-precision"] = float(ap_macro)
        report_full["macro avg"]["jaccard"] = float(jaccard_macro)

    weighted_avg = report.get("weighted avg")
    if isinstance(weighted_avg, dict):
        report_full["weighted avg"] = dict(weighted_avg)
        report_full["weighted avg"]["average-precision"] = float(ap_weighted)
        report_full["weighted avg"]["jaccard"] = float(jaccard_weighted)

    # Add scalar summary metrics
    report_full["mcc"] = float(mcc)
    accuracy_val = report.get("accuracy")
    if isinstance(accuracy_val, (int, float)):
        report_full["accuracy"] = float(accuracy_val)

    return report_full
|