smftools 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. smftools/_version.py +1 -1
  2. smftools/cli/helpers.py +32 -6
  3. smftools/cli/hmm_adata.py +232 -31
  4. smftools/cli/latent_adata.py +318 -0
  5. smftools/cli/load_adata.py +77 -73
  6. smftools/cli/preprocess_adata.py +178 -53
  7. smftools/cli/spatial_adata.py +149 -101
  8. smftools/cli_entry.py +12 -0
  9. smftools/config/conversion.yaml +11 -1
  10. smftools/config/default.yaml +38 -1
  11. smftools/config/experiment_config.py +53 -1
  12. smftools/constants.py +65 -0
  13. smftools/hmm/HMM.py +88 -0
  14. smftools/informatics/__init__.py +6 -0
  15. smftools/informatics/bam_functions.py +358 -8
  16. smftools/informatics/converted_BAM_to_adata.py +584 -163
  17. smftools/informatics/h5ad_functions.py +115 -2
  18. smftools/informatics/modkit_extract_to_adata.py +1003 -425
  19. smftools/informatics/sequence_encoding.py +72 -0
  20. smftools/logging_utils.py +21 -2
  21. smftools/metadata.py +1 -1
  22. smftools/plotting/__init__.py +9 -0
  23. smftools/plotting/general_plotting.py +2411 -628
  24. smftools/plotting/hmm_plotting.py +85 -7
  25. smftools/preprocessing/__init__.py +1 -0
  26. smftools/preprocessing/append_base_context.py +17 -17
  27. smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
  28. smftools/preprocessing/calculate_consensus.py +1 -1
  29. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  30. smftools/readwrite.py +53 -17
  31. smftools/schema/anndata_schema_v1.yaml +15 -1
  32. smftools/tools/__init__.py +4 -0
  33. smftools/tools/calculate_leiden.py +57 -0
  34. smftools/tools/calculate_nmf.py +119 -0
  35. smftools/tools/calculate_umap.py +91 -8
  36. smftools/tools/rolling_nn_distance.py +235 -0
  37. smftools/tools/tensor_factorization.py +169 -0
  38. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/METADATA +8 -6
  39. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/RECORD +42 -35
  40. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
  41. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
  42. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
smftools/tools/calculate_nmf.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+import numpy as np
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def calculate_nmf(
+    adata: "ad.AnnData",
+    layer: str | None = "nan_half",
+    var_filters: Sequence[str] | None = None,
+    n_components: int = 2,
+    max_iter: int = 200,
+    random_state: int = 0,
+    overwrite: bool = True,
+    embedding_key: str = "X_nmf",
+    components_key: str = "H_nmf",
+    uns_key: str = "nmf",
+) -> "ad.AnnData":
+    """Compute a low-dimensional NMF embedding.
+
+    Args:
+        adata: AnnData object to update.
+        layer: Layer name to use for NMF (``None`` uses ``adata.X``).
+        var_filters: Optional list of var masks to subset features.
+        n_components: Number of NMF components to compute.
+        max_iter: Maximum number of NMF iterations.
+        random_state: Random seed for the NMF initializer.
+        overwrite: Whether to recompute if the embedding already exists.
+        embedding_key: Key for the embedding in ``adata.obsm``.
+        components_key: Key for the components matrix in ``adata.varm``.
+        uns_key: Key for metadata stored in ``adata.uns``.
+
+    Returns:
+        anndata.AnnData: Updated AnnData object.
+    """
+    from scipy.sparse import issparse
+
+    require("sklearn", extra="ml-base", purpose="NMF calculation")
+    from sklearn.decomposition import NMF
+
+    has_embedding = embedding_key in adata.obsm
+    has_components = components_key in adata.varm
+    if has_embedding and has_components and not overwrite:
+        logger.info("NMF embedding and components already present; skipping recomputation.")
+        return adata
+    if has_embedding and not has_components and not overwrite:
+        logger.info("NMF embedding present without components; recomputing to store components.")
+
+    subset_mask = None
+    if var_filters:
+        subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
+        adata_subset = adata[:, subset_mask].copy()
+        logger.info(
+            "Subsetting adata: retained %s features based on filters %s",
+            adata_subset.shape[1],
+            var_filters,
+        )
+    else:
+        adata_subset = adata.copy()
+        logger.info("No var filters provided. Using all features.")
+
+    data = adata_subset.layers[layer] if layer else adata_subset.X
+    if issparse(data):
+        data = data.copy()
+        if data.data.size and np.isnan(data.data).any():
+            logger.warning("NaNs detected in sparse data, filling with 0.5 before NMF.")
+            data.data = np.nan_to_num(data.data, nan=0.5)
+        if data.data.size and (data.data < 0).any():
+            logger.warning("Negative values detected in sparse data, clipping to 0 for NMF.")
+            data.data[data.data < 0] = 0
+    else:
+        if np.isnan(data).any():
+            logger.warning("NaNs detected, filling with 0.5 before NMF.")
+            data = np.nan_to_num(data, nan=0.5)
+        if (data < 0).any():
+            logger.warning("Negative values detected, clipping to 0 for NMF.")
+            data = np.clip(data, a_min=0, a_max=None)
+
+    model = NMF(
+        n_components=n_components,
+        init="nndsvda",
+        max_iter=max_iter,
+        random_state=random_state,
+    )
+    embedding = model.fit_transform(data)
+    components = model.components_.T
+
+    if subset_mask is not None:
+        components_matrix = np.zeros((adata.shape[1], components.shape[1]))
+        components_matrix[subset_mask, :] = components
+    else:
+        components_matrix = components
+
+    adata.obsm[embedding_key] = embedding
+    adata.varm[components_key] = components_matrix
+    adata.uns[uns_key] = {
+        "n_components": n_components,
+        "max_iter": max_iter,
+        "random_state": random_state,
+        "layer": layer,
+        "var_filters": list(var_filters) if var_filters else None,
+        "components_key": components_key,
+    }
+
+    logger.info(
+        "Stored: adata.obsm['%s'] and adata.varm['%s']",
+        embedding_key,
+        components_key,
+    )
+    return adata
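
A minimal usage sketch of the new NMF helper (illustrative, not from the package: the toy matrix and layer contents are invented; the import path follows the file list above):

    import anndata as ad
    import numpy as np

    from smftools.tools.calculate_nmf import calculate_nmf

    # Toy binary methylation matrix: 50 reads x 200 positions.
    rng = np.random.default_rng(0)
    X = rng.integers(0, 2, size=(50, 200)).astype(np.float32)
    adata = ad.AnnData(X)
    # "nan_half" is the default layer name; here it simply mirrors X.
    adata.layers["nan_half"] = X.copy()

    adata = calculate_nmf(adata, layer="nan_half", n_components=4)
    print(adata.obsm["X_nmf"].shape)  # (50, 4) per-read embedding
    print(adata.varm["H_nmf"].shape)  # (200, 4) per-position loadings
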
smftools/tools/calculate_umap.py
@@ -19,6 +19,7 @@ def calculate_umap(
     knn_neighbors: int = 100,
     overwrite: bool = True,
     threads: int = 8,
+    random_state: int | None = 0,
 ) -> "ad.AnnData":
     """Compute PCA, neighbors, and UMAP embeddings.
 
@@ -37,9 +38,11 @@
     import os
 
     import numpy as np
+    import scipy.linalg as spla
+    import scipy.sparse as sp
 
-    sc = require("scanpy", extra="scanpy", purpose="UMAP calculation")
-    from scipy.sparse import issparse
+    umap = require("umap", extra="umap", purpose="UMAP calculation")
+    pynndescent = require("pynndescent", extra="umap", purpose="KNN graph computation")
 
     os.environ["OMP_NUM_THREADS"] = str(threads)
 
@@ -59,7 +62,7 @@
     # Step 2: NaN handling inside layer
     if layer:
         data = adata_subset.layers[layer]
-        if not issparse(data):
+        if not sp.issparse(data):
             if np.isnan(data).any():
                 logger.warning("NaNs detected, filling with 0.5 before PCA + neighbors.")
                 data = np.nan_to_num(data, nan=0.5)
@@ -75,18 +78,98 @@
     if "X_umap" not in adata_subset.obsm or overwrite:
         n_pcs = min(adata_subset.shape[1], n_pcs)
         logger.info("Running PCA with n_pcs=%s", n_pcs)
-        sc.pp.pca(adata_subset, layer=layer)
-        logger.info("Running neighborhood graph")
-        sc.pp.neighbors(adata_subset, use_rep="X_pca", n_pcs=n_pcs, n_neighbors=knn_neighbors)
+
+        if layer:
+            matrix = adata_subset.layers[layer]
+        else:
+            matrix = adata_subset.X
+
+        if sp.issparse(matrix):
+            logger.warning("Converting sparse matrix to dense for PCA.")
+            matrix = matrix.toarray()
+
+        matrix = np.asarray(matrix, dtype=float)
+        mean = matrix.mean(axis=0)
+        centered = matrix - mean
+
+        if centered.shape[0] == 0 or centered.shape[1] == 0:
+            raise ValueError("PCA requires a non-empty matrix.")
+
+        if n_pcs <= 0:
+            raise ValueError("n_pcs must be positive.")
+
+        if centered.shape[1] <= n_pcs:
+            n_pcs = centered.shape[1]
+
+        if centered.shape[0] < n_pcs:
+            n_pcs = centered.shape[0]
+
+        u, s, vt = spla.svd(centered, full_matrices=False)
+
+        u = u[:, :n_pcs]
+        s = s[:n_pcs]
+        vt = vt[:n_pcs]
+
+        adata_subset.obsm["X_pca"] = u * s
+        adata_subset.varm["PCs"] = vt.T
+
+        logger.info("Running neighborhood graph with pynndescent (n_neighbors=%s)", knn_neighbors)
+        n_neighbors = min(knn_neighbors, max(1, adata_subset.n_obs - 1))
+        nn_index = pynndescent.NNDescent(
+            adata_subset.obsm["X_pca"],
+            n_neighbors=n_neighbors,
+            metric="euclidean",
+            random_state=random_state,
+            n_jobs=threads,
+        )
+        knn_indices, knn_dists = nn_index.neighbor_graph
+
+        rows = np.repeat(np.arange(adata_subset.n_obs), n_neighbors)
+        cols = knn_indices.reshape(-1)
+        distances = sp.coo_matrix(
+            (knn_dists.reshape(-1), (rows, cols)),
+            shape=(adata_subset.n_obs, adata_subset.n_obs),
+        ).tocsr()
+        adata_subset.obsp["distances"] = distances
+
         logger.info("Running UMAP")
-        sc.tl.umap(adata_subset)
+        umap_model = umap.UMAP(
+            n_neighbors=n_neighbors,
+            n_components=2,
+            metric="euclidean",
+            random_state=random_state,
+        )
+        adata_subset.obsm["X_umap"] = umap_model.fit_transform(adata_subset.obsm["X_pca"])
+
+        try:
+            from umap.umap_ import fuzzy_simplicial_set
+
+            fuzzy_result = fuzzy_simplicial_set(
+                adata_subset.obsm["X_pca"],
+                n_neighbors=n_neighbors,
+                random_state=random_state,
+                metric="euclidean",
+                knn_indices=knn_indices,
+                knn_dists=knn_dists,
+            )
+            connectivities = fuzzy_result[0] if isinstance(fuzzy_result, tuple) else fuzzy_result
+        except TypeError:
+            connectivities = umap_model.graph_
+
+        adata_subset.obsp["connectivities"] = connectivities
 
     # Step 4: Store results in original adata
     adata.obsm["X_pca"] = adata_subset.obsm["X_pca"]
     adata.obsm["X_umap"] = adata_subset.obsm["X_umap"]
     adata.obsp["distances"] = adata_subset.obsp["distances"]
     adata.obsp["connectivities"] = adata_subset.obsp["connectivities"]
-    adata.uns["neighbors"] = adata_subset.uns["neighbors"]
+    adata.uns["neighbors"] = {
+        "params": {
+            "n_neighbors": knn_neighbors,
+            "method": "pynndescent",
+            "metric": "euclidean",
+        }
+    }
 
     # Fix varm["PCs"] shape mismatch
     pc_matrix = np.zeros((adata.shape[1], adata_subset.varm["PCs"].shape[1]))
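
For reference, a minimal call sketch against the reworked function (toy data; only parameters visible in this diff are passed, and the remaining arguments are assumed to keep their defaults):

    import anndata as ad
    import numpy as np

    from smftools.tools.calculate_umap import calculate_umap

    adata = ad.AnnData(np.random.default_rng(0).random((300, 120)).astype(np.float32))

    # scanpy is no longer imported here: PCA is a plain SVD, the kNN graph
    # comes from pynndescent, and the embedding from umap-learn directly.
    adata = calculate_umap(adata, knn_neighbors=30, threads=4, random_state=0)
    print(adata.obsm["X_pca"].shape)   # (300, n_pcs)
    print(adata.obsm["X_umap"].shape)  # (300, 2)
    print(adata.uns["neighbors"]["params"]["method"])  # "pynndescent"
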
smftools/tools/rolling_nn_distance.py
@@ -0,0 +1,235 @@
+from __future__ import annotations
+
+import ast
+import json
+from typing import TYPE_CHECKING, Optional, Sequence, Tuple
+
+import numpy as np
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def _pack_bool_to_u64(B: np.ndarray) -> np.ndarray:
+    """
+    Pack a boolean (or 0/1) matrix (n, w) into uint64 blocks (n, ceil(w/64)).
+    Safe w.r.t. contiguity/layout.
+    """
+    B = np.asarray(B, dtype=np.uint8)
+    packed_u8 = np.packbits(B, axis=1)  # (n, ceil(w/8)) uint8
+
+    n, nb = packed_u8.shape
+    pad = (-nb) % 8
+    if pad:
+        packed_u8 = np.pad(packed_u8, ((0, 0), (0, pad)), mode="constant", constant_values=0)
+
+    packed_u8 = np.ascontiguousarray(packed_u8)
+
+    # group 8 bytes -> uint64
+    packed_u64 = packed_u8.reshape(n, -1, 8).view(np.uint64).reshape(n, -1)
+    return packed_u64
+
+
+def _popcount_u64_matrix(A_u64: np.ndarray) -> np.ndarray:
+    """
+    Popcount for an array of uint64, vectorized and portable across NumPy versions.
+
+    Returns an integer array with the SAME SHAPE as A_u64.
+    """
+    A_u64 = np.ascontiguousarray(A_u64)
+    # View as bytes; IMPORTANT: reshape to add a trailing byte axis of length 8
+    b = A_u64.view(np.uint8).reshape(A_u64.shape + (8,))
+    # unpack bits within that byte axis -> (..., 64), then sum
+    return np.unpackbits(b, axis=-1).sum(axis=-1)
+
+
+def rolling_window_nn_distance(
+    adata,
+    layer: Optional[str] = None,
+    window: int = 15,
+    step: int = 2,
+    min_overlap: int = 10,
+    return_fraction: bool = True,
+    block_rows: int = 256,
+    block_cols: int = 2048,
+    store_obsm: Optional[str] = "rolling_nn_dist",
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Rolling-window nearest-neighbor distance per read, overlap-aware.
+
+    Distance between reads i,j in a window:
+      - use only positions where BOTH are observed (non-NaN)
+      - require overlap >= min_overlap
+      - mismatch = count(x_i != x_j) over overlapped positions
+      - distance = mismatch/overlap (if return_fraction) else mismatch
+
+    Returns
+    -------
+    out : (n_obs, n_windows) float
+        Nearest-neighbor distance per read per window (NaN if no valid neighbor).
+    starts : (n_windows,) int
+        Window start indices in var-space.
+    """
+    X = adata.layers[layer] if layer is not None else adata.X
+    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
+
+    n, p = X.shape
+    if window > p:
+        raise ValueError(f"window={window} is larger than n_vars={p}")
+    if window <= 0:
+        raise ValueError("window must be > 0")
+    if step <= 0:
+        raise ValueError("step must be > 0")
+    if min_overlap <= 0:
+        raise ValueError("min_overlap must be > 0")
+
+    starts = np.arange(0, p - window + 1, step, dtype=int)
+    nW = len(starts)
+    out = np.full((n, nW), np.nan, dtype=float)
+
+    for wi, s in enumerate(starts):
+        wX = X[:, s : s + window]  # (n, window)
+
+        # observed mask; values as 0/1 where observed, 0 elsewhere
+        M = ~np.isnan(wX)
+        V = np.where(M, wX, 0).astype(np.float32)
+
+        # ensure binary 0/1
+        V = (V > 0).astype(np.uint8)
+
+        M64 = _pack_bool_to_u64(M)
+        V64 = _pack_bool_to_u64(V.astype(bool))
+
+        best = np.full(n, np.inf, dtype=float)
+
+        for i0 in range(0, n, block_rows):
+            i1 = min(n, i0 + block_rows)
+            Mi = M64[i0:i1]  # (bi, nb)
+            Vi = V64[i0:i1]
+            bi = i1 - i0
+
+            local_best = np.full(bi, np.inf, dtype=float)
+
+            for j0 in range(0, n, block_cols):
+                j1 = min(n, j0 + block_cols)
+                Mj = M64[j0:j1]  # (bj, nb)
+                Vj = V64[j0:j1]
+                bj = j1 - j0
+
+                overlap_counts = np.zeros((bi, bj), dtype=np.uint16)
+                mismatch_counts = np.zeros((bi, bj), dtype=np.uint16)
+
+                for k in range(Mi.shape[1]):
+                    ob = (Mi[:, k][:, None] & Mj[:, k][None, :]).astype(np.uint64)
+                    overlap_counts += _popcount_u64_matrix(ob).astype(np.uint16)
+
+                    mb = ((Vi[:, k][:, None] ^ Vj[:, k][None, :]) & ob).astype(np.uint64)
+                    mismatch_counts += _popcount_u64_matrix(mb).astype(np.uint16)
+
+                ok = overlap_counts >= min_overlap
+                if not np.any(ok):
+                    continue
+
+                dist = np.full((bi, bj), np.inf, dtype=float)
+                if return_fraction:
+                    dist[ok] = mismatch_counts[ok] / overlap_counts[ok]
+                else:
+                    dist[ok] = mismatch_counts[ok].astype(float)
+
+                # exclude self comparisons (diagonal) when blocks overlap
+                if (i0 <= j1) and (j0 <= i1):
+                    ii = np.arange(i0, i1)
+                    jj = ii[(ii >= j0) & (ii < j1)]
+                    if jj.size:
+                        dist[(jj - i0), (jj - j0)] = np.inf
+
+                local_best = np.minimum(local_best, dist.min(axis=1))
+
+            best[i0:i1] = local_best
+
+        best[~np.isfinite(best)] = np.nan
+        out[:, wi] = best
+
+    if store_obsm is not None:
+        adata.obsm[store_obsm] = out
+        adata.uns[f"{store_obsm}_starts"] = starts
+        adata.uns[f"{store_obsm}_window"] = int(window)
+        adata.uns[f"{store_obsm}_step"] = int(step)
+        adata.uns[f"{store_obsm}_min_overlap"] = int(min_overlap)
+        adata.uns[f"{store_obsm}_return_fraction"] = bool(return_fraction)
+        adata.uns[f"{store_obsm}_layer"] = layer if layer is not None else "X"
+
+    return out, starts
+
+
+def assign_rolling_nn_results(
+    parent_adata: "ad.AnnData",
+    subset_adata: "ad.AnnData",
+    values: np.ndarray,
+    starts: np.ndarray,
+    obsm_key: str,
+    window: int,
+    step: int,
+    min_overlap: int,
+    return_fraction: bool,
+    layer: Optional[str],
+) -> None:
+    """
+    Assign rolling NN results computed on a subset back onto a parent AnnData.
+
+    Parameters
+    ----------
+    parent_adata : AnnData
+        Parent AnnData that should store the combined results.
+    subset_adata : AnnData
+        Subset AnnData used to compute `values`.
+    values : np.ndarray
+        Rolling NN output with shape (n_subset_obs, n_windows).
+    starts : np.ndarray
+        Window start indices corresponding to `values`.
+    obsm_key : str
+        Key to store results under in parent_adata.obsm.
+    window : int
+        Rolling window size (stored in parent_adata.uns).
+    step : int
+        Rolling window step size (stored in parent_adata.uns).
+    min_overlap : int
+        Minimum overlap (stored in parent_adata.uns).
+    return_fraction : bool
+        Whether distances are fractional (stored in parent_adata.uns).
+    layer : str | None
+        Layer used for calculations (stored in parent_adata.uns).
+    """
+    n_obs = parent_adata.n_obs
+    n_windows = values.shape[1]
+
+    if obsm_key not in parent_adata.obsm:
+        parent_adata.obsm[obsm_key] = np.full((n_obs, n_windows), np.nan, dtype=float)
+        parent_adata.uns[f"{obsm_key}_starts"] = starts
+        parent_adata.uns[f"{obsm_key}_window"] = int(window)
+        parent_adata.uns[f"{obsm_key}_step"] = int(step)
+        parent_adata.uns[f"{obsm_key}_min_overlap"] = int(min_overlap)
+        parent_adata.uns[f"{obsm_key}_return_fraction"] = bool(return_fraction)
+        parent_adata.uns[f"{obsm_key}_layer"] = layer if layer is not None else "X"
+    else:
+        existing = parent_adata.obsm[obsm_key]
+        if existing.shape[1] != n_windows:
+            raise ValueError(
+                f"Existing obsm[{obsm_key!r}] has {existing.shape[1]} windows; "
+                f"new values have {n_windows} windows."
+            )
+        existing_starts = parent_adata.uns.get(f"{obsm_key}_starts")
+        if existing_starts is not None and not np.array_equal(existing_starts, starts):
+            raise ValueError(
+                f"Existing obsm[{obsm_key!r}] has different window starts than new values."
+            )
+
+    parent_indexer = parent_adata.obs_names.get_indexer(subset_adata.obs_names)
+    if (parent_indexer < 0).any():
+        raise ValueError("Subset AnnData contains obs not present in parent AnnData.")
+
+    parent_adata.obsm[obsm_key][parent_indexer, :] = values
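
A small end-to-end sketch of the rolling nearest-neighbor distance (invented toy reads; a real input would be a binarized methylation layer):

    import anndata as ad
    import numpy as np

    from smftools.tools.rolling_nn_distance import rolling_window_nn_distance

    rng = np.random.default_rng(0)
    X = rng.integers(0, 2, size=(100, 60)).astype(float)
    X[rng.random(X.shape) < 0.2] = np.nan  # ~20% unobserved positions

    adata = ad.AnnData(np.nan_to_num(X))
    adata.layers["binary"] = X

    out, starts = rolling_window_nn_distance(
        adata, layer="binary", window=15, step=5, min_overlap=10
    )
    # out[i, w] is the mismatch fraction between read i and its nearest
    # neighbor in the window starting at starts[w] (NaN if no neighbor
    # shares at least min_overlap observed positions).
    print(out.shape, starts[:3])                # (100, 10) [ 0  5 10]
    print(adata.obsm["rolling_nn_dist"].shape)  # stored by default
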
smftools/tools/tensor_factorization.py
@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable, Sequence
+
+import numpy as np
+
+from smftools.constants import MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def build_sequence_one_hot_and_mask(
+    encoded_sequences: np.ndarray,
+    *,
+    bases: Sequence[str] = ("A", "C", "G", "T"),
+    dtype: np.dtype | type[np.floating] = np.float32,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Build one-hot encoded reads and a seen/unseen mask.
+
+    Args:
+        encoded_sequences: Integer-encoded sequences shaped (n_reads, seq_len).
+        bases: Bases to one-hot encode.
+        dtype: Output dtype for the one-hot tensor.
+
+    Returns:
+        Tuple of (one_hot_tensor, mask) where:
+            - one_hot_tensor: (n_reads, seq_len, n_bases)
+            - mask: (n_reads, seq_len) boolean array indicating seen bases.
+    """
+    encoded = np.asarray(encoded_sequences)
+    if encoded.ndim != 2:
+        raise ValueError(
+            f"encoded_sequences must be 2D with shape (n_reads, seq_len); got {encoded.shape}."
+        )
+
+    base_values = np.array(
+        [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[base] for base in bases],
+        dtype=encoded.dtype,
+    )
+
+    if np.issubdtype(encoded.dtype, np.floating):
+        encoded = encoded.copy()
+        encoded[np.isnan(encoded)] = -1
+
+    mask = np.isin(encoded, base_values)
+    one_hot = np.zeros((*encoded.shape, len(base_values)), dtype=dtype)
+
+    for idx, base_value in enumerate(base_values):
+        one_hot[..., idx] = encoded == base_value
+
+    return one_hot, mask
+
+
+def calculate_sequence_cp_decomposition(
+    adata: "ad.AnnData",
+    *,
+    layer: str,
+    rank: int = 5,
+    n_iter_max: int = 100,
+    random_state: int = 0,
+    overwrite: bool = True,
+    embedding_key: str = "X_cp_sequence",
+    components_key: str = "H_cp_sequence",
+    uns_key: str = "cp_sequence",
+    bases: Iterable[str] = ("A", "C", "G", "T"),
+    backend: str = "pytorch",
+    show_progress: bool = False,
+    init: str = "random",
+) -> "ad.AnnData":
+    """Compute CP decomposition on one-hot encoded sequence data with masking.
+
+    Args:
+        adata: AnnData object to update.
+        layer: Layer name containing integer-encoded sequences.
+        rank: CP rank.
+        n_iter_max: Maximum number of iterations for the solver.
+        random_state: Random seed for initialization.
+        overwrite: Whether to recompute if the embedding already exists.
+        embedding_key: Key for embedding in ``adata.obsm``.
+        components_key: Key for position factors in ``adata.varm``.
+        uns_key: Key for metadata stored in ``adata.uns``.
+        bases: Bases to one-hot encode (in order).
+        backend: Tensorly backend to use (``numpy`` or ``pytorch``).
+        show_progress: Whether to display progress during factorization if supported.
+
+    Returns:
+        Updated AnnData object containing the CP decomposition outputs.
+    """
+    if embedding_key in adata.obsm and components_key in adata.varm and not overwrite:
+        logger.info("CP embedding and components already present; skipping recomputation.")
+        return adata
+
+    if backend not in {"numpy", "pytorch"}:
+        raise ValueError(f"Unsupported backend '{backend}'. Use 'numpy' or 'pytorch'.")
+
+    tensorly = require("tensorly", extra="ml-base", purpose="CP decomposition")
+    from tensorly.decomposition import parafac
+
+    tensorly.set_backend(backend)
+
+    if layer not in adata.layers:
+        raise KeyError(f"Layer '{layer}' not found in adata.layers.")
+
+    one_hot, mask = build_sequence_one_hot_and_mask(adata.layers[layer], bases=tuple(bases))
+    mask_tensor = np.repeat(mask[:, :, None], one_hot.shape[2], axis=2)
+
+    device = "numpy"
+    if backend == "pytorch":
+        torch = require("torch", extra="ml-base", purpose="CP decomposition backend")
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        one_hot = torch.tensor(one_hot, dtype=torch.float32, device=device)
+        mask_tensor = torch.tensor(mask_tensor, dtype=torch.float32, device=device)
+
+    parafac_kwargs = {
+        "rank": rank,
+        "n_iter_max": n_iter_max,
+        "init": init,
+        "mask": mask_tensor,
+        "random_state": random_state,
+    }
+    import inspect
+
+    if "verbose" in inspect.signature(parafac).parameters:
+        parafac_kwargs["verbose"] = show_progress
+
+    cp = parafac(one_hot, **parafac_kwargs)
+
+    if backend == "pytorch":
+        weights = cp.weights.detach().cpu().numpy()
+        read_factors, position_factors, base_factors = [
+            factor.detach().cpu().numpy() for factor in cp.factors
+        ]
+    else:
+        weights = np.asarray(cp.weights)
+        read_factors, position_factors, base_factors = [np.asarray(f) for f in cp.factors]
+
+    adata.obsm[embedding_key] = read_factors
+    adata.varm[components_key] = position_factors
+    adata.uns[uns_key] = {
+        "rank": rank,
+        "n_iter_max": n_iter_max,
+        "random_state": random_state,
+        "layer": layer,
+        "components_key": components_key,
+        "weights": weights,
+        "base_factors": base_factors,
+        "base_labels": list(bases),
+        "backend": backend,
+        "device": str(device),
+    }
+
+    logger.info(
+        "Stored: adata.obsm['%s'], adata.varm['%s'], adata.uns['%s']",
+        embedding_key,
+        components_key,
+        uns_key,
+    )
+    return adata
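
A sketch of the CP decomposition entry point on toy integer-encoded reads (the encoding values come from the package's MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT map; the numpy backend avoids the optional torch dependency):

    import anndata as ad
    import numpy as np

    from smftools.constants import MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT
    from smftools.tools.tensor_factorization import calculate_sequence_cp_decomposition

    rng = np.random.default_rng(0)
    codes = [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[b] for b in "ACGT"]
    seqs = rng.choice(codes, size=(40, 80)).astype(float)
    seqs[rng.random(seqs.shape) < 0.1] = np.nan  # unseen bases get masked out

    adata = ad.AnnData(np.zeros(seqs.shape, dtype=np.float32))
    adata.layers["sequence"] = seqs

    adata = calculate_sequence_cp_decomposition(
        adata, layer="sequence", rank=3, backend="numpy"
    )
    print(adata.obsm["X_cp_sequence"].shape)  # (40, 3) read factors
    print(adata.varm["H_cp_sequence"].shape)  # (80, 3) position factors
    print(adata.uns["cp_sequence"]["base_factors"].shape)  # (4, 3) base factors
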