smftools 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. smftools/__init__.py +39 -7
  2. smftools/_settings.py +2 -0
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +34 -6
  7. smftools/cli/hmm_adata.py +239 -33
  8. smftools/cli/latent_adata.py +318 -0
  9. smftools/cli/load_adata.py +167 -131
  10. smftools/cli/preprocess_adata.py +180 -53
  11. smftools/cli/spatial_adata.py +152 -100
  12. smftools/cli_entry.py +38 -1
  13. smftools/config/__init__.py +2 -0
  14. smftools/config/conversion.yaml +11 -1
  15. smftools/config/default.yaml +42 -2
  16. smftools/config/experiment_config.py +59 -1
  17. smftools/constants.py +65 -0
  18. smftools/datasets/__init__.py +2 -0
  19. smftools/hmm/HMM.py +97 -3
  20. smftools/hmm/__init__.py +24 -13
  21. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  22. smftools/hmm/archived/calculate_distances.py +2 -0
  23. smftools/hmm/archived/call_hmm_peaks.py +2 -0
  24. smftools/hmm/archived/train_hmm.py +2 -0
  25. smftools/hmm/call_hmm_peaks.py +5 -2
  26. smftools/hmm/display_hmm.py +4 -1
  27. smftools/hmm/hmm_readwrite.py +7 -2
  28. smftools/hmm/nucleosome_hmm_refinement.py +2 -0
  29. smftools/informatics/__init__.py +59 -34
  30. smftools/informatics/archived/bam_conversion.py +2 -0
  31. smftools/informatics/archived/bam_direct.py +2 -0
  32. smftools/informatics/archived/basecall_pod5s.py +2 -0
  33. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  34. smftools/informatics/archived/conversion_smf.py +2 -0
  35. smftools/informatics/archived/deaminase_smf.py +1 -0
  36. smftools/informatics/archived/direct_smf.py +2 -0
  37. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  38. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  39. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
  40. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  41. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  42. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  43. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  44. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  45. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  46. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  47. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  48. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  49. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  50. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  52. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  53. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  54. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  55. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  56. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  57. smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
  58. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  59. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  60. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  61. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  62. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  63. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  64. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  65. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
  66. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  67. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  68. smftools/informatics/archived/print_bam_query_seq.py +2 -0
  69. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  70. smftools/informatics/archived/subsample_pod5.py +2 -0
  71. smftools/informatics/bam_functions.py +1093 -176
  72. smftools/informatics/basecalling.py +2 -0
  73. smftools/informatics/bed_functions.py +271 -61
  74. smftools/informatics/binarize_converted_base_identities.py +3 -0
  75. smftools/informatics/complement_base_list.py +2 -0
  76. smftools/informatics/converted_BAM_to_adata.py +641 -176
  77. smftools/informatics/fasta_functions.py +94 -10
  78. smftools/informatics/h5ad_functions.py +123 -4
  79. smftools/informatics/modkit_extract_to_adata.py +1019 -431
  80. smftools/informatics/modkit_functions.py +2 -0
  81. smftools/informatics/ohe.py +2 -0
  82. smftools/informatics/pod5_functions.py +3 -2
  83. smftools/informatics/sequence_encoding.py +72 -0
  84. smftools/logging_utils.py +21 -2
  85. smftools/machine_learning/__init__.py +22 -6
  86. smftools/machine_learning/data/__init__.py +2 -0
  87. smftools/machine_learning/data/anndata_data_module.py +18 -4
  88. smftools/machine_learning/data/preprocessing.py +2 -0
  89. smftools/machine_learning/evaluation/__init__.py +2 -0
  90. smftools/machine_learning/evaluation/eval_utils.py +2 -0
  91. smftools/machine_learning/evaluation/evaluators.py +14 -9
  92. smftools/machine_learning/inference/__init__.py +2 -0
  93. smftools/machine_learning/inference/inference_utils.py +2 -0
  94. smftools/machine_learning/inference/lightning_inference.py +6 -1
  95. smftools/machine_learning/inference/sklearn_inference.py +2 -0
  96. smftools/machine_learning/inference/sliding_window_inference.py +2 -0
  97. smftools/machine_learning/models/__init__.py +2 -0
  98. smftools/machine_learning/models/base.py +7 -2
  99. smftools/machine_learning/models/cnn.py +7 -2
  100. smftools/machine_learning/models/lightning_base.py +16 -11
  101. smftools/machine_learning/models/mlp.py +5 -1
  102. smftools/machine_learning/models/positional.py +7 -2
  103. smftools/machine_learning/models/rnn.py +5 -1
  104. smftools/machine_learning/models/sklearn_models.py +14 -9
  105. smftools/machine_learning/models/transformer.py +7 -2
  106. smftools/machine_learning/models/wrappers.py +6 -2
  107. smftools/machine_learning/training/__init__.py +2 -0
  108. smftools/machine_learning/training/train_lightning_model.py +13 -3
  109. smftools/machine_learning/training/train_sklearn_model.py +2 -0
  110. smftools/machine_learning/utils/__init__.py +2 -0
  111. smftools/machine_learning/utils/device.py +5 -1
  112. smftools/machine_learning/utils/grl.py +5 -1
  113. smftools/metadata.py +1 -1
  114. smftools/optional_imports.py +31 -0
  115. smftools/plotting/__init__.py +41 -31
  116. smftools/plotting/autocorrelation_plotting.py +9 -5
  117. smftools/plotting/classifiers.py +16 -4
  118. smftools/plotting/general_plotting.py +2415 -629
  119. smftools/plotting/hmm_plotting.py +97 -9
  120. smftools/plotting/position_stats.py +15 -7
  121. smftools/plotting/qc_plotting.py +6 -1
  122. smftools/preprocessing/__init__.py +36 -37
  123. smftools/preprocessing/append_base_context.py +17 -17
  124. smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
  125. smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
  126. smftools/preprocessing/archived/calculate_complexity.py +2 -0
  127. smftools/preprocessing/archived/mark_duplicates.py +2 -0
  128. smftools/preprocessing/archived/preprocessing.py +2 -0
  129. smftools/preprocessing/archived/remove_duplicates.py +2 -0
  130. smftools/preprocessing/binary_layers_to_ohe.py +2 -1
  131. smftools/preprocessing/calculate_complexity_II.py +4 -1
  132. smftools/preprocessing/calculate_consensus.py +1 -1
  133. smftools/preprocessing/calculate_pairwise_differences.py +2 -0
  134. smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
  135. smftools/preprocessing/calculate_position_Youden.py +9 -2
  136. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  137. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
  138. smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
  139. smftools/preprocessing/flag_duplicate_reads.py +42 -54
  140. smftools/preprocessing/make_dirs.py +2 -1
  141. smftools/preprocessing/min_non_diagonal.py +2 -0
  142. smftools/preprocessing/recipes.py +2 -0
  143. smftools/readwrite.py +53 -17
  144. smftools/schema/anndata_schema_v1.yaml +15 -1
  145. smftools/tools/__init__.py +30 -18
  146. smftools/tools/archived/apply_hmm.py +2 -0
  147. smftools/tools/archived/classifiers.py +2 -0
  148. smftools/tools/archived/classify_methylated_features.py +2 -0
  149. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  150. smftools/tools/archived/subset_adata_v1.py +2 -0
  151. smftools/tools/archived/subset_adata_v2.py +2 -0
  152. smftools/tools/calculate_leiden.py +57 -0
  153. smftools/tools/calculate_nmf.py +119 -0
  154. smftools/tools/calculate_umap.py +93 -8
  155. smftools/tools/cluster_adata_on_methylation.py +7 -1
  156. smftools/tools/position_stats.py +17 -27
  157. smftools/tools/rolling_nn_distance.py +235 -0
  158. smftools/tools/tensor_factorization.py +169 -0
  159. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
  160. smftools-0.3.1.dist-info/RECORD +189 -0
  161. smftools-0.2.5.dist-info/RECORD +0 -181
  162. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
  163. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
  164. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
smftools/tools/calculate_umap.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Sequence
 
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 if TYPE_CHECKING:
     import anndata as ad
@@ -18,6 +19,7 @@ def calculate_umap(
     knn_neighbors: int = 100,
     overwrite: bool = True,
     threads: int = 8,
+    random_state: int | None = 0,
 ) -> "ad.AnnData":
     """Compute PCA, neighbors, and UMAP embeddings.
 
@@ -36,8 +38,11 @@
     import os
 
     import numpy as np
-    import scanpy as sc
-    from scipy.sparse import issparse
+    import scipy.linalg as spla
+    import scipy.sparse as sp
+
+    umap = require("umap", extra="umap", purpose="UMAP calculation")
+    pynndescent = require("pynndescent", extra="umap", purpose="KNN graph computation")
 
     os.environ["OMP_NUM_THREADS"] = str(threads)
 
@@ -57,7 +62,7 @@
     # Step 2: NaN handling inside layer
     if layer:
         data = adata_subset.layers[layer]
-        if not issparse(data):
+        if not sp.issparse(data):
             if np.isnan(data).any():
                 logger.warning("NaNs detected, filling with 0.5 before PCA + neighbors.")
                 data = np.nan_to_num(data, nan=0.5)
@@ -73,18 +78,98 @@
     if "X_umap" not in adata_subset.obsm or overwrite:
         n_pcs = min(adata_subset.shape[1], n_pcs)
         logger.info("Running PCA with n_pcs=%s", n_pcs)
-        sc.pp.pca(adata_subset, layer=layer)
-        logger.info("Running neighborhood graph")
-        sc.pp.neighbors(adata_subset, use_rep="X_pca", n_pcs=n_pcs, n_neighbors=knn_neighbors)
+
+        if layer:
+            matrix = adata_subset.layers[layer]
+        else:
+            matrix = adata_subset.X
+
+        if sp.issparse(matrix):
+            logger.warning("Converting sparse matrix to dense for PCA.")
+            matrix = matrix.toarray()
+
+        matrix = np.asarray(matrix, dtype=float)
+        mean = matrix.mean(axis=0)
+        centered = matrix - mean
+
+        if centered.shape[0] == 0 or centered.shape[1] == 0:
+            raise ValueError("PCA requires a non-empty matrix.")
+
+        if n_pcs <= 0:
+            raise ValueError("n_pcs must be positive.")
+
+        if centered.shape[1] <= n_pcs:
+            n_pcs = centered.shape[1]
+
+        if centered.shape[0] < n_pcs:
+            n_pcs = centered.shape[0]
+
+        u, s, vt = spla.svd(centered, full_matrices=False)
+
+        u = u[:, :n_pcs]
+        s = s[:n_pcs]
+        vt = vt[:n_pcs]
+
+        adata_subset.obsm["X_pca"] = u * s
+        adata_subset.varm["PCs"] = vt.T
+
+        logger.info("Running neighborhood graph with pynndescent (n_neighbors=%s)", knn_neighbors)
+        n_neighbors = min(knn_neighbors, max(1, adata_subset.n_obs - 1))
+        nn_index = pynndescent.NNDescent(
+            adata_subset.obsm["X_pca"],
+            n_neighbors=n_neighbors,
+            metric="euclidean",
+            random_state=random_state,
+            n_jobs=threads,
+        )
+        knn_indices, knn_dists = nn_index.neighbor_graph
+
+        rows = np.repeat(np.arange(adata_subset.n_obs), n_neighbors)
+        cols = knn_indices.reshape(-1)
+        distances = sp.coo_matrix(
+            (knn_dists.reshape(-1), (rows, cols)),
+            shape=(adata_subset.n_obs, adata_subset.n_obs),
+        ).tocsr()
+        adata_subset.obsp["distances"] = distances
+
         logger.info("Running UMAP")
-        sc.tl.umap(adata_subset)
+        umap_model = umap.UMAP(
+            n_neighbors=n_neighbors,
+            n_components=2,
+            metric="euclidean",
+            random_state=random_state,
+        )
+        adata_subset.obsm["X_umap"] = umap_model.fit_transform(adata_subset.obsm["X_pca"])
+
+        try:
+            from umap.umap_ import fuzzy_simplicial_set
+
+            fuzzy_result = fuzzy_simplicial_set(
+                adata_subset.obsm["X_pca"],
+                n_neighbors=n_neighbors,
+                random_state=random_state,
+                metric="euclidean",
+                knn_indices=knn_indices,
+                knn_dists=knn_dists,
+            )
+            connectivities = fuzzy_result[0] if isinstance(fuzzy_result, tuple) else fuzzy_result
+        except TypeError:
+            connectivities = umap_model.graph_
+
+        adata_subset.obsp["connectivities"] = connectivities
 
     # Step 4: Store results in original adata
     adata.obsm["X_pca"] = adata_subset.obsm["X_pca"]
     adata.obsm["X_umap"] = adata_subset.obsm["X_umap"]
     adata.obsp["distances"] = adata_subset.obsp["distances"]
     adata.obsp["connectivities"] = adata_subset.obsp["connectivities"]
-    adata.uns["neighbors"] = adata_subset.uns["neighbors"]
+    adata.uns["neighbors"] = {
+        "params": {
+            "n_neighbors": knn_neighbors,
+            "method": "pynndescent",
+            "metric": "euclidean",
+        }
+    }
 
     # Fix varm["PCs"] shape mismatch
     pc_matrix = np.zeros((adata.shape[1], adata_subset.varm["PCs"].shape[1]))
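
Taken together, these hunks drop the scanpy dependency from `calculate_umap`: PCA is now a plain SVD on the (densified) matrix, the kNN graph comes from pynndescent, the embedding from umap-learn, and the new `random_state` argument seeds both. A hypothetical call sketch follows; the import path is inferred from the file list, the toy AnnData is invented, and parameters not visible in this diff are assumed to keep usable defaults.

import anndata as ad
import numpy as np

from smftools.tools.calculate_umap import calculate_umap  # assumed module path

adata = ad.AnnData(X=np.random.rand(200, 50).astype(np.float32))

# random_state (new in 0.3.1) seeds both NNDescent and umap.UMAP;
# passing None restores the previous non-deterministic behavior.
adata = calculate_umap(adata, knn_neighbors=30, threads=4, random_state=0)

print(adata.obsm["X_umap"].shape)        # (200, 2)
print(adata.uns["neighbors"]["params"])  # {'n_neighbors': 30, 'method': 'pynndescent', ...}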
smftools/tools/cluster_adata_on_methylation.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Sequence
 
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 if TYPE_CHECKING:
     import anndata as ad
@@ -109,7 +110,12 @@ def cluster_adata_on_methylation(
         )
     elif method == "kmeans":
        try:
-            from sklearn.cluster import KMeans
+            sklearn_cluster = require(
+                "sklearn.cluster",
+                extra="ml-base",
+                purpose="k-means clustering",
+            )
+            KMeans = sklearn_cluster.KMeans
 
             kmeans = KMeans(n_clusters=n_clusters)
             kmeans.fit(site_subset.layers[layer])
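
Both files now route optional dependencies through the new `smftools/optional_imports.py` (+31 lines; its body is not shown in this diff). Judging from the call sites, `require` imports a possibly-dotted module path and converts a missing optional dependency into an install hint naming a pip extra. A minimal sketch of that pattern, not the released implementation:

import importlib

def require(module: str, *, extra: str, purpose: str):
    """Import `module`, or fail with a hint naming the relevant pip extra."""
    try:
        return importlib.import_module(module)  # handles dotted paths like "sklearn.cluster"
    except ImportError as exc:
        raise ImportError(
            f"{module!r} is required for {purpose}. "
            f"Install it with: pip install 'smftools[{extra}]'"
        ) from exc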
smftools/tools/position_stats.py

@@ -1,41 +1,26 @@
 from __future__ import annotations
 
+import os
 import warnings
+from contextlib import contextmanager
+from itertools import cycle
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
 
-if TYPE_CHECKING:
-    import anndata as ad
-
-import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+from scipy.stats import chi2_contingency
+from tqdm import tqdm
 
-# optional imports
-try:
-    from joblib import Parallel, delayed
-
-    JOBLIB_AVAILABLE = True
-except Exception:
-    JOBLIB_AVAILABLE = False
+from smftools.optional_imports import require
 
-try:
-    from scipy.stats import chi2_contingency
+if TYPE_CHECKING:
+    import anndata as ad
 
-    SCIPY_STATS_AVAILABLE = True
-except Exception:
-    SCIPY_STATS_AVAILABLE = False
+plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
 
 # -----------------------------
 # Compute positionwise statistic (multi-method + simple site_types)
 # -----------------------------
-import os
-from contextlib import contextmanager
-from itertools import cycle
-
-import joblib
-from joblib import Parallel, cpu_count, delayed
-from scipy.stats import chi2_contingency
-from tqdm import tqdm
 
 
 # ------------------------- Utilities -------------------------
@@ -197,6 +182,8 @@ def calculate_relative_risk_on_activity(
 @contextmanager
 def tqdm_joblib(tqdm_object: tqdm):
     """Context manager to patch joblib to update a tqdm progress bar."""
+    joblib = require("joblib", extra="ml-base", purpose="parallel position statistics")
+
     old = joblib.parallel.BatchCompletionCallBack
 
     class TqdmBatchCompletionCallback(old):  # type: ignore
@@ -315,6 +302,8 @@ def compute_positionwise_statistics(
        max_threads: Maximum number of threads.
        reverse_indices_on_store: Whether to reverse indices on output storage.
    """
+    joblib = require("joblib", extra="ml-base", purpose="parallel position statistics")
+
    if isinstance(methods, str):
        methods = [methods]
    methods = [m.lower() for m in methods]
@@ -325,7 +314,7 @@
 
    # workers
    if max_threads is None or max_threads <= 0:
-        n_jobs = max(1, cpu_count() or 1)
+        n_jobs = max(1, joblib.cpu_count() or 1)
    else:
        n_jobs = max(1, int(max_threads))
 
@@ -439,13 +428,14 @@
            worker = _relative_risk_row_job
            out = np.full((n_pos, n_pos), np.nan, dtype=float)
            tasks = (
-                delayed(worker)(i, X_bin, min_count_for_pairwise) for i in range(n_pos)
+                joblib.delayed(worker)(i, X_bin, min_count_for_pairwise)
+                for i in range(n_pos)
            )
            pbar_rows = tqdm(
                total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False
            )
            with tqdm_joblib(pbar_rows):
-                results = Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
+                results = joblib.Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
            pbar_rows.close()
            for i, row in results:
                out[int(i), :] = row
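
The `tqdm_joblib` helper is only partially visible above (the hunk truncates the callback class body). It follows the standard recipe of temporarily swapping joblib's `BatchCompletionCallBack` so each completed batch advances a progress bar; a self-contained version with an invented toy workload:

from contextlib import contextmanager

import joblib
from joblib import Parallel, delayed
from tqdm import tqdm

@contextmanager
def tqdm_joblib(tqdm_object: tqdm):
    """Patch joblib so finished batches tick the given tqdm bar."""
    old = joblib.parallel.BatchCompletionCallBack

    class TqdmBatchCompletionCallback(old):  # type: ignore
        def __call__(self, *args, **kwargs):
            tqdm_object.update(n=self.batch_size)  # one tick per completed batch
            return super().__call__(*args, **kwargs)

    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        joblib.parallel.BatchCompletionCallBack = old  # always restore the original
        tqdm_object.close()

with tqdm_joblib(tqdm(total=100)):
    Parallel(n_jobs=2)(delayed(pow)(i, 2) for i in range(100))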
smftools/tools/rolling_nn_distance.py (new file)

@@ -0,0 +1,235 @@
+from __future__ import annotations
+
+import ast
+import json
+from typing import TYPE_CHECKING, Optional, Sequence, Tuple
+
+import numpy as np
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def _pack_bool_to_u64(B: np.ndarray) -> np.ndarray:
+    """
+    Pack a boolean (or 0/1) matrix (n, w) into uint64 blocks (n, ceil(w/64)).
+    Safe w.r.t. contiguity/layout.
+    """
+    B = np.asarray(B, dtype=np.uint8)
+    packed_u8 = np.packbits(B, axis=1)  # (n, ceil(w/8)) uint8
+
+    n, nb = packed_u8.shape
+    pad = (-nb) % 8
+    if pad:
+        packed_u8 = np.pad(packed_u8, ((0, 0), (0, pad)), mode="constant", constant_values=0)
+
+    packed_u8 = np.ascontiguousarray(packed_u8)
+
+    # group 8 bytes -> uint64
+    packed_u64 = packed_u8.reshape(n, -1, 8).view(np.uint64).reshape(n, -1)
+    return packed_u64
+
+
+def _popcount_u64_matrix(A_u64: np.ndarray) -> np.ndarray:
+    """
+    Popcount for an array of uint64, vectorized and portable across NumPy versions.
+
+    Returns an integer array with the SAME SHAPE as A_u64.
+    """
+    A_u64 = np.ascontiguousarray(A_u64)
+    # View as bytes; IMPORTANT: reshape to add a trailing byte axis of length 8
+    b = A_u64.view(np.uint8).reshape(A_u64.shape + (8,))
+    # unpack bits within that byte axis -> (..., 64), then sum
+    return np.unpackbits(b, axis=-1).sum(axis=-1)
+
+
+def rolling_window_nn_distance(
+    adata,
+    layer: Optional[str] = None,
+    window: int = 15,
+    step: int = 2,
+    min_overlap: int = 10,
+    return_fraction: bool = True,
+    block_rows: int = 256,
+    block_cols: int = 2048,
+    store_obsm: Optional[str] = "rolling_nn_dist",
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Rolling-window nearest-neighbor distance per read, overlap-aware.
+
+    Distance between reads i,j in a window:
+      - use only positions where BOTH are observed (non-NaN)
+      - require overlap >= min_overlap
+      - mismatch = count(x_i != x_j) over overlapped positions
+      - distance = mismatch/overlap (if return_fraction) else mismatch
+
+    Returns
+    -------
+    out : (n_obs, n_windows) float
+        Nearest-neighbor distance per read per window (NaN if no valid neighbor).
+    starts : (n_windows,) int
+        Window start indices in var-space.
+    """
+    X = adata.layers[layer] if layer is not None else adata.X
+    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
+
+    n, p = X.shape
+    if window > p:
+        raise ValueError(f"window={window} is larger than n_vars={p}")
+    if window <= 0:
+        raise ValueError("window must be > 0")
+    if step <= 0:
+        raise ValueError("step must be > 0")
+    if min_overlap <= 0:
+        raise ValueError("min_overlap must be > 0")
+
+    starts = np.arange(0, p - window + 1, step, dtype=int)
+    nW = len(starts)
+    out = np.full((n, nW), np.nan, dtype=float)
+
+    for wi, s in enumerate(starts):
+        wX = X[:, s : s + window]  # (n, window)
+
+        # observed mask; values as 0/1 where observed, 0 elsewhere
+        M = ~np.isnan(wX)
+        V = np.where(M, wX, 0).astype(np.float32)
+
+        # ensure binary 0/1
+        V = (V > 0).astype(np.uint8)
+
+        M64 = _pack_bool_to_u64(M)
+        V64 = _pack_bool_to_u64(V.astype(bool))
+
+        best = np.full(n, np.inf, dtype=float)
+
+        for i0 in range(0, n, block_rows):
+            i1 = min(n, i0 + block_rows)
+            Mi = M64[i0:i1]  # (bi, nb)
+            Vi = V64[i0:i1]
+            bi = i1 - i0
+
+            local_best = np.full(bi, np.inf, dtype=float)
+
+            for j0 in range(0, n, block_cols):
+                j1 = min(n, j0 + block_cols)
+                Mj = M64[j0:j1]  # (bj, nb)
+                Vj = V64[j0:j1]
+                bj = j1 - j0
+
+                overlap_counts = np.zeros((bi, bj), dtype=np.uint16)
+                mismatch_counts = np.zeros((bi, bj), dtype=np.uint16)
+
+                for k in range(Mi.shape[1]):
+                    ob = (Mi[:, k][:, None] & Mj[:, k][None, :]).astype(np.uint64)
+                    overlap_counts += _popcount_u64_matrix(ob).astype(np.uint16)
+
+                    mb = ((Vi[:, k][:, None] ^ Vj[:, k][None, :]) & ob).astype(np.uint64)
+                    mismatch_counts += _popcount_u64_matrix(mb).astype(np.uint16)
+
+                ok = overlap_counts >= min_overlap
+                if not np.any(ok):
+                    continue
+
+                dist = np.full((bi, bj), np.inf, dtype=float)
+                if return_fraction:
+                    dist[ok] = mismatch_counts[ok] / overlap_counts[ok]
+                else:
+                    dist[ok] = mismatch_counts[ok].astype(float)
+
+                # exclude self comparisons (diagonal) when blocks overlap
+                if (i0 <= j1) and (j0 <= i1):
+                    ii = np.arange(i0, i1)
+                    jj = ii[(ii >= j0) & (ii < j1)]
+                    if jj.size:
+                        dist[(jj - i0), (jj - j0)] = np.inf
+
+                local_best = np.minimum(local_best, dist.min(axis=1))
+
+            best[i0:i1] = local_best
+
+        best[~np.isfinite(best)] = np.nan
+        out[:, wi] = best
+
+    if store_obsm is not None:
+        adata.obsm[store_obsm] = out
+        adata.uns[f"{store_obsm}_starts"] = starts
+        adata.uns[f"{store_obsm}_window"] = int(window)
+        adata.uns[f"{store_obsm}_step"] = int(step)
+        adata.uns[f"{store_obsm}_min_overlap"] = int(min_overlap)
+        adata.uns[f"{store_obsm}_return_fraction"] = bool(return_fraction)
+        adata.uns[f"{store_obsm}_layer"] = layer if layer is not None else "X"
+
+    return out, starts
+
+
+def assign_rolling_nn_results(
+    parent_adata: "ad.AnnData",
+    subset_adata: "ad.AnnData",
+    values: np.ndarray,
+    starts: np.ndarray,
+    obsm_key: str,
+    window: int,
+    step: int,
+    min_overlap: int,
+    return_fraction: bool,
+    layer: Optional[str],
+) -> None:
+    """
+    Assign rolling NN results computed on a subset back onto a parent AnnData.
+
+    Parameters
+    ----------
+    parent_adata : AnnData
+        Parent AnnData that should store the combined results.
+    subset_adata : AnnData
+        Subset AnnData used to compute `values`.
+    values : np.ndarray
+        Rolling NN output with shape (n_subset_obs, n_windows).
+    starts : np.ndarray
+        Window start indices corresponding to `values`.
+    obsm_key : str
+        Key to store results under in parent_adata.obsm.
+    window : int
+        Rolling window size (stored in parent_adata.uns).
+    step : int
+        Rolling window step size (stored in parent_adata.uns).
+    min_overlap : int
+        Minimum overlap (stored in parent_adata.uns).
+    return_fraction : bool
+        Whether distances are fractional (stored in parent_adata.uns).
+    layer : str | None
+        Layer used for calculations (stored in parent_adata.uns).
+    """
+    n_obs = parent_adata.n_obs
+    n_windows = values.shape[1]
+
+    if obsm_key not in parent_adata.obsm:
+        parent_adata.obsm[obsm_key] = np.full((n_obs, n_windows), np.nan, dtype=float)
+        parent_adata.uns[f"{obsm_key}_starts"] = starts
+        parent_adata.uns[f"{obsm_key}_window"] = int(window)
+        parent_adata.uns[f"{obsm_key}_step"] = int(step)
+        parent_adata.uns[f"{obsm_key}_min_overlap"] = int(min_overlap)
+        parent_adata.uns[f"{obsm_key}_return_fraction"] = bool(return_fraction)
+        parent_adata.uns[f"{obsm_key}_layer"] = layer if layer is not None else "X"
+    else:
+        existing = parent_adata.obsm[obsm_key]
+        if existing.shape[1] != n_windows:
+            raise ValueError(
+                f"Existing obsm[{obsm_key!r}] has {existing.shape[1]} windows; "
+                f"new values have {n_windows} windows."
+            )
+        existing_starts = parent_adata.uns.get(f"{obsm_key}_starts")
+        if existing_starts is not None and not np.array_equal(existing_starts, starts):
+            raise ValueError(
+                f"Existing obsm[{obsm_key!r}] has different window starts than new values."
+            )
+
+    parent_indexer = parent_adata.obs_names.get_indexer(subset_adata.obs_names)
+    if (parent_indexer < 0).any():
+        raise ValueError("Subset AnnData contains obs not present in parent AnnData.")
+
+    parent_adata.obsm[obsm_key][parent_indexer, :] = values
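
The hot loop in this new module works at the bit level: the observed-mask and value matrices are packed into uint64 words, so per-pair overlap and mismatch counts reduce to AND/XOR plus a popcount, processed block by block. A hypothetical usage sketch (import path inferred from the file layout; the data is invented):

import anndata as ad
import numpy as np

from smftools.tools.rolling_nn_distance import rolling_window_nn_distance  # assumed path

# 100 reads x 60 positions: binary methylation calls with ~10% missing (NaN)
rng = np.random.default_rng(0)
X = rng.choice([0.0, 1.0, np.nan], size=(100, 60), p=[0.45, 0.45, 0.10])
adata = ad.AnnData(X=X.astype(np.float32))

dists, starts = rolling_window_nn_distance(
    adata, window=15, step=2, min_overlap=10, return_fraction=True
)
# dists[i, w] = mismatch fraction between read i and its nearest neighbor in
# the window starting at var index starts[w]; NaN if no neighbor shares at
# least min_overlap observed positions.
print(dists.shape, starts[:5])
print(adata.obsm["rolling_nn_dist"].shape)  # results are also stored on adata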
smftools/tools/tensor_factorization.py (new file)

@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable, Sequence
+
+import numpy as np
+
+from smftools.constants import MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def build_sequence_one_hot_and_mask(
+    encoded_sequences: np.ndarray,
+    *,
+    bases: Sequence[str] = ("A", "C", "G", "T"),
+    dtype: np.dtype | type[np.floating] = np.float32,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Build one-hot encoded reads and a seen/unseen mask.
+
+    Args:
+        encoded_sequences: Integer-encoded sequences shaped (n_reads, seq_len).
+        bases: Bases to one-hot encode.
+        dtype: Output dtype for the one-hot tensor.
+
+    Returns:
+        Tuple of (one_hot_tensor, mask) where:
+            - one_hot_tensor: (n_reads, seq_len, n_bases)
+            - mask: (n_reads, seq_len) boolean array indicating seen bases.
+    """
+    encoded = np.asarray(encoded_sequences)
+    if encoded.ndim != 2:
+        raise ValueError(
+            f"encoded_sequences must be 2D with shape (n_reads, seq_len); got {encoded.shape}."
+        )
+
+    base_values = np.array(
+        [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[base] for base in bases],
+        dtype=encoded.dtype,
+    )
+
+    if np.issubdtype(encoded.dtype, np.floating):
+        encoded = encoded.copy()
+        encoded[np.isnan(encoded)] = -1
+
+    mask = np.isin(encoded, base_values)
+    one_hot = np.zeros((*encoded.shape, len(base_values)), dtype=dtype)
+
+    for idx, base_value in enumerate(base_values):
+        one_hot[..., idx] = encoded == base_value
+
+    return one_hot, mask
+
+
+def calculate_sequence_cp_decomposition(
+    adata: "ad.AnnData",
+    *,
+    layer: str,
+    rank: int = 5,
+    n_iter_max: int = 100,
+    random_state: int = 0,
+    overwrite: bool = True,
+    embedding_key: str = "X_cp_sequence",
+    components_key: str = "H_cp_sequence",
+    uns_key: str = "cp_sequence",
+    bases: Iterable[str] = ("A", "C", "G", "T"),
+    backend: str = "pytorch",
+    show_progress: bool = False,
+    init: str = "random",
+) -> "ad.AnnData":
+    """Compute CP decomposition on one-hot encoded sequence data with masking.
+
+    Args:
+        adata: AnnData object to update.
+        layer: Layer name containing integer-encoded sequences.
+        rank: CP rank.
+        n_iter_max: Maximum number of iterations for the solver.
+        random_state: Random seed for initialization.
+        overwrite: Whether to recompute if the embedding already exists.
+        embedding_key: Key for embedding in ``adata.obsm``.
+        components_key: Key for position factors in ``adata.varm``.
+        uns_key: Key for metadata stored in ``adata.uns``.
+        bases: Bases to one-hot encode (in order).
+        backend: Tensorly backend to use (``numpy`` or ``pytorch``).
+        show_progress: Whether to display progress during factorization if supported.
+
+    Returns:
+        Updated AnnData object containing the CP decomposition outputs.
+    """
+    if embedding_key in adata.obsm and components_key in adata.varm and not overwrite:
+        logger.info("CP embedding and components already present; skipping recomputation.")
+        return adata
+
+    if backend not in {"numpy", "pytorch"}:
+        raise ValueError(f"Unsupported backend '{backend}'. Use 'numpy' or 'pytorch'.")
+
+    tensorly = require("tensorly", extra="ml-base", purpose="CP decomposition")
+    from tensorly.decomposition import parafac
+
+    tensorly.set_backend(backend)
+
+    if layer not in adata.layers:
+        raise KeyError(f"Layer '{layer}' not found in adata.layers.")
+
+    one_hot, mask = build_sequence_one_hot_and_mask(adata.layers[layer], bases=tuple(bases))
+    mask_tensor = np.repeat(mask[:, :, None], one_hot.shape[2], axis=2)
+
+    device = "numpy"
+    if backend == "pytorch":
+        torch = require("torch", extra="ml-base", purpose="CP decomposition backend")
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        one_hot = torch.tensor(one_hot, dtype=torch.float32, device=device)
+        mask_tensor = torch.tensor(mask_tensor, dtype=torch.float32, device=device)
+
+    parafac_kwargs = {
+        "rank": rank,
+        "n_iter_max": n_iter_max,
+        "init": init,
+        "mask": mask_tensor,
+        "random_state": random_state,
+    }
+    import inspect
+
+    if "verbose" in inspect.signature(parafac).parameters:
+        parafac_kwargs["verbose"] = show_progress
+
+    cp = parafac(one_hot, **parafac_kwargs)
+
+    if backend == "pytorch":
+        weights = cp.weights.detach().cpu().numpy()
+        read_factors, position_factors, base_factors = [
+            factor.detach().cpu().numpy() for factor in cp.factors
+        ]
+    else:
+        weights = np.asarray(cp.weights)
+        read_factors, position_factors, base_factors = [np.asarray(f) for f in cp.factors]
+
+    adata.obsm[embedding_key] = read_factors
+    adata.varm[components_key] = position_factors
+    adata.uns[uns_key] = {
+        "rank": rank,
+        "n_iter_max": n_iter_max,
+        "random_state": random_state,
+        "layer": layer,
+        "components_key": components_key,
+        "weights": weights,
+        "base_factors": base_factors,
+        "base_labels": list(bases),
+        "backend": backend,
+        "device": str(device),
+    }
+
+    logger.info(
+        "Stored: adata.obsm['%s'], adata.varm['%s'], adata.uns['%s']",
+        embedding_key,
+        components_key,
+        uns_key,
+    )
+    return adata
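
A hypothetical end-to-end sketch for this module; the import path and the "sequence" layer name are assumptions, and backend="numpy" is chosen so the example does not require torch. The integer codes must match smftools.constants.MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT, which this release adds in smftools/constants.py:

import anndata as ad
import numpy as np

from smftools.constants import MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT
from smftools.tools.tensor_factorization import (  # assumed module path
    calculate_sequence_cp_decomposition,
)

# 50 reads x 40 positions of integer-encoded bases (float dtype so NaN can mark gaps)
rng = np.random.default_rng(0)
codes = [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[b] for b in "ACGT"]
seqs = rng.choice(codes, size=(50, 40)).astype(np.float32)

adata = ad.AnnData(X=np.zeros_like(seqs))
adata.layers["sequence"] = seqs  # layer name invented for illustration

adata = calculate_sequence_cp_decomposition(
    adata, layer="sequence", rank=5, backend="numpy"
)
print(adata.obsm["X_cp_sequence"].shape)  # (50, 5): one embedding per read
print(adata.varm["H_cp_sequence"].shape)  # (40, 5): one loading per position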