smftools 0.2.5-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +39 -7
- smftools/_settings.py +2 -0
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +34 -6
- smftools/cli/hmm_adata.py +239 -33
- smftools/cli/latent_adata.py +318 -0
- smftools/cli/load_adata.py +167 -131
- smftools/cli/preprocess_adata.py +180 -53
- smftools/cli/spatial_adata.py +152 -100
- smftools/cli_entry.py +38 -1
- smftools/config/__init__.py +2 -0
- smftools/config/conversion.yaml +11 -1
- smftools/config/default.yaml +42 -2
- smftools/config/experiment_config.py +59 -1
- smftools/constants.py +65 -0
- smftools/datasets/__init__.py +2 -0
- smftools/hmm/HMM.py +97 -3
- smftools/hmm/__init__.py +24 -13
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +2 -0
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +5 -2
- smftools/hmm/display_hmm.py +4 -1
- smftools/hmm/hmm_readwrite.py +7 -2
- smftools/hmm/nucleosome_hmm_refinement.py +2 -0
- smftools/informatics/__init__.py +59 -34
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +2 -0
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1093 -176
- smftools/informatics/basecalling.py +2 -0
- smftools/informatics/bed_functions.py +271 -61
- smftools/informatics/binarize_converted_base_identities.py +3 -0
- smftools/informatics/complement_base_list.py +2 -0
- smftools/informatics/converted_BAM_to_adata.py +641 -176
- smftools/informatics/fasta_functions.py +94 -10
- smftools/informatics/h5ad_functions.py +123 -4
- smftools/informatics/modkit_extract_to_adata.py +1019 -431
- smftools/informatics/modkit_functions.py +2 -0
- smftools/informatics/ohe.py +2 -0
- smftools/informatics/pod5_functions.py +3 -2
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/machine_learning/__init__.py +22 -6
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +18 -4
- smftools/machine_learning/data/preprocessing.py +2 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +2 -0
- smftools/machine_learning/evaluation/evaluators.py +14 -9
- smftools/machine_learning/inference/__init__.py +2 -0
- smftools/machine_learning/inference/inference_utils.py +2 -0
- smftools/machine_learning/inference/lightning_inference.py +6 -1
- smftools/machine_learning/inference/sklearn_inference.py +2 -0
- smftools/machine_learning/inference/sliding_window_inference.py +2 -0
- smftools/machine_learning/models/__init__.py +2 -0
- smftools/machine_learning/models/base.py +7 -2
- smftools/machine_learning/models/cnn.py +7 -2
- smftools/machine_learning/models/lightning_base.py +16 -11
- smftools/machine_learning/models/mlp.py +5 -1
- smftools/machine_learning/models/positional.py +7 -2
- smftools/machine_learning/models/rnn.py +5 -1
- smftools/machine_learning/models/sklearn_models.py +14 -9
- smftools/machine_learning/models/transformer.py +7 -2
- smftools/machine_learning/models/wrappers.py +6 -2
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +13 -3
- smftools/machine_learning/training/train_sklearn_model.py +2 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +5 -1
- smftools/machine_learning/utils/grl.py +5 -1
- smftools/metadata.py +1 -1
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +41 -31
- smftools/plotting/autocorrelation_plotting.py +9 -5
- smftools/plotting/classifiers.py +16 -4
- smftools/plotting/general_plotting.py +2415 -629
- smftools/plotting/hmm_plotting.py +97 -9
- smftools/plotting/position_stats.py +15 -7
- smftools/plotting/qc_plotting.py +6 -1
- smftools/preprocessing/__init__.py +36 -37
- smftools/preprocessing/append_base_context.py +17 -17
- smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
- smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/archived/calculate_complexity.py +2 -0
- smftools/preprocessing/archived/mark_duplicates.py +2 -0
- smftools/preprocessing/archived/preprocessing.py +2 -0
- smftools/preprocessing/archived/remove_duplicates.py +2 -0
- smftools/preprocessing/binary_layers_to_ohe.py +2 -1
- smftools/preprocessing/calculate_complexity_II.py +4 -1
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_pairwise_differences.py +2 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
- smftools/preprocessing/calculate_position_Youden.py +9 -2
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
- smftools/preprocessing/flag_duplicate_reads.py +42 -54
- smftools/preprocessing/make_dirs.py +2 -1
- smftools/preprocessing/min_non_diagonal.py +2 -0
- smftools/preprocessing/recipes.py +2 -0
- smftools/readwrite.py +53 -17
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +30 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +2 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +2 -0
- smftools/tools/archived/subset_adata_v2.py +2 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +119 -0
- smftools/tools/calculate_umap.py +93 -8
- smftools/tools/cluster_adata_on_methylation.py +7 -1
- smftools/tools/position_stats.py +17 -27
- smftools/tools/rolling_nn_distance.py +235 -0
- smftools/tools/tensor_factorization.py +169 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
- smftools-0.3.1.dist-info/RECORD +189 -0
- smftools-0.2.5.dist-info/RECORD +0 -181
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
smftools/tools/calculate_umap.py
CHANGED

@@ -3,6 +3,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Sequence
 
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 if TYPE_CHECKING:
     import anndata as ad
@@ -18,6 +19,7 @@ def calculate_umap(
     knn_neighbors: int = 100,
     overwrite: bool = True,
     threads: int = 8,
+    random_state: int | None = 0,
 ) -> "ad.AnnData":
     """Compute PCA, neighbors, and UMAP embeddings.
 
@@ -36,8 +38,11 @@
     import os
 
     import numpy as np
-    import
-
+    import scipy.linalg as spla
+    import scipy.sparse as sp
+
+    umap = require("umap", extra="umap", purpose="UMAP calculation")
+    pynndescent = require("pynndescent", extra="umap", purpose="KNN graph computation")
 
     os.environ["OMP_NUM_THREADS"] = str(threads)
 
@@ -57,7 +62,7 @@
     # Step 2: NaN handling inside layer
     if layer:
         data = adata_subset.layers[layer]
-        if not issparse(data):
+        if not sp.issparse(data):
             if np.isnan(data).any():
                 logger.warning("NaNs detected, filling with 0.5 before PCA + neighbors.")
                 data = np.nan_to_num(data, nan=0.5)
@@ -73,18 +78,98 @@
     if "X_umap" not in adata_subset.obsm or overwrite:
         n_pcs = min(adata_subset.shape[1], n_pcs)
         logger.info("Running PCA with n_pcs=%s", n_pcs)
-
-
-
+
+        if layer:
+            matrix = adata_subset.layers[layer]
+        else:
+            matrix = adata_subset.X
+
+        if sp.issparse(matrix):
+            logger.warning("Converting sparse matrix to dense for PCA.")
+            matrix = matrix.toarray()
+
+        matrix = np.asarray(matrix, dtype=float)
+        mean = matrix.mean(axis=0)
+        centered = matrix - mean
+
+        if centered.shape[0] == 0 or centered.shape[1] == 0:
+            raise ValueError("PCA requires a non-empty matrix.")
+
+        if n_pcs <= 0:
+            raise ValueError("n_pcs must be positive.")
+
+        if centered.shape[1] <= n_pcs:
+            n_pcs = centered.shape[1]
+
+        if centered.shape[0] < n_pcs:
+            n_pcs = centered.shape[0]
+
+        u, s, vt = spla.svd(centered, full_matrices=False)
+
+        u = u[:, :n_pcs]
+        s = s[:n_pcs]
+        vt = vt[:n_pcs]
+
+        adata_subset.obsm["X_pca"] = u * s
+        adata_subset.varm["PCs"] = vt.T
+
+        logger.info("Running neighborhood graph with pynndescent (n_neighbors=%s)", knn_neighbors)
+        n_neighbors = min(knn_neighbors, max(1, adata_subset.n_obs - 1))
+        nn_index = pynndescent.NNDescent(
+            adata_subset.obsm["X_pca"],
+            n_neighbors=n_neighbors,
+            metric="euclidean",
+            random_state=random_state,
+            n_jobs=threads,
+        )
+        knn_indices, knn_dists = nn_index.neighbor_graph
+
+        rows = np.repeat(np.arange(adata_subset.n_obs), n_neighbors)
+        cols = knn_indices.reshape(-1)
+        distances = sp.coo_matrix(
+            (knn_dists.reshape(-1), (rows, cols)),
+            shape=(adata_subset.n_obs, adata_subset.n_obs),
+        ).tocsr()
+        adata_subset.obsp["distances"] = distances
+
         logger.info("Running UMAP")
-
+        umap_model = umap.UMAP(
+            n_neighbors=n_neighbors,
+            n_components=2,
+            metric="euclidean",
+            random_state=random_state,
+        )
+        adata_subset.obsm["X_umap"] = umap_model.fit_transform(adata_subset.obsm["X_pca"])
+
+        try:
+            from umap.umap_ import fuzzy_simplicial_set
+
+            fuzzy_result = fuzzy_simplicial_set(
+                adata_subset.obsm["X_pca"],
+                n_neighbors=n_neighbors,
+                random_state=random_state,
+                metric="euclidean",
+                knn_indices=knn_indices,
+                knn_dists=knn_dists,
+            )
+            connectivities = fuzzy_result[0] if isinstance(fuzzy_result, tuple) else fuzzy_result
+        except TypeError:
+            connectivities = umap_model.graph_
+
+        adata_subset.obsp["connectivities"] = connectivities
 
     # Step 4: Store results in original adata
     adata.obsm["X_pca"] = adata_subset.obsm["X_pca"]
     adata.obsm["X_umap"] = adata_subset.obsm["X_umap"]
     adata.obsp["distances"] = adata_subset.obsp["distances"]
     adata.obsp["connectivities"] = adata_subset.obsp["connectivities"]
-    adata.uns["neighbors"] =
+    adata.uns["neighbors"] = {
+        "params": {
+            "n_neighbors": knn_neighbors,
+            "method": "pynndescent",
+            "metric": "euclidean",
+        }
+    }
 
     # Fix varm["PCs"] shape mismatch
     pc_matrix = np.zeros((adata.shape[1], adata_subset.varm["PCs"].shape[1]))
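Net effect of these hunks: calculate_umap no longer routes PCA/neighbors/UMAP through a third-party pipeline; it computes PCA via a centered SVD (scipy.linalg.svd), builds the KNN graph with pynndescent, calls umap-learn directly, and threads the new random_state through NNDescent, UMAP, and fuzzy_simplicial_set for reproducible embeddings, with umap and pynndescent now loaded lazily via require(). A minimal usage sketch under stated assumptions (that the AnnData is the first positional argument and the function is importable from its module path; the toy matrix is illustrative):

    import anndata as ad
    import numpy as np

    from smftools.tools.calculate_umap import calculate_umap  # path per this diff

    # Toy binary methylation matrix: 200 reads x 50 positions (illustrative).
    rng = np.random.default_rng(0)
    adata = ad.AnnData(X=rng.integers(0, 2, size=(200, 50)).astype(float))

    # random_state now seeds NNDescent, UMAP, and fuzzy_simplicial_set.
    adata = calculate_umap(adata, knn_neighbors=15, threads=4, random_state=0)

    print(adata.obsm["X_umap"].shape)        # (200, 2)
    print(adata.uns["neighbors"]["params"])  # {'n_neighbors': 15, 'method': 'pynndescent', ...}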
smftools/tools/cluster_adata_on_methylation.py
CHANGED

@@ -4,6 +4,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Sequence
 
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 if TYPE_CHECKING:
     import anndata as ad
@@ -109,7 +110,12 @@ def cluster_adata_on_methylation(
         )
     elif method == "kmeans":
         try:
-
+            sklearn_cluster = require(
+                "sklearn.cluster",
+                extra="ml-base",
+                purpose="k-means clustering",
+            )
+            KMeans = sklearn_cluster.KMeans
 
             kmeans = KMeans(n_clusters=n_clusters)
             kmeans.fit(site_subset.layers[layer])
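Both files above adopt the new smftools.optional_imports.require helper (smftools/optional_imports.py, +31 lines in this release) in place of ad-hoc try/except import guards. Its implementation is not shown in this diff; from the call sites it appears to import an optional dependency and, on failure, point the user at the matching package extra. A hypothetical sketch of that contract (the function body, names, and message format here are assumptions, not the actual smftools code):

    # Hypothetical sketch of require()'s contract, inferred from call sites
    # in this diff; the real implementation in smftools/optional_imports.py
    # is not shown here.
    import importlib

    def require(module: str, *, extra: str, purpose: str):
        """Import an optional dependency or fail with an actionable message."""
        try:
            return importlib.import_module(module)
        except ImportError as err:
            raise ImportError(
                f"{module} is required for {purpose}. "
                f"Install it with: pip install 'smftools[{extra}]'"
            ) from err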
smftools/tools/position_stats.py
CHANGED

@@ -1,41 +1,26 @@
 from __future__ import annotations
 
+import os
 import warnings
+from contextlib import contextmanager
+from itertools import cycle
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
 
-if TYPE_CHECKING:
-    import anndata as ad
-
-import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+from scipy.stats import chi2_contingency
+from tqdm import tqdm
 
-
-try:
-    from joblib import Parallel, delayed
-
-    JOBLIB_AVAILABLE = True
-except Exception:
-    JOBLIB_AVAILABLE = False
+from smftools.optional_imports import require
 
-
-
+if TYPE_CHECKING:
+    import anndata as ad
 
-
-except Exception:
-    SCIPY_STATS_AVAILABLE = False
+plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
 
 # -----------------------------
 # Compute positionwise statistic (multi-method + simple site_types)
 # -----------------------------
-import os
-from contextlib import contextmanager
-from itertools import cycle
-
-import joblib
-from joblib import Parallel, cpu_count, delayed
-from scipy.stats import chi2_contingency
-from tqdm import tqdm
 
 
 # ------------------------- Utilities -------------------------
@@ -197,6 +182,8 @@ def calculate_relative_risk_on_activity(
 @contextmanager
 def tqdm_joblib(tqdm_object: tqdm):
     """Context manager to patch joblib to update a tqdm progress bar."""
+    joblib = require("joblib", extra="ml-base", purpose="parallel position statistics")
+
     old = joblib.parallel.BatchCompletionCallBack
 
     class TqdmBatchCompletionCallback(old):  # type: ignore
@@ -315,6 +302,8 @@ def compute_positionwise_statistics(
         max_threads: Maximum number of threads.
         reverse_indices_on_store: Whether to reverse indices on output storage.
     """
+    joblib = require("joblib", extra="ml-base", purpose="parallel position statistics")
+
     if isinstance(methods, str):
         methods = [methods]
     methods = [m.lower() for m in methods]
@@ -325,7 +314,7 @@
 
     # workers
     if max_threads is None or max_threads <= 0:
-        n_jobs = max(1, cpu_count() or 1)
+        n_jobs = max(1, joblib.cpu_count() or 1)
     else:
        n_jobs = max(1, int(max_threads))
 
@@ -439,13 +428,14 @@
            worker = _relative_risk_row_job
            out = np.full((n_pos, n_pos), np.nan, dtype=float)
            tasks = (
-               delayed(worker)(i, X_bin, min_count_for_pairwise)
+               joblib.delayed(worker)(i, X_bin, min_count_for_pairwise)
+               for i in range(n_pos)
            )
            pbar_rows = tqdm(
                total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False
            )
            with tqdm_joblib(pbar_rows):
-               results = Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
+               results = joblib.Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
            pbar_rows.close()
            for i, row in results:
                out[int(i), :] = row
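position_stats.py consolidates its imports at the top of the module, drops the JOBLIB_AVAILABLE/SCIPY_STATS_AVAILABLE feature flags, and instead resolves joblib lazily via require() inside tqdm_joblib and compute_positionwise_statistics. The tqdm_joblib context manager patches joblib.parallel.BatchCompletionCallBack so a progress bar ticks as parallel batches finish; the bottom of the hunk uses it exactly as in this sketch (row_job is a stand-in worker, not the package's _relative_risk_row_job):

    import joblib
    from tqdm import tqdm

    from smftools.tools.position_stats import tqdm_joblib  # defined in this diff

    def row_job(i):  # stand-in for the package's _relative_risk_row_job
        return i, i * i

    if __name__ == "__main__":  # needed for process workers on spawn platforms
        tasks = (joblib.delayed(row_job)(i) for i in range(1000))
        with tqdm_joblib(tqdm(total=1000, desc="rows", leave=False)):
            results = joblib.Parallel(n_jobs=4, prefer="processes")(tasks)
        assert sorted(i for i, _ in results) == list(range(1000))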
smftools/tools/rolling_nn_distance.py
ADDED

@@ -0,0 +1,235 @@
+from __future__ import annotations
+
+import ast
+import json
+from typing import TYPE_CHECKING, Optional, Sequence, Tuple
+
+import numpy as np
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def _pack_bool_to_u64(B: np.ndarray) -> np.ndarray:
+    """
+    Pack a boolean (or 0/1) matrix (n, w) into uint64 blocks (n, ceil(w/64)).
+    Safe w.r.t. contiguity/layout.
+    """
+    B = np.asarray(B, dtype=np.uint8)
+    packed_u8 = np.packbits(B, axis=1)  # (n, ceil(w/8)) uint8
+
+    n, nb = packed_u8.shape
+    pad = (-nb) % 8
+    if pad:
+        packed_u8 = np.pad(packed_u8, ((0, 0), (0, pad)), mode="constant", constant_values=0)
+
+    packed_u8 = np.ascontiguousarray(packed_u8)
+
+    # group 8 bytes -> uint64
+    packed_u64 = packed_u8.reshape(n, -1, 8).view(np.uint64).reshape(n, -1)
+    return packed_u64
+
+
+def _popcount_u64_matrix(A_u64: np.ndarray) -> np.ndarray:
+    """
+    Popcount for an array of uint64, vectorized and portable across NumPy versions.
+
+    Returns an integer array with the SAME SHAPE as A_u64.
+    """
+    A_u64 = np.ascontiguousarray(A_u64)
+    # View as bytes; IMPORTANT: reshape to add a trailing byte axis of length 8
+    b = A_u64.view(np.uint8).reshape(A_u64.shape + (8,))
+    # unpack bits within that byte axis -> (..., 64), then sum
+    return np.unpackbits(b, axis=-1).sum(axis=-1)
+
+
+def rolling_window_nn_distance(
+    adata,
+    layer: Optional[str] = None,
+    window: int = 15,
+    step: int = 2,
+    min_overlap: int = 10,
+    return_fraction: bool = True,
+    block_rows: int = 256,
+    block_cols: int = 2048,
+    store_obsm: Optional[str] = "rolling_nn_dist",
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Rolling-window nearest-neighbor distance per read, overlap-aware.
+
+    Distance between reads i,j in a window:
+      - use only positions where BOTH are observed (non-NaN)
+      - require overlap >= min_overlap
+      - mismatch = count(x_i != x_j) over overlapped positions
+      - distance = mismatch/overlap (if return_fraction) else mismatch
+
+    Returns
+    -------
+    out : (n_obs, n_windows) float
+        Nearest-neighbor distance per read per window (NaN if no valid neighbor).
+    starts : (n_windows,) int
+        Window start indices in var-space.
+    """
+    X = adata.layers[layer] if layer is not None else adata.X
+    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
+
+    n, p = X.shape
+    if window > p:
+        raise ValueError(f"window={window} is larger than n_vars={p}")
+    if window <= 0:
+        raise ValueError("window must be > 0")
+    if step <= 0:
+        raise ValueError("step must be > 0")
+    if min_overlap <= 0:
+        raise ValueError("min_overlap must be > 0")
+
+    starts = np.arange(0, p - window + 1, step, dtype=int)
+    nW = len(starts)
+    out = np.full((n, nW), np.nan, dtype=float)
+
+    for wi, s in enumerate(starts):
+        wX = X[:, s : s + window]  # (n, window)
+
+        # observed mask; values as 0/1 where observed, 0 elsewhere
+        M = ~np.isnan(wX)
+        V = np.where(M, wX, 0).astype(np.float32)
+
+        # ensure binary 0/1
+        V = (V > 0).astype(np.uint8)
+
+        M64 = _pack_bool_to_u64(M)
+        V64 = _pack_bool_to_u64(V.astype(bool))
+
+        best = np.full(n, np.inf, dtype=float)
+
+        for i0 in range(0, n, block_rows):
+            i1 = min(n, i0 + block_rows)
+            Mi = M64[i0:i1]  # (bi, nb)
+            Vi = V64[i0:i1]
+            bi = i1 - i0
+
+            local_best = np.full(bi, np.inf, dtype=float)
+
+            for j0 in range(0, n, block_cols):
+                j1 = min(n, j0 + block_cols)
+                Mj = M64[j0:j1]  # (bj, nb)
+                Vj = V64[j0:j1]
+                bj = j1 - j0
+
+                overlap_counts = np.zeros((bi, bj), dtype=np.uint16)
+                mismatch_counts = np.zeros((bi, bj), dtype=np.uint16)
+
+                for k in range(Mi.shape[1]):
+                    ob = (Mi[:, k][:, None] & Mj[:, k][None, :]).astype(np.uint64)
+                    overlap_counts += _popcount_u64_matrix(ob).astype(np.uint16)
+
+                    mb = ((Vi[:, k][:, None] ^ Vj[:, k][None, :]) & ob).astype(np.uint64)
+                    mismatch_counts += _popcount_u64_matrix(mb).astype(np.uint16)
+
+                ok = overlap_counts >= min_overlap
+                if not np.any(ok):
+                    continue
+
+                dist = np.full((bi, bj), np.inf, dtype=float)
+                if return_fraction:
+                    dist[ok] = mismatch_counts[ok] / overlap_counts[ok]
+                else:
+                    dist[ok] = mismatch_counts[ok].astype(float)
+
+                # exclude self comparisons (diagonal) when blocks overlap
+                if (i0 <= j1) and (j0 <= i1):
+                    ii = np.arange(i0, i1)
+                    jj = ii[(ii >= j0) & (ii < j1)]
+                    if jj.size:
+                        dist[(jj - i0), (jj - j0)] = np.inf
+
+                local_best = np.minimum(local_best, dist.min(axis=1))
+
+            best[i0:i1] = local_best
+
+        best[~np.isfinite(best)] = np.nan
+        out[:, wi] = best
+
+    if store_obsm is not None:
+        adata.obsm[store_obsm] = out
+        adata.uns[f"{store_obsm}_starts"] = starts
+        adata.uns[f"{store_obsm}_window"] = int(window)
+        adata.uns[f"{store_obsm}_step"] = int(step)
+        adata.uns[f"{store_obsm}_min_overlap"] = int(min_overlap)
+        adata.uns[f"{store_obsm}_return_fraction"] = bool(return_fraction)
+        adata.uns[f"{store_obsm}_layer"] = layer if layer is not None else "X"
+
+    return out, starts
+
+
+def assign_rolling_nn_results(
+    parent_adata: "ad.AnnData",
+    subset_adata: "ad.AnnData",
+    values: np.ndarray,
+    starts: np.ndarray,
+    obsm_key: str,
+    window: int,
+    step: int,
+    min_overlap: int,
+    return_fraction: bool,
+    layer: Optional[str],
+) -> None:
+    """
+    Assign rolling NN results computed on a subset back onto a parent AnnData.
+
+    Parameters
+    ----------
+    parent_adata : AnnData
+        Parent AnnData that should store the combined results.
+    subset_adata : AnnData
+        Subset AnnData used to compute `values`.
+    values : np.ndarray
+        Rolling NN output with shape (n_subset_obs, n_windows).
+    starts : np.ndarray
+        Window start indices corresponding to `values`.
+    obsm_key : str
+        Key to store results under in parent_adata.obsm.
+    window : int
+        Rolling window size (stored in parent_adata.uns).
+    step : int
+        Rolling window step size (stored in parent_adata.uns).
+    min_overlap : int
+        Minimum overlap (stored in parent_adata.uns).
+    return_fraction : bool
+        Whether distances are fractional (stored in parent_adata.uns).
+    layer : str | None
+        Layer used for calculations (stored in parent_adata.uns).
+    """
+    n_obs = parent_adata.n_obs
+    n_windows = values.shape[1]
+
+    if obsm_key not in parent_adata.obsm:
+        parent_adata.obsm[obsm_key] = np.full((n_obs, n_windows), np.nan, dtype=float)
+        parent_adata.uns[f"{obsm_key}_starts"] = starts
+        parent_adata.uns[f"{obsm_key}_window"] = int(window)
+        parent_adata.uns[f"{obsm_key}_step"] = int(step)
+        parent_adata.uns[f"{obsm_key}_min_overlap"] = int(min_overlap)
+        parent_adata.uns[f"{obsm_key}_return_fraction"] = bool(return_fraction)
+        parent_adata.uns[f"{obsm_key}_layer"] = layer if layer is not None else "X"
+    else:
+        existing = parent_adata.obsm[obsm_key]
+        if existing.shape[1] != n_windows:
+            raise ValueError(
+                f"Existing obsm[{obsm_key!r}] has {existing.shape[1]} windows; "
+                f"new values have {n_windows} windows."
+            )
+        existing_starts = parent_adata.uns.get(f"{obsm_key}_starts")
+        if existing_starts is not None and not np.array_equal(existing_starts, starts):
+            raise ValueError(
+                f"Existing obsm[{obsm_key!r}] has different window starts than new values."
+            )
+
+    parent_indexer = parent_adata.obs_names.get_indexer(subset_adata.obs_names)
+    if (parent_indexer < 0).any():
+        raise ValueError("Subset AnnData contains obs not present in parent AnnData.")
+
+    parent_adata.obsm[obsm_key][parent_indexer, :] = values
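rolling_nn_distance.py is new in 0.3.1. For each sliding window it computes every read's distance to its nearest neighbor, counting mismatches only over positions both reads have observed and requiring at least min_overlap shared positions. The speed comes from bit-packing: masks and 0/1 values are packed into uint64 words, so per-pair overlap and mismatch counts reduce to AND/XOR plus popcounts over a handful of words instead of a loop over positions. A small end-to-end sketch on synthetic data (shapes and NaN rate are illustrative):

    import anndata as ad
    import numpy as np

    from smftools.tools.rolling_nn_distance import rolling_window_nn_distance

    rng = np.random.default_rng(0)
    X = rng.integers(0, 2, size=(100, 60)).astype(float)
    X[rng.random(X.shape) < 0.2] = np.nan       # simulate unobserved positions

    adata = ad.AnnData(X=X)
    out, starts = rolling_window_nn_distance(
        adata, window=15, step=5, min_overlap=10, return_fraction=True
    )
    print(out.shape, starts[:3])                # (100, 10) distances in [0, 1]
    print(adata.obsm["rolling_nn_dist"].shape)  # stored via the default store_obsm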
smftools/tools/tensor_factorization.py
ADDED

@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable, Sequence
+
+import numpy as np
+
+from smftools.constants import MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def build_sequence_one_hot_and_mask(
+    encoded_sequences: np.ndarray,
+    *,
+    bases: Sequence[str] = ("A", "C", "G", "T"),
+    dtype: np.dtype | type[np.floating] = np.float32,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Build one-hot encoded reads and a seen/unseen mask.
+
+    Args:
+        encoded_sequences: Integer-encoded sequences shaped (n_reads, seq_len).
+        bases: Bases to one-hot encode.
+        dtype: Output dtype for the one-hot tensor.
+
+    Returns:
+        Tuple of (one_hot_tensor, mask) where:
+            - one_hot_tensor: (n_reads, seq_len, n_bases)
+            - mask: (n_reads, seq_len) boolean array indicating seen bases.
+    """
+    encoded = np.asarray(encoded_sequences)
+    if encoded.ndim != 2:
+        raise ValueError(
+            f"encoded_sequences must be 2D with shape (n_reads, seq_len); got {encoded.shape}."
+        )
+
+    base_values = np.array(
+        [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[base] for base in bases],
+        dtype=encoded.dtype,
+    )
+
+    if np.issubdtype(encoded.dtype, np.floating):
+        encoded = encoded.copy()
+        encoded[np.isnan(encoded)] = -1
+
+    mask = np.isin(encoded, base_values)
+    one_hot = np.zeros((*encoded.shape, len(base_values)), dtype=dtype)
+
+    for idx, base_value in enumerate(base_values):
+        one_hot[..., idx] = encoded == base_value
+
+    return one_hot, mask
+
+
+def calculate_sequence_cp_decomposition(
+    adata: "ad.AnnData",
+    *,
+    layer: str,
+    rank: int = 5,
+    n_iter_max: int = 100,
+    random_state: int = 0,
+    overwrite: bool = True,
+    embedding_key: str = "X_cp_sequence",
+    components_key: str = "H_cp_sequence",
+    uns_key: str = "cp_sequence",
+    bases: Iterable[str] = ("A", "C", "G", "T"),
+    backend: str = "pytorch",
+    show_progress: bool = False,
+    init: str = "random",
+) -> "ad.AnnData":
+    """Compute CP decomposition on one-hot encoded sequence data with masking.
+
+    Args:
+        adata: AnnData object to update.
+        layer: Layer name containing integer-encoded sequences.
+        rank: CP rank.
+        n_iter_max: Maximum number of iterations for the solver.
+        random_state: Random seed for initialization.
+        overwrite: Whether to recompute if the embedding already exists.
+        embedding_key: Key for embedding in ``adata.obsm``.
+        components_key: Key for position factors in ``adata.varm``.
+        uns_key: Key for metadata stored in ``adata.uns``.
+        bases: Bases to one-hot encode (in order).
+        backend: Tensorly backend to use (``numpy`` or ``pytorch``).
+        show_progress: Whether to display progress during factorization if supported.
+
+    Returns:
+        Updated AnnData object containing the CP decomposition outputs.
+    """
+    if embedding_key in adata.obsm and components_key in adata.varm and not overwrite:
+        logger.info("CP embedding and components already present; skipping recomputation.")
+        return adata
+
+    if backend not in {"numpy", "pytorch"}:
+        raise ValueError(f"Unsupported backend '{backend}'. Use 'numpy' or 'pytorch'.")
+
+    tensorly = require("tensorly", extra="ml-base", purpose="CP decomposition")
+    from tensorly.decomposition import parafac
+
+    tensorly.set_backend(backend)
+
+    if layer not in adata.layers:
+        raise KeyError(f"Layer '{layer}' not found in adata.layers.")
+
+    one_hot, mask = build_sequence_one_hot_and_mask(adata.layers[layer], bases=tuple(bases))
+    mask_tensor = np.repeat(mask[:, :, None], one_hot.shape[2], axis=2)
+
+    device = "numpy"
+    if backend == "pytorch":
+        torch = require("torch", extra="ml-base", purpose="CP decomposition backend")
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        one_hot = torch.tensor(one_hot, dtype=torch.float32, device=device)
+        mask_tensor = torch.tensor(mask_tensor, dtype=torch.float32, device=device)
+
+    parafac_kwargs = {
+        "rank": rank,
+        "n_iter_max": n_iter_max,
+        "init": init,
+        "mask": mask_tensor,
+        "random_state": random_state,
+    }
+    import inspect
+
+    if "verbose" in inspect.signature(parafac).parameters:
+        parafac_kwargs["verbose"] = show_progress
+
+    cp = parafac(one_hot, **parafac_kwargs)
+
+    if backend == "pytorch":
+        weights = cp.weights.detach().cpu().numpy()
+        read_factors, position_factors, base_factors = [
+            factor.detach().cpu().numpy() for factor in cp.factors
+        ]
+    else:
+        weights = np.asarray(cp.weights)
+        read_factors, position_factors, base_factors = [np.asarray(f) for f in cp.factors]
+
+    adata.obsm[embedding_key] = read_factors
+    adata.varm[components_key] = position_factors
+    adata.uns[uns_key] = {
+        "rank": rank,
+        "n_iter_max": n_iter_max,
+        "random_state": random_state,
+        "layer": layer,
+        "components_key": components_key,
+        "weights": weights,
+        "base_factors": base_factors,
+        "base_labels": list(bases),
+        "backend": backend,
+        "device": str(device),
+    }
+
+    logger.info(
+        "Stored: adata.obsm['%s'], adata.varm['%s'], adata.uns['%s']",
+        embedding_key,
+        components_key,
+        uns_key,
+    )
+    return adata
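tensor_factorization.py is also new: it one-hot encodes an integer-encoded sequence layer into an (n_reads, seq_len, n_bases) tensor, masks unseen bases, and fits a masked CP (PARAFAC) decomposition through tensorly, writing read factors to obsm, position factors to varm, and weights/base factors to uns. A usage sketch assuming the ml-base extra is installed, using backend="numpy" to avoid the torch requirement (the synthetic layer is illustrative; real layers must use the package's MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT codes, as below):

    import anndata as ad
    import numpy as np

    from smftools.constants import MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT
    from smftools.tools.tensor_factorization import calculate_sequence_cp_decomposition

    # Synthetic integer-encoded reads: 50 reads x 40 positions.
    codes = [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[b] for b in "ACGT"]
    rng = np.random.default_rng(0)
    seqs = rng.choice(codes, size=(50, 40)).astype(float)

    adata = ad.AnnData(X=np.zeros_like(seqs))
    adata.layers["sequence"] = seqs

    adata = calculate_sequence_cp_decomposition(
        adata, layer="sequence", rank=3, backend="numpy", n_iter_max=50
    )
    print(adata.obsm["X_cp_sequence"].shape)  # (50, 3) read factors
    print(adata.varm["H_cp_sequence"].shape)  # (40, 3) position factors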