smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +49 -7
  4. smftools/cli/hmm_adata.py +250 -32
  5. smftools/cli/latent_adata.py +773 -0
  6. smftools/cli/load_adata.py +78 -74
  7. smftools/cli/preprocess_adata.py +122 -58
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +74 -112
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +52 -4
  12. smftools/config/conversion.yaml +1 -1
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +85 -12
  15. smftools/config/experiment_config.py +146 -1
  16. smftools/constants.py +69 -0
  17. smftools/hmm/HMM.py +88 -0
  18. smftools/hmm/call_hmm_peaks.py +1 -1
  19. smftools/informatics/__init__.py +6 -0
  20. smftools/informatics/bam_functions.py +358 -8
  21. smftools/informatics/binarize_converted_base_identities.py +2 -89
  22. smftools/informatics/converted_BAM_to_adata.py +636 -175
  23. smftools/informatics/h5ad_functions.py +198 -2
  24. smftools/informatics/modkit_extract_to_adata.py +1007 -425
  25. smftools/informatics/sequence_encoding.py +72 -0
  26. smftools/logging_utils.py +21 -2
  27. smftools/metadata.py +1 -1
  28. smftools/plotting/__init__.py +26 -3
  29. smftools/plotting/autocorrelation_plotting.py +22 -4
  30. smftools/plotting/chimeric_plotting.py +1893 -0
  31. smftools/plotting/classifiers.py +28 -14
  32. smftools/plotting/general_plotting.py +62 -1583
  33. smftools/plotting/hmm_plotting.py +1670 -8
  34. smftools/plotting/latent_plotting.py +804 -0
  35. smftools/plotting/plotting_utils.py +243 -0
  36. smftools/plotting/position_stats.py +16 -8
  37. smftools/plotting/preprocess_plotting.py +281 -0
  38. smftools/plotting/qc_plotting.py +8 -3
  39. smftools/plotting/spatial_plotting.py +1134 -0
  40. smftools/plotting/variant_plotting.py +1231 -0
  41. smftools/preprocessing/__init__.py +4 -0
  42. smftools/preprocessing/append_base_context.py +18 -18
  43. smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
  44. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  45. smftools/preprocessing/append_variant_call_layer.py +480 -0
  46. smftools/preprocessing/calculate_consensus.py +1 -1
  47. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  48. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  49. smftools/preprocessing/invert_adata.py +1 -0
  50. smftools/readwrite.py +159 -99
  51. smftools/schema/anndata_schema_v1.yaml +15 -1
  52. smftools/tools/__init__.py +10 -0
  53. smftools/tools/calculate_knn.py +121 -0
  54. smftools/tools/calculate_leiden.py +57 -0
  55. smftools/tools/calculate_nmf.py +130 -0
  56. smftools/tools/calculate_pca.py +180 -0
  57. smftools/tools/calculate_umap.py +79 -80
  58. smftools/tools/position_stats.py +4 -4
  59. smftools/tools/rolling_nn_distance.py +872 -0
  60. smftools/tools/sequence_alignment.py +140 -0
  61. smftools/tools/tensor_factorization.py +217 -0
  62. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
  63. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
  64. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  65. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  66. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
smftools/tools/rolling_nn_distance.py (new file)
@@ -0,0 +1,872 @@
+from __future__ import annotations
+
+from math import floor
+from typing import TYPE_CHECKING, Optional, Tuple
+
+import numpy as np
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def zero_pairs_to_dataframe(adata, zero_pairs_uns_key: str) -> "pd.DataFrame":
+    """
+    Build a DataFrame of zero-Hamming pairs per window.
+
+    Args:
+        adata: AnnData containing zero-pair window data in ``adata.uns``.
+        zero_pairs_uns_key: Key for zero-pair window data in ``adata.uns``.
+
+    Returns:
+        DataFrame with one row per zero-Hamming pair per window.
+    """
+    import pandas as pd
+
+    if zero_pairs_uns_key not in adata.uns:
+        raise KeyError(f"Missing zero-pair data in adata.uns[{zero_pairs_uns_key!r}].")
+
+    zero_pairs_by_window = adata.uns[zero_pairs_uns_key]
+    starts = np.asarray(adata.uns.get(f"{zero_pairs_uns_key}_starts"))
+    window = int(adata.uns.get(f"{zero_pairs_uns_key}_window", 0))
+    if starts.size == 0 or window <= 0:
+        raise ValueError("Zero-pair metadata missing starts/window information.")
+
+    obs_names = np.asarray(adata.obs_names, dtype=object)
+    rows = []
+    for wi, pairs in enumerate(zero_pairs_by_window):
+        if pairs is None or len(pairs) == 0:
+            continue
+        start = int(starts[wi])
+        end = start + window
+        for read_i, read_j in pairs:
+            read_i = int(read_i)
+            read_j = int(read_j)
+            rows.append(
+                {
+                    "window_index": wi,
+                    "window_start": start,
+                    "window_end": end,
+                    "read_i": read_i,
+                    "read_j": read_j,
+                    "read_i_name": str(obs_names[read_i]),
+                    "read_j_name": str(obs_names[read_j]),
+                }
+            )
+
+    return pd.DataFrame(
+        rows,
+        columns=[
+            "window_index",
+            "window_start",
+            "window_end",
+            "read_i",
+            "read_j",
+            "read_i_name",
+            "read_j_name",
+        ],
+    )
+
+
+def zero_hamming_segments_to_dataframe(
+    records: list[dict],
+    var_names: np.ndarray,
+) -> "pd.DataFrame":
+    """
+    Build a DataFrame of merged/refined zero-Hamming segments.
+
+    Args:
+        records: Output records from ``annotate_zero_hamming_segments``.
+        var_names: AnnData var names for labeling segment coordinates.
+
+    Returns:
+        DataFrame with one row per zero-Hamming segment.
+    """
+    import pandas as pd
+
+    var_names = np.asarray(var_names, dtype=object)
+
+    def _label_at(idx: int) -> Optional[str]:
+        if 0 <= idx < var_names.size:
+            return str(var_names[idx])
+        return None
+
+    rows = []
+    for record in records:
+        read_i = int(record["read_i"])
+        read_j = int(record["read_j"])
+        read_i_name = record.get("read_i_name")
+        read_j_name = record.get("read_j_name")
+        for seg_start, seg_end in record.get("segments", []):
+            seg_start = int(seg_start)
+            seg_end = int(seg_end)
+            end_inclusive = max(seg_start, seg_end - 1)
+            rows.append(
+                {
+                    "read_i": read_i,
+                    "read_j": read_j,
+                    "read_i_name": read_i_name,
+                    "read_j_name": read_j_name,
+                    "segment_start": seg_start,
+                    "segment_end_exclusive": seg_end,
+                    "segment_end_inclusive": end_inclusive,
+                    "segment_start_label": _label_at(seg_start),
+                    "segment_end_label": _label_at(end_inclusive),
+                }
+            )
+
+    return pd.DataFrame(
+        rows,
+        columns=[
+            "read_i",
+            "read_j",
+            "read_i_name",
+            "read_j_name",
+            "segment_start",
+            "segment_end_exclusive",
+            "segment_end_inclusive",
+            "segment_start_label",
+            "segment_end_label",
+        ],
+    )
+
+
+def _window_center_coordinates(adata, starts: np.ndarray, window: int) -> np.ndarray:
+    """
+    Compute window center coordinates using AnnData var positions.
+
+    If coordinates are numeric, return the mean coordinate per window.
+    If not numeric, return the midpoint label for each window.
+    """
+    coord_source = adata.var_names
+
+    coords = np.asarray(coord_source)
+    if coords.size == 0:
+        return np.array([], dtype=float)
+
+    try:
+        coords_numeric = coords.astype(float)
+        return np.array(
+            [floor(np.nanmean(coords_numeric[s : s + window])) for s in starts], dtype=float
+        )
+    except Exception:
+        mid = np.clip(starts + (window // 2), 0, coords.size - 1)
+        return coords[mid]
+
+
+def _pack_bool_to_u64(B: np.ndarray) -> np.ndarray:
+    """
+    Pack a boolean (or 0/1) matrix (n, w) into uint64 blocks (n, ceil(w/64)).
+    Safe w.r.t. contiguity/layout.
+    """
+    B = np.asarray(B, dtype=np.uint8)
+    packed_u8 = np.packbits(B, axis=1)  # (n, ceil(w/8)) uint8
+
+    n, nb = packed_u8.shape
+    pad = (-nb) % 8
+    if pad:
+        packed_u8 = np.pad(packed_u8, ((0, 0), (0, pad)), mode="constant", constant_values=0)
+
+    packed_u8 = np.ascontiguousarray(packed_u8)
+
+    # group 8 bytes -> uint64
+    packed_u64 = packed_u8.reshape(n, -1, 8).view(np.uint64).reshape(n, -1)
+    return packed_u64
+
+
+def _popcount_u64_matrix(A_u64: np.ndarray) -> np.ndarray:
+    """
+    Popcount for an array of uint64, vectorized and portable across NumPy versions.
+
+    Returns an integer array with the SAME SHAPE as A_u64.
+    """
+    A_u64 = np.ascontiguousarray(A_u64)
+    # View as bytes; IMPORTANT: reshape to add a trailing byte axis of length 8
+    b = A_u64.view(np.uint8).reshape(A_u64.shape + (8,))
+    # unpack bits within that byte axis -> (..., 64), then sum
+    return np.unpackbits(b, axis=-1).sum(axis=-1)
+
+
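The two helpers above are the hot path of this module: a 64-position chunk of two reads is compared with a single AND/XOR on packed words, and np.unpackbits recovers the bit count without needing a native popcount. A minimal self-check sketch of that equivalence, assuming smftools is installed and the module is importable as smftools.tools.rolling_nn_distance (the path shown in the file list above):

import numpy as np

from smftools.tools.rolling_nn_distance import _pack_bool_to_u64, _popcount_u64_matrix

rng = np.random.default_rng(0)
M = rng.random((4, 100)) < 0.8           # observed mask: 4 reads x 100 positions
packed = _pack_bool_to_u64(M)            # shape (4, 2): 100 bits -> two uint64 blocks

# Pairwise overlap counts via AND + popcount, accumulated per 64-bit block
overlap = np.zeros((4, 4), dtype=int)
for k in range(packed.shape[1]):
    col = packed[:, k]
    overlap += _popcount_u64_matrix(col[:, None] & col[None, :])

# Reference: direct boolean computation of the same pairwise overlaps
ref = (M[:, None, :] & M[None, :, :]).sum(axis=-1)
assert np.array_equal(overlap, ref)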
+def rolling_window_nn_distance(
+    adata,
+    layer: Optional[str] = None,
+    window: int = 15,
+    step: int = 2,
+    min_overlap: int = 10,
+    return_fraction: bool = True,
+    block_rows: int = 256,
+    block_cols: int = 2048,
+    store_obsm: Optional[str] = "rolling_nn_dist",
+    collect_zero_pairs: bool = False,
+    zero_pairs_uns_key: Optional[str] = None,
+    sample_labels: Optional[np.ndarray] = None,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Rolling-window nearest-neighbor distance per read, overlap-aware.
+
+    Distance between reads i,j in a window:
+      - use only positions where BOTH are observed (non-NaN)
+      - require overlap >= min_overlap
+      - mismatch = count(x_i != x_j) over overlapped positions
+      - distance = mismatch/overlap (if return_fraction) else mismatch
+
+    Returns
+    -------
+    out : (n_obs, n_windows) float
+        Nearest-neighbor distance per read per window (NaN if no valid neighbor).
+    starts : (n_windows,) int
+        Window start indices in var-space.
+
+    Window center coordinates derived from AnnData var positions are not part of
+    the return value; they are stored in ``adata.uns`` when ``store_obsm`` is set.
+    """
+    X = adata.layers[layer] if layer is not None else adata.X
+    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
+
+    n, p = X.shape
+    if window > p:
+        raise ValueError(f"window={window} is larger than n_vars={p}")
+    if window <= 0:
+        raise ValueError("window must be > 0")
+    if step <= 0:
+        raise ValueError("step must be > 0")
+    if min_overlap <= 0:
+        raise ValueError("min_overlap must be > 0")
+
+    starts = np.arange(0, p - window + 1, step, dtype=int)
+    nW = len(starts)
+    out = np.full((n, nW), np.nan, dtype=float)
+
+    zero_pairs_by_window = [] if collect_zero_pairs else None
+
+    for wi, s in enumerate(starts):
+        wX = X[:, s : s + window]  # (n, window)
+
+        # observed mask; values as 0/1 where observed, 0 elsewhere
+        M = ~np.isnan(wX)
+        V = np.where(M, wX, 0).astype(np.float32)
+
+        # ensure binary 0/1
+        V = (V > 0).astype(np.uint8)
+
+        M64 = _pack_bool_to_u64(M)
+        V64 = _pack_bool_to_u64(V.astype(bool))
+
+        best = np.full(n, np.inf, dtype=float)
+
+        window_pairs = [] if collect_zero_pairs else None
+
+        for i0 in range(0, n, block_rows):
+            i1 = min(n, i0 + block_rows)
+            Mi = M64[i0:i1]  # (bi, nb)
+            Vi = V64[i0:i1]
+            bi = i1 - i0
+
+            local_best = np.full(bi, np.inf, dtype=float)
+
+            for j0 in range(0, n, block_cols):
+                j1 = min(n, j0 + block_cols)
+                Mj = M64[j0:j1]  # (bj, nb)
+                Vj = V64[j0:j1]
+                bj = j1 - j0
+
+                overlap_counts = np.zeros((bi, bj), dtype=np.uint16)
+                mismatch_counts = np.zeros((bi, bj), dtype=np.uint16)
+
+                for k in range(Mi.shape[1]):
+                    ob = (Mi[:, k][:, None] & Mj[:, k][None, :]).astype(np.uint64)
+                    overlap_counts += _popcount_u64_matrix(ob).astype(np.uint16)
+
+                    mb = ((Vi[:, k][:, None] ^ Vj[:, k][None, :]) & ob).astype(np.uint64)
+                    mismatch_counts += _popcount_u64_matrix(mb).astype(np.uint16)
+
+                ok = overlap_counts >= min_overlap
+                if not np.any(ok):
+                    continue
+
+                dist = np.full((bi, bj), np.inf, dtype=float)
+                if return_fraction:
+                    dist[ok] = mismatch_counts[ok] / overlap_counts[ok]
+                else:
+                    dist[ok] = mismatch_counts[ok].astype(float)
+
+                # exclude self comparisons (diagonal) when blocks overlap
+                if (i0 <= j1) and (j0 <= i1):
+                    ii = np.arange(i0, i1)
+                    jj = ii[(ii >= j0) & (ii < j1)]
+                    if jj.size:
+                        dist[(jj - i0), (jj - j0)] = np.inf
+
+                # exclude same-sample comparisons when sample_labels provided
+                if sample_labels is not None:
+                    sl_i = sample_labels[i0:i1]
+                    sl_j = sample_labels[j0:j1]
+                    same_sample = sl_i[:, None] == sl_j[None, :]
+                    dist[same_sample] = np.inf
+
+                local_best = np.minimum(local_best, dist.min(axis=1))
+
+                if collect_zero_pairs:
+                    zero_mask = ok & (mismatch_counts == 0)
+                    if np.any(zero_mask):
+                        i_idx, j_idx = np.where(zero_mask)
+                        gi = i0 + i_idx
+                        gj = j0 + j_idx
+                        keep = gi < gj
+                        if sample_labels is not None:
+                            keep = keep & (sample_labels[gi] != sample_labels[gj])
+                        if np.any(keep):
+                            window_pairs.append(np.stack([gi[keep], gj[keep]], axis=1))
+
+            best[i0:i1] = local_best
+
+        best[~np.isfinite(best)] = np.nan
+        out[:, wi] = best
+        if collect_zero_pairs:
+            if window_pairs:
+                zero_pairs_by_window.append(np.vstack(window_pairs))
+            else:
+                zero_pairs_by_window.append(np.empty((0, 2), dtype=int))
+
+    if store_obsm is not None:
+        adata.obsm[store_obsm] = out
+        adata.uns[f"{store_obsm}_starts"] = starts
+        adata.uns[f"{store_obsm}_centers"] = _window_center_coordinates(adata, starts, window)
+        adata.uns[f"{store_obsm}_window"] = int(window)
+        adata.uns[f"{store_obsm}_step"] = int(step)
+        adata.uns[f"{store_obsm}_min_overlap"] = int(min_overlap)
+        adata.uns[f"{store_obsm}_return_fraction"] = bool(return_fraction)
+        adata.uns[f"{store_obsm}_layer"] = layer if layer is not None else "X"
+    if collect_zero_pairs:
+        if zero_pairs_uns_key is None:
+            zero_pairs_uns_key = (
+                f"{store_obsm}_zero_pairs" if store_obsm is not None else "rolling_nn_zero_pairs"
+            )
+        adata.uns[zero_pairs_uns_key] = zero_pairs_by_window
+        adata.uns[f"{zero_pairs_uns_key}_starts"] = starts
+        adata.uns[f"{zero_pairs_uns_key}_window"] = int(window)
+        adata.uns[f"{zero_pairs_uns_key}_step"] = int(step)
+        adata.uns[f"{zero_pairs_uns_key}_min_overlap"] = int(min_overlap)
+        adata.uns[f"{zero_pairs_uns_key}_return_fraction"] = bool(return_fraction)
+        adata.uns[f"{zero_pairs_uns_key}_layer"] = layer if layer is not None else "X"
+
+    return out, starts
+
+
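As a usage sketch (not from the package: the toy AnnData and parameter values below are illustrative), the function can be exercised on a small binarized matrix with missing positions; sample_labels restricts nearest neighbors to reads from a different sample:

import anndata as ad
import numpy as np

from smftools.tools.rolling_nn_distance import rolling_window_nn_distance

rng = np.random.default_rng(0)
X = (rng.random((50, 200)) < 0.5).astype(float)
X[rng.random((50, 200)) < 0.1] = np.nan   # ~10% unobserved positions
adata = ad.AnnData(X)
adata.obs["sample"] = np.where(np.arange(50) < 25, "a", "b")

out, starts = rolling_window_nn_distance(
    adata,
    window=20,
    step=5,
    min_overlap=10,
    collect_zero_pairs=True,
    sample_labels=adata.obs["sample"].to_numpy(),  # only cross-sample neighbors
)
print(out.shape)                                   # (50, n_windows)
print(sorted(k for k in adata.uns if k.startswith("rolling_nn_dist")))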
+def annotate_zero_hamming_segments(
+    adata,
+    zero_pairs_uns_key: Optional[str] = None,
+    output_uns_key: str = "zero_hamming_segments",
+    layer: Optional[str] = None,
+    min_overlap: Optional[int] = None,
+    refine_segments: bool = True,
+    max_nan_run: Optional[int] = None,
+    merge_gap: int = 0,
+    max_segments_per_read: Optional[int] = None,
+    max_segment_overlap: Optional[int] = None,
+) -> list[dict]:
+    """
+    Merge zero-Hamming windows into maximal segments and annotate onto AnnData.
+
+    Args:
+        adata: AnnData containing zero-pair window data in ``.uns``.
+        zero_pairs_uns_key: Key for zero-pair window data in ``adata.uns``.
+        output_uns_key: Key to store merged/refined segments in ``adata.uns``.
+        layer: Layer to use for refinement (defaults to adata.X).
+        min_overlap: Minimum overlap required to keep a refined segment.
+        refine_segments: Whether to refine merged windows to maximal segments.
+        max_nan_run: Maximum consecutive NaN positions allowed when expanding segments.
+            If reached, expansion stops before the NaN run. Set to ``None`` to ignore NaNs.
+        merge_gap: Merge segments with gaps of at most this size (in positions).
+        max_segments_per_read: Maximum number of segments to retain per read pair.
+        max_segment_overlap: Maximum allowed overlap between retained segments (inclusive, in
+            var-index coordinates).
+
+    Returns:
+        List of segment records stored in ``adata.uns[output_uns_key]``.
+    """
+    if zero_pairs_uns_key is None:
+        candidate_keys = [key for key in adata.uns if key.endswith("_zero_pairs")]
+        if len(candidate_keys) == 1:
+            zero_pairs_uns_key = candidate_keys[0]
+        elif not candidate_keys:
+            raise KeyError("No zero-pair data found in adata.uns.")
+        else:
+            raise KeyError(
+                "Multiple zero-pair keys found in adata.uns; please specify zero_pairs_uns_key."
+            )
+
+    if zero_pairs_uns_key not in adata.uns:
+        raise KeyError(f"Missing zero-pair data in adata.uns[{zero_pairs_uns_key!r}].")
+
+    zero_pairs_by_window = adata.uns[zero_pairs_uns_key]
+    starts = np.asarray(adata.uns.get(f"{zero_pairs_uns_key}_starts"))
+    window = int(adata.uns.get(f"{zero_pairs_uns_key}_window", 0))
+    if starts.size == 0 or window <= 0:
+        raise ValueError("Zero-pair metadata missing starts/window information.")
+
+    if min_overlap is None:
+        min_overlap = int(adata.uns.get(f"{zero_pairs_uns_key}_min_overlap", 1))
+
+    X = adata.layers[layer] if layer is not None else adata.X
+    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
+    observed = ~np.isnan(X)
+    values = (np.where(observed, X, 0.0) > 0).astype(np.uint8)
+    max_nan_run = None if max_nan_run is None else max(int(max_nan_run), 1)
+
+    pair_segments: dict[tuple[int, int], list[tuple[int, int]]] = {}
+    for wi, pairs in enumerate(zero_pairs_by_window):
+        if pairs is None or len(pairs) == 0:
+            continue
+        start = int(starts[wi])
+        end = start + window
+        for i, j in pairs:
+            key = (int(i), int(j))
+            pair_segments.setdefault(key, []).append((start, end))
+
+    merge_gap = max(int(merge_gap), 0)
+
+    def _merge_segments(segments: list[tuple[int, int]]) -> list[tuple[int, int]]:
+        if not segments:
+            return []
+        segments = sorted(segments, key=lambda seg: seg[0])
+        merged = [segments[0]]
+        for seg_start, seg_end in segments[1:]:
+            last_start, last_end = merged[-1]
+            if seg_start <= last_end + merge_gap:
+                merged[-1] = (last_start, max(last_end, seg_end))
+            else:
+                merged.append((seg_start, seg_end))
+        return merged
+
+    def _refine_segment(
+        read_i: int,
+        read_j: int,
+        start: int,
+        end: int,
+    ) -> Optional[tuple[int, int]]:
+        if not refine_segments:
+            return (start, end)
+        left = start
+        right = end
+        nan_run = 0
+        while left > 0:
+            idx = left - 1
+            both_observed = observed[read_i, idx] and observed[read_j, idx]
+            if both_observed:
+                nan_run = 0
+                if values[read_i, idx] != values[read_j, idx]:
+                    break
+            else:
+                nan_run += 1
+                if max_nan_run is not None and nan_run >= max_nan_run:
+                    break
+            left -= 1
+        n_vars = values.shape[1]
+        nan_run = 0
+        while right < n_vars:
+            idx = right
+            both_observed = observed[read_i, idx] and observed[read_j, idx]
+            if both_observed:
+                nan_run = 0
+                if values[read_i, idx] != values[read_j, idx]:
+                    break
+            else:
+                nan_run += 1
+                if max_nan_run is not None and nan_run >= max_nan_run:
+                    break
+            right += 1
+        overlap = np.sum(observed[read_i, left:right] & observed[read_j, left:right])
+        if overlap < min_overlap:
+            return None
+        return (left, right)
+
+    def _segment_length(segment: tuple[int, int]) -> int:
+        return int(segment[1]) - int(segment[0])
+
+    def _segment_overlap(first: tuple[int, int], second: tuple[int, int]) -> int:
+        return max(0, min(first[1], second[1]) - max(first[0], second[0]))
+
+    def _select_segments(segments: list[tuple[int, int]]) -> list[tuple[int, int]]:
+        if not segments:
+            return []
+        if max_segments_per_read is None and max_segment_overlap is None:
+            return segments
+        ordered = sorted(
+            segments,
+            key=lambda seg: (_segment_length(seg), -seg[0]),
+            reverse=True,
+        )
+        max_segments = len(ordered) if max_segments_per_read is None else max_segments_per_read
+        if max_segment_overlap is None:
+            return ordered[:max_segments]
+        selected: list[tuple[int, int]] = []
+        for segment in ordered:
+            if len(selected) >= max_segments:
+                break
+            if all(_segment_overlap(segment, other) <= max_segment_overlap for other in selected):
+                selected.append(segment)
+        return selected
+
+    records: list[dict] = []
+    obs_names = adata.obs_names
+    for (read_i, read_j), segments in pair_segments.items():
+        merged = _merge_segments(segments)
+        refined_segments = []
+        for seg_start, seg_end in merged:
+            refined = _refine_segment(read_i, read_j, seg_start, seg_end)
+            if refined is not None:
+                refined_segments.append(refined)
+        refined_segments = _select_segments(refined_segments)
+        if refined_segments:
+            records.append(
+                {
+                    "read_i": read_i,
+                    "read_j": read_j,
+                    "read_i_name": str(obs_names[read_i]),
+                    "read_j_name": str(obs_names[read_j]),
+                    "segments": refined_segments,
+                }
+            )
+
+    adata.uns[output_uns_key] = records
+    return records
+
+
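Continuing the sketch above: the collected zero pairs can be merged into maximal segments and tabulated. The uns key follows the default naming set by rolling_window_nn_distance(store_obsm="rolling_nn_dist"); the merge_gap and max_nan_run values are illustrative:

from smftools.tools.rolling_nn_distance import (
    annotate_zero_hamming_segments,
    zero_hamming_segments_to_dataframe,
    zero_pairs_to_dataframe,
)

# One row per zero-Hamming pair per window
pairs_df = zero_pairs_to_dataframe(adata, "rolling_nn_dist_zero_pairs")

records = annotate_zero_hamming_segments(
    adata,
    zero_pairs_uns_key="rolling_nn_dist_zero_pairs",
    merge_gap=5,       # bridge windows separated by gaps of up to 5 positions
    max_nan_run=3,     # stop extending a segment after 3 consecutive unobserved sites
)

# One row per merged/refined segment, labeled with var-name coordinates
segments_df = zero_hamming_segments_to_dataframe(records, adata.var_names)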
+def assign_per_read_segments_layer(
+    parent_adata: "ad.AnnData",
+    subset_adata: "ad.AnnData",
+    per_read_segments: "pd.DataFrame",
+    layer_key: str,
+) -> None:
+    """
+    Assign per-read segments into a summed span layer on a parent AnnData.
+
+    Args:
+        parent_adata: AnnData that should receive the span layer.
+        subset_adata: AnnData used to compute per-read segments.
+        per_read_segments: DataFrame with ``read_id``, ``segment_start``, and
+            ``segment_end_exclusive`` columns. If ``segment_start_label`` and
+            ``segment_end_label`` are present and numeric, they are used to
+            map segments using label coordinates.
+        layer_key: Name of the layer to store in ``parent_adata.layers``.
+    """
+    import pandas as pd
+
+    if per_read_segments.empty:
+        parent_adata.layers[layer_key] = np.zeros(
+            (parent_adata.n_obs, parent_adata.n_vars), dtype=np.uint16
+        )
+        return
+    required_cols = {"read_id", "segment_start", "segment_end_exclusive"}
+    missing = required_cols.difference(per_read_segments.columns)
+    if missing:
+        raise KeyError(f"per_read_segments missing required columns: {sorted(missing)}")
+
+    target_layer = np.zeros((parent_adata.n_obs, parent_adata.n_vars), dtype=np.uint16)
+
+    parent_obs_indexer = parent_adata.obs_names.get_indexer(subset_adata.obs_names)
+    if (parent_obs_indexer < 0).any():
+        raise ValueError("Subset AnnData contains obs not present in parent AnnData.")
+    parent_var_indexer = parent_adata.var_names.get_indexer(subset_adata.var_names)
+    if (parent_var_indexer < 0).any():
+        raise ValueError("Subset AnnData contains vars not present in parent AnnData.")
+
+    label_indexer = None
+    label_columns = {"segment_start_label", "segment_end_label"}
+    if label_columns.issubset(per_read_segments.columns):
+        try:
+            parent_label_values = [int(label) for label in parent_adata.var_names]
+            label_indexer = {label: idx for idx, label in enumerate(parent_label_values)}
+        except (TypeError, ValueError):
+            label_indexer = None
+
+    def _label_to_index(value: object) -> Optional[int]:
+        if label_indexer is None or value is None or pd.isna(value):
+            return None
+        try:
+            return label_indexer.get(int(value))
+        except (TypeError, ValueError):
+            return None
+
+    for row in per_read_segments.itertuples(index=False):
+        read_id = int(row.read_id)
+        seg_start = int(row.segment_start)
+        seg_end = int(row.segment_end_exclusive)
+        if seg_start >= seg_end:
+            continue
+        target_read = parent_obs_indexer[read_id]
+        if target_read < 0:
+            raise ValueError("Segment read_id not found in parent AnnData.")
+
+        label_start = _label_to_index(getattr(row, "segment_start_label", None))
+        label_end = _label_to_index(getattr(row, "segment_end_label", None))
+        if label_start is not None and label_end is not None:
+            parent_start = min(label_start, label_end)
+            parent_end = max(label_start, label_end)
+        else:
+            parent_positions = parent_var_indexer[seg_start:seg_end]
+            if parent_positions.size == 0:
+                continue
+            parent_start = int(parent_positions.min())
+            parent_end = int(parent_positions.max())
+
+        target_layer[target_read, parent_start : parent_end + 1] += 1
+
+    parent_adata.layers[layer_key] = target_layer
+
+
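A small sketch of the coordinate mapping (the DataFrame below is hand-built; in practice it comes from segments_to_per_read_dataframe or select_top_segments_per_read, defined later in this file): read_id indexes rows of the subset, and the painted span lands on the matching parent row and columns:

import anndata as ad
import numpy as np
import pandas as pd

from smftools.tools.rolling_nn_distance import assign_per_read_segments_layer

parent = ad.AnnData(np.zeros((4, 10)))
subset = parent[[1, 3], :].copy()      # subset keeps parent reads 1 and 3

per_read_segments = pd.DataFrame(
    [{"read_id": 0, "segment_start": 2, "segment_end_exclusive": 6}]
)
assign_per_read_segments_layer(parent, subset, per_read_segments, "zero_hamming_spans")
print(parent.layers["zero_hamming_spans"][1])  # ones over parent columns 2..5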
+def segments_to_per_read_dataframe(
+    records: list[dict],
+    var_names: np.ndarray,
+) -> "pd.DataFrame":
+    """
+    Build a per-read DataFrame of zero-Hamming segments.
+
+    Args:
+        records: Output records from ``annotate_zero_hamming_segments``.
+        var_names: AnnData var names for labeling segment coordinates.
+
+    Returns:
+        DataFrame with one row per segment per read.
+    """
+    import pandas as pd
+
+    var_names = np.asarray(var_names, dtype=object)
+
+    def _label_at(idx: int) -> Optional[str]:
+        if 0 <= idx < var_names.size:
+            return str(var_names[idx])
+        return None
+
+    rows = []
+    for record in records:
+        read_i = int(record["read_i"])
+        read_j = int(record["read_j"])
+        read_i_name = record.get("read_i_name")
+        read_j_name = record.get("read_j_name")
+        for seg_start, seg_end in record.get("segments", []):
+            seg_start = int(seg_start)
+            seg_end = int(seg_end)
+            end_inclusive = max(seg_start, seg_end - 1)
+            start_label = _label_at(seg_start)
+            end_label = _label_at(end_inclusive)
+            rows.append(
+                {
+                    "read_id": read_i,
+                    "partner_id": read_j,
+                    "read_name": read_i_name,
+                    "partner_name": read_j_name,
+                    "segment_start": seg_start,
+                    "segment_end_exclusive": seg_end,
+                    "segment_end_inclusive": end_inclusive,
+                    "segment_start_label": start_label,
+                    "segment_end_label": end_label,
+                }
+            )
+            rows.append(
+                {
+                    "read_id": read_j,
+                    "partner_id": read_i,
+                    "read_name": read_j_name,
+                    "partner_name": read_i_name,
+                    "segment_start": seg_start,
+                    "segment_end_exclusive": seg_end,
+                    "segment_end_inclusive": end_inclusive,
+                    "segment_start_label": start_label,
+                    "segment_end_label": end_label,
+                }
+            )
+
+    return pd.DataFrame(
+        rows,
+        columns=[
+            "read_id",
+            "partner_id",
+            "read_name",
+            "partner_name",
+            "segment_start",
+            "segment_end_exclusive",
+            "segment_end_inclusive",
+            "segment_start_label",
+            "segment_end_label",
+        ],
+    )
+
+
+def select_top_segments_per_read(
+    records: list[dict],
+    var_names: np.ndarray,
+    max_segments_per_read: Optional[int] = None,
+    max_segment_overlap: Optional[int] = None,
+    min_span: Optional[float] = None,
+) -> tuple["pd.DataFrame", "pd.DataFrame"]:
+    """
+    Select top segments per read from distinct partner pairs.
+
+    Args:
+        records: Output records from ``annotate_zero_hamming_segments``.
+        var_names: AnnData var names for labeling segment coordinates.
+        max_segments_per_read: Maximum number of segments to keep per read.
+        max_segment_overlap: Maximum allowed overlap between kept segments.
+        min_span: Minimum span length to keep (var-name coordinate if numeric, else index span).
+
+    Returns:
+        Tuple of (raw per-read segments, filtered per-read segments).
+    """
+    import pandas as pd
+
+    raw_df = segments_to_per_read_dataframe(records, var_names)
+    if raw_df.empty:
+        raw_df = raw_df.copy()
+        raw_df["segment_length_index"] = pd.Series(dtype=int)
+        raw_df["segment_length_label"] = pd.Series(dtype=float)
+        return raw_df, raw_df.copy()
+
+    def _span_length(row) -> float:
+        try:
+            start = float(row["segment_start_label"])
+            end = float(row["segment_end_label"])
+            return abs(end - start)
+        except (TypeError, ValueError):
+            return float(row["segment_end_exclusive"] - row["segment_start"])
+
+    raw_df = raw_df.copy()
+    raw_df["segment_length_index"] = (
+        raw_df["segment_end_exclusive"] - raw_df["segment_start"]
+    ).astype(int)
+    raw_df["segment_length_label"] = raw_df.apply(_span_length, axis=1)
+    if min_span is not None:
+        raw_df = raw_df[raw_df["segment_length_label"] >= float(min_span)]
+
+    if raw_df.empty:
+        return raw_df, raw_df.copy()
+
+    def _segment_overlap(a, b) -> int:
+        return max(0, min(a[1], b[1]) - max(a[0], b[0]))
+
+    filtered_rows = []
+    max_segments = max_segments_per_read
+    for read_id, read_df in raw_df.groupby("read_id", sort=False):
+        per_partner = (
+            read_df.sort_values(
+                ["segment_length_label", "segment_start"],
+                ascending=[False, True],
+            )
+            .groupby("partner_id", sort=False)
+            .head(1)
+        )
+        ordered = per_partner.sort_values(
+            ["segment_length_label", "segment_start"],
+            ascending=[False, True],
+        ).itertuples(index=False)
+        selected = []
+        for row in ordered:
+            if max_segments is not None and len(selected) >= max_segments:
+                break
+            seg = (row.segment_start, row.segment_end_exclusive)
+            if max_segment_overlap is not None:
+                if any(
+                    _segment_overlap(seg, (s.segment_start, s.segment_end_exclusive))
+                    > max_segment_overlap
+                    for s in selected
+                ):
+                    continue
+            selected.append(row)
+        for row in selected:
+            filtered_rows.append(row._asdict())
+
+    filtered_df = pd.DataFrame(filtered_rows, columns=raw_df.columns)
+    if not filtered_df.empty:
+        filtered_df["selection_rank"] = (
+            filtered_df.sort_values(
+                ["read_id", "segment_length_label", "segment_start"],
+                ascending=[True, False, True],
+            )
+            .groupby("read_id", sort=False)
+            .cumcount()
+            + 1
+        )
+
+    return raw_df, filtered_df
+
+
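A sketch of the per-read selection, reusing records from the segment-annotation example above (parameter values are illustrative): the longest segment per partner is considered first, then segments are greedily accepted up to the caps:

from smftools.tools.rolling_nn_distance import select_top_segments_per_read

raw_df, filtered_df = select_top_segments_per_read(
    records,
    adata.var_names,
    max_segments_per_read=2,   # keep at most two segments per read
    max_segment_overlap=10,    # kept segments may share at most 10 positions
    min_span=20.0,             # drop spans shorter than 20 coordinate units
)
if not filtered_df.empty:
    print(filtered_df[["read_id", "partner_id", "segment_start", "selection_rank"]])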
+def assign_rolling_nn_results(
+    parent_adata: "ad.AnnData",
+    subset_adata: "ad.AnnData",
+    values: np.ndarray,
+    starts: np.ndarray,
+    obsm_key: str,
+    window: int,
+    step: int,
+    min_overlap: int,
+    return_fraction: bool,
+    layer: Optional[str],
+) -> None:
+    """
+    Assign rolling NN results computed on a subset back onto a parent AnnData.
+
+    Parameters
+    ----------
+    parent_adata : AnnData
+        Parent AnnData that should store the combined results.
+    subset_adata : AnnData
+        Subset AnnData used to compute `values`.
+    values : np.ndarray
+        Rolling NN output with shape (n_subset_obs, n_windows).
+    starts : np.ndarray
+        Window start indices corresponding to `values`.
+    obsm_key : str
+        Key to store results under in parent_adata.obsm.
+    window : int
+        Rolling window size (stored in parent_adata.uns).
+    step : int
+        Rolling window step size (stored in parent_adata.uns).
+    min_overlap : int
+        Minimum overlap (stored in parent_adata.uns).
+    return_fraction : bool
+        Whether distances are fractional (stored in parent_adata.uns).
+    layer : str | None
+        Layer used for calculations (stored in parent_adata.uns).
+    """
+    n_obs = parent_adata.n_obs
+    n_windows = values.shape[1]
+
+    if obsm_key not in parent_adata.obsm:
+        parent_adata.obsm[obsm_key] = np.full((n_obs, n_windows), np.nan, dtype=float)
+        parent_adata.uns[f"{obsm_key}_starts"] = starts
+        parent_adata.uns[f"{obsm_key}_centers"] = _window_center_coordinates(
+            subset_adata, starts, window
+        )
+        parent_adata.uns[f"{obsm_key}_window"] = int(window)
+        parent_adata.uns[f"{obsm_key}_step"] = int(step)
+        parent_adata.uns[f"{obsm_key}_min_overlap"] = int(min_overlap)
+        parent_adata.uns[f"{obsm_key}_return_fraction"] = bool(return_fraction)
+        parent_adata.uns[f"{obsm_key}_layer"] = layer if layer is not None else "X"
+    else:
+        existing = parent_adata.obsm[obsm_key]
+        if existing.shape[1] != n_windows:
+            raise ValueError(
+                f"Existing obsm[{obsm_key!r}] has {existing.shape[1]} windows; "
+                f"new values have {n_windows} windows."
+            )
+        existing_starts = parent_adata.uns.get(f"{obsm_key}_starts")
+        if existing_starts is not None and not np.array_equal(existing_starts, starts):
+            raise ValueError(
+                f"Existing obsm[{obsm_key!r}] has different window starts than new values."
+            )
+        existing_centers = parent_adata.uns.get(f"{obsm_key}_centers")
+        if existing_centers is not None:
+            expected_centers = _window_center_coordinates(subset_adata, starts, window)
+            if not np.array_equal(existing_centers, expected_centers):
+                raise ValueError(
+                    f"Existing obsm[{obsm_key!r}] has different window centers than new values."
+                )
+
+    parent_indexer = parent_adata.obs_names.get_indexer(subset_adata.obs_names)
+    if (parent_indexer < 0).any():
+        raise ValueError("Subset AnnData contains obs not present in parent AnnData.")
+
+    parent_adata.obsm[obsm_key][parent_indexer, :] = values
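A sketch of the intended subset workflow (the "sample" column and obsm key are illustrative): compute rolling NN distances per group with store_obsm=None, then stitch the per-group results back onto the parent, which validates that window starts and centers agree across calls:

from smftools.tools.rolling_nn_distance import (
    assign_rolling_nn_results,
    rolling_window_nn_distance,
)

params = dict(window=20, step=5, min_overlap=10, return_fraction=True, layer=None)
for group in adata.obs["sample"].unique():
    subset = adata[adata.obs["sample"] == group].copy()
    values, starts = rolling_window_nn_distance(subset, store_obsm=None, **params)
    assign_rolling_nn_results(
        parent_adata=adata,
        subset_adata=subset,
        values=values,
        starts=starts,
        obsm_key="rolling_nn_dist_by_sample",
        window=params["window"],
        step=params["step"],
        min_overlap=params["min_overlap"],
        return_fraction=params["return_fraction"],
        layer=params["layer"],
    )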