smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/chimeric_adata.py +1563 -0
- smftools/cli/helpers.py +49 -7
- smftools/cli/hmm_adata.py +250 -32
- smftools/cli/latent_adata.py +773 -0
- smftools/cli/load_adata.py +78 -74
- smftools/cli/preprocess_adata.py +122 -58
- smftools/cli/recipes.py +26 -0
- smftools/cli/spatial_adata.py +74 -112
- smftools/cli/variant_adata.py +423 -0
- smftools/cli_entry.py +52 -4
- smftools/config/conversion.yaml +1 -1
- smftools/config/deaminase.yaml +3 -0
- smftools/config/default.yaml +85 -12
- smftools/config/experiment_config.py +146 -1
- smftools/constants.py +69 -0
- smftools/hmm/HMM.py +88 -0
- smftools/hmm/call_hmm_peaks.py +1 -1
- smftools/informatics/__init__.py +6 -0
- smftools/informatics/bam_functions.py +358 -8
- smftools/informatics/binarize_converted_base_identities.py +2 -89
- smftools/informatics/converted_BAM_to_adata.py +636 -175
- smftools/informatics/h5ad_functions.py +198 -2
- smftools/informatics/modkit_extract_to_adata.py +1007 -425
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/metadata.py +1 -1
- smftools/plotting/__init__.py +26 -3
- smftools/plotting/autocorrelation_plotting.py +22 -4
- smftools/plotting/chimeric_plotting.py +1893 -0
- smftools/plotting/classifiers.py +28 -14
- smftools/plotting/general_plotting.py +62 -1583
- smftools/plotting/hmm_plotting.py +1670 -8
- smftools/plotting/latent_plotting.py +804 -0
- smftools/plotting/plotting_utils.py +243 -0
- smftools/plotting/position_stats.py +16 -8
- smftools/plotting/preprocess_plotting.py +281 -0
- smftools/plotting/qc_plotting.py +8 -3
- smftools/plotting/spatial_plotting.py +1134 -0
- smftools/plotting/variant_plotting.py +1231 -0
- smftools/preprocessing/__init__.py +4 -0
- smftools/preprocessing/append_base_context.py +18 -18
- smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
- smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
- smftools/preprocessing/append_variant_call_layer.py +480 -0
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/preprocessing/flag_duplicate_reads.py +4 -4
- smftools/preprocessing/invert_adata.py +1 -0
- smftools/readwrite.py +159 -99
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +10 -0
- smftools/tools/calculate_knn.py +121 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +130 -0
- smftools/tools/calculate_pca.py +180 -0
- smftools/tools/calculate_umap.py +79 -80
- smftools/tools/position_stats.py +4 -4
- smftools/tools/rolling_nn_distance.py +872 -0
- smftools/tools/sequence_alignment.py +140 -0
- smftools/tools/tensor_factorization.py +217 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -60,6 +60,20 @@ stages:
|
|
|
60
60
|
notes: "Mapping quality score."
|
|
61
61
|
requires: []
|
|
62
62
|
optional_inputs: []
|
|
63
|
+
reference_start:
|
|
64
|
+
dtype: "float"
|
|
65
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
66
|
+
modified_by: []
|
|
67
|
+
notes: "0-based reference start position for the alignment."
|
|
68
|
+
requires: []
|
|
69
|
+
optional_inputs: []
|
|
70
|
+
reference_end:
|
|
71
|
+
dtype: "float"
|
|
72
|
+
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
73
|
+
modified_by: []
|
|
74
|
+
notes: "0-based reference end position (exclusive) for the alignment."
|
|
75
|
+
requires: []
|
|
76
|
+
optional_inputs: []
|
|
63
77
|
read_length_to_reference_length_ratio:
|
|
64
78
|
dtype: "float"
|
|
65
79
|
created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
|
|
@@ -179,7 +193,7 @@ stages:
|
|
|
179
193
|
obs:
|
|
180
194
|
leiden:
|
|
181
195
|
dtype: "category"
|
|
182
|
-
created_by: "smftools.tools.
|
|
196
|
+
created_by: "smftools.tools.calculate_leiden"
|
|
183
197
|
modified_by: []
|
|
184
198
|
notes: "Leiden cluster assignments."
|
|
185
199
|
requires: [["obsm.X_umap"]]
|
smftools/tools/__init__.py
CHANGED
|
@@ -3,6 +3,11 @@ from __future__ import annotations
|
|
|
3
3
|
from importlib import import_module
|
|
4
4
|
|
|
5
5
|
_LAZY_ATTRS = {
|
|
6
|
+
"calculate_leiden": "smftools.tools.calculate_leiden",
|
|
7
|
+
"calculate_nmf": "smftools.tools.calculate_nmf",
|
|
8
|
+
"calculate_sequence_cp_decomposition": "smftools.tools.tensor_factorization",
|
|
9
|
+
"calculate_pca": "smftools.tools.calculate_pca",
|
|
10
|
+
"calculate_knn": "smftools.tools.calculate_knn",
|
|
6
11
|
"calculate_umap": "smftools.tools.calculate_umap",
|
|
7
12
|
"cluster_adata_on_methylation": "smftools.tools.cluster_adata_on_methylation",
|
|
8
13
|
"combine_layers": "smftools.tools.general_tools",
|
|
@@ -11,6 +16,11 @@ _LAZY_ATTRS = {
|
|
|
11
16
|
"calculate_relative_risk_on_activity": "smftools.tools.position_stats",
|
|
12
17
|
"compute_positionwise_statistics": "smftools.tools.position_stats",
|
|
13
18
|
"calculate_row_entropy": "smftools.tools.read_stats",
|
|
19
|
+
"align_sequences_with_mismatches": "smftools.tools.sequence_alignment",
|
|
20
|
+
"rolling_window_nn_distance": "smftools.tools.rolling_nn_distance",
|
|
21
|
+
"annotate_zero_hamming_segments": "smftools.tools.rolling_nn_distance",
|
|
22
|
+
"assign_per_read_segments_layer": "smftools.tools.rolling_nn_distance",
|
|
23
|
+
"select_top_segments_per_read": "smftools.tools.rolling_nn_distance",
|
|
14
24
|
"subset_adata": "smftools.tools.subset_adata",
|
|
15
25
|
}
|
|
16
26
|
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from smftools.logging_utils import get_logger
|
|
6
|
+
from smftools.optional_imports import require
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
import anndata as ad
|
|
10
|
+
|
|
11
|
+
logger = get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _symmetrize_min_distance(distances):
    """Symmetrize a sparse distance graph, keeping the smaller distance per edge.

    Every edge stored in either direction (``i -> j`` or ``j -> i``) is kept in
    both directions; where both directions are stored, the smaller distance wins.
    ``A.minimum(A.T)`` is NOT equivalent: scipy.sparse treats unstored entries as
    0 there, so one-directional edges would be dropped and only mutual neighbors
    would survive.

    Args:
        distances: Square scipy.sparse matrix of distances.

    Returns:
        scipy.sparse.csr_matrix with the same shape and a symmetric pattern.
    """
    import numpy as np
    import scipy.sparse as sp

    coo = distances.tocoo()
    # Stack every edge together with its mirrored copy.
    row_idx = np.concatenate([coo.row, coo.col])
    col_idx = np.concatenate([coo.col, coo.row])
    values = np.concatenate([coo.data, coo.data])

    # np.lexsort sorts by the LAST key first (row, then col, then value), so the
    # first entry of every (row, col) run carries the minimum distance.
    order = np.lexsort((values, col_idx, row_idx))
    row_idx, col_idx, values = row_idx[order], col_idx[order], values[order]
    is_first = np.ones(row_idx.shape[0], dtype=bool)
    is_first[1:] = (row_idx[1:] != row_idx[:-1]) | (col_idx[1:] != col_idx[:-1])

    return sp.csr_matrix(
        (values[is_first], (row_idx[is_first], col_idx[is_first])),
        shape=distances.shape,
    )


def calculate_knn(
    adata: "ad.AnnData",
    obsm: str = "X_pca",
    knn_neighbors: int = 100,
    overwrite: bool = True,
    threads: int = 8,
    random_state: int | None = 0,
    symmetrize: bool = True,
) -> "ad.AnnData":
    """Compute a KNN distance graph on an embedding in `adata.obsm[obsm]`.

    Stores:
        - adata.obsp[f"knn_distances_{obsm}"] : CSR sparse matrix of distances
        - adata.uns[f"knn_distances_{obsm}"]["params"] : metadata

    Self-neighbors are never stored. Zero-distance entries are also removed
    from the sparse graph, so distinct observations with identical embeddings
    are not connected to each other.

    Args:
        adata: AnnData object to update (modified in place).
        obsm: Key in `adata.obsm` to use as the embedding.
        knn_neighbors: Target number of neighbors (will be clipped to n_obs-1).
        overwrite: If False and graph exists, do nothing.
        threads: Parallel jobs for pynndescent.
        random_state: Seed for pynndescent.
        symmetrize: If True, make the graph symmetric: an edge found in either
            direction is kept in both, using the smaller directed distance when
            both directions exist.

    Returns:
        The same AnnData object, updated.

    Raises:
        KeyError: If `obsm` is not present in `adata.obsm`.
        ValueError: If there are fewer than 2 observations.
    """
    import numpy as np
    import scipy.sparse as sp

    if obsm not in adata.obsm:
        raise KeyError(f"`{obsm}` not found in adata.obsm. Available: {list(adata.obsm.keys())}")

    out_key = f"knn_distances_{obsm}"
    if not overwrite and out_key in adata.obsp:
        logger.info("KNN graph %r already exists and overwrite=False; skipping.", out_key)
        return adata

    data = adata.obsm[obsm]

    if sp.issparse(data):
        # Convert to float32 for pynndescent/numba friendliness if needed
        data = data.astype(np.float32)
        logger.info(
            "Sparse embedding detected (%s). Proceeding without NaN check.", type(data).__name__
        )
    else:
        data = np.asarray(data)
        if np.isnan(data).any():
            logger.warning("NaNs detected in %s; filling NaNs with 0.5 before KNN.", obsm)
            data = np.nan_to_num(data, nan=0.5)
        data = data.astype(np.float32, copy=False)

    pynndescent = require("pynndescent", extra="umap", purpose="KNN graph computation")

    n_obs = data.shape[0]
    if n_obs < 2:
        raise ValueError(f"Need at least 2 observations for KNN; got n_obs={n_obs}")

    n_neighbors = min(int(knn_neighbors), n_obs - 1)
    if n_neighbors < 1:
        raise ValueError(f"Computed n_neighbors={n_neighbors}; check knn_neighbors and n_obs.")

    logger.info(
        "Running pynndescent KNN (obsm=%s, n_neighbors=%d, metric=euclidean, n_jobs=%d)",
        obsm,
        n_neighbors,
        threads,
    )

    nn_index = pynndescent.NNDescent(
        data,
        n_neighbors=n_neighbors,
        metric="euclidean",
        random_state=random_state,
        n_jobs=threads,
    )
    knn_indices, knn_dists = nn_index.neighbor_graph  # shapes: (n_obs, n_neighbors)

    rows = np.repeat(np.arange(n_obs, dtype=np.int64), knn_indices.shape[1])
    cols = knn_indices.reshape(-1).astype(np.int64, copy=False)
    vals = knn_dists.reshape(-1).astype(np.float32, copy=False)

    # Drop self-matches up front (avoids changing the sparsity structure later
    # via setdiag) and, defensively, any unfilled neighbor slots: a negative
    # column index would make scipy raise, and a non-finite distance is not a
    # usable edge.
    keep = (cols >= 0) & (cols != rows) & np.isfinite(vals)
    distances = sp.csr_matrix(
        (vals[keep], (rows[keep], cols[keep])), shape=(n_obs, n_obs)
    )
    # Zero distances are not stored (duplicate embeddings end up unconnected).
    distances.eliminate_zeros()

    if symmetrize:
        # Union of both directions, keeping the smaller distance on shared edges.
        distances = _symmetrize_min_distance(distances)

    adata.obsp[out_key] = distances
    adata.uns[out_key] = {
        "params": {
            "obsm": obsm,
            "n_neighbors_requested": int(knn_neighbors),
            "n_neighbors_used": int(n_neighbors),
            "method": "pynndescent",
            "metric": "euclidean",
            "random_state": random_state,
            "n_jobs": int(threads),
            "symmetrize": bool(symmetrize),
        }
    }

    return adata
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from smftools.logging_utils import get_logger
|
|
9
|
+
from smftools.optional_imports import require
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
import anndata as ad
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def calculate_leiden(
    adata: "ad.AnnData",
    *,
    resolution: float = 0.1,
    key_added: str = "leiden",
    connectivities_key: str = "connectivities",
    random_state: int | None = 0,
) -> "ad.AnnData":
    """Compute Leiden clusters from a connectivity graph.

    Every stored nonzero entry ``(i, j)`` of the connectivity matrix becomes a
    weighted, undirected igraph edge. For a symmetric matrix this means each
    pair contributes two parallel edges (one from each triangle).

    Args:
        adata: AnnData object with ``obsp[connectivities_key]`` set.
        resolution: Resolution parameter for Leiden clustering.
        key_added: Column name to store cluster assignments in ``adata.obs``.
        connectivities_key: Key in ``adata.obsp`` containing an adjacency matrix
            (sparse or dense).
        random_state: Seed forwarded to ``leidenalg.find_partition`` so runs are
            reproducible; ``None`` lets leidenalg choose a random seed.

    Returns:
        Updated AnnData object with Leiden labels in ``adata.obs`` stored as a
        categorical whose categories are ordered numerically ("0", "1", ..., "10").

    Raises:
        KeyError: If ``connectivities_key`` is missing from ``adata.obsp``.
    """
    import scipy.sparse as sp

    if connectivities_key not in adata.obsp:
        raise KeyError(f"Missing connectivities '{connectivities_key}' in adata.obsp.")

    igraph = require("igraph", extra="cluster", purpose="Leiden clustering")
    leidenalg = require("leidenalg", extra="cluster", purpose="Leiden clustering")

    connectivities = adata.obsp[connectivities_key]
    # coo_matrix accepts sparse and dense inputs alike (``.tocoo()`` would fail
    # on a dense ndarray).
    coo = sp.coo_matrix(connectivities)
    # Explicitly stored zeros are not real edges; use a mask so the matrix held
    # in adata.obsp is never mutated.
    nonzero = coo.data != 0
    sources = coo.row[nonzero].tolist()
    targets = coo.col[nonzero].tolist()

    graph = igraph.Graph(n=coo.shape[0], edges=list(zip(sources, targets)), directed=False)
    graph.es["weight"] = coo.data[nonzero].tolist()

    partition = leidenalg.find_partition(
        graph,
        leidenalg.RBConfigurationVertexPartition,
        weights=graph.es["weight"],
        resolution_parameter=resolution,
        seed=random_state,
    )

    membership = np.asarray(partition.membership, dtype=np.int64)
    # np.unique sorts integers numerically, so categories come out as
    # "0", "1", ..., "10" instead of the lexical "0", "1", "10", "2".
    categories = [str(label) for label in np.unique(membership)]
    adata.obs[key_added] = pd.Categorical(membership.astype(str), categories=categories)
    logger.info("Stored Leiden clusters in adata.obs['%s'].", key_added)
    return adata
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from smftools.logging_utils import get_logger
|
|
8
|
+
from smftools.optional_imports import require
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import anndata as ad
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def calculate_nmf(
    adata: "ad.AnnData",
    layer: str | None = "nan_half",
    var_mask: "np.ndarray | Sequence[bool] | None" = None,
    n_components: int = 2,
    max_iter: int = 200,
    random_state: int = 0,
    overwrite: bool = True,
    embedding_key: str = "X_nmf",
    components_key: str = "H_nmf",
    uns_key: str = "nmf",
    suffix: str | None = None,
) -> "ad.AnnData":
    """Compute a low-dimensional NMF embedding.

    NMF needs non-negative, finite input, so NaNs are replaced with 0.5 and
    negative values are clipped to 0 on a working copy before fitting; the
    matrix stored in ``adata`` is left untouched.

    Args:
        adata: AnnData object to update.
        layer: Layer name to use for NMF (``None`` uses ``adata.X``).
        var_mask: Optional boolean mask to subset features.
        n_components: Number of NMF components to compute.
        max_iter: Maximum number of NMF iterations.
        random_state: Random seed for the NMF initializer.
        overwrite: Whether to recompute if the embedding already exists.
        embedding_key: Key for the embedding in ``adata.obsm``.
        components_key: Key for the components matrix in ``adata.varm``.
        uns_key: Key for metadata stored in ``adata.uns``.
        suffix: Optional suffix; when set, ``_{suffix}`` is appended to
            ``embedding_key``, ``components_key`` and ``uns_key``.

    Returns:
        anndata.AnnData: Updated AnnData object. Features excluded by
        ``var_mask`` get all-zero rows in ``adata.varm[components_key]``.

    Raises:
        KeyError: If ``layer`` is given but not present in ``adata.layers``.
        ValueError: If ``var_mask`` has the wrong shape or retains no features.
    """
    from scipy.sparse import issparse

    require("sklearn", extra="ml-base", purpose="NMF calculation")
    from sklearn.decomposition import NMF

    if suffix:
        embedding_key = f"{embedding_key}_{suffix}"
        components_key = f"{components_key}_{suffix}"
        uns_key = f"{uns_key}_{suffix}"

    has_embedding = embedding_key in adata.obsm
    has_components = components_key in adata.varm
    if has_embedding and has_components and not overwrite:
        logger.info("NMF embedding and components already present; skipping recomputation.")
        return adata
    if has_embedding and not has_components and not overwrite:
        logger.info("NMF embedding present without components; recomputing to store components.")

    subset_mask = None
    if var_mask is not None:
        subset_mask = np.asarray(var_mask, dtype=bool)
        if subset_mask.ndim != 1 or subset_mask.shape[0] != adata.n_vars:
            raise ValueError(
                "var_mask must be a 1D boolean array with length matching adata.n_vars."
            )
        n_retained = int(subset_mask.sum())
        if n_retained == 0:
            raise ValueError("var_mask retained 0 features.")
        logger.info(
            "Subsetting adata: retained %s features based on filters %s",
            n_retained,
            "var_mask",
        )
    else:
        logger.info("No var_mask provided. Using all features.")

    # Read only the matrix we need instead of copying the whole AnnData
    # (every layer, obsm, varm, ...) just to access one array.
    if layer:
        if layer not in adata.layers:
            raise KeyError(
                f"Layer {layer!r} not found in adata.layers. Available: {list(adata.layers.keys())}"
            )
        data = adata.layers[layer]
    else:
        data = adata.X
    if subset_mask is not None:
        data = data[:, subset_mask]

    if issparse(data):
        # Copy so in-place edits of .data never reach the matrix held by adata.
        data = data.copy()
        if data.data.size and np.isnan(data.data).any():
            logger.warning("NaNs detected in sparse data, filling with 0.5 before NMF.")
            data.data = np.nan_to_num(data.data, nan=0.5)
        if data.data.size and (data.data < 0).any():
            logger.warning("Negative values detected in sparse data, clipping to 0 for NMF.")
            data.data[data.data < 0] = 0
    else:
        # nan_to_num / clip return new arrays, so adata's matrix is not modified.
        if np.isnan(data).any():
            logger.warning("NaNs detected, filling with 0.5 before NMF.")
            data = np.nan_to_num(data, nan=0.5)
        if (data < 0).any():
            logger.warning("Negative values detected, clipping to 0 for NMF.")
            data = np.clip(data, a_min=0, a_max=None)

    model = NMF(
        n_components=n_components,
        init="nndsvda",
        max_iter=max_iter,
        random_state=random_state,
    )
    embedding = model.fit_transform(data)
    components = model.components_.T  # (n_features_used, n_components)

    if subset_mask is not None:
        # Scatter back to the full variable axis; masked-out features stay 0.
        components_matrix = np.zeros((adata.shape[1], components.shape[1]))
        components_matrix[subset_mask, :] = components
    else:
        components_matrix = components

    adata.obsm[embedding_key] = embedding
    adata.varm[components_key] = components_matrix
    adata.uns[uns_key] = {
        "n_components": n_components,
        "max_iter": max_iter,
        "random_state": random_state,
        "layer": layer,
        "var_mask_provided": var_mask is not None,
        "components_key": components_key,
    }

    logger.info(
        "Stored: adata.obsm['%s'] and adata.varm['%s']",
        embedding_key,
        components_key,
    )
    return adata
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Sequence
|
|
4
|
+
|
|
5
|
+
from smftools.logging_utils import get_logger
|
|
6
|
+
from smftools.optional_imports import require
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
import anndata as ad
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def calculate_pca(
    adata: "ad.AnnData",
    layer: str | None = "nan_half",
    var_mask: "np.ndarray | Sequence[bool] | None" = None,
    n_pcs: int = 15,
    overwrite: bool = True,
    output_suffix: str | None = None,
    fill_nan: float | None = 0.5,
    random_state: int | None = 0,
) -> "ad.AnnData":
    """Compute PCA and store scores in `.obsm` and loadings in `.varm`.

    Stores:
        - adata.obsm["X_pca" or f"X_pca_{output_suffix}"]: scores (n_obs, n_pcs_used)
        - adata.varm["PCs" or f"PCs_{output_suffix}"]: loadings (n_vars, n_pcs_used);
          features excluded by ``var_mask`` get all-zero rows
        - adata.uns[<obsm key>]: parameters, explained variance ratio and mean

    Dense input is decomposed with sklearn's randomized PCA (centered). Sparse
    input uses sklearn's TruncatedSVD, which does NOT center the data. Without
    scikit-learn, dense input falls back to a full scipy SVD and sparse input
    raises.

    Args:
        adata: AnnData object to update.
        layer: Layer to decompose (``None`` uses ``adata.X``).
        var_mask: Optional 1D boolean mask over ``adata.var`` selecting features.
        n_pcs: Requested number of components; clipped to min(n_obs, n_features).
        overwrite: If False and both outputs already exist, do nothing.
        output_suffix: Optional suffix appended to the output keys.
        fill_nan: Value used to replace NaNs before decomposition (dense values
            and stored sparse values); ``None`` disables filling.
        random_state: Seed passed to sklearn's PCA / TruncatedSVD.

    Returns:
        The updated AnnData object.

    Raises:
        KeyError: If ``layer`` is not present in ``adata.layers``.
        ValueError: If ``var_mask`` is malformed or the matrix is too small.
        RuntimeError: If the input is sparse and scikit-learn is unavailable.
    """
    import numpy as np
    import scipy.sparse as sp

    obsm_output = f"X_pca_{output_suffix}" if output_suffix else "X_pca"
    varm_output = f"PCs_{output_suffix}" if output_suffix else "PCs"

    if not overwrite and obsm_output in adata.obsm and varm_output in adata.varm:
        logger.info(
            "PCA outputs already exist and overwrite=False; skipping (%s, %s).",
            obsm_output,
            varm_output,
        )
        return adata

    # --- Build feature subset mask (over vars) ---
    if var_mask is not None:
        subset_mask = np.asarray(var_mask, dtype=bool)
        if subset_mask.ndim != 1 or subset_mask.shape[0] != adata.n_vars:
            raise ValueError(
                "var_mask must be a 1D boolean array with length matching adata.n_vars."
            )
        n_vars_used = int(subset_mask.sum())
        if n_vars_used == 0:
            raise ValueError("var_mask retained 0 features.")
        logger.info(
            "Subsetting vars: retained %d / %d features from var_mask",
            n_vars_used,
            adata.n_vars,
        )
    else:
        subset_mask = slice(None)
        n_vars_used = adata.n_vars
        logger.info("No var_mask provided; using all %d features.", adata.n_vars)

    # --- Pull matrix ---
    if layer is None:
        matrix = adata.X
        layer_used = None
    else:
        if layer not in adata.layers:
            raise KeyError(
                f"Layer {layer!r} not found in adata.layers. Available: {list(adata.layers.keys())}"
            )
        matrix = adata.layers[layer]
        layer_used = layer

    matrix = matrix[:, subset_mask]

    n_obs = matrix.shape[0]
    if n_obs < 2:
        raise ValueError(f"PCA requires at least 2 observations; got n_obs={n_obs}")
    if n_vars_used < 1:
        raise ValueError("PCA requires at least 1 feature.")

    n_pcs_requested = int(n_pcs)
    n_pcs_used = min(n_pcs_requested, n_obs, n_vars_used)
    if n_pcs_used < 1:
        raise ValueError(f"n_pcs_used became {n_pcs_used}; check inputs.")

    is_sparse = sp.issparse(matrix)

    # --- NaN handling ---
    if is_sparse:
        # TruncatedSVD works on CSR/CSC; normalizing here also guarantees a
        # numeric ``.data`` array for the NaN check below.
        X = matrix if matrix.format in ("csr", "csc") else matrix.tocsr()
        if fill_nan is not None and X.data.size and np.isnan(X.data).any():
            logger.warning(
                "NaNs detected in sparse data; filling stored NaNs with %s before PCA.",
                fill_nan,
            )
            X = X.copy()  # never mutate the matrix held by adata
            X.data = np.nan_to_num(X.data, nan=float(fill_nan))
    else:
        X = np.asarray(matrix, dtype=np.float32)
        if fill_nan is not None and np.isnan(X).any():
            logger.warning("NaNs detected; filling NaNs with %s before PCA.", fill_nan)
            X = np.nan_to_num(X, nan=float(fill_nan))

    # --- Is scikit-learn available? ---
    # Only the availability check sits inside the try: errors raised while
    # *fitting* must propagate instead of silently triggering the slow fallback
    # (or a misleading "install scikit-learn" error for sparse input).
    sklearn_error = None
    try:
        require("sklearn", extra="ml", purpose="PCA computation")
        from sklearn.decomposition import PCA, TruncatedSVD
    except Exception as exc:  # `require` may raise a project-specific error type
        sklearn_error = exc

    # --- PCA ---
    if sklearn_error is None:
        if is_sparse:
            # TruncatedSVD works on sparse without centering; good approximation.
            logger.info("Running TruncatedSVD (sparse) with n_components=%d", n_pcs_used)
            model = TruncatedSVD(n_components=n_pcs_used, random_state=random_state)
            scores = model.fit_transform(X)  # (n_obs, n_pcs)
            loadings = model.components_.T  # (n_vars_used, n_pcs)
            mean = None
            explained_variance_ratio = getattr(model, "explained_variance_ratio_", None)
        else:
            logger.info(
                "Running sklearn PCA with n_components=%d (svd_solver=randomized)", n_pcs_used
            )
            model = PCA(
                n_components=n_pcs_used, svd_solver="randomized", random_state=random_state
            )
            scores = model.fit_transform(X)  # (n_obs, n_pcs)
            loadings = model.components_.T  # (n_vars_used, n_pcs)
            mean = model.mean_
            explained_variance_ratio = model.explained_variance_ratio_
        used_sklearn = True
    else:
        if is_sparse:
            raise RuntimeError(
                "Sparse input PCA fallback is not implemented without sklearn. "
                "Install scikit-learn (extra 'ml') or densify upstream."
            ) from sklearn_error

        import scipy.linalg as spla

        logger.warning(
            "sklearn PCA unavailable; falling back to full SVD (can be slow). Reason: %s",
            sklearn_error,
        )
        Xd = np.asarray(X, dtype=np.float64)
        mean = Xd.mean(axis=0)
        u, s, vt = spla.svd(Xd - mean, full_matrices=False)
        scores = u[:, :n_pcs_used] * s[:n_pcs_used]
        loadings = vt[:n_pcs_used].T
        # Variance captured by each kept component relative to the total,
        # computed from the full singular value spectrum.
        total_sq = float(np.sum(s**2))
        explained_variance_ratio = (s[:n_pcs_used] ** 2) / total_sq if total_sq > 0 else None
        used_sklearn = False

    # --- Store scores (obsm) ---
    adata.obsm[obsm_output] = scores

    # --- Store loadings (varm) with original var dimension ---
    pc_matrix = np.zeros((adata.n_vars, n_pcs_used), dtype=np.float32)
    pc_matrix[subset_mask, :] = loadings  # works for both the bool mask and slice(None)
    adata.varm[varm_output] = pc_matrix

    # --- Metadata ---
    adata.uns[obsm_output] = {
        "params": {
            "layer": layer_used,
            "var_mask_provided": var_mask is not None,
            "n_pcs_requested": n_pcs_requested,
            "n_pcs_used": int(n_pcs_used),
            "used_sklearn": used_sklearn,
            "fill_nan": fill_nan,
            "note_sparse": bool(is_sparse),
            "random_state": random_state,
        },
        "explained_variance_ratio": explained_variance_ratio,
        "mean": mean.tolist() if isinstance(mean, np.ndarray) else None,
    }

    logger.info(
        "Stored PCA: adata.obsm[%s] (%s) and adata.varm[%s] (%s)",
        obsm_output,
        scores.shape,
        varm_output,
        pc_matrix.shape,
    )
    return adata
|