pg-sui 1.6.16a3-py3-none-any.whl → 1.7.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +577 -125
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +203 -530
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1269 -534
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
- pgsui/impute/unsupervised/imputers/vae.py +931 -787
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
pgsui/impute/deterministic/imputers/mode.py

```diff
@@ -1,7 +1,8 @@
 # Standard library imports
+import copy
 import json
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
 
 # Third-party imports
 import matplotlib.pyplot as plt
@@ -11,14 +12,14 @@ from matplotlib.figure import Figure
 from plotly.graph_objs import Figure as PlotlyFigure
 from sklearn.exceptions import NotFittedError
 from sklearn.metrics import (
-    accuracy_score,
+    average_precision_score,
     classification_report,
-    f1_score,
-    precision_score,
-    recall_score,
+    jaccard_score,
+    matthews_corrcoef,
 )
 from snpio import GenotypeEncoder
 from snpio.utils.logging import LoggerManager
+from snpio.utils.misc import validate_input_type
 
 from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
 from pgsui.data_processing.containers import MostFrequentConfig
@@ -54,6 +55,7 @@ def ensure_mostfrequent_config(
     if isinstance(config, str):
         return load_yaml_to_dataclass(config, MostFrequentConfig)
     if isinstance(config, dict):
+        config = copy.deepcopy(config)  # copy
         base = MostFrequentConfig()
         # honor optional top-level 'preset'
         preset = config.pop("preset", None)
```
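
The `copy.deepcopy` added to `ensure_mostfrequent_config` matters because the function immediately mutates its argument via `config.pop("preset", None)`; without the copy, the caller's dict would silently lose its `preset` key. A minimal standalone sketch of the hazard (the `build` helper is illustrative, not part of pg-sui's API):

```python
import copy


def build(config: dict) -> dict:
    config = copy.deepcopy(config)  # protect the caller's dict
    config.pop("preset", None)      # now mutates only the local copy
    return config


user_cfg = {"preset": "fast", "algo": {"by_populations": True}}
build(user_cfg)
assert "preset" in user_cfg  # still intact thanks to the deepcopy
```
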
```diff
@@ -77,9 +79,9 @@ def ensure_mostfrequent_config(
 
 
 class ImputeMostFrequent:
-    """Most-frequent (mode) imputer
+    """Most-frequent (mode) deterministic imputer for 0/1/2 genotypes.
 
-
+    Computes the per-locus mode (globally or per population) from the training set and uses it to fill missing values. The evaluation protocol mirrors the DL imputers: train/test split with evaluation on either all observed test cells or a simulated-missing subset (depending on config), plus classification reports and plots. It handles both diploid and haploid data. Input genotypes are expected in 0/1/2 encoding with missing values represented by any negative integer. Output is returned as IUPAC strings via ``decode_012``.
     """
 
     def __init__(
@@ -109,14 +111,14 @@ class ImputeMostFrequent:
             tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for nonrandom sim_strategy modes.
             config (MostFrequentConfig | dict | str | None): Configuration as a dataclass,
                 nested dict, or YAML path. If None, defaults are used.
-            overrides (dict
+            overrides (Optional[dict]): Flat dot-key overrides applied last with highest precedence, e.g. {'algo.by_populations': True, 'split.test_size': 0.3}.
             simulate_missing (bool): Whether to simulate missing data if enabled in config. Defaults to True.
-            sim_strategy (Literal): Strategy for simulating missing data if enabled in config.
-            sim_prop (float): Proportion of data to simulate as missing if enabled in config.
+            sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Strategy for simulating missing data if enabled in config.
+            sim_prop (float): Proportion of data to simulate as missing if enabled in config. Default is 0.2.
             sim_kwargs (Optional[dict]): Additional keyword arguments for the simulated missing data transformer.
 
         Notes:
-            - This mirrors other config-driven models (AE/VAE
+            - This mirrors other config-driven models (AE/VAE).
            - Evaluation split behavior uses cfg.split; plotting uses cfg.plot.
            - I/O/logging seeds and verbosity use cfg.io.
        """
```
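
The clarified `overrides` docstring describes flat dot-key overrides applied last, with highest precedence. A generic sketch of the idea (this `apply_overrides` helper is illustrative only; pg-sui's actual implementation lives in `apply_dot_overrides`, whose signature may differ):

```python
def apply_overrides(cfg: dict, overrides: dict) -> dict:
    """Set nested dict values from flat dot-keys such as 'split.test_size'."""
    for dotted, value in overrides.items():
        node = cfg
        *parents, leaf = dotted.split(".")
        for key in parents:
            node = node.setdefault(key, {})  # walk/create the nested path
        node[leaf] = value
    return cfg


cfg = {"algo": {"by_populations": False}, "split": {"test_size": 0.2}}
apply_overrides(cfg, {"algo.by_populations": True, "split.test_size": 0.3})
print(cfg)  # {'algo': {'by_populations': True}, 'split': {'test_size': 0.3}}
```
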
```diff
@@ -151,18 +153,19 @@ class ImputeMostFrequent:
         self.rng = np.random.default_rng(cfg.io.seed)
         self.encoder = GenotypeEncoder(self.genotype_data)
 
-
-        X012 = self.encoder.genotypes_012.astype(np.int16, copy=False)
+        self.missing_internal = -1
 
-        #
-
-        np.nan_to_num(X012, nan=-1.0, copy=False)
+        # include common missing value aliases
+        self.missing_aliases = {int(cfg.algo.missing), -9, -1}
 
-
-
-
+        X = np.asarray(self.encoder.genotypes_012)
+        Xf = X.astype(np.float32, copy=False)
+        Xf = np.where(np.isnan(Xf), -1.0, Xf)
+        Xf[Xf < 0] = -1.0
+        self.X012_ = Xf.astype(np.int8, copy=False)
+        self.num_features_ = self.X012_.shape[1]
 
-        # Simulated-missing controls (mirror VAE/AE
+        # Simulated-missing controls (mirror VAE/AE semantics where possible)
         sim_cfg = getattr(self.cfg, "sim", None)
         sim_cfg_kwargs = dict(getattr(sim_cfg, "sim_kwargs", {}) or {})
 
```
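
The rewritten constructor funnels every missing-value alias (NaN, -9, -1, or any other negative code) into a single internal `-1` before casting to `int8`. A standalone sketch of just that transform:

```python
import numpy as np

X = np.array([[0.0, 1.0, np.nan],
              [2.0, -9.0, 1.0]])       # NaN and -9 are both "missing" aliases

Xf = X.astype(np.float32, copy=False)
Xf = np.where(np.isnan(Xf), -1.0, Xf)  # NaN -> -1
Xf[Xf < 0] = -1.0                      # any other negative alias -> -1
X012 = Xf.astype(np.int8, copy=False)
print(X012)                            # [[ 0  1 -1] [ 2 -1  1]]
```
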
```diff
@@ -226,12 +229,11 @@ class ImputeMostFrequent:
         self.test_idx_: Optional[np.ndarray] = None
         self.X_train_df_: Optional[pd.DataFrame] = None
         self.ground_truth012_: Optional[np.ndarray] = None
-        self.metrics_: Dict[str, int | float] = {}
         self.X_imputed012_: Optional[np.ndarray] = None
 
         # Ploidy heuristic for 0/1/2 scoring parity
-
-        self.is_haploid_ =
+        self.ploidy = self.cfg.io.ploidy
+        self.is_haploid_ = self.ploidy == 1
 
         # Plotting (use config, not genotype_data fields)
         self.plot_format = cfg.plot.fmt
@@ -243,6 +245,11 @@ class ImputeMostFrequent:
         self.model_name = (
             "ImputeMostFrequentPerPop" if self.by_populations else "ImputeMostFrequent"
         )
+
+        # Output dirs
+        dirs = ["models", "plots", "metrics", "optimize", "parameters"]
+        self._create_model_directories(self.prefix, dirs)
+
         self.plotter_ = Plotting(
             self.model_name,
             prefix=self.prefix,
@@ -258,10 +265,6 @@ class ImputeMostFrequent:
             multiqc_section=f"PG-SUI: {self.model_name} Model Imputation",
         )
 
-        # Output dirs
-        dirs = ["models", "plots", "metrics", "optimize", "parameters"]
-        self._create_model_directories(self.prefix, dirs)
-
         if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
             msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
             self.logger.error(msg)
@@ -280,14 +283,21 @@ class ImputeMostFrequent:
 
         # Work in DataFrame with NaN as missing for mode computation
         df_all = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
-        df_all =
-        df_all = df_all.replace(-9, np.nan)  # Just in case
+        df_all[df_all < 0] = np.nan
 
         # Modes from TRAIN rows only (per-locus)
         df_train = df_all.iloc[self.train_idx_].copy()
-
-
-
+
+        modes = {}
+        for col in df_train.columns:
+            s = df_train[col].dropna()
+            if s.empty:
+                modes[col] = self.default
+            else:
+                vc = s.value_counts()
+                # deterministic tie-break: smallest genotype among ties
+                modes[col] = int(vc.index[vc.to_numpy() == vc.to_numpy().max()].min())
+        self.global_modes_ = modes
 
         self.group_modes_.clear()
         if self.by_populations:
```
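
The new mode loop replaces pandas' implicit tie handling with an explicit, deterministic rule: among equally frequent genotypes, the smallest wins. A standalone sketch of the tie-break expression (data is illustrative):

```python
import pandas as pd

s = pd.Series([0, 0, 2, 2, 1])  # genotypes 0 and 2 tie at count 2
vc = s.value_counts()

# Mask the counted index down to the tied maximum, then take the minimum.
mode = int(vc.index[vc.to_numpy() == vc.to_numpy().max()].min())
assert mode == 0  # smallest genotype wins the tie
```
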
```diff
@@ -304,13 +314,14 @@ class ImputeMostFrequent:
             self.logger.error(msg)
             raise ValueError(msg)
 
-        # ------------------------------
         # Simulated-missing mask (global → test-only)
-        # ------------------------------
         obs_mask = df_all.notna().to_numpy()  # observed = not NaN
-        n_samples
+        n_samples = obs_mask.shape[0]
 
         if self.simulate_missing:
+            X_for_sim = self.ground_truth012_.astype(np.float32, copy=True)
+            X_for_sim[X_for_sim < 0] = -9.0
+
             # Use the same transformer as VAE
             tr = SimMissingTransformer(
                 genotype_data=self.genotype_data,
@@ -322,11 +333,7 @@ class ImputeMostFrequent:
                 verbose=self.verbose,
                 **self.sim_kwargs,
             )
-            # Fit on 0/1/2 with -1 for missing, like VAE
-            X_for_sim = self.ground_truth012_.astype(float, copy=True)
-            X_for_sim[X_for_sim < 0] = np.nan
             tr.fit(X_for_sim)
-
             sim_mask_global = tr.sim_missing_mask_.astype(bool)
 
             # Don't simulate on already-missing cells
@@ -359,7 +366,7 @@ class ImputeMostFrequent:
         self.X_train_df_ = df_sim
         self.is_fit_ = True
 
-        # Save parameters
+        # Save parameters
         best_params = self.cfg.to_dict()
         params_fp = self.parameters_dir / "best_parameters.json"
         with open(params_fp, "w") as f:
@@ -389,11 +396,14 @@ class ImputeMostFrequent:
             msg = "Model is not fitted. Call fit() before transform()."
             self.logger.error(msg)
             raise NotFittedError(msg)
-
+
+        assert (
+            self.X_train_df_ is not None
+        ), f"[{self.model_name}] X_train_df_ is not set after fit()."
 
         # 1) Impute the evaluation-masked copy (to compute metrics)
         imputed_eval_df = self._impute_df(self.X_train_df_)
-        X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.
+        X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.int8)
         self.X_imputed012_ = X_imputed_eval
 
         # Evaluate like DL models (0/1/2, then 10-class from decoded strings)
@@ -401,22 +411,31 @@ class ImputeMostFrequent:
 
         # 2) Impute the FULL dataset (only true missings)
         df_missingonly = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
-        df_missingonly
+        df_missingonly[df_missingonly < 0] = np.nan
+
         imputed_full_df = self._impute_df(df_missingonly)
-        X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.
+        X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.int8)
+
+        neg = int(np.count_nonzero(X_imputed_full_012 < 0))
+        if neg:
+            msg = f"{neg} negative entries remain after REF imputation. Unique: {np.unique(X_imputed_full_012[X_imputed_full_012 < 0])[:10]}"
+            self.logger.error(msg)
+            raise RuntimeError(msg)
 
         # Plot distributions (parity with DL transform())
         if self.ground_truth012_ is None:
-
-
-            )
+            msg = "ground_truth012_ is not set; cannot plot distributions."
+            self.logger.error(msg)
+            raise NotFittedError(msg)
+
+        imp_decoded = self.decode_012(X_imputed_full_012)
 
-
-
-
-
+        if self.show_plots:
+            gt_decoded = self.decode_012(self.ground_truth012_)
+            self.plotter_.plot_gt_distribution(gt_decoded, is_imputed=False)
+            self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)
 
-        # Return IUPAC strings
+        # Return IUPAC strings
         return imp_decoded
 
     def _impute_df(self, df_in: pd.DataFrame) -> pd.DataFrame:
```
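
The `np.count_nonzero` guard added to `transform()` is a cheap invariant check that no missing code survived imputation; in sketch form:

```python
import numpy as np

X = np.array([[0, 1], [-1, 2]], dtype=np.int8)  # a -1 slipped through
neg = int(np.count_nonzero(X < 0))
if neg:
    raise RuntimeError(f"{neg} negative entries remain; unique: {np.unique(X[X < 0])}")
```
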
```diff
@@ -452,7 +471,7 @@ class ImputeMostFrequent:
             df = df_in.fillna(modes)
         else:
             df = df_in.copy()
-        return df.astype(np.
+        return df.astype(np.int8)
 
     def _impute_by_population_mode(self, df_in: pd.DataFrame) -> pd.DataFrame:
         """Impute missing cells in df_in using population-specific modes.
@@ -466,7 +485,7 @@ class ImputeMostFrequent:
             pd.DataFrame: DataFrame with missing values imputed.
         """
         if not df_in.isnull().values.any():
-            return df_in.astype(np.
+            return df_in.astype(np.int8)
 
         df = df_in.copy()
         pops = pd.Series(self.pops, index=df.index)
@@ -489,7 +508,7 @@ class ImputeMostFrequent:
         mask = np.isnan(values)
         values[mask] = replacements[mask]
 
-        return pd.DataFrame(values, columns=df.columns, index=df.index).astype(np.
+        return pd.DataFrame(values, columns=df.columns, index=df.index).astype(np.int8)
 
     def _series_mode(self, s: pd.Series) -> int:
         """Compute the mode of a pandas Series, ignoring NaNs.
@@ -505,11 +524,13 @@ class ImputeMostFrequent:
         s_valid = s.dropna().astype(int)
         if s_valid.empty:
             return self.default
+
         # Mode among {0,1,2}; if ties, pandas picks the smallest (okay)
         mode_val = int(s_valid.mode().iloc[0])
         if mode_val not in (0, 1, 2):
             # Safety: clamp to valid zygosity in case of odd inputs
             mode_val = self.default if self.default in (0, 1, 2) else 0
+
         return mode_val
 
     def _evaluate_and_report(self) -> None:
```
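
The comment in `_series_mode` leans on the fact that `Series.mode()` returns its values sorted ascending, so `iloc[0]` is always the smallest tied genotype:

```python
import pandas as pd

s = pd.Series([2, 2, 0, 0, 1])
print(s.mode().tolist())      # [0, 2] -- ties come back sorted ascending
print(int(s.mode().iloc[0]))  # 0, i.e. the smallest tied value
```
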
```diff
@@ -540,8 +561,8 @@ class ImputeMostFrequent:
         X_pred_eval = self.ground_truth012_.copy()
         X_pred_eval[self.sim_mask_] = self.X_imputed012_[self.sim_mask_]
 
-        y_true_dec = self.
-        y_pred_dec = self.
+        y_true_dec = self.decode_012(self.ground_truth012_)
+        y_pred_dec = self.decode_012(X_pred_eval)
 
         encodings_dict = {
             "A": 0,
@@ -565,43 +586,35 @@ class ImputeMostFrequent:
 
         y_true_10 = y_true_int[self.sim_mask_]
         y_pred_10 = y_pred_int[self.sim_mask_]
+
+        m = (y_true_10 >= 0) & (y_pred_10 >= 0)
+        y_true_10, y_pred_10 = y_true_10[m], y_pred_10[m]
+        if y_true_10.size == 0:
+            self.logger.warning(
+                "No valid IUPAC test cells; skipping 10-class evaluation."
+            )
+            return
+
         self._evaluate_iupac10_and_plot(y_true_10, y_pred_10)
 
     def _evaluate_012_and_plot(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
         """0/1/2 zygosity report & confusion matrix.
 
-        This method generates a classification report and confusion matrix for genotypes encoded as 0 (REF), 1 (HET), and 2 (ALT). If the data is
+        This method generates a classification report and confusion matrix for genotypes encoded as 0 (REF), 1 (HET), and 2 (ALT). If the data is haploid (only 0 and 2 present), it folds ALT (2) into the binary ALT/PRESENT class (1) for evaluation. The method computes metrics, logs the report, and creates visualizations of the results.
 
         Args:
             y_true (np.ndarray): True genotypes (0/1/2) for masked
             y_pred (np.ndarray): Predicted genotypes (0/1/2) for masked
-
-        Raises:
-            NotFittedError: If fit() and transform() have not been called.
         """
-        labels = [0, 1, 2]
+        labels: list[int] = [0, 1, 2]
+        report_names: list[str] = ["REF", "HET", "ALT"]
+
         # Haploid parity: fold ALT (2) into ALT/Present (1)
         if self.is_haploid_:
-            y_true
-            y_pred
-            labels = [0, 1]
-
-        metrics = {
-            "n_masked_test": int(y_true.size),
-            "accuracy": accuracy_score(y_true, y_pred),
-            "f1": f1_score(
-                y_true, y_pred, average="macro", labels=labels, zero_division=0
-            ),
-            "precision": precision_score(
-                y_true, y_pred, average="macro", labels=labels, zero_division=0
-            ),
-            "recall": recall_score(
-                y_true, y_pred, average="macro", labels=labels, zero_division=0
-            ),
-        }
-        self.metrics_.update({f"zygosity_{k}": v for k, v in metrics.items()})
-
-        report_names = ["REF", "HET"] if self.is_haploid_ else ["REF", "HET", "ALT"]
+            y_true = np.where(y_true == 2, 1, y_true)
+            y_pred = np.where(y_pred == 2, 1, y_pred)
+            labels: list[int] = [0, 1]
+            report_names: list[str] = ["REF", "ALT"]
 
         report: dict | str = classification_report(
             y_true,
```
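
The haploid branch now folds class 2 into class 1 explicitly with `np.where`; in isolation:

```python
import numpy as np

y_true = np.array([0, 2, 2, 0])            # haploid calls use only 0 and 2
y_true = np.where(y_true == 2, 1, y_true)  # fold ALT (2) into present (1)
print(y_true.tolist())                     # [0, 1, 1, 0]
```
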
```diff
@@ -617,50 +630,46 @@ class ImputeMostFrequent:
             self.logger.error(msg)
             raise TypeError(msg)
 
-
-
-
-
-
-
-
-
-                report_subset[k] = tmp
-
-        if report_subset:
-            pm = PrettyMetrics(
-                report_subset,
-                precision=3,
-                title=f"{self.model_name} Zygosity Report",
+        if self.show_plots:
+            viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
+
+            plots = viz.plot_all(
+                report,
+                title_prefix=f"{self.model_name} Zygosity Report",
+                show=self.show_plots,
+                heatmap_classes_only=True,
             )
-            pm.render()
 
-
+            for name, fig in plots.items():
+                fout = self.plots_dir / f"zygosity_report_{name}.{self.plot_format}"
+                if hasattr(fig, "savefig") and isinstance(fig, Figure):
+                    fig.savefig(fout, dpi=300, facecolor="#111122")
+                    plt.close(fig)
+                elif isinstance(fig, PlotlyFigure):
+                    fig.write_html(file=fout.with_suffix(".html"))
 
-
-                report,
-                title_prefix=f"{self.model_name} Zygosity Report",
-                show=getattr(self, "show_plots", False),
-                heatmap_classes_only=True,
-            )
+            viz._reset_mpl_style()
 
-
-
-
-
-                plt.close(fig)
-            elif isinstance(fig, PlotlyFigure):
-                fig.write_html(file=fout.with_suffix(".html"))
+        # Confusion matrix
+        self.plotter_.plot_confusion_matrix(
+            y_true, y_pred, label_names=report_names, prefix="zygosity"
+        )
 
-
+        # ------ Additional metrics ------
+        report_full = self._additional_metrics(
+            y_true, y_pred, labels, report_names, report
+        )
 
-
-
+        if self.verbose or self.debug:
+            pm = PrettyMetrics(
+                report_full,
+                precision=2,
+                title=f"{self.model_name} Zygosity Report",
+            )
+            pm.render()
 
-        #
-        self.
-            y_true, y_pred, label_names=report_names, prefix="zygosity"
-        )
+        # Save JSON
+        self._save_report(report_full, suffix="zygosity")
 
     def _evaluate_iupac10_and_plot(
         self, y_true: np.ndarray, y_pred: np.ndarray
@@ -672,32 +681,18 @@ class ImputeMostFrequent:
         Args:
             y_true (np.ndarray): True genotypes (0-9) for masked
             y_pred (np.ndarray): Predicted genotypes (0-9) for masked
-
-        Raises:
-            NotFittedError: If fit() and transform() have not been called.
         """
         labels_idx = list(range(10))
-
-
-
-
-            "f1": f1_score(
-                y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
-            ),
-            "precision": precision_score(
-                y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
-            ),
-            "recall": recall_score(
-                y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
-            ),
-        }
-        self.metrics_.update({f"iupac_{k}": v for k, v in metrics.items()})
+        report_names = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
+
+        # Create an identity matrix and use the targets array as indices
+        y_score = np.eye(len(report_names))[y_pred]
 
         report: dict | str = classification_report(
             y_true,
             y_pred,
             labels=labels_idx,
-            target_names=
+            target_names=report_names,
             zero_division=0,
             output_dict=True,
         )
```
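
Both the 10-class path and `_additional_metrics` build a pseudo-score matrix by one-hot encoding the hard predictions: row `i` of an identity matrix is the one-hot vector for class `i`, so fancy-indexing with the prediction vector yields an `(n_samples, n_classes)` matrix of 0/1 "probabilities". In isolation:

```python
import numpy as np

n_classes = 10
y_pred = np.array([0, 3, 3, 9])      # hard class predictions
y_score = np.eye(n_classes)[y_pred]  # identity rows act as one-hot vectors
print(y_score.shape)                 # (4, 10)
print(y_score[1])                    # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```
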
```diff
@@ -707,54 +702,50 @@ class ImputeMostFrequent:
             self.logger.error(msg)
             raise TypeError(msg)
 
-
-
-
-
-
-
-
-
-                report_subset[k] = tmp
-
-        if report_subset:
-            pm = PrettyMetrics(
-                report_subset,
-                precision=3,
-                title=f"{self.model_name} IUPAC 10-Class Report",
+        if self.show_plots:
+            viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
+
+            plots = viz.plot_all(
+                report,
+                title_prefix=f"{self.model_name} IUPAC Report",
+                show=self.show_plots,
+                heatmap_classes_only=True,
             )
-            pm.render()
 
-
+            # Reset the style from Optuna's plotting.
+            plt.rcParams.update(self.plotter_.param_dict)
 
-
-
-
-
-
-
+            for name, fig in plots.items():
+                fout = self.plots_dir / f"iupac_report_{name}.{self.plot_format}"
+                if hasattr(fig, "savefig") and isinstance(fig, Figure):
+                    fig.savefig(fout, dpi=300, facecolor="#111122")
+                    plt.close(fig)
+                elif isinstance(fig, PlotlyFigure):
+                    fig.write_html(file=fout.with_suffix(".html"))
 
-
-
+            # Reset the style
+            viz._reset_mpl_style()
 
-
-
-
-
-                plt.close(fig)
-            elif isinstance(fig, PlotlyFigure):
-                fig.write_html(file=fout.with_suffix(".html"))
+        # Confusion matrix
+        self.plotter_.plot_confusion_matrix(
+            y_true, y_pred, label_names=report_names, prefix="iupac"
+        )
 
-        #
-
+        # ------ Additional metrics ------
+        report_full = self._additional_metrics(
+            y_true, y_pred, labels_idx, report_names, report
+        )
 
-
-
+        if self.verbose or self.debug:
+            pm = PrettyMetrics(
+                report_full,
+                precision=2,
+                title=f"{self.model_name} IUPAC 10-Class Report",
+            )
+            pm.render()
 
-        #
-        self.
-            y_true, y_pred, label_names=labels_names, prefix="iupac"
-        )
+        # Save JSON
+        self._save_report(report_full, suffix="iupac")
 
     def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
         """Create train/test split indices.
@@ -780,14 +771,14 @@ class ImputeMostFrequent:
             buckets = []
             for pop in np.unique(self.pops):
                 rows = np.where(self.pops == pop)[0]
-                k = int(round(self.test_size * rows.size))
+                k = max(1, int(round(self.test_size * rows.size)))
                 if k > 0:
                     buckets.append(self.rng.choice(rows, size=k, replace=False))
             test_idx = (
                 np.sort(np.concatenate(buckets)) if buckets else np.array([], dtype=int)
             )
         else:
-            k = int(round(self.test_size * n))
+            k = max(1, int(round(self.test_size * n)))
             test_idx = (
                 self.rng.choice(n, size=k, replace=False)
                 if k > 0
```
```diff
@@ -797,13 +788,13 @@ class ImputeMostFrequent:
         train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
         return train_idx, test_idx
 
-    def _save_report(self, report_dict: Dict[str,
+    def _save_report(self, report_dict: Dict[str, Any], suffix: str) -> None:
         """Save classification report dictionary as a JSON file.
 
         This method saves the provided classification report dictionary to a JSON file in the metrics directory, appending the specified suffix to the filename.
 
         Args:
-            report_dict (Dict[str,
+            report_dict (Dict[str, Any]): The classification report dictionary to save.
             suffix (str): Suffix to append to the filename (e.g., 'zygosity' or 'iupac').
 
         Raises:
@@ -842,3 +833,305 @@ class ImputeMostFrequent:
                 msg = f"Failed to create directory {getattr(self, f'{d}_dir')}: {e}"
                 self.logger.error(msg)
                 raise Exception(msg)
+
+    def decode_012(
+        self, X: np.ndarray | pd.DataFrame | list[list[int]], is_nuc: bool = False
+    ) -> np.ndarray:
+        """Decode 012-encodings to IUPAC chars with metadata repair.
+
+        This method converts genotype calls encoded as integers (0, 1, 2, etc.) into their corresponding IUPAC nucleotide codes. It supports two modes of decoding:
+        1. Nucleotide mode (`is_nuc=True`): Decodes integer codes (0-9) directly to IUPAC nucleotide codes.
+        2. Metadata mode (`is_nuc=False`): Uses reference and alternate allele metadata to determine the appropriate IUPAC codes. If metadata is missing or inconsistent, the method attempts to repair the decoding by scanning the source SNP data for valid IUPAC codes.
+
+        Args:
+            X (np.ndarray | pd.DataFrame | list[list[int]]): Input genotype calls as integers. Can be a NumPy array, Pandas DataFrame, or nested list.
+            is_nuc (bool): If True, decode 0-9 nucleotide codes; else use ref/alt metadata. Defaults to False.
+
+        Returns:
+            np.ndarray: IUPAC strings as a 2D array of shape (n_samples, n_snps).
+
+        Notes:
+            - The method normalizes input values to handle various formats, including strings, lists, and arrays.
+            - It uses a predefined mapping of IUPAC codes to nucleotide bases and vice versa.
+            - Missing or invalid codes are represented as 'N' if they can't be resolved.
+            - The method includes repair logic to infer missing metadata from the source SNP data when necessary.
+
+        Raises:
+            ValueError: If input is not a DataFrame.
+        """
+        df = validate_input_type(X, return_type="df")
+
+        if not isinstance(df, pd.DataFrame):
+            msg = f"Expected a pandas.DataFrame in 'decode_012', but got: {type(df)}."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        # IUPAC Definitions
+        iupac_to_bases: dict[str, set[str]] = {
+            "A": {"A"},
+            "C": {"C"},
+            "G": {"G"},
+            "T": {"T"},
+            "R": {"A", "G"},
+            "Y": {"C", "T"},
+            "S": {"G", "C"},
+            "W": {"A", "T"},
+            "K": {"G", "T"},
+            "M": {"A", "C"},
+            "B": {"C", "G", "T"},
+            "D": {"A", "G", "T"},
+            "H": {"A", "C", "T"},
+            "V": {"A", "C", "G"},
+            "N": set(),
+        }
+        bases_to_iupac = {
+            frozenset(v): k for k, v in iupac_to_bases.items() if k != "N"
+        }
+        missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|.", "NAN", "nan"}
+
+        def _normalize_iupac(value: object) -> str | None:
+            """Normalize an input into a single IUPAC code token or None."""
+            if value is None:
+                return None
+
+            # Bytes -> str (make type narrowing explicit)
+            if isinstance(value, (bytes, np.bytes_)):
+                value = bytes(value).decode("utf-8", errors="ignore")
+
+            # Handle list/tuple/array/Series: take first valid
+            if isinstance(value, (list, tuple, pd.Series, np.ndarray)):
+                # Convert Series to numpy array for consistent behavior
+                if isinstance(value, pd.Series):
+                    arr = value.to_numpy()
+                else:
+                    arr = value
+
+                # Scalar numpy array fast path
+                if isinstance(arr, np.ndarray) and arr.ndim == 0:
+                    return _normalize_iupac(arr.item())
+
+                # Empty sequence/array
+                if len(arr) == 0:
+                    return None
+
+                # First valid element wins
+                for item in arr:
+                    code = _normalize_iupac(item)
+                    if code is not None:
+                        return code
+                return None
+
+            s = str(value).upper().strip()
+            if not s or s in missing_codes:
+                return None
+
+            if "," in s:
+                for tok in (t.strip() for t in s.split(",")):
+                    if tok and tok not in missing_codes and tok in iupac_to_bases:
+                        return tok
+                return None
+
+            return s if s in iupac_to_bases else None
+
+        codes_df = df.apply(pd.to_numeric, errors="coerce")
+        codes = codes_df.fillna(-1).astype(np.int8).to_numpy()
+        n_rows, n_cols = codes.shape
+
+        if is_nuc:
+            iupac_list = np.array(
+                ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"], dtype="<U1"
+            )
+            out = np.full((n_rows, n_cols), "N", dtype="<U1")
+            mask = (codes >= 0) & (codes <= 9)
+            out[mask] = iupac_list[codes[mask]]
+            return out
+
+        # Metadata fetch
+        ref_alleles = getattr(self.genotype_data, "ref", [])
+        alt_alleles = getattr(self.genotype_data, "alt", [])
+
+        if len(ref_alleles) != n_cols:
+            ref_alleles = getattr(self, "_ref", [None] * n_cols)
+        if len(alt_alleles) != n_cols:
+            alt_alleles = getattr(self, "_alt", [None] * n_cols)
+
+        # Ensure list length matches
+        if len(ref_alleles) != n_cols:
+            ref_alleles = [None] * n_cols
+        if len(alt_alleles) != n_cols:
+            alt_alleles = [None] * n_cols
+
+        out = np.full((n_rows, n_cols), "N", dtype="<U1")
+        source_snp_data = None
+
+        for j in range(n_cols):
+            ref = _normalize_iupac(ref_alleles[j])
+            alt = _normalize_iupac(alt_alleles[j])
+
+            # --- REPAIR LOGIC ---
+            # If metadata is missing, scan the source column.
+            if ref is None or alt is None:
+                if source_snp_data is None and self.genotype_data.snp_data is not None:
+                    try:
+                        source_snp_data = np.asarray(self.genotype_data.snp_data)
+                    except Exception:
+                        pass  # if lazy loading fails
+
+                if source_snp_data is not None:
+                    try:
+                        col_data = source_snp_data[:, j]
+                        uniques = set()
+                        # Optimization: check up to 200 non-empty values
+                        count = 0
+                        for val in col_data:
+                            norm = _normalize_iupac(val)
+                            if norm:
+                                uniques.add(norm)
+                                count += 1
+                            if len(uniques) >= 2 or count > 200:
+                                break
+
+                        sorted_u = sorted(list(uniques))
+                        if len(sorted_u) >= 1 and ref is None:
+                            ref = sorted_u[0]
+                        if len(sorted_u) >= 2 and alt is None:
+                            alt = sorted_u[1]
+                    except Exception:
+                        pass
+
+            # --- DEFAULTS FOR MISSING ---
+            # If still missing, we cannot decode.
+            if ref is None and alt is None:
+                ref = "N"
+                alt = "N"
+            elif ref is None:
+                ref = alt
+            elif alt is None:
+                alt = ref  # Monomorphic site: ALT becomes REF
+
+            # --- COMPUTE HET CODE ---
+            if ref == alt:
+                het_code = ref
+            else:
+                ref_set = iupac_to_bases.get(ref, set()) if ref is not None else set()
+                alt_set = iupac_to_bases.get(alt, set()) if alt is not None else set()
+                union_set = frozenset(ref_set | alt_set)
+                het_code = bases_to_iupac.get(union_set, "N")
+
+            # --- ASSIGNMENT WITH SAFETY FALLBACKS ---
+            col_codes = codes[:, j]
+
+            # Case 0: REF
+            if ref != "N":
+                out[col_codes == 0, j] = ref
+
+            # Case 1: HET
+            if het_code != "N":
+                out[col_codes == 1, j] = het_code
+            else:
+                # If HET code is invalid (e.g. ref='A', alt='N'),
+                # fallback to REF
+                # Fix for an issue where a HET prediction at a monomorphic site
+                # produced 'N'
+                if ref != "N":
+                    out[col_codes == 1, j] = ref
+
+            # Case 2: ALT
+            if alt != "N":
+                out[col_codes == 2, j] = alt
+            else:
+                # If ALT is invalid (e.g. ref='A', alt='N'), fallback to REF
+                # Fix for an issue where an ALT prediction on a monomorphic site
+                # produced 'N'
+                if ref != "N":
+                    out[col_codes == 2, j] = ref
+
+        return out
+
+    def _additional_metrics(
+        self,
+        y_true: np.ndarray,
+        y_pred: np.ndarray,
+        labels: list[int],
+        report_names: list[str],
+        report: dict[str, dict[str, float] | float],
+    ) -> dict[str, dict[str, float] | float]:
+        """Compute additional metrics and augment the report dictionary.
+
+        Args:
+            y_true (np.ndarray): True genotypes.
+            y_pred (np.ndarray): Predicted genotypes.
+            labels (list[int]): List of label indices.
+            report_names (list[str]): List of report names corresponding to labels.
+            report (dict[str, dict[str, float] | float]): Classification report dictionary to augment.
+
+        Returns:
+            dict[str, dict[str, float] | float]: Augmented report dictionary with additional metrics.
+        """
+        # Create an identity matrix and use the targets array as indices
+        y_score = np.eye(len(report_names))[y_pred]
+
+        # Per-class metrics
+        ap_pc = average_precision_score(y_true, y_score, average=None)
+        jaccard_pc = jaccard_score(
+            y_true, y_pred, average=None, labels=labels, zero_division=0
+        )
+
+        # Macro/weighted metrics
+        ap_macro = average_precision_score(y_true, y_score, average="macro")
+        ap_weighted = average_precision_score(y_true, y_score, average="weighted")
+        jaccard_macro = jaccard_score(y_true, y_pred, average="macro", zero_division=0)
+        jaccard_weighted = jaccard_score(
+            y_true, y_pred, average="weighted", zero_division=0
+        )
+
+        # Matthews correlation coefficient (MCC)
+        mcc = matthews_corrcoef(y_true, y_pred)
+
+        if not isinstance(ap_pc, np.ndarray):
+            msg = "average_precision_score or f1_score did not return np.ndarray as expected."
+            self.logger.error(msg)
+            raise TypeError(msg)
+
+        if not isinstance(jaccard_pc, np.ndarray):
+            msg = "jaccard_score did not return np.ndarray as expected."
+            self.logger.error(msg)
+            raise TypeError(msg)
+
+        # Add per-class metrics
+        report_full = {}
+        dd_subset = {
+            k: v for k, v in report.items() if k in report_names and isinstance(v, dict)
+        }
+        for i, class_name in enumerate(report_names):
+            class_report: dict[str, float] = {}
+            if class_name in dd_subset:
+                class_report = dd_subset[class_name]
+
+            if isinstance(class_report, float) or not class_report:
+                continue
+
+            report_full[class_name] = dict(class_report)
+            report_full[class_name]["average-precision"] = float(ap_pc[i])
+            report_full[class_name]["jaccard"] = float(jaccard_pc[i])
+
+        macro_avg = report.get("macro avg")
+        if isinstance(macro_avg, dict):
+            report_full["macro avg"] = dict(macro_avg)
+            report_full["macro avg"]["average-precision"] = float(ap_macro)
+            report_full["macro avg"]["jaccard"] = float(jaccard_macro)
+
+        weighted_avg = report.get("weighted avg")
+        if isinstance(weighted_avg, dict):
+            report_full["weighted avg"] = dict(weighted_avg)
+            report_full["weighted avg"]["average-precision"] = float(ap_weighted)
+            report_full["weighted avg"]["jaccard"] = float(jaccard_weighted)
+
+        # Add scalar summary metrics
+        report_full["mcc"] = float(mcc)
+        accuracy_val = report.get("accuracy")
+
+        if isinstance(accuracy_val, (int, float)):
+            report_full["accuracy"] = float(accuracy_val)
+
+        return report_full
```