smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/chimeric_adata.py +1563 -0
- smftools/cli/helpers.py +49 -7
- smftools/cli/hmm_adata.py +250 -32
- smftools/cli/latent_adata.py +773 -0
- smftools/cli/load_adata.py +78 -74
- smftools/cli/preprocess_adata.py +122 -58
- smftools/cli/recipes.py +26 -0
- smftools/cli/spatial_adata.py +74 -112
- smftools/cli/variant_adata.py +423 -0
- smftools/cli_entry.py +52 -4
- smftools/config/conversion.yaml +1 -1
- smftools/config/deaminase.yaml +3 -0
- smftools/config/default.yaml +85 -12
- smftools/config/experiment_config.py +146 -1
- smftools/constants.py +69 -0
- smftools/hmm/HMM.py +88 -0
- smftools/hmm/call_hmm_peaks.py +1 -1
- smftools/informatics/__init__.py +6 -0
- smftools/informatics/bam_functions.py +358 -8
- smftools/informatics/binarize_converted_base_identities.py +2 -89
- smftools/informatics/converted_BAM_to_adata.py +636 -175
- smftools/informatics/h5ad_functions.py +198 -2
- smftools/informatics/modkit_extract_to_adata.py +1007 -425
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/metadata.py +1 -1
- smftools/plotting/__init__.py +26 -3
- smftools/plotting/autocorrelation_plotting.py +22 -4
- smftools/plotting/chimeric_plotting.py +1893 -0
- smftools/plotting/classifiers.py +28 -14
- smftools/plotting/general_plotting.py +62 -1583
- smftools/plotting/hmm_plotting.py +1670 -8
- smftools/plotting/latent_plotting.py +804 -0
- smftools/plotting/plotting_utils.py +243 -0
- smftools/plotting/position_stats.py +16 -8
- smftools/plotting/preprocess_plotting.py +281 -0
- smftools/plotting/qc_plotting.py +8 -3
- smftools/plotting/spatial_plotting.py +1134 -0
- smftools/plotting/variant_plotting.py +1231 -0
- smftools/preprocessing/__init__.py +4 -0
- smftools/preprocessing/append_base_context.py +18 -18
- smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
- smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
- smftools/preprocessing/append_variant_call_layer.py +480 -0
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/preprocessing/flag_duplicate_reads.py +4 -4
- smftools/preprocessing/invert_adata.py +1 -0
- smftools/readwrite.py +159 -99
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +10 -0
- smftools/tools/calculate_knn.py +121 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +130 -0
- smftools/tools/calculate_pca.py +180 -0
- smftools/tools/calculate_umap.py +79 -80
- smftools/tools/position_stats.py +4 -4
- smftools/tools/rolling_nn_distance.py +872 -0
- smftools/tools/sequence_alignment.py +140 -0
- smftools/tools/tensor_factorization.py +217 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,872 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from math import floor
|
|
4
|
+
from typing import TYPE_CHECKING, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from smftools.logging_utils import get_logger
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import anndata as ad
|
|
12
|
+
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def zero_pairs_to_dataframe(adata, zero_pairs_uns_key: str) -> "pd.DataFrame":
    """
    Build a DataFrame of zero-Hamming pairs per window.

    Reads the per-window pair arrays stored in ``adata.uns[zero_pairs_uns_key]``
    together with the companion ``{zero_pairs_uns_key}_starts`` and
    ``{zero_pairs_uns_key}_window`` entries (as written by
    ``rolling_window_nn_distance(collect_zero_pairs=True)``).

    Args:
        adata: AnnData containing zero-pair window data in ``adata.uns``.
        zero_pairs_uns_key: Key for zero-pair window data in ``adata.uns``.

    Returns:
        DataFrame with one row per zero-Hamming pair per window. ``window_end``
        is exclusive (``window_start + window``); ``read_i``/``read_j`` are
        positional obs indices and ``read_*_name`` the matching obs names.

    Raises:
        KeyError: If ``zero_pairs_uns_key`` is not present in ``adata.uns``.
        ValueError: If the starts/window metadata is missing or invalid, or if a
            window with pairs has no corresponding start.
    """
    import pandas as pd

    columns = [
        "window_index",
        "window_start",
        "window_end",
        "read_i",
        "read_j",
        "read_i_name",
        "read_j_name",
    ]

    if zero_pairs_uns_key not in adata.uns:
        raise KeyError(f"Missing zero-pair data in adata.uns[{zero_pairs_uns_key!r}].")

    zero_pairs_by_window = adata.uns[zero_pairs_uns_key]
    raw_starts = adata.uns.get(f"{zero_pairs_uns_key}_starts")
    # ``np.asarray(None)`` is a 0-d array of size 1, so a missing entry must be
    # detected before conversion or it slips past the size check below.
    if raw_starts is None:
        starts = np.array([], dtype=int)
    else:
        starts = np.atleast_1d(np.asarray(raw_starts))
    window = int(adata.uns.get(f"{zero_pairs_uns_key}_window", 0))
    if starts.size == 0 or window <= 0:
        raise ValueError("Zero-pair metadata missing starts/window information.")

    obs_names = np.asarray(adata.obs_names, dtype=object)
    rows = []
    for wi, pairs in enumerate(zero_pairs_by_window):
        if pairs is None or len(pairs) == 0:
            continue
        if wi >= starts.size:
            raise ValueError(
                f"Zero-pair metadata has {starts.size} window starts but window {wi} has pairs."
            )
        start = int(starts[wi])
        end = start + window
        for read_i, read_j in pairs:
            read_i = int(read_i)
            read_j = int(read_j)
            rows.append(
                {
                    "window_index": wi,
                    "window_start": start,
                    "window_end": end,
                    "read_i": read_i,
                    "read_j": read_j,
                    "read_i_name": str(obs_names[read_i]),
                    "read_j_name": str(obs_names[read_j]),
                }
            )

    return pd.DataFrame(rows, columns=columns)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def zero_hamming_segments_to_dataframe(
    records: list[dict],
    var_names: np.ndarray,
) -> "pd.DataFrame":
    """
    Flatten zero-Hamming segment records into a table with one row per segment.

    Args:
        records: Records as produced by ``annotate_zero_hamming_segments``. Each
            record carries ``read_i``/``read_j`` indices, optional
            ``read_i_name``/``read_j_name`` entries, and a ``segments`` list of
            ``(start, end_exclusive)`` var-index pairs.
        var_names: AnnData var names, used to label the first and last covered
            position of every segment.

    Returns:
        DataFrame with one row per segment. Label columns hold ``None`` when the
        corresponding index falls outside ``var_names``.
    """
    import pandas as pd

    names = np.asarray(var_names, dtype=object)

    def _name_for(position: int) -> Optional[str]:
        # Indices outside the var axis have no label.
        return str(names[position]) if 0 <= position < names.size else None

    table = []
    for rec in records:
        first_read = int(rec["read_i"])
        second_read = int(rec["read_j"])
        first_name = rec.get("read_i_name")
        second_name = rec.get("read_j_name")
        for raw_start, raw_end in rec.get("segments", []):
            lo = int(raw_start)
            hi = int(raw_end)
            # Last covered position; clamped so a degenerate segment still labels ``lo``.
            last = max(lo, hi - 1)
            table.append(
                {
                    "read_i": first_read,
                    "read_j": second_read,
                    "read_i_name": first_name,
                    "read_j_name": second_name,
                    "segment_start": lo,
                    "segment_end_exclusive": hi,
                    "segment_end_inclusive": last,
                    "segment_start_label": _name_for(lo),
                    "segment_end_label": _name_for(last),
                }
            )

    column_order = [
        "read_i",
        "read_j",
        "read_i_name",
        "read_j_name",
        "segment_start",
        "segment_end_exclusive",
        "segment_end_inclusive",
        "segment_start_label",
        "segment_end_label",
    ]
    return pd.DataFrame(table, columns=column_order)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _window_center_coordinates(adata, starts: np.ndarray, window: int) -> np.ndarray:
|
|
138
|
+
"""
|
|
139
|
+
Compute window center coordinates using AnnData var positions.
|
|
140
|
+
|
|
141
|
+
If coordinates are numeric, return the mean coordinate per window.
|
|
142
|
+
If not numeric, return the midpoint label for each window.
|
|
143
|
+
"""
|
|
144
|
+
coord_source = adata.var_names
|
|
145
|
+
|
|
146
|
+
coords = np.asarray(coord_source)
|
|
147
|
+
if coords.size == 0:
|
|
148
|
+
return np.array([], dtype=float)
|
|
149
|
+
|
|
150
|
+
try:
|
|
151
|
+
coords_numeric = coords.astype(float)
|
|
152
|
+
return np.array(
|
|
153
|
+
[floor(np.nanmean(coords_numeric[s : s + window])) for s in starts], dtype=float
|
|
154
|
+
)
|
|
155
|
+
except Exception:
|
|
156
|
+
mid = np.clip(starts + (window // 2), 0, coords.size - 1)
|
|
157
|
+
return coords[mid]
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _pack_bool_to_u64(B: np.ndarray) -> np.ndarray:
|
|
161
|
+
"""
|
|
162
|
+
Pack a boolean (or 0/1) matrix (n, w) into uint64 blocks (n, ceil(w/64)).
|
|
163
|
+
Safe w.r.t. contiguity/layout.
|
|
164
|
+
"""
|
|
165
|
+
B = np.asarray(B, dtype=np.uint8)
|
|
166
|
+
packed_u8 = np.packbits(B, axis=1) # (n, ceil(w/8)) uint8
|
|
167
|
+
|
|
168
|
+
n, nb = packed_u8.shape
|
|
169
|
+
pad = (-nb) % 8
|
|
170
|
+
if pad:
|
|
171
|
+
packed_u8 = np.pad(packed_u8, ((0, 0), (0, pad)), mode="constant", constant_values=0)
|
|
172
|
+
|
|
173
|
+
packed_u8 = np.ascontiguousarray(packed_u8)
|
|
174
|
+
|
|
175
|
+
# group 8 bytes -> uint64
|
|
176
|
+
packed_u64 = packed_u8.reshape(n, -1, 8).view(np.uint64).reshape(n, -1)
|
|
177
|
+
return packed_u64
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _popcount_u64_matrix(A_u64: np.ndarray) -> np.ndarray:
|
|
181
|
+
"""
|
|
182
|
+
Popcount for an array of uint64, vectorized and portable across NumPy versions.
|
|
183
|
+
|
|
184
|
+
Returns an integer array with the SAME SHAPE as A_u64.
|
|
185
|
+
"""
|
|
186
|
+
A_u64 = np.ascontiguousarray(A_u64)
|
|
187
|
+
# View as bytes; IMPORTANT: reshape to add a trailing byte axis of length 8
|
|
188
|
+
b = A_u64.view(np.uint8).reshape(A_u64.shape + (8,))
|
|
189
|
+
# unpack bits within that byte axis -> (..., 64), then sum
|
|
190
|
+
return np.unpackbits(b, axis=-1).sum(axis=-1)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def rolling_window_nn_distance(
    adata,
    layer: Optional[str] = None,
    window: int = 15,
    step: int = 2,
    min_overlap: int = 10,
    return_fraction: bool = True,
    block_rows: int = 256,
    block_cols: int = 2048,
    store_obsm: Optional[str] = "rolling_nn_dist",
    collect_zero_pairs: bool = False,
    zero_pairs_uns_key: Optional[str] = None,
    sample_labels: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Rolling-window nearest-neighbor distance per read, overlap-aware.

    Distance between reads i,j in a window:
      - use only positions where BOTH are observed (non-NaN)
      - require overlap >= min_overlap
      - mismatch = count(x_i != x_j) over overlapped positions, where values
        ``> 0`` count as 1 and all other observed values as 0
      - distance = mismatch/overlap (if return_fraction) else mismatch

    Each read is compared against every other read (processed in
    ``block_rows x block_cols`` tiles of bit-packed masks). Self comparisons
    are excluded, as are same-sample comparisons when ``sample_labels`` is given.

    Parameters
    ----------
    adata : AnnData
        Source data; ``adata.layers[layer]`` or ``adata.X`` (dense or sparse).
    layer : str, optional
        Layer to read; ``None`` uses ``adata.X``.
    window, step : int
        Window width and stride in var positions (both > 0).
    min_overlap : int
        Minimum jointly observed positions for a pair to be scored (> 0).
    return_fraction : bool
        Return mismatch fraction instead of raw mismatch count.
    block_rows, block_cols : int
        Tile sizes for the pairwise computation (both > 0).
    store_obsm : str, optional
        If set, store ``out`` in ``adata.obsm[store_obsm]`` and window metadata
        (starts, centers, window, step, min_overlap, return_fraction, layer)
        under ``adata.uns[f"{store_obsm}_*"]``.
    collect_zero_pairs : bool
        Record, per window, read pairs (i < j) with zero mismatches and
        sufficient overlap.
    zero_pairs_uns_key : str, optional
        ``adata.uns`` key for the zero pairs; defaults to
        ``f"{store_obsm}_zero_pairs"`` or ``"rolling_nn_zero_pairs"``.
    sample_labels : array-like, optional
        One label per read; pairs sharing a label are not compared.

    Returns
    -------
    out : (n_obs, n_windows) float
        Nearest-neighbor distance per read per window (NaN if no valid neighbor).
    starts : (n_windows,) int
        Window start indices in var-space.

    Window center coordinates are not returned. When ``store_obsm`` is set they
    are stored in ``adata.uns[f"{store_obsm}_centers"]``.

    Raises
    ------
    ValueError
        If ``window``/``step``/``min_overlap``/block sizes are invalid, or if
        ``sample_labels`` is not 1-D with one entry per read.
    """
    X = adata.layers[layer] if layer is not None else adata.X
    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)

    n, p = X.shape
    if window > p:
        raise ValueError(f"window={window} is larger than n_vars={p}")
    if window <= 0:
        raise ValueError("window must be > 0")
    if step <= 0:
        raise ValueError("step must be > 0")
    if min_overlap <= 0:
        raise ValueError("min_overlap must be > 0")
    if block_rows <= 0 or block_cols <= 0:
        raise ValueError("block_rows and block_cols must be > 0")
    if sample_labels is not None:
        # Accept lists/Series; the block slicing below needs ndarray broadcasting.
        sample_labels = np.asarray(sample_labels)
        if sample_labels.ndim != 1 or sample_labels.shape[0] != n:
            raise ValueError(
                f"sample_labels must be 1-D with length n_obs={n}; got shape {sample_labels.shape}"
            )

    starts = np.arange(0, p - window + 1, step, dtype=int)
    nW = len(starts)
    out = np.full((n, nW), np.nan, dtype=float)

    zero_pairs_by_window = [] if collect_zero_pairs else None

    for wi, s in enumerate(starts):
        wX = X[:, s : s + window]  # (n, window)

        # observed mask; values as 0/1 where observed, 0 elsewhere
        M = ~np.isnan(wX)
        V = np.where(M, wX, 0).astype(np.float32)

        # ensure binary 0/1
        V = (V > 0).astype(np.uint8)

        M64 = _pack_bool_to_u64(M)
        V64 = _pack_bool_to_u64(V.astype(bool))

        best = np.full(n, np.inf, dtype=float)

        window_pairs = [] if collect_zero_pairs else None

        for i0 in range(0, n, block_rows):
            i1 = min(n, i0 + block_rows)
            Mi = M64[i0:i1]  # (bi, nb)
            Vi = V64[i0:i1]
            bi = i1 - i0

            local_best = np.full(bi, np.inf, dtype=float)

            for j0 in range(0, n, block_cols):
                j1 = min(n, j0 + block_cols)
                Mj = M64[j0:j1]  # (bj, nb)
                Vj = V64[j0:j1]
                bj = j1 - j0

                # uint32 accumulators: a uint16 count would silently wrap for
                # windows wider than 65535 positions.
                overlap_counts = np.zeros((bi, bj), dtype=np.uint32)
                mismatch_counts = np.zeros((bi, bj), dtype=np.uint32)

                for k in range(Mi.shape[1]):
                    ob = (Mi[:, k][:, None] & Mj[:, k][None, :]).astype(np.uint64)
                    overlap_counts += _popcount_u64_matrix(ob).astype(np.uint32)

                    mb = ((Vi[:, k][:, None] ^ Vj[:, k][None, :]) & ob).astype(np.uint64)
                    mismatch_counts += _popcount_u64_matrix(mb).astype(np.uint32)

                ok = overlap_counts >= min_overlap
                if not np.any(ok):
                    continue

                dist = np.full((bi, bj), np.inf, dtype=float)
                if return_fraction:
                    dist[ok] = mismatch_counts[ok] / overlap_counts[ok]
                else:
                    dist[ok] = mismatch_counts[ok].astype(float)

                # exclude self comparisons (diagonal) when the row/col ranges intersect
                if (i0 < j1) and (j0 < i1):
                    ii = np.arange(i0, i1)
                    jj = ii[(ii >= j0) & (ii < j1)]
                    if jj.size:
                        dist[(jj - i0), (jj - j0)] = np.inf

                # exclude same-sample comparisons when sample_labels provided
                if sample_labels is not None:
                    same_sample = sample_labels[i0:i1][:, None] == sample_labels[j0:j1][None, :]
                    dist[same_sample] = np.inf

                local_best = np.minimum(local_best, dist.min(axis=1))

                if collect_zero_pairs:
                    zero_mask = ok & (mismatch_counts == 0)
                    if np.any(zero_mask):
                        i_idx, j_idx = np.where(zero_mask)
                        gi = i0 + i_idx
                        gj = j0 + j_idx
                        # Each unordered pair is seen twice in the full n x n sweep;
                        # keeping gi < gj records it once and drops self pairs.
                        keep = gi < gj
                        if sample_labels is not None:
                            keep = keep & (sample_labels[gi] != sample_labels[gj])
                        if np.any(keep):
                            window_pairs.append(np.stack([gi[keep], gj[keep]], axis=1))

            best[i0:i1] = local_best

        best[~np.isfinite(best)] = np.nan
        out[:, wi] = best
        if collect_zero_pairs:
            if window_pairs:
                zero_pairs_by_window.append(np.vstack(window_pairs))
            else:
                zero_pairs_by_window.append(np.empty((0, 2), dtype=int))

    if store_obsm is not None:
        adata.obsm[store_obsm] = out
        adata.uns[f"{store_obsm}_starts"] = starts
        adata.uns[f"{store_obsm}_centers"] = _window_center_coordinates(adata, starts, window)
        adata.uns[f"{store_obsm}_window"] = int(window)
        adata.uns[f"{store_obsm}_step"] = int(step)
        adata.uns[f"{store_obsm}_min_overlap"] = int(min_overlap)
        adata.uns[f"{store_obsm}_return_fraction"] = bool(return_fraction)
        adata.uns[f"{store_obsm}_layer"] = layer if layer is not None else "X"
    if collect_zero_pairs:
        if zero_pairs_uns_key is None:
            zero_pairs_uns_key = (
                f"{store_obsm}_zero_pairs" if store_obsm is not None else "rolling_nn_zero_pairs"
            )
        adata.uns[zero_pairs_uns_key] = zero_pairs_by_window
        adata.uns[f"{zero_pairs_uns_key}_starts"] = starts
        adata.uns[f"{zero_pairs_uns_key}_window"] = int(window)
        adata.uns[f"{zero_pairs_uns_key}_step"] = int(step)
        adata.uns[f"{zero_pairs_uns_key}_min_overlap"] = int(min_overlap)
        adata.uns[f"{zero_pairs_uns_key}_return_fraction"] = bool(return_fraction)
        adata.uns[f"{zero_pairs_uns_key}_layer"] = layer if layer is not None else "X"

    return out, starts
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def annotate_zero_hamming_segments(
    adata,
    zero_pairs_uns_key: Optional[str] = None,
    output_uns_key: str = "zero_hamming_segments",
    layer: Optional[str] = None,
    min_overlap: Optional[int] = None,
    refine_segments: bool = True,
    max_nan_run: Optional[int] = None,
    merge_gap: int = 0,
    max_segments_per_read: Optional[int] = None,
    max_segment_overlap: Optional[int] = None,
) -> list[dict]:
    """
    Merge zero-Hamming windows into maximal segments and annotate onto AnnData.

    For every read pair in the zero-pair window data, the windows in which the
    pair had zero mismatches are merged into ``[start, end)`` spans (var-index
    space). Each span is optionally extended outward while the two reads keep
    agreeing, filtered by jointly observed overlap, and optionally down-selected
    per pair.

    Args:
        adata: AnnData containing zero-pair window data in ``.uns``.
        zero_pairs_uns_key: Key for zero-pair window data in ``adata.uns``. If
            ``None``, the single ``adata.uns`` key ending in ``"_zero_pairs"`` is used.
        output_uns_key: Key to store merged/refined segments in ``adata.uns``.
        layer: Layer to use for refinement (defaults to adata.X). Values ``> 0``
            count as 1, other observed values as 0, and NaN as unobserved.
        min_overlap: Minimum number of jointly observed positions required to keep
            a refined segment. Defaults to ``adata.uns[f"{zero_pairs_uns_key}_min_overlap"]``
            (or 1 if absent).
        refine_segments: Whether to extend merged windows to maximal segments.
        max_nan_run: Maximum consecutive positions (where either read is NaN) that
            extension may skip. When a run reaches this length, extension stops
            before the NaN run. Set to ``None`` to skip NaN runs of any length.
            Extended edges always land on positions where both reads are observed
            and agree, so unobserved padding is never added to a segment.
        merge_gap: Merge segments with gaps of at most this size (in positions).
            Positions inside a merged gap are not checked for mismatches.
        max_segments_per_read: Maximum number of segments to retain per read pair
            (longest first, ties broken by earliest start).
        max_segment_overlap: Maximum allowed overlap between retained segments (inclusive, in
            var-index coordinates).

    Returns:
        List of segment records stored in ``adata.uns[output_uns_key]``. Each record
        has ``read_i``, ``read_j``, ``read_i_name``, ``read_j_name`` and
        ``segments`` (a list of ``(start, end_exclusive)`` tuples).

    Raises:
        KeyError: If no zero-pair key is found, or several are found and none was given.
        ValueError: If the starts/window metadata is missing or invalid.
    """
    if zero_pairs_uns_key is None:
        candidate_keys = [key for key in adata.uns if key.endswith("_zero_pairs")]
        if len(candidate_keys) == 1:
            zero_pairs_uns_key = candidate_keys[0]
        elif not candidate_keys:
            raise KeyError("No zero-pair data found in adata.uns.")
        else:
            raise KeyError(
                "Multiple zero-pair keys found in adata.uns; please specify zero_pairs_uns_key."
            )

    if zero_pairs_uns_key not in adata.uns:
        raise KeyError(f"Missing zero-pair data in adata.uns[{zero_pairs_uns_key!r}].")

    zero_pairs_by_window = adata.uns[zero_pairs_uns_key]
    raw_starts = adata.uns.get(f"{zero_pairs_uns_key}_starts")
    # ``np.asarray(None)`` is a 0-d array of size 1, so a missing entry must be
    # detected before conversion or it slips past the size check below.
    if raw_starts is None:
        starts = np.array([], dtype=int)
    else:
        starts = np.atleast_1d(np.asarray(raw_starts))
    window = int(adata.uns.get(f"{zero_pairs_uns_key}_window", 0))
    if starts.size == 0 or window <= 0:
        raise ValueError("Zero-pair metadata missing starts/window information.")

    if min_overlap is None:
        min_overlap = int(adata.uns.get(f"{zero_pairs_uns_key}_min_overlap", 1))

    X = adata.layers[layer] if layer is not None else adata.X
    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
    observed = ~np.isnan(X)
    values = (np.where(observed, X, 0.0) > 0).astype(np.uint8)
    max_nan_run = None if max_nan_run is None else max(int(max_nan_run), 1)
    n_vars = values.shape[1]

    # Collect, per read pair, the [start, end) span of every zero-mismatch window.
    pair_segments: dict[tuple[int, int], list[tuple[int, int]]] = {}
    for wi, pairs in enumerate(zero_pairs_by_window):
        if pairs is None or len(pairs) == 0:
            continue
        if wi >= starts.size:
            raise ValueError(
                f"Zero-pair metadata has {starts.size} window starts but window {wi} has pairs."
            )
        start = int(starts[wi])
        end = start + window
        for i, j in pairs:
            key = (int(i), int(j))
            pair_segments.setdefault(key, []).append((start, end))

    merge_gap = max(int(merge_gap), 0)

    def _merge_segments(segments: list[tuple[int, int]]) -> list[tuple[int, int]]:
        """Merge overlapping spans, or spans separated by at most ``merge_gap``."""
        if not segments:
            return []
        segments = sorted(segments, key=lambda seg: seg[0])
        merged = [segments[0]]
        for seg_start, seg_end in segments[1:]:
            last_start, last_end = merged[-1]
            if seg_start <= last_end + merge_gap:
                merged[-1] = (last_start, max(last_end, seg_end))
            else:
                merged.append((seg_start, seg_end))
        return merged

    def _position_status(read_i: int, read_j: int, idx: int) -> Optional[bool]:
        """True if both reads are observed and agree, False on a mismatch, None if either is NaN."""
        if not (observed[read_i, idx] and observed[read_j, idx]):
            return None
        return bool(values[read_i, idx] == values[read_j, idx])

    def _refine_segment(
        read_i: int,
        read_j: int,
        start: int,
        end: int,
    ) -> Optional[tuple[int, int]]:
        """Extend ``[start, end)`` outward over agreeing positions; None if overlap is too small."""
        if not refine_segments:
            return (start, end)

        # Walk left. ``left`` only moves onto jointly observed, agreeing positions,
        # so NaN positions skipped along the way never become segment edges.
        left = start
        nan_run = 0
        for idx in range(start - 1, -1, -1):
            status = _position_status(read_i, read_j, idx)
            if status is None:
                nan_run += 1
                if max_nan_run is not None and nan_run >= max_nan_run:
                    break
                continue
            if not status:
                break
            nan_run = 0
            left = idx

        # Walk right with the same rule; ``right`` stays exclusive.
        right = end
        nan_run = 0
        for idx in range(end, n_vars):
            status = _position_status(read_i, read_j, idx)
            if status is None:
                nan_run += 1
                if max_nan_run is not None and nan_run >= max_nan_run:
                    break
                continue
            if not status:
                break
            nan_run = 0
            right = idx + 1

        overlap = int(np.sum(observed[read_i, left:right] & observed[read_j, left:right]))
        if overlap < min_overlap:
            return None
        return (left, right)

    def _segment_length(segment: tuple[int, int]) -> int:
        return int(segment[1]) - int(segment[0])

    def _segment_overlap(first: tuple[int, int], second: tuple[int, int]) -> int:
        return max(0, min(first[1], second[1]) - max(first[0], second[0]))

    def _select_segments(segments: list[tuple[int, int]]) -> list[tuple[int, int]]:
        """Keep the longest segments subject to the count and mutual-overlap limits."""
        if not segments:
            return []
        if max_segments_per_read is None and max_segment_overlap is None:
            return segments
        ordered = sorted(
            segments,
            key=lambda seg: (_segment_length(seg), -seg[0]),
            reverse=True,
        )
        max_segments = len(ordered) if max_segments_per_read is None else max_segments_per_read
        if max_segment_overlap is None:
            return ordered[:max_segments]
        selected: list[tuple[int, int]] = []
        for segment in ordered:
            if len(selected) >= max_segments:
                break
            if all(_segment_overlap(segment, other) <= max_segment_overlap for other in selected):
                selected.append(segment)
        return selected

    records: list[dict] = []
    obs_names = adata.obs_names
    for (read_i, read_j), segments in pair_segments.items():
        merged = _merge_segments(segments)
        refined_segments = []
        for seg_start, seg_end in merged:
            refined = _refine_segment(read_i, read_j, seg_start, seg_end)
            if refined is not None:
                refined_segments.append(refined)
        refined_segments = _select_segments(refined_segments)
        if refined_segments:
            records.append(
                {
                    "read_i": read_i,
                    "read_j": read_j,
                    "read_i_name": str(obs_names[read_i]),
                    "read_j_name": str(obs_names[read_j]),
                    "segments": refined_segments,
                }
            )

    adata.uns[output_uns_key] = records
    return records
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
def assign_per_read_segments_layer(
    parent_adata: "ad.AnnData",
    subset_adata: "ad.AnnData",
    per_read_segments: "pd.DataFrame",
    layer_key: str,
) -> None:
    """
    Accumulate per-read segments into a uint16 count layer on a parent AnnData.

    Each segment row adds 1 to every parent var position it covers (inclusive)
    in the parent row that corresponds to ``subset_adata.obs_names[read_id]``.

    Args:
        parent_adata: AnnData that receives ``layers[layer_key]``.
        subset_adata: AnnData whose obs/var indices the segments refer to. All of
            its obs and var names must exist in ``parent_adata``.
        per_read_segments: DataFrame with ``read_id``, ``segment_start`` and
            ``segment_end_exclusive`` columns (subset indices). If
            ``segment_start_label`` and ``segment_end_label`` columns are present,
            the parent var names all convert to ``int``, and both labels of a row
            resolve, the parent columns are located by label instead of by index.
        layer_key: Name of the layer to store in ``parent_adata.layers``.

    Raises:
        KeyError: If a required column is missing from a non-empty table.
        ValueError: If ``subset_adata`` has obs or vars absent from ``parent_adata``.
    """
    import pandas as pd

    if per_read_segments.empty:
        parent_adata.layers[layer_key] = np.zeros(
            (parent_adata.n_obs, parent_adata.n_vars), dtype=np.uint16
        )
        return

    absent = {"read_id", "segment_start", "segment_end_exclusive"} - set(per_read_segments.columns)
    if absent:
        raise KeyError(f"per_read_segments missing required columns: {sorted(absent)}")

    counts = np.zeros((parent_adata.n_obs, parent_adata.n_vars), dtype=np.uint16)

    obs_map = parent_adata.obs_names.get_indexer(subset_adata.obs_names)
    if (obs_map < 0).any():
        raise ValueError("Subset AnnData contains obs not present in parent AnnData.")
    var_map = parent_adata.var_names.get_indexer(subset_adata.var_names)
    if (var_map < 0).any():
        raise ValueError("Subset AnnData contains vars not present in parent AnnData.")

    # Integer var-name label -> parent column, built only if every parent var name is an int.
    label_lookup: Optional[dict] = None
    if {"segment_start_label", "segment_end_label"} <= set(per_read_segments.columns):
        try:
            label_lookup = {
                int(name): position for position, name in enumerate(parent_adata.var_names)
            }
        except (TypeError, ValueError):
            label_lookup = None

    def _resolve(label: object) -> Optional[int]:
        """Map a segment label to a parent column, or None if it cannot be mapped."""
        if label_lookup is None or label is None or pd.isna(label):
            return None
        try:
            return label_lookup.get(int(label))
        except (TypeError, ValueError):
            return None

    for seg in per_read_segments.itertuples(index=False):
        rid = int(seg.read_id)
        lo = int(seg.segment_start)
        hi = int(seg.segment_end_exclusive)
        if lo >= hi:
            continue
        dest_row = obs_map[rid]
        if dest_row < 0:
            raise ValueError("Segment read_id not found in parent AnnData.")

        first = _resolve(getattr(seg, "segment_start_label", None))
        last = _resolve(getattr(seg, "segment_end_label", None))
        if first is None or last is None:
            # Fall back to translating subset var indices into parent columns.
            mapped = var_map[lo:hi]
            if mapped.size == 0:
                continue
            col_lo = int(mapped.min())
            col_hi = int(mapped.max())
        else:
            col_lo = min(first, last)
            col_hi = max(first, last)

        counts[dest_row, col_lo : col_hi + 1] += 1

    parent_adata.layers[layer_key] = counts
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
def segments_to_per_read_dataframe(
    records: list[dict],
    var_names: np.ndarray,
) -> "pd.DataFrame":
    """
    Expand zero-Hamming segment records into a per-read table.

    Every segment of a ``(read_i, read_j)`` record is emitted twice: once from
    ``read_i``'s point of view (partner ``read_j``) and once from ``read_j``'s
    (partner ``read_i``).

    Args:
        records: Records as produced by ``annotate_zero_hamming_segments``.
        var_names: AnnData var names, used to label the first and last covered
            position of every segment.

    Returns:
        DataFrame with one row per segment per read. Label columns hold ``None``
        when the index falls outside ``var_names``.
    """
    import pandas as pd

    labels = np.asarray(var_names, dtype=object)

    def _label_for(position: int) -> Optional[str]:
        if position < 0 or position >= labels.size:
            return None
        return str(labels[position])

    out_rows: list[dict] = []
    for rec in records:
        a_id = int(rec["read_i"])
        b_id = int(rec["read_j"])
        a_name = rec.get("read_i_name")
        b_name = rec.get("read_j_name")
        for raw_start, raw_end in rec.get("segments", []):
            seg_lo = int(raw_start)
            seg_hi = int(raw_end)
            seg_last = max(seg_lo, seg_hi - 1)
            shared = {
                "segment_start": seg_lo,
                "segment_end_exclusive": seg_hi,
                "segment_end_inclusive": seg_last,
                "segment_start_label": _label_for(seg_lo),
                "segment_end_label": _label_for(seg_last),
            }
            # One row per participating read, each pointing at the other as partner.
            for me, me_name, other, other_name in (
                (a_id, a_name, b_id, b_name),
                (b_id, b_name, a_id, a_name),
            ):
                out_rows.append(
                    {
                        "read_id": me,
                        "partner_id": other,
                        "read_name": me_name,
                        "partner_name": other_name,
                        **shared,
                    }
                )

    column_order = [
        "read_id",
        "partner_id",
        "read_name",
        "partner_name",
        "segment_start",
        "segment_end_exclusive",
        "segment_end_inclusive",
        "segment_start_label",
        "segment_end_label",
    ]
    return pd.DataFrame(out_rows, columns=column_order)
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
def select_top_segments_per_read(
    records: list[dict],
    var_names: np.ndarray,
    max_segments_per_read: Optional[int] = None,
    max_segment_overlap: Optional[int] = None,
    min_span: Optional[float] = None,
) -> tuple["pd.DataFrame", "pd.DataFrame"]:
    """
    Select top segments per read from distinct partner pairs.

    For every read, only the single longest segment shared with each partner
    is considered. Those candidates are visited longest-first (ties broken by
    the smaller ``segment_start``) and kept greedily until
    ``max_segments_per_read`` is reached. A candidate is skipped when its
    overlap with an already-kept segment exceeds ``max_segment_overlap``.

    Args:
        records: Output records from ``annotate_zero_hamming_segments``.
        var_names: AnnData var names for labeling segment coordinates.
        max_segments_per_read: Maximum number of segments to keep per read.
            ``None`` keeps every candidate that passes the overlap check; a
            value ``<= 0`` keeps nothing.
        max_segment_overlap: Maximum allowed overlap, measured in index
            positions (``segment_start``/``segment_end_exclusive``), between
            kept segments of the same read. ``None`` disables the check.
        min_span: Minimum span length to keep (var-name coordinate if numeric,
            else index span). Rows below it are removed from BOTH returned
            tables.

    Returns:
        Tuple of (raw per-read segments, filtered per-read segments). Both
        tables carry ``segment_length_index`` and ``segment_length_label``.
        The filtered table additionally carries a 1-based ``selection_rank``
        column, which is present (empty) even when nothing was selected.
    """
    import pandas as pd

    # Longest span first, then leftmost start.
    sort_keys = ["segment_length_label", "segment_start"]
    sort_ascending = [False, True]

    def _with_empty_rank(df):
        # Keep the filtered table's schema stable even when it has no rows.
        out = df.copy()
        out["selection_rank"] = pd.Series(dtype=int)
        return out

    raw_df = segments_to_per_read_dataframe(records, var_names)
    if raw_df.empty:
        raw_df = raw_df.copy()
        raw_df["segment_length_index"] = pd.Series(dtype=int)
        raw_df["segment_length_label"] = pd.Series(dtype=float)
        return raw_df, _with_empty_rank(raw_df)

    def _span_length(row) -> float:
        # Use the var-name coordinate distance when both labels are numeric,
        # otherwise fall back to the index span.
        # NOTE(review): the label distance is measured to the *inclusive* end
        # label, while the fallback uses the *exclusive* end index, so the two
        # differ by one position for unit-spaced coordinates. Confirm this is
        # intended before relying on ``min_span`` across label types.
        try:
            start = float(row["segment_start_label"])
            end = float(row["segment_end_label"])
        except (TypeError, ValueError):
            return float(row["segment_end_exclusive"] - row["segment_start"])
        return abs(end - start)

    raw_df = raw_df.copy()
    raw_df["segment_length_index"] = (
        raw_df["segment_end_exclusive"] - raw_df["segment_start"]
    ).astype(int)
    raw_df["segment_length_label"] = raw_df.apply(_span_length, axis=1)
    if min_span is not None:
        raw_df = raw_df[raw_df["segment_length_label"] >= float(min_span)]

    if raw_df.empty:
        return raw_df, _with_empty_rank(raw_df)

    def _segment_overlap(a, b) -> int:
        # Length of the intersection of two half-open [start, end) intervals.
        return max(0, min(a[1], b[1]) - max(a[0], b[0]))

    filtered_rows = []
    for _, read_df in raw_df.groupby("read_id", sort=False):
        # Keep only the longest segment for each distinct partner of this read.
        per_partner = (
            read_df.sort_values(sort_keys, ascending=sort_ascending)
            .groupby("partner_id", sort=False)
            .head(1)
        )
        ordered = per_partner.sort_values(sort_keys, ascending=sort_ascending)

        selected = []
        for row in ordered.itertuples(index=False):
            if (
                max_segments_per_read is not None
                and len(selected) >= max_segments_per_read
            ):
                break
            seg = (row.segment_start, row.segment_end_exclusive)
            if max_segment_overlap is not None and any(
                _segment_overlap(seg, (s.segment_start, s.segment_end_exclusive))
                > max_segment_overlap
                for s in selected
            ):
                continue
            selected.append(row)
        filtered_rows.extend(row._asdict() for row in selected)

    filtered_df = pd.DataFrame(filtered_rows, columns=raw_df.columns)
    if filtered_df.empty:
        # Happens e.g. when max_segments_per_read <= 0.
        return raw_df, _with_empty_rank(filtered_df)

    filtered_df["selection_rank"] = (
        filtered_df.sort_values(
            ["read_id", "segment_length_label", "segment_start"],
            ascending=[True, False, True],
        )
        .groupby("read_id", sort=False)
        .cumcount()
        + 1
    )

    return raw_df, filtered_df
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
def assign_rolling_nn_results(
    parent_adata: "ad.AnnData",
    subset_adata: "ad.AnnData",
    values: np.ndarray,
    starts: np.ndarray,
    obsm_key: str,
    window: int,
    step: int,
    min_overlap: int,
    return_fraction: bool,
    layer: Optional[str],
) -> None:
    """
    Assign rolling NN results computed on a subset back onto a parent AnnData.

    All inputs are validated before ``parent_adata`` is modified, so a call
    that raises leaves the parent unchanged.

    The first call for ``obsm_key`` creates a NaN-filled
    ``(parent_adata.n_obs, n_windows)`` matrix in ``parent_adata.obsm`` and
    records the window metadata in ``parent_adata.uns`` under
    ``f"{obsm_key}_starts"``, ``_centers``, ``_window``, ``_step``,
    ``_min_overlap``, ``_return_fraction`` and ``_layer``. Later calls only
    check that the stored window layout matches; they leave that metadata
    alone. Every call then writes ``values`` into the rows that correspond to
    ``subset_adata.obs_names``. Other rows keep their previous contents.

    Parameters
    ----------
    parent_adata : AnnData
        Parent AnnData that should store the combined results.
    subset_adata : AnnData
        Subset AnnData used to compute `values`.
    values : np.ndarray
        Rolling NN output with shape (n_subset_obs, n_windows).
    starts : np.ndarray
        Window start indices corresponding to `values`.
    obsm_key : str
        Key to store results under in parent_adata.obsm.
    window : int
        Rolling window size (stored in parent_adata.uns).
    step : int
        Rolling window step size (stored in parent_adata.uns).
    min_overlap : int
        Minimum overlap (stored in parent_adata.uns).
    return_fraction : bool
        Whether distances are fractional (stored in parent_adata.uns).
    layer : str | None
        Layer used for calculations (stored in parent_adata.uns; ``None`` is
        recorded as ``"X"``).

    Raises
    ------
    ValueError
        If `values` is not 2-D or its row count differs from
        ``subset_adata.n_obs``. Also if an existing ``obsm[obsm_key]`` has a
        different number of windows, window starts, or window centers, or if
        ``subset_adata`` contains obs names that are missing from
        ``parent_adata``.
    """
    # ---- Validation: nothing below this block touches parent_adata. ----
    if np.ndim(values) != 2:
        raise ValueError(
            f"values must be 2-D (n_subset_obs, n_windows); got ndim={np.ndim(values)}."
        )
    if values.shape[0] != subset_adata.n_obs:
        raise ValueError(
            f"values has {values.shape[0]} rows but subset AnnData has "
            f"{subset_adata.n_obs} obs."
        )

    n_obs = parent_adata.n_obs
    n_windows = values.shape[1]
    is_new_key = obsm_key not in parent_adata.obsm

    if not is_new_key:
        existing = parent_adata.obsm[obsm_key]
        if existing.shape[1] != n_windows:
            raise ValueError(
                f"Existing obsm[{obsm_key!r}] has {existing.shape[1]} windows; "
                f"new values have {n_windows} windows."
            )
        existing_starts = parent_adata.uns.get(f"{obsm_key}_starts")
        if existing_starts is not None and not np.array_equal(existing_starts, starts):
            raise ValueError(
                f"Existing obsm[{obsm_key!r}] has different window starts than new values."
            )
        existing_centers = parent_adata.uns.get(f"{obsm_key}_centers")
        if existing_centers is not None:
            expected_centers = _window_center_coordinates(subset_adata, starts, window)
            if not np.array_equal(existing_centers, expected_centers):
                raise ValueError(
                    f"Existing obsm[{obsm_key!r}] has different window centers than new values."
                )

    # Negative indexer entries mark subset obs names absent from the parent.
    parent_indexer = parent_adata.obs_names.get_indexer(subset_adata.obs_names)
    if (parent_indexer < 0).any():
        raise ValueError("Subset AnnData contains obs not present in parent AnnData.")

    # ---- Mutation: only reached once every check has passed. ----
    if is_new_key:
        # Compute before writing anything so a failure here cannot leave a
        # partially initialised obsm/uns entry behind.
        centers = _window_center_coordinates(subset_adata, starts, window)
        parent_adata.obsm[obsm_key] = np.full((n_obs, n_windows), np.nan, dtype=float)
        parent_adata.uns[f"{obsm_key}_starts"] = starts
        parent_adata.uns[f"{obsm_key}_centers"] = centers
        parent_adata.uns[f"{obsm_key}_window"] = int(window)
        parent_adata.uns[f"{obsm_key}_step"] = int(step)
        parent_adata.uns[f"{obsm_key}_min_overlap"] = int(min_overlap)
        parent_adata.uns[f"{obsm_key}_return_fraction"] = bool(return_fraction)
        parent_adata.uns[f"{obsm_key}_layer"] = layer if layer is not None else "X"

    # In-place write into the stored matrix; rows outside the subset are untouched.
    parent_adata.obsm[obsm_key][parent_indexer, :] = values
|