pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.3.dist-info/RECORD +0 -75
- pg_sui-0.2.3.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
@@ -0,0 +1,669 @@
+# Standard library
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
+
+# Third-party
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from matplotlib.figure import Figure
+from plotly.graph_objs import Figure as PlotlyFigure
+from sklearn.exceptions import NotFittedError
+from sklearn.metrics import (
+    accuracy_score,
+    classification_report,
+    f1_score,
+    precision_score,
+    recall_score,
+)
+
+# Project
+from snpio import GenotypeEncoder
+from snpio.utils.logging import LoggerManager
+
+from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
+from pgsui.data_processing.containers import RefAlleleConfig
+from pgsui.data_processing.transformers import SimMissingTransformer
+from pgsui.utils.classification_viz import ClassificationReportVisualizer
+from pgsui.utils.logging_utils import configure_logger
+from pgsui.utils.plotting import Plotting
+from pgsui.utils.pretty_metrics import PrettyMetrics
+
+if TYPE_CHECKING:
+    from snpio import TreeParser
+    from snpio.read_input.genotype_data import GenotypeData
+
+
+def ensure_refallele_config(
+    config: Union[RefAlleleConfig, dict, str, None],
+) -> RefAlleleConfig:
+    """Return a concrete RefAlleleConfig (dataclass, dict, YAML path, or None).
+
+    This function normalizes the input configuration for the RefAllele imputer. It accepts a RefAlleleConfig instance, a dictionary of parameters, a path to a YAML file, or None. If None is provided, it returns a default RefAlleleConfig instance. If a dictionary is provided, it flattens any nested structures and applies the parameters to a base configuration, honoring any top-level 'preset' key. If a string path is provided, it loads the configuration from the specified YAML file.
+
+    Args:
+        config (Union[RefAlleleConfig, dict, str, None]): Configuration input which can be a RefAlleleConfig instance, a dictionary of parameters, a path to a YAML file, or None.
+
+    Returns:
+        RefAlleleConfig: A concrete RefAlleleConfig instance.
+
+    Raises:
+        TypeError: If the input type is not supported.
+    """
+    if config is None:
+        return RefAlleleConfig()
+    if isinstance(config, RefAlleleConfig):
+        return config
+    if isinstance(config, str):
+        return load_yaml_to_dataclass(config, RefAlleleConfig)
+    if isinstance(config, dict):
+        base = RefAlleleConfig()
+        # honor optional top-level 'preset'
+        preset = config.pop("preset", None)
+        if preset:
+            base = RefAlleleConfig.from_preset(preset)
+
+        def _flatten(prefix: str, d: dict, out: dict) -> dict:
+            for k, v in d.items():
+                kk = f"{prefix}.{k}" if prefix else k
+                if isinstance(v, dict):
+                    _flatten(kk, v, out)
+                else:
+                    out[kk] = v
+            return out
+
+        flat = _flatten("", config, {})
+        return apply_dot_overrides(base, flat)
+
+    raise TypeError(
+        f"config must be RefAlleleConfig, dict, YAML path, or None, but got: {type(config)}."
+    )
+
+
+class ImputeRefAllele:
+    """Deterministic imputer that replaces all missing 0/1/2 genotype values with the REF genotype (0).
+
+    The imputer works on 0/1/2 with -1 as missing. Evaluation splits samples into TRAIN/TEST once and masks ALL originally observed cells on TEST rows for eval. Produces a 0/1/2 (zygosity) classification report + confusion matrix and a 10-class IUPAC classification report (via decode_012) + confusion matrix, and plots the genotype distribution before/after imputation.
+    """
+
+    def __init__(
+        self,
+        genotype_data: "GenotypeData",
+        *,
+        tree_parser: Optional["TreeParser"] = None,
+        config: Optional[Union[RefAlleleConfig, dict, str]] = None,
+        overrides: Optional[dict] = None,
+        simulate_missing: bool = True,
+        sim_strategy: Literal[
+            "random",
+            "random_weighted",
+            "random_weighted_inv",
+            "nonrandom",
+            "nonrandom_weighted",
+        ] = "random",
+        sim_prop: float = 0.2,
+        sim_kwargs: Optional[dict] = None,
+    ) -> None:
+        """Initialize the Ref-Allele imputer from a unified config.
+
+        This constructor ensures that the provided configuration is valid and initializes the imputer's internal state. It sets up logging, random number generation, genotype encoding, and various parameters based on the configuration. The imputer is prepared to handle population-specific modes if specified in the configuration.
+
+        Args:
+            genotype_data (GenotypeData): Backing genotype data.
+            tree_parser (Optional[TreeParser]): Optional SNPio phylogenetic tree parser for population-specific modes.
+            config (RefAlleleConfig | dict | str | None): Configuration as a dataclass, nested dict, or YAML path. If None, defaults are used.
+            overrides (dict | None): Flat dot-key overrides applied last with highest precedence, e.g. {'split.test_size': 0.25, 'algo.missing': -1}.
+            simulate_missing (bool): Whether to simulate missing data during evaluation. Default is True.
+            sim_strategy (Literal): Strategy for simulating missing data if enabled in config.
+            sim_prop (float): Proportion of data to simulate as missing if enabled in config.
+            sim_kwargs (Optional[dict]): Additional keyword arguments for the simulated missing data transformer.
+        """
+        # Normalize config then apply highest-precedence overrides
+        cfg = ensure_refallele_config(config)
+        if overrides:
+            cfg = apply_dot_overrides(cfg, overrides)
+        self.cfg = cfg
+
+        # Basic fields
+        self.genotype_data = genotype_data
+        self.tree_parser = tree_parser
+        self.prefix = cfg.io.prefix
+        self.verbose = cfg.io.verbose
+        self.debug = cfg.io.debug
+
+        # Simulation knobs (shared with other deterministic imputers)
+        if cfg.sim is None:
+            self.simulate_missing = simulate_missing
+            self.sim_strategy = sim_strategy
+            self.sim_prop = sim_prop
+            self.sim_kwargs = sim_kwargs or {}
+        else:
+            sim_cfg = cfg.sim
+            self.simulate_missing = getattr(
+                sim_cfg, "simulate_missing", simulate_missing
+            )
+            self.sim_strategy = getattr(sim_cfg, "sim_strategy", sim_strategy)
+            self.sim_prop = float(getattr(sim_cfg, "sim_prop", sim_prop))
+            self.sim_kwargs: Dict[str, Any] = dict(
+                getattr(sim_cfg, "sim_kwargs", sim_kwargs) or {}
+            )
+
+        # Output dirs
+        self.plots_dir: Path
+        self.metrics_dir: Path
+        self.parameters_dir: Path
+        self.model_dir: Path
+        self.optimize_dir: Path
+
+        # Logger
+        logman = LoggerManager(
+            __name__, prefix=self.prefix, verbose=self.verbose, debug=self.debug
+        )
+        self.logger = configure_logger(
+            logman.get_logger(), verbose=self.verbose, debug=self.debug
+        )
+
+        if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
+            msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        # RNG / encoder
+        self.rng = np.random.default_rng(cfg.io.seed)
+        self.encoder = GenotypeEncoder(self.genotype_data)
+
+        # Work in 0/1/2 with -1 for missing
+        X012 = self.encoder.genotypes_012.astype(np.int16, copy=True)
+        X012[X012 < 0] = -1
+        self.X012_ = X012
+        self.num_features_ = X012.shape[1]
+
+        # Split & algo knobs
+        self.test_size = float(cfg.split.test_size)
+        self.test_indices = (
+            None
+            if cfg.split.test_indices is None
+            else np.asarray(cfg.split.test_indices, dtype=int)
+        )
+        self.missing = int(cfg.algo.missing)
+
+        # State
+        self.is_fit_: bool = False
+        self.sim_mask_: np.ndarray | None = None
+        self.train_idx_: np.ndarray | None = None
+        self.test_idx_: np.ndarray | None = None
+        self.X_train_df_: pd.DataFrame | None = None
+        self.ground_truth012_: np.ndarray | None = None
+        self.X_imputed012_: np.ndarray | None = None
+        self.metrics_: Dict[str, int | float] = {}
+
+        # Ploidy heuristic for 0/1/2 scoring parity
+        uniq = np.unique(self.X012_[self.X012_ != -1])
+        self.is_haploid_ = np.array_equal(np.sort(uniq), np.array([0, 2]))
+
+        # Plotting (use config)
+        self.plot_format = cfg.plot.fmt
+        self.plot_fontsize = cfg.plot.fontsize
+        self.plot_despine = cfg.plot.despine
+        self.plot_dpi = cfg.plot.dpi
+        self.show_plots = cfg.plot.show
+
+        self.model_name = "ImputeRefAllele"
+        self.plotter_ = Plotting(
+            self.model_name,
+            prefix=self.prefix,
+            plot_format=self.plot_format,
+            plot_fontsize=self.plot_fontsize,
+            plot_dpi=self.plot_dpi,
+            title_fontsize=self.plot_fontsize,
+            despine=self.plot_despine,
+            show_plots=self.show_plots,
+            verbose=self.verbose,
+            debug=self.debug,
+            multiqc=True,
+            multiqc_section=f"PG-SUI: {self.model_name} Model Imputation",
+        )
+
+        # Output dirs
+        dirs = ["models", "plots", "metrics", "optimize", "parameters"]
+        self._create_model_directories(self.prefix, dirs)
+
+    def fit(self) -> "ImputeRefAllele":
+        """Create TRAIN/TEST split and build eval mask, with optional sim-missing.
+
+        This method prepares the imputer by splitting the data into training and testing sets and constructing an evaluation mask. If `cfg.sim.simulate_missing` is False (default), it masks all originally observed genotype entries on TEST rows. If `cfg.sim.simulate_missing` is True, it uses SimMissingTransformer to select a subset of observed cells as simulated-missing, then restricts that mask to TEST rows only. Evaluation is then performed only on these simulated-missing cells, mirroring the deep learning models.
+
+        Returns:
+            ImputeRefAllele: The fitted imputer instance.
+        """
+        # Train/test split indices
+        self.train_idx_, self.test_idx_ = self._make_train_test_split()
+        self.ground_truth012_ = self.X012_.copy()
+
+        # Use NaN for missing inside a DataFrame to leverage fillna
+        df_all = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
+        df_all = df_all.replace(self.missing, np.nan)
+        df_all = df_all.replace(-9, np.nan)  # Just in case
+
+        # Observed mask in the ORIGINAL data (before any simulated-missing)
+        obs_mask = df_all.notna().to_numpy()  # shape (n_samples, n_loci)
+
+        # TEST row selector
+        test_rows_mask = np.zeros(obs_mask.shape[0], dtype=bool)
+        if self.test_idx_ is not None and self.test_idx_.size > 0:
+            test_rows_mask[self.test_idx_] = True
+
+        # Decide how to build the sim mask: legacy vs simulated-missing
+        if getattr(self, "simulate_missing", False):
+            # Simulate missing on the full matrix; we only use the mask.
+            tr = SimMissingTransformer(
+                genotype_data=self.genotype_data,
+                tree_parser=self.tree_parser,
+                prop_missing=self.sim_prop,
+                strategy=self.sim_strategy,
+                missing_val=-9,
+                mask_missing=True,
+                verbose=self.verbose,
+                **(self.sim_kwargs or {}),
+            )
+            tr.fit(self.ground_truth012_.copy())
+            sim_mask_global = tr.sim_missing_mask_.astype(bool)
+
+            # Only consider cells that were originally observed
+            sim_mask_global = sim_mask_global & obs_mask
+
+            # Restrict evaluation to TEST rows only
+            sim_mask = sim_mask_global & test_rows_mask[:, None]
+            mode_desc = "simulated missing on TEST rows"
+        else:
+            # Legacy behavior: mask ALL originally observed TEST cells
+            sim_mask = obs_mask & test_rows_mask[:, None]
+            mode_desc = "all originally observed cells on TEST rows"
+
+        # Apply eval mask: set these cells to NaN in the eval DataFrame
+        df_sim = df_all.copy()
+        df_sim.values[sim_mask] = np.nan
+
+        # Store state
+        self.sim_mask_ = sim_mask
+        self.X_train_df_ = df_sim
+        self.is_fit_ = True
+
+        n_masked = int(sim_mask.sum())
+        self.logger.info(
+            f"Fit complete. Train rows: {self.train_idx_.size}, "
+            f"Test rows: {self.test_idx_.size}. "
+            f"Masked {n_masked} cells for evaluation ({mode_desc})."
+        )
+
+        # Persist config for reproducibility
+        params_fp = self.parameters_dir / "best_parameters.json"
+        best_params = self.cfg.to_dict()
+        with open(params_fp, "w") as f:
+            json.dump(best_params, f, indent=4)
+
+        return self
+
+    def transform(self) -> np.ndarray:
+        """Impute missing values with REF genotype (0) and evaluate on masked test cells.
+
+        This method performs the imputation by replacing all missing genotype values with the REF genotype (0). It evaluates the imputation performance on the masked test cells, producing classification reports and plots that mirror those generated by deep learning models. The final output is the fully imputed genotype matrix in IUPAC string format.
+
+        Returns:
+            np.ndarray: The fully imputed genotype matrix in IUPAC string format.
+
+        Raises:
+            NotFittedError: If the model has not been fitted yet.
+        """
+        if not self.is_fit_:
+            raise NotFittedError("Model is not fitted. Call fit() before transform().")
+        assert self.X_train_df_ is not None
+
+        # 1) Impute the evaluation-masked copy (compute metrics)
+        imputed_eval_df = self._impute_ref(df_in=self.X_train_df_)
+        X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.int16)
+        self.X_imputed012_ = X_imputed_eval
+
+        # Evaluate parity with DL models
+        self._evaluate_and_report()
+
+        # 2) Impute the FULL dataset (only true missings)
+        df_missingonly = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
+        df_missingonly = df_missingonly.replace(self.missing, np.nan)
+        df_missingonly = df_missingonly.replace(-9, np.nan)  # Just in case
+
+        imputed_full_df = self._impute_ref(df_in=df_missingonly)
+        X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.int16)
+
+        # Plot distributions (like DL .transform())
+
+        if self.ground_truth012_ is None:
+            msg = "ground_truth012_ is None; cannot plot distributions."
+            self.logger.error(msg)
+
+            raise NotFittedError("ground_truth012_ is None; cannot plot distributions.")
+        gt_decoded = self.encoder.decode_012(self.ground_truth012_)
+        imp_decoded = self.encoder.decode_012(X_imputed_full_012)
+        self.plotter_.plot_gt_distribution(gt_decoded, is_imputed=False)
+        self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)
+
+        # Return IUPAC strings
+        return imp_decoded
+
+    def _impute_ref(self, df_in: pd.DataFrame) -> pd.DataFrame:
+        """Replace every NaN with the REF genotype code (0) across all loci.
+
+        This is the deterministic REF-allele imputation in 0/1/2 encoding. The method fills all NaN values in the input DataFrame with 0, representing the REF genotype. The operation is performed column-wise, and since the fill value is constant, it is efficient to apply it in a vectorized manner.
+
+        Args:
+            df_in (pd.DataFrame): Input DataFrame with NaNs representing missing genotypes.
+
+        Returns:
+            pd.DataFrame: DataFrame with NaNs replaced by 0 (REF genotype).
+        """
+        df = df_in.copy()
+        # Fill all NaNs with 0 (homozygous REF) column-wise; constant so vectorized is fine
+        df = df.fillna(0)
+        return df.astype(np.int16)
+
+    def _evaluate_and_report(self) -> None:
+        """Evaluate imputed vs. ground truth on masked test cells; produce reports and plots.
+
+        Requires that fit() and transform() have been called. This method evaluates the imputed genotypes against the ground truth for the masked test cells, generating classification reports and confusion matrices for both 0/1/2 zygosity and 10-class IUPAC codes. It logs the results and saves the reports and plots to the designated output directories.
+
+        Raises:
+            NotFittedError: If fit() and transform() have not been called.
+        """
+        assert (
+            self.sim_mask_ is not None
+            and self.ground_truth012_ is not None
+            and self.X_imputed012_ is not None
+        )
+        y_true_012 = self.ground_truth012_[self.sim_mask_]
+        y_pred_012 = self.X_imputed012_[self.sim_mask_]
+
+        if y_true_012.size == 0:
+            self.logger.info("No masked test cells; skipping evaluation.")
+            return
+
+        # 0/1/2 report (REF/HET/ALT), with haploid folding 2->1 if needed
+        self._evaluate_012_and_plot(y_true_012.copy(), y_pred_012.copy())
+
+        # 10-class IUPAC report from decoded strings (parity with DL)
+        X_pred_eval = self.ground_truth012_.copy()
+        X_pred_eval[self.sim_mask_] = self.X_imputed012_[self.sim_mask_]
+
+        y_true_dec = self.encoder.decode_012(self.ground_truth012_)
+        y_pred_dec = self.encoder.decode_012(X_pred_eval)
+
+        encodings_dict = {
+            "A": 0,
+            "C": 1,
+            "G": 2,
+            "T": 3,
+            "W": 4,
+            "R": 5,
+            "M": 6,
+            "K": 7,
+            "Y": 8,
+            "S": 9,
+            "N": -1,
+        }
+        y_true_int = self.encoder.convert_int_iupac(
+            y_true_dec, encodings_dict=encodings_dict
+        )
+        y_pred_int = self.encoder.convert_int_iupac(
+            y_pred_dec, encodings_dict=encodings_dict
+        )
+        y_true_10 = y_true_int[self.sim_mask_]
+        y_pred_10 = y_pred_int[self.sim_mask_]
+        self._evaluate_iupac10_and_plot(y_true_10, y_pred_10)
+
+    def _evaluate_012_and_plot(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
+        """0/1/2 zygosity report & confusion matrix.
+
+        This method generates a classification report and confusion matrix for genotypes encoded as 0 (REF), 1 (HET), and 2 (ALT). If the data is determined to be haploid (only 0 and 2 present), it folds the ALT genotype code (2) down to 1 for evaluation purposes. The method computes various performance metrics, logs the classification report, and creates visualizations of the results.
+
+        Args:
+            y_true (np.ndarray): True genotypes (0/1/2) for masked test cells.
+            y_pred (np.ndarray): Predicted genotypes (0/1/2) for masked test cells.
+        """
+        labels = [0, 1, 2]
+        report_names = ["REF", "HET", "ALT"]
+
+        # Haploid parity: fold 2 -> 1
+        if self.is_haploid_:
+            y_true[y_true == 2] = 1
+            y_pred[y_pred == 2] = 1
+            labels = [0, 1]
+            report_names = ["REF", "ALT"]
+
+        metrics = {
+            "n_masked_test": int(y_true.size),
+            "accuracy": accuracy_score(y_true, y_pred),
+            "f1": f1_score(
+                y_true, y_pred, average="weighted", labels=labels, zero_division=0
+            ),
+            "precision": precision_score(
+                y_true, y_pred, average="weighted", labels=labels, zero_division=0
+            ),
+            "recall": recall_score(
+                y_true, y_pred, average="weighted", labels=labels, zero_division=0
+            ),
+        }
+        self.metrics_.update({f"zygosity_{k}": v for k, v in metrics.items()})
+
+        report: str | dict = classification_report(
+            y_true,
+            y_pred,
+            labels=labels,
+            target_names=report_names,
+            zero_division=0,
+            output_dict=True,
+        )
+
+        if not isinstance(report, dict):
+            msg = "classification_report did not return a dict as expected."
+            self.logger.error(msg)
+            raise TypeError(msg)
+
+        report_subset = {}
+        for k, v in report.items():
+            tmp = {}
+            if isinstance(v, dict) and "support" in v:
+                for k2, v2 in v.items():
+                    if k2 != "support":
+                        tmp[k2] = v2
+            if tmp:
+                report_subset[k] = tmp
+
+        if report_subset:
+            pm = PrettyMetrics(
+                report_subset,
+                precision=3,
+                title=f"{self.model_name} Zygosity Report",
+            )
+            pm.render()
+
+        viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
+
+        if not isinstance(report, dict):
+            msg = "classification_report did not return a dict as expected."
+            self.logger.error(msg)
+            raise TypeError(msg)
+
+        plots = viz.plot_all(
+            report,
+            title_prefix=f"{self.model_name} Zygosity Report",
+            show=getattr(self, "show_plots", False),
+            heatmap_classes_only=True,
+        )
+
+        # Reset the style from Optuna's plotting.
+        plt.rcParams.update(self.plotter_.param_dict)
+
+        for name, fig in plots.items():
+            fout = self.plots_dir / f"zygosity_report_{name}.{self.plot_format}"
+            if hasattr(fig, "savefig") and isinstance(fig, Figure):
+                fig.savefig(fout, dpi=300, facecolor="#111122")
+                plt.close(fig)
+            elif isinstance(fig, PlotlyFigure):
+                fig.write_html(file=fout.with_suffix(".html"))
+
+        viz._reset_mpl_style()
+
+        self._save_report(report, suffix="zygosity")
+
+        # Confusion matrix
+        self.plotter_.plot_confusion_matrix(
+            y_true, y_pred, label_names=report_names, prefix="zygosity"
+        )
+
+    def _evaluate_iupac10_and_plot(
+        self, y_true: np.ndarray, y_pred: np.ndarray
+    ) -> None:
+        """10-class IUPAC report & confusion matrix.
+
+        This method generates a classification report and confusion matrix for genotypes encoded using the 10 IUPAC codes (0-9). The IUPAC codes represent various nucleotide combinations, including ambiguous bases.
+
+        Args:
+            y_true (np.ndarray): True genotypes (0-9) for masked test cells.
+            y_pred (np.ndarray): Predicted genotypes (0-9) for masked test cells.
+        """
+        labels_idx = list(range(10))
+        labels_names = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
+
+        metrics = {
+            "accuracy": accuracy_score(y_true, y_pred),
+            "f1": f1_score(
+                y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
+            ),
+            "precision": precision_score(
+                y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
+            ),
+            "recall": recall_score(
+                y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
+            ),
+        }
+        self.metrics_.update({f"iupac_{k}": v for k, v in metrics.items()})
+
+        report = classification_report(
+            y_true,
+            y_pred,
+            labels=labels_idx,
+            target_names=labels_names,
+            zero_division=0,
+            output_dict=True,
+        )
+
+        if not isinstance(report, dict):
+            msg = "classification_report did not return a dict as expected."
+            self.logger.error(msg)
+            raise TypeError(msg)
+
+        report_subset = {}
+        for k, v in report.items():
+            tmp = {}
+            if isinstance(v, dict) and "support" in v:
+                for k2, v2 in v.items():
+                    if k2 != "support":
+                        tmp[k2] = v2
+            if tmp:
+                report_subset[k] = tmp
+
+        if report_subset:
+            pm = PrettyMetrics(
+                report_subset,
+                precision=3,
+                title=f"{self.model_name} IUPAC 10-Class Report",
+            )
+            pm.render()
+
+        self._save_report(report, suffix="iupac")
+
+        # Confusion matrix
+        self.plotter_.plot_confusion_matrix(
+            y_true, y_pred, label_names=labels_names, prefix="iupac"
+        )
+
+    def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
+        """Create train/test split indices.
+
+        This method generates training and testing indices for the dataset. If specific test indices are provided, it uses those; otherwise, it randomly selects a proportion of samples as the test set based on the specified test size. The method ensures that the selected test indices are within valid bounds and that there is no overlap between training and testing sets.
+
+        Returns:
+            Tuple[np.ndarray, np.ndarray]: Arrays of train and test indices.
+
+        Raises:
+            IndexError: If provided test_indices are out of bounds.
+        """
+        n = self.X012_.shape[0]
+        all_idx = np.arange(n, dtype=int)
+
+        if self.test_indices is not None:
+            test_idx = np.unique(self.test_indices)
+
+            if np.any((test_idx < 0) | (test_idx >= n)):
+                msg = "Some test_indices are out of bounds."
+                self.logger.error(msg)
+                raise IndexError(msg)
+
+            train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
+            return train_idx, test_idx
+
+        k = int(round(self.test_size * n))
+
+        test_idx = (
+            self.rng.choice(n, size=k, replace=False)
+            if k > 0
+            else np.array([], dtype=int)
+        )
+
+        train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
+        return train_idx, test_idx
+
+    def _save_report(self, report_dict: Dict[str, float], suffix: str) -> None:
+        """Save classification report dictionary as a JSON file.
+
+        This method saves the provided classification report dictionary to a JSON file in the metrics directory. The filename includes a suffix to distinguish between different types of reports (e.g., 'zygosity' or 'iupac').
+
+        Args:
+            report_dict (Dict[str, float]): The classification report dictionary to save.
+            suffix (str): Suffix to append to the filename (e.g., 'zygosity' or 'iupac').
+
+        Raises:
+            NotFittedError: If fit() and transform() have not been called.
+        """
+        if not self.is_fit_ or self.X_imputed012_ is None:
+            raise NotFittedError("No report to save. Ensure fit() and transform() ran.")
+
+        out_fp = self.metrics_dir / f"classification_report_{suffix}.json"
+        with open(out_fp, "w") as f:
+            json.dump(report_dict, f, indent=4)
+        self.logger.info(f"{self.model_name} {suffix} report saved to {out_fp}.")
+
+    def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
+        """Creates the directory structure for storing model outputs.
+
+        This method sets up a standardized folder hierarchy for saving models, plots, metrics, and optimization results, organized under a main directory named after the provided prefix.
+
+        Args:
+            prefix (str): The prefix for the main output directory.
+            outdirs (List[str]): A list of subdirectory names to create within the main directory.
+
+        Raises:
+            Exception: If any of the directories cannot be created.
+        """
+        formatted_output_dir = Path(f"{prefix}_output")
+        base_dir = formatted_output_dir / "Deterministic"
+
+        for d in outdirs:
+            subdir = base_dir / d / self.model_name
+            setattr(self, f"{d}_dir", subdir)
+            try:
+                getattr(self, f"{d}_dir").mkdir(parents=True, exist_ok=True)
+            except Exception as e:
+                msg = f"Failed to create directory {getattr(self, f'{d}_dir')}: {e}"
+                self.logger.error(msg)
+                raise Exception(msg)
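
The new `pgsui/impute/deterministic/imputers/ref_allele.py` module above exposes a fit/transform workflow on the `ImputeRefAllele` class. The sketch below is illustrative only and is not part of the diff; it assumes `gt_data` is an already-loaded SNPio `GenotypeData` object and infers the import path from the file layout listed at the top of this page.

```python
from pgsui.impute.deterministic.imputers.ref_allele import ImputeRefAllele

# Assumption: `gt_data` is a snpio GenotypeData instance loaded elsewhere,
# e.g. from the example VCF shipped in pgsui/example_data/vcf_files/
# (loading code omitted here).

# Nested config dicts are flattened to dot-keys (e.g. "split.test_size");
# `overrides` is applied last with the highest precedence.
imputer = ImputeRefAllele(
    gt_data,
    config={"split": {"test_size": 0.25}},
    overrides={"algo.missing": -1},
    simulate_missing=True,   # hide a subset of observed cells for evaluation
    sim_strategy="random",
    sim_prop=0.2,
)

imputer.fit()                  # builds the TRAIN/TEST split and the eval mask
X_iupac = imputer.transform()  # ndarray of IUPAC genotype strings
print(imputer.metrics_)        # e.g. zygosity_accuracy, iupac_f1, ...
```

Per `_create_model_directories` and `_save_report`, reports and figures are written under `{prefix}_output/Deterministic/<models|plots|metrics|optimize|parameters>/ImputeRefAllele/`, including `classification_report_zygosity.json`, `classification_report_iupac.json`, and `best_parameters.json`.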
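`ensure_refallele_config` (near the top of the new module) accepts a dataclass, a nested dict, a YAML path, or None, and flattens nested dicts into dot-keys before handing them to `apply_dot_overrides`. A minimal standalone sketch of that flattening step, mirroring the `_flatten` helper shown in the diff (the real `apply_dot_overrides` lives in `pgsui.data_processing.config` and is not reproduced here):

```python
def flatten(prefix: str, d: dict, out: dict) -> dict:
    # Recursively turn {"split": {"test_size": 0.25}} into {"split.test_size": 0.25}.
    for k, v in d.items():
        kk = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            flatten(kk, v, out)
        else:
            out[kk] = v
    return out


nested = {"split": {"test_size": 0.25}, "algo": {"missing": -1}}
print(flatten("", nested, {}))
# -> {'split.test_size': 0.25, 'algo.missing': -1}
```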
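In `fit()`, the legacy evaluation mask is the intersection of originally observed cells with TEST rows, and `transform()` then fills masked or missing cells with the REF code 0 and scores them against the hidden truth. A toy NumPy sketch of that masking-and-scoring logic, independent of PG-SUI (the array values here are made up):

```python
import numpy as np

X = np.array([[0, 1, -1],
              [2, -1, 0],
              [1, 0, 2]], dtype=np.int16)   # 0/1/2 genotypes, -1 = missing
test_rows = np.array([False, True, True])   # rows held out for evaluation

obs_mask = X != -1                           # cells observed in the original data
eval_mask = obs_mask & test_rows[:, None]    # legacy mode: all observed TEST cells

X_eval = X.astype(float)
X_eval[eval_mask] = np.nan                   # hide the ground truth
X_imputed = np.where(np.isnan(X_eval) | (X_eval == -1), 0, X_eval)  # REF-allele fill

acc = (X_imputed[eval_mask] == X[eval_mask]).mean()
print(f"zygosity accuracy on masked test cells: {acc:.2f}")  # 0.40 for this toy matrix
```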