smftools 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +32 -6
- smftools/cli/hmm_adata.py +232 -31
- smftools/cli/latent_adata.py +318 -0
- smftools/cli/load_adata.py +77 -73
- smftools/cli/preprocess_adata.py +178 -53
- smftools/cli/spatial_adata.py +149 -101
- smftools/cli_entry.py +12 -0
- smftools/config/conversion.yaml +11 -1
- smftools/config/default.yaml +38 -1
- smftools/config/experiment_config.py +53 -1
- smftools/constants.py +65 -0
- smftools/hmm/HMM.py +88 -0
- smftools/informatics/__init__.py +6 -0
- smftools/informatics/bam_functions.py +358 -8
- smftools/informatics/converted_BAM_to_adata.py +584 -163
- smftools/informatics/h5ad_functions.py +115 -2
- smftools/informatics/modkit_extract_to_adata.py +1003 -425
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/metadata.py +1 -1
- smftools/plotting/__init__.py +9 -0
- smftools/plotting/general_plotting.py +2411 -628
- smftools/plotting/hmm_plotting.py +85 -7
- smftools/preprocessing/__init__.py +1 -0
- smftools/preprocessing/append_base_context.py +17 -17
- smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/readwrite.py +53 -17
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +4 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +119 -0
- smftools/tools/calculate_umap.py +91 -8
- smftools/tools/rolling_nn_distance.py +235 -0
- smftools/tools/tensor_factorization.py +169 -0
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/METADATA +8 -6
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/RECORD +42 -35
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
smftools/tools/calculate_nmf.py
ADDED
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+import numpy as np
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def calculate_nmf(
+    adata: "ad.AnnData",
+    layer: str | None = "nan_half",
+    var_filters: Sequence[str] | None = None,
+    n_components: int = 2,
+    max_iter: int = 200,
+    random_state: int = 0,
+    overwrite: bool = True,
+    embedding_key: str = "X_nmf",
+    components_key: str = "H_nmf",
+    uns_key: str = "nmf",
+) -> "ad.AnnData":
+    """Compute a low-dimensional NMF embedding.
+
+    Args:
+        adata: AnnData object to update.
+        layer: Layer name to use for NMF (``None`` uses ``adata.X``).
+        var_filters: Optional list of var masks to subset features.
+        n_components: Number of NMF components to compute.
+        max_iter: Maximum number of NMF iterations.
+        random_state: Random seed for the NMF initializer.
+        overwrite: Whether to recompute if the embedding already exists.
+        embedding_key: Key for the embedding in ``adata.obsm``.
+        components_key: Key for the components matrix in ``adata.varm``.
+        uns_key: Key for metadata stored in ``adata.uns``.
+
+    Returns:
+        anndata.AnnData: Updated AnnData object.
+    """
+    from scipy.sparse import issparse
+
+    require("sklearn", extra="ml-base", purpose="NMF calculation")
+    from sklearn.decomposition import NMF
+
+    has_embedding = embedding_key in adata.obsm
+    has_components = components_key in adata.varm
+    if has_embedding and has_components and not overwrite:
+        logger.info("NMF embedding and components already present; skipping recomputation.")
+        return adata
+    if has_embedding and not has_components and not overwrite:
+        logger.info("NMF embedding present without components; recomputing to store components.")
+
+    subset_mask = None
+    if var_filters:
+        subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
+        adata_subset = adata[:, subset_mask].copy()
+        logger.info(
+            "Subsetting adata: retained %s features based on filters %s",
+            adata_subset.shape[1],
+            var_filters,
+        )
+    else:
+        adata_subset = adata.copy()
+        logger.info("No var filters provided. Using all features.")
+
+    data = adata_subset.layers[layer] if layer else adata_subset.X
+    if issparse(data):
+        data = data.copy()
+        if data.data.size and np.isnan(data.data).any():
+            logger.warning("NaNs detected in sparse data, filling with 0.5 before NMF.")
+            data.data = np.nan_to_num(data.data, nan=0.5)
+        if data.data.size and (data.data < 0).any():
+            logger.warning("Negative values detected in sparse data, clipping to 0 for NMF.")
+            data.data[data.data < 0] = 0
+    else:
+        if np.isnan(data).any():
+            logger.warning("NaNs detected, filling with 0.5 before NMF.")
+            data = np.nan_to_num(data, nan=0.5)
+        if (data < 0).any():
+            logger.warning("Negative values detected, clipping to 0 for NMF.")
+            data = np.clip(data, a_min=0, a_max=None)
+
+    model = NMF(
+        n_components=n_components,
+        init="nndsvda",
+        max_iter=max_iter,
+        random_state=random_state,
+    )
+    embedding = model.fit_transform(data)
+    components = model.components_.T
+
+    if subset_mask is not None:
+        components_matrix = np.zeros((adata.shape[1], components.shape[1]))
+        components_matrix[subset_mask, :] = components
+    else:
+        components_matrix = components
+
+    adata.obsm[embedding_key] = embedding
+    adata.varm[components_key] = components_matrix
+    adata.uns[uns_key] = {
+        "n_components": n_components,
+        "max_iter": max_iter,
+        "random_state": random_state,
+        "layer": layer,
+        "var_filters": list(var_filters) if var_filters else None,
+        "components_key": components_key,
+    }
+
+    logger.info(
+        "Stored: adata.obsm['%s'] and adata.varm['%s']",
+        embedding_key,
+        components_key,
+    )
+    return adata
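A minimal usage sketch for the new helper, with a toy AnnData built here purely for illustration (the `nan_half` layer name mirrors the function's default):

    import anndata as ad
    import numpy as np

    from smftools.tools.calculate_nmf import calculate_nmf

    # Illustrative object: 50 reads x 200 positions with values in [0, 1].
    rng = np.random.default_rng(0)
    X = rng.random((50, 200))
    adata = ad.AnnData(X)
    adata.layers["nan_half"] = X

    adata = calculate_nmf(adata, layer="nan_half", n_components=2)
    print(adata.obsm["X_nmf"].shape)  # (50, 2) per-read embedding
    print(adata.varm["H_nmf"].shape)  # (200, 2) per-position loadings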
smftools/tools/calculate_umap.py
CHANGED
@@ -19,6 +19,7 @@ def calculate_umap(
     knn_neighbors: int = 100,
     overwrite: bool = True,
     threads: int = 8,
+    random_state: int | None = 0,
 ) -> "ad.AnnData":
     """Compute PCA, neighbors, and UMAP embeddings.
 
@@ -37,9 +38,11 @@ def calculate_umap(
     import os
 
     import numpy as np
+    import scipy.linalg as spla
+    import scipy.sparse as sp
 
-
-
+    umap = require("umap", extra="umap", purpose="UMAP calculation")
+    pynndescent = require("pynndescent", extra="umap", purpose="KNN graph computation")
 
     os.environ["OMP_NUM_THREADS"] = str(threads)
 
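The `require(...)` calls come from `smftools.optional_imports`, whose implementation is not part of this diff. Judging by the call sites, it imports an optional dependency and points at the matching packaging extra on failure; a minimal sketch of that assumed behavior:

    import importlib

    def require(module_name: str, *, extra: str, purpose: str):
        # Assumed behavior: return the module, or fail with an actionable hint.
        try:
            return importlib.import_module(module_name)
        except ImportError as exc:
            raise ImportError(
                f"{module_name} is required for {purpose}. "
                f"Try: pip install 'smftools[{extra}]'"
            ) from exc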
@@ -59,7 +62,7 @@
     # Step 2: NaN handling inside layer
     if layer:
         data = adata_subset.layers[layer]
-        if not issparse(data):
+        if not sp.issparse(data):
             if np.isnan(data).any():
                 logger.warning("NaNs detected, filling with 0.5 before PCA + neighbors.")
                 data = np.nan_to_num(data, nan=0.5)
@@ -75,18 +78,98 @@
     if "X_umap" not in adata_subset.obsm or overwrite:
         n_pcs = min(adata_subset.shape[1], n_pcs)
         logger.info("Running PCA with n_pcs=%s", n_pcs)
-
-
-
+
+        if layer:
+            matrix = adata_subset.layers[layer]
+        else:
+            matrix = adata_subset.X
+
+        if sp.issparse(matrix):
+            logger.warning("Converting sparse matrix to dense for PCA.")
+            matrix = matrix.toarray()
+
+        matrix = np.asarray(matrix, dtype=float)
+        mean = matrix.mean(axis=0)
+        centered = matrix - mean
+
+        if centered.shape[0] == 0 or centered.shape[1] == 0:
+            raise ValueError("PCA requires a non-empty matrix.")
+
+        if n_pcs <= 0:
+            raise ValueError("n_pcs must be positive.")
+
+        if centered.shape[1] <= n_pcs:
+            n_pcs = centered.shape[1]
+
+        if centered.shape[0] < n_pcs:
+            n_pcs = centered.shape[0]
+
+        u, s, vt = spla.svd(centered, full_matrices=False)
+
+        u = u[:, :n_pcs]
+        s = s[:n_pcs]
+        vt = vt[:n_pcs]
+
+        adata_subset.obsm["X_pca"] = u * s
+        adata_subset.varm["PCs"] = vt.T
+
+        logger.info("Running neighborhood graph with pynndescent (n_neighbors=%s)", knn_neighbors)
+        n_neighbors = min(knn_neighbors, max(1, adata_subset.n_obs - 1))
+        nn_index = pynndescent.NNDescent(
+            adata_subset.obsm["X_pca"],
+            n_neighbors=n_neighbors,
+            metric="euclidean",
+            random_state=random_state,
+            n_jobs=threads,
+        )
+        knn_indices, knn_dists = nn_index.neighbor_graph
+
+        rows = np.repeat(np.arange(adata_subset.n_obs), n_neighbors)
+        cols = knn_indices.reshape(-1)
+        distances = sp.coo_matrix(
+            (knn_dists.reshape(-1), (rows, cols)),
+            shape=(adata_subset.n_obs, adata_subset.n_obs),
+        ).tocsr()
+        adata_subset.obsp["distances"] = distances
+
         logger.info("Running UMAP")
-
+        umap_model = umap.UMAP(
+            n_neighbors=n_neighbors,
+            n_components=2,
+            metric="euclidean",
+            random_state=random_state,
+        )
+        adata_subset.obsm["X_umap"] = umap_model.fit_transform(adata_subset.obsm["X_pca"])
+
+        try:
+            from umap.umap_ import fuzzy_simplicial_set
+
+            fuzzy_result = fuzzy_simplicial_set(
+                adata_subset.obsm["X_pca"],
+                n_neighbors=n_neighbors,
+                random_state=random_state,
+                metric="euclidean",
+                knn_indices=knn_indices,
+                knn_dists=knn_dists,
+            )
+            connectivities = fuzzy_result[0] if isinstance(fuzzy_result, tuple) else fuzzy_result
+        except TypeError:
+            connectivities = umap_model.graph_
+
+        adata_subset.obsp["connectivities"] = connectivities
 
     # Step 4: Store results in original adata
     adata.obsm["X_pca"] = adata_subset.obsm["X_pca"]
     adata.obsm["X_umap"] = adata_subset.obsm["X_umap"]
     adata.obsp["distances"] = adata_subset.obsp["distances"]
     adata.obsp["connectivities"] = adata_subset.obsp["connectivities"]
-    adata.uns["neighbors"] =
+    adata.uns["neighbors"] = {
+        "params": {
+            "n_neighbors": knn_neighbors,
+            "method": "pynndescent",
+            "metric": "euclidean",
+        }
+    }
 
     # Fix varm["PCs"] shape mismatch
     pc_matrix = np.zeros((adata.shape[1], adata_subset.varm["PCs"].shape[1]))
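The new PCA path stores `u * s` as the scores, which equals projecting the centered matrix onto the right singular vectors. A quick standalone check of that identity:

    import numpy as np
    import scipy.linalg as spla

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 30))
    centered = X - X.mean(axis=0)

    u, s, vt = spla.svd(centered, full_matrices=False)
    scores = u * s               # what the diff stores in obsm["X_pca"]
    projected = centered @ vt.T  # projection onto the principal axes

    assert np.allclose(scores, projected)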
smftools/tools/rolling_nn_distance.py
ADDED
@@ -0,0 +1,235 @@
+from __future__ import annotations
+
+import ast
+import json
+from typing import TYPE_CHECKING, Optional, Sequence, Tuple
+
+import numpy as np
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def _pack_bool_to_u64(B: np.ndarray) -> np.ndarray:
+    """
+    Pack a boolean (or 0/1) matrix (n, w) into uint64 blocks (n, ceil(w/64)).
+    Safe w.r.t. contiguity/layout.
+    """
+    B = np.asarray(B, dtype=np.uint8)
+    packed_u8 = np.packbits(B, axis=1)  # (n, ceil(w/8)) uint8
+
+    n, nb = packed_u8.shape
+    pad = (-nb) % 8
+    if pad:
+        packed_u8 = np.pad(packed_u8, ((0, 0), (0, pad)), mode="constant", constant_values=0)
+
+    packed_u8 = np.ascontiguousarray(packed_u8)
+
+    # group 8 bytes -> uint64
+    packed_u64 = packed_u8.reshape(n, -1, 8).view(np.uint64).reshape(n, -1)
+    return packed_u64
+
+
+def _popcount_u64_matrix(A_u64: np.ndarray) -> np.ndarray:
+    """
+    Popcount for an array of uint64, vectorized and portable across NumPy versions.
+
+    Returns an integer array with the SAME SHAPE as A_u64.
+    """
+    A_u64 = np.ascontiguousarray(A_u64)
+    # View as bytes; IMPORTANT: reshape to add a trailing byte axis of length 8
+    b = A_u64.view(np.uint8).reshape(A_u64.shape + (8,))
+    # unpack bits within that byte axis -> (..., 64), then sum
+    return np.unpackbits(b, axis=-1).sum(axis=-1)
+
+
+def rolling_window_nn_distance(
+    adata,
+    layer: Optional[str] = None,
+    window: int = 15,
+    step: int = 2,
+    min_overlap: int = 10,
+    return_fraction: bool = True,
+    block_rows: int = 256,
+    block_cols: int = 2048,
+    store_obsm: Optional[str] = "rolling_nn_dist",
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Rolling-window nearest-neighbor distance per read, overlap-aware.
+
+    Distance between reads i,j in a window:
+      - use only positions where BOTH are observed (non-NaN)
+      - require overlap >= min_overlap
+      - mismatch = count(x_i != x_j) over overlapped positions
+      - distance = mismatch/overlap (if return_fraction) else mismatch
+
+    Returns
+    -------
+    out : (n_obs, n_windows) float
+        Nearest-neighbor distance per read per window (NaN if no valid neighbor).
+    starts : (n_windows,) int
+        Window start indices in var-space.
+    """
+    X = adata.layers[layer] if layer is not None else adata.X
+    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
+
+    n, p = X.shape
+    if window > p:
+        raise ValueError(f"window={window} is larger than n_vars={p}")
+    if window <= 0:
+        raise ValueError("window must be > 0")
+    if step <= 0:
+        raise ValueError("step must be > 0")
+    if min_overlap <= 0:
+        raise ValueError("min_overlap must be > 0")
+
+    starts = np.arange(0, p - window + 1, step, dtype=int)
+    nW = len(starts)
+    out = np.full((n, nW), np.nan, dtype=float)
+
+    for wi, s in enumerate(starts):
+        wX = X[:, s : s + window]  # (n, window)
+
+        # observed mask; values as 0/1 where observed, 0 elsewhere
+        M = ~np.isnan(wX)
+        V = np.where(M, wX, 0).astype(np.float32)
+
+        # ensure binary 0/1
+        V = (V > 0).astype(np.uint8)
+
+        M64 = _pack_bool_to_u64(M)
+        V64 = _pack_bool_to_u64(V.astype(bool))
+
+        best = np.full(n, np.inf, dtype=float)
+
+        for i0 in range(0, n, block_rows):
+            i1 = min(n, i0 + block_rows)
+            Mi = M64[i0:i1]  # (bi, nb)
+            Vi = V64[i0:i1]
+            bi = i1 - i0
+
+            local_best = np.full(bi, np.inf, dtype=float)
+
+            for j0 in range(0, n, block_cols):
+                j1 = min(n, j0 + block_cols)
+                Mj = M64[j0:j1]  # (bj, nb)
+                Vj = V64[j0:j1]
+                bj = j1 - j0
+
+                overlap_counts = np.zeros((bi, bj), dtype=np.uint16)
+                mismatch_counts = np.zeros((bi, bj), dtype=np.uint16)
+
+                for k in range(Mi.shape[1]):
+                    ob = (Mi[:, k][:, None] & Mj[:, k][None, :]).astype(np.uint64)
+                    overlap_counts += _popcount_u64_matrix(ob).astype(np.uint16)
+
+                    mb = ((Vi[:, k][:, None] ^ Vj[:, k][None, :]) & ob).astype(np.uint64)
+                    mismatch_counts += _popcount_u64_matrix(mb).astype(np.uint16)
+
+                ok = overlap_counts >= min_overlap
+                if not np.any(ok):
+                    continue
+
+                dist = np.full((bi, bj), np.inf, dtype=float)
+                if return_fraction:
+                    dist[ok] = mismatch_counts[ok] / overlap_counts[ok]
+                else:
+                    dist[ok] = mismatch_counts[ok].astype(float)
+
+                # exclude self comparisons (diagonal) when blocks overlap
+                if (i0 <= j1) and (j0 <= i1):
+                    ii = np.arange(i0, i1)
+                    jj = ii[(ii >= j0) & (ii < j1)]
+                    if jj.size:
+                        dist[(jj - i0), (jj - j0)] = np.inf
+
+                local_best = np.minimum(local_best, dist.min(axis=1))
+
+            best[i0:i1] = local_best
+
+        best[~np.isfinite(best)] = np.nan
+        out[:, wi] = best
+
+    if store_obsm is not None:
+        adata.obsm[store_obsm] = out
+        adata.uns[f"{store_obsm}_starts"] = starts
+        adata.uns[f"{store_obsm}_window"] = int(window)
+        adata.uns[f"{store_obsm}_step"] = int(step)
+        adata.uns[f"{store_obsm}_min_overlap"] = int(min_overlap)
+        adata.uns[f"{store_obsm}_return_fraction"] = bool(return_fraction)
+        adata.uns[f"{store_obsm}_layer"] = layer if layer is not None else "X"
+
+    return out, starts
+
+
|
|
169
|
+
def assign_rolling_nn_results(
|
|
170
|
+
parent_adata: "ad.AnnData",
|
|
171
|
+
subset_adata: "ad.AnnData",
|
|
172
|
+
values: np.ndarray,
|
|
173
|
+
starts: np.ndarray,
|
|
174
|
+
obsm_key: str,
|
|
175
|
+
window: int,
|
|
176
|
+
step: int,
|
|
177
|
+
min_overlap: int,
|
|
178
|
+
return_fraction: bool,
|
|
179
|
+
layer: Optional[str],
|
|
180
|
+
) -> None:
|
|
181
|
+
"""
|
|
182
|
+
Assign rolling NN results computed on a subset back onto a parent AnnData.
|
|
183
|
+
|
|
184
|
+
Parameters
|
|
185
|
+
----------
|
|
186
|
+
parent_adata : AnnData
|
|
187
|
+
Parent AnnData that should store the combined results.
|
|
188
|
+
subset_adata : AnnData
|
|
189
|
+
Subset AnnData used to compute `values`.
|
|
190
|
+
values : np.ndarray
|
|
191
|
+
Rolling NN output with shape (n_subset_obs, n_windows).
|
|
192
|
+
starts : np.ndarray
|
|
193
|
+
Window start indices corresponding to `values`.
|
|
194
|
+
obsm_key : str
|
|
195
|
+
Key to store results under in parent_adata.obsm.
|
|
196
|
+
window : int
|
|
197
|
+
Rolling window size (stored in parent_adata.uns).
|
|
198
|
+
step : int
|
|
199
|
+
Rolling window step size (stored in parent_adata.uns).
|
|
200
|
+
min_overlap : int
|
|
201
|
+
Minimum overlap (stored in parent_adata.uns).
|
|
202
|
+
return_fraction : bool
|
|
203
|
+
Whether distances are fractional (stored in parent_adata.uns).
|
|
204
|
+
layer : str | None
|
|
205
|
+
Layer used for calculations (stored in parent_adata.uns).
|
|
206
|
+
"""
|
|
207
|
+
n_obs = parent_adata.n_obs
|
|
208
|
+
n_windows = values.shape[1]
|
|
209
|
+
|
|
210
|
+
if obsm_key not in parent_adata.obsm:
|
|
211
|
+
parent_adata.obsm[obsm_key] = np.full((n_obs, n_windows), np.nan, dtype=float)
|
|
212
|
+
parent_adata.uns[f"{obsm_key}_starts"] = starts
|
|
213
|
+
parent_adata.uns[f"{obsm_key}_window"] = int(window)
|
|
214
|
+
parent_adata.uns[f"{obsm_key}_step"] = int(step)
|
|
215
|
+
parent_adata.uns[f"{obsm_key}_min_overlap"] = int(min_overlap)
|
|
216
|
+
parent_adata.uns[f"{obsm_key}_return_fraction"] = bool(return_fraction)
|
|
217
|
+
parent_adata.uns[f"{obsm_key}_layer"] = layer if layer is not None else "X"
|
|
218
|
+
else:
|
|
219
|
+
existing = parent_adata.obsm[obsm_key]
|
|
220
|
+
if existing.shape[1] != n_windows:
|
|
221
|
+
raise ValueError(
|
|
222
|
+
f"Existing obsm[{obsm_key!r}] has {existing.shape[1]} windows; "
|
|
223
|
+
f"new values have {n_windows} windows."
|
|
224
|
+
)
|
|
225
|
+
existing_starts = parent_adata.uns.get(f"{obsm_key}_starts")
|
|
226
|
+
if existing_starts is not None and not np.array_equal(existing_starts, starts):
|
|
227
|
+
raise ValueError(
|
|
228
|
+
f"Existing obsm[{obsm_key!r}] has different window starts than new values."
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
parent_indexer = parent_adata.obs_names.get_indexer(subset_adata.obs_names)
|
|
232
|
+
if (parent_indexer < 0).any():
|
|
233
|
+
raise ValueError("Subset AnnData contains obs not present in parent AnnData.")
|
|
234
|
+
|
|
235
|
+
parent_adata.obsm[obsm_key][parent_indexer, :] = values
|
|
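A plausible end-to-end use of the pair: compute per-window distances on a per-sample subset, then write them back into the parent object (the toy data and `sample` column are illustrative):

    import anndata as ad
    import numpy as np

    from smftools.tools.rolling_nn_distance import (
        assign_rolling_nn_results,
        rolling_window_nn_distance,
    )

    rng = np.random.default_rng(0)
    X = rng.integers(0, 2, size=(40, 120)).astype(float)
    X[rng.random(X.shape) < 0.2] = np.nan  # unobserved positions
    adata = ad.AnnData(X)
    adata.obs["sample"] = ["a"] * 20 + ["b"] * 20

    subset = adata[adata.obs["sample"] == "a"].copy()
    values, starts = rolling_window_nn_distance(
        subset, window=15, step=2, min_overlap=10, store_obsm=None
    )
    assign_rolling_nn_results(
        adata, subset, values, starts,
        obsm_key="rolling_nn_dist", window=15, step=2,
        min_overlap=10, return_fraction=True, layer=None,
    )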
smftools/tools/tensor_factorization.py
ADDED
@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable, Sequence
+
+import numpy as np
+
+from smftools.constants import MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def build_sequence_one_hot_and_mask(
+    encoded_sequences: np.ndarray,
+    *,
+    bases: Sequence[str] = ("A", "C", "G", "T"),
+    dtype: np.dtype | type[np.floating] = np.float32,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Build one-hot encoded reads and a seen/unseen mask.
+
+    Args:
+        encoded_sequences: Integer-encoded sequences shaped (n_reads, seq_len).
+        bases: Bases to one-hot encode.
+        dtype: Output dtype for the one-hot tensor.
+
+    Returns:
+        Tuple of (one_hot_tensor, mask) where:
+        - one_hot_tensor: (n_reads, seq_len, n_bases)
+        - mask: (n_reads, seq_len) boolean array indicating seen bases.
+    """
+    encoded = np.asarray(encoded_sequences)
+    if encoded.ndim != 2:
+        raise ValueError(
+            f"encoded_sequences must be 2D with shape (n_reads, seq_len); got {encoded.shape}."
+        )
+
+    base_values = np.array(
+        [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[base] for base in bases],
+        dtype=encoded.dtype,
+    )
+
+    if np.issubdtype(encoded.dtype, np.floating):
+        encoded = encoded.copy()
+        encoded[np.isnan(encoded)] = -1
+
+    mask = np.isin(encoded, base_values)
+    one_hot = np.zeros((*encoded.shape, len(base_values)), dtype=dtype)
+
+    for idx, base_value in enumerate(base_values):
+        one_hot[..., idx] = encoded == base_value
+
+    return one_hot, mask
+
+
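The mask is what lets the factorization ignore unseen positions: only codes that map to a real base count as observed. A small shape illustration with a stand-in mapping (the real values live in smftools.constants.MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT, which this diff does not show):

    import numpy as np

    # Stand-in base->integer mapping for illustration only.
    base_to_int = {"A": 0, "C": 1, "G": 2, "T": 3}
    base_values = np.array(list(base_to_int.values()))

    encoded = np.array([[0, 1, 3, -1],   # A C T <unseen>
                        [2, 2, 0, 1]])   # G G A C

    mask = np.isin(encoded, base_values)  # (2, 4) bool, False where unseen
    one_hot = np.zeros((*encoded.shape, len(base_values)), dtype=np.float32)
    for idx, v in enumerate(base_values):
        one_hot[..., idx] = encoded == v  # (2, 4, 4) one-hot tensor

    print(mask[0])        # [ True  True  True False]
    print(one_hot[0, 2])  # [0. 0. 0. 1.] -> T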
+def calculate_sequence_cp_decomposition(
+    adata: "ad.AnnData",
+    *,
+    layer: str,
+    rank: int = 5,
+    n_iter_max: int = 100,
+    random_state: int = 0,
+    overwrite: bool = True,
+    embedding_key: str = "X_cp_sequence",
+    components_key: str = "H_cp_sequence",
+    uns_key: str = "cp_sequence",
+    bases: Iterable[str] = ("A", "C", "G", "T"),
+    backend: str = "pytorch",
+    show_progress: bool = False,
+    init: str = "random",
+) -> "ad.AnnData":
+    """Compute CP decomposition on one-hot encoded sequence data with masking.
+
+    Args:
+        adata: AnnData object to update.
+        layer: Layer name containing integer-encoded sequences.
+        rank: CP rank.
+        n_iter_max: Maximum number of iterations for the solver.
+        random_state: Random seed for initialization.
+        overwrite: Whether to recompute if the embedding already exists.
+        embedding_key: Key for embedding in ``adata.obsm``.
+        components_key: Key for position factors in ``adata.varm``.
+        uns_key: Key for metadata stored in ``adata.uns``.
+        bases: Bases to one-hot encode (in order).
+        backend: Tensorly backend to use (``numpy`` or ``pytorch``).
+        show_progress: Whether to display progress during factorization if supported.
+
+    Returns:
+        Updated AnnData object containing the CP decomposition outputs.
+    """
+    if embedding_key in adata.obsm and components_key in adata.varm and not overwrite:
+        logger.info("CP embedding and components already present; skipping recomputation.")
+        return adata
+
+    if backend not in {"numpy", "pytorch"}:
+        raise ValueError(f"Unsupported backend '{backend}'. Use 'numpy' or 'pytorch'.")
+
+    tensorly = require("tensorly", extra="ml-base", purpose="CP decomposition")
+    from tensorly.decomposition import parafac
+
+    tensorly.set_backend(backend)
+
+    if layer not in adata.layers:
+        raise KeyError(f"Layer '{layer}' not found in adata.layers.")
+
+    one_hot, mask = build_sequence_one_hot_and_mask(adata.layers[layer], bases=tuple(bases))
+    mask_tensor = np.repeat(mask[:, :, None], one_hot.shape[2], axis=2)
+
+    device = "numpy"
+    if backend == "pytorch":
+        torch = require("torch", extra="ml-base", purpose="CP decomposition backend")
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        one_hot = torch.tensor(one_hot, dtype=torch.float32, device=device)
+        mask_tensor = torch.tensor(mask_tensor, dtype=torch.float32, device=device)
+
+    parafac_kwargs = {
+        "rank": rank,
+        "n_iter_max": n_iter_max,
+        "init": init,
+        "mask": mask_tensor,
+        "random_state": random_state,
+    }
+    import inspect
+
+    if "verbose" in inspect.signature(parafac).parameters:
+        parafac_kwargs["verbose"] = show_progress
+
+    cp = parafac(one_hot, **parafac_kwargs)
+
+    if backend == "pytorch":
+        weights = cp.weights.detach().cpu().numpy()
+        read_factors, position_factors, base_factors = [
+            factor.detach().cpu().numpy() for factor in cp.factors
+        ]
+    else:
+        weights = np.asarray(cp.weights)
+        read_factors, position_factors, base_factors = [np.asarray(f) for f in cp.factors]
+
+    adata.obsm[embedding_key] = read_factors
+    adata.varm[components_key] = position_factors
+    adata.uns[uns_key] = {
+        "rank": rank,
+        "n_iter_max": n_iter_max,
+        "random_state": random_state,
+        "layer": layer,
+        "components_key": components_key,
+        "weights": weights,
+        "base_factors": base_factors,
+        "base_labels": list(bases),
+        "backend": backend,
+        "device": str(device),
+    }
+
+    logger.info(
+        "Stored: adata.obsm['%s'], adata.varm['%s'], adata.uns['%s']",
+        embedding_key,
+        components_key,
+        uns_key,
+    )
+    return adata
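A minimal invocation sketch for the CP helper, assuming an AnnData with an integer-encoded sequence layer already in hand (the layer name is illustrative; the numpy backend avoids the torch dependency):

    from smftools.tools.tensor_factorization import calculate_sequence_cp_decomposition

    adata = calculate_sequence_cp_decomposition(
        adata,
        layer="sequence_int",  # illustrative integer-encoded sequence layer
        rank=5,
        backend="numpy",
    )
    print(adata.obsm["X_cp_sequence"].shape)        # (n_reads, rank) read factors
    print(adata.varm["H_cp_sequence"].shape)        # (n_positions, rank) position factors
    print(adata.uns["cp_sequence"]["base_labels"])  # ['A', 'C', 'G', 'T']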