smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +18 -2
  4. smftools/cli/hmm_adata.py +18 -1
  5. smftools/cli/latent_adata.py +522 -67
  6. smftools/cli/load_adata.py +2 -2
  7. smftools/cli/preprocess_adata.py +32 -93
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +23 -109
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +41 -5
  12. smftools/config/conversion.yaml +0 -10
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +49 -13
  15. smftools/config/experiment_config.py +96 -3
  16. smftools/constants.py +4 -0
  17. smftools/hmm/call_hmm_peaks.py +1 -1
  18. smftools/informatics/binarize_converted_base_identities.py +2 -89
  19. smftools/informatics/converted_BAM_to_adata.py +53 -13
  20. smftools/informatics/h5ad_functions.py +83 -0
  21. smftools/informatics/modkit_extract_to_adata.py +4 -0
  22. smftools/plotting/__init__.py +26 -12
  23. smftools/plotting/autocorrelation_plotting.py +22 -4
  24. smftools/plotting/chimeric_plotting.py +1893 -0
  25. smftools/plotting/classifiers.py +28 -14
  26. smftools/plotting/general_plotting.py +58 -3362
  27. smftools/plotting/hmm_plotting.py +1586 -2
  28. smftools/plotting/latent_plotting.py +804 -0
  29. smftools/plotting/plotting_utils.py +243 -0
  30. smftools/plotting/position_stats.py +16 -8
  31. smftools/plotting/preprocess_plotting.py +281 -0
  32. smftools/plotting/qc_plotting.py +8 -3
  33. smftools/plotting/spatial_plotting.py +1134 -0
  34. smftools/plotting/variant_plotting.py +1231 -0
  35. smftools/preprocessing/__init__.py +3 -0
  36. smftools/preprocessing/append_base_context.py +1 -1
  37. smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
  38. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  39. smftools/preprocessing/append_variant_call_layer.py +480 -0
  40. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  41. smftools/preprocessing/invert_adata.py +1 -0
  42. smftools/readwrite.py +109 -85
  43. smftools/tools/__init__.py +6 -0
  44. smftools/tools/calculate_knn.py +121 -0
  45. smftools/tools/calculate_nmf.py +18 -7
  46. smftools/tools/calculate_pca.py +180 -0
  47. smftools/tools/calculate_umap.py +70 -154
  48. smftools/tools/position_stats.py +4 -4
  49. smftools/tools/rolling_nn_distance.py +640 -3
  50. smftools/tools/sequence_alignment.py +140 -0
  51. smftools/tools/tensor_factorization.py +52 -4
  52. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
  53. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
  54. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  55. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  56. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
smftools/tools/rolling_nn_distance.py
@@ -1,8 +1,7 @@
  from __future__ import annotations

- import ast
- import json
- from typing import TYPE_CHECKING, Optional, Sequence, Tuple
+ from math import floor
+ from typing import TYPE_CHECKING, Optional, Tuple

  import numpy as np

@@ -14,6 +13,150 @@ if TYPE_CHECKING:
  logger = get_logger(__name__)


+ def zero_pairs_to_dataframe(adata, zero_pairs_uns_key: str) -> "pd.DataFrame":
+     """
+     Build a DataFrame of zero-Hamming pairs per window.
+
+     Args:
+         adata: AnnData containing zero-pair window data in ``adata.uns``.
+         zero_pairs_uns_key: Key for zero-pair window data in ``adata.uns``.
+
+     Returns:
+         DataFrame with one row per zero-Hamming pair per window.
+     """
+     import pandas as pd
+
+     if zero_pairs_uns_key not in adata.uns:
+         raise KeyError(f"Missing zero-pair data in adata.uns[{zero_pairs_uns_key!r}].")
+
+     zero_pairs_by_window = adata.uns[zero_pairs_uns_key]
+     starts = np.asarray(adata.uns.get(f"{zero_pairs_uns_key}_starts"))
+     window = int(adata.uns.get(f"{zero_pairs_uns_key}_window", 0))
+     if starts.size == 0 or window <= 0:
+         raise ValueError("Zero-pair metadata missing starts/window information.")
+
+     obs_names = np.asarray(adata.obs_names, dtype=object)
+     rows = []
+     for wi, pairs in enumerate(zero_pairs_by_window):
+         if pairs is None or len(pairs) == 0:
+             continue
+         start = int(starts[wi])
+         end = start + window
+         for read_i, read_j in pairs:
+             read_i = int(read_i)
+             read_j = int(read_j)
+             rows.append(
+                 {
+                     "window_index": wi,
+                     "window_start": start,
+                     "window_end": end,
+                     "read_i": read_i,
+                     "read_j": read_j,
+                     "read_i_name": str(obs_names[read_i]),
+                     "read_j_name": str(obs_names[read_j]),
+                 }
+             )
+
+     return pd.DataFrame(
+         rows,
+         columns=[
+             "window_index",
+             "window_start",
+             "window_end",
+             "read_i",
+             "read_j",
+             "read_i_name",
+             "read_j_name",
+         ],
+     )
+
+
+ def zero_hamming_segments_to_dataframe(
+     records: list[dict],
+     var_names: np.ndarray,
+ ) -> "pd.DataFrame":
+     """
+     Build a DataFrame of merged/refined zero-Hamming segments.
+
+     Args:
+         records: Output records from ``annotate_zero_hamming_segments``.
+         var_names: AnnData var names for labeling segment coordinates.
+
+     Returns:
+         DataFrame with one row per zero-Hamming segment.
+     """
+     import pandas as pd
+
+     var_names = np.asarray(var_names, dtype=object)
+
+     def _label_at(idx: int) -> Optional[str]:
+         if 0 <= idx < var_names.size:
+             return str(var_names[idx])
+         return None
+
+     rows = []
+     for record in records:
+         read_i = int(record["read_i"])
+         read_j = int(record["read_j"])
+         read_i_name = record.get("read_i_name")
+         read_j_name = record.get("read_j_name")
+         for seg_start, seg_end in record.get("segments", []):
+             seg_start = int(seg_start)
+             seg_end = int(seg_end)
+             end_inclusive = max(seg_start, seg_end - 1)
+             rows.append(
+                 {
+                     "read_i": read_i,
+                     "read_j": read_j,
+                     "read_i_name": read_i_name,
+                     "read_j_name": read_j_name,
+                     "segment_start": seg_start,
+                     "segment_end_exclusive": seg_end,
+                     "segment_end_inclusive": end_inclusive,
+                     "segment_start_label": _label_at(seg_start),
+                     "segment_end_label": _label_at(end_inclusive),
+                 }
+             )
+
+     return pd.DataFrame(
+         rows,
+         columns=[
+             "read_i",
+             "read_j",
+             "read_i_name",
+             "read_j_name",
+             "segment_start",
+             "segment_end_exclusive",
+             "segment_end_inclusive",
+             "segment_start_label",
+             "segment_end_label",
+         ],
+     )
+
+
+ def _window_center_coordinates(adata, starts: np.ndarray, window: int) -> np.ndarray:
+     """
+     Compute window center coordinates using AnnData var positions.
+
+     If coordinates are numeric, return the floored mean coordinate per window.
+     If not numeric, return the midpoint label for each window.
+     """
+     coord_source = adata.var_names
+
+     coords = np.asarray(coord_source)
+     if coords.size == 0:
+         return np.array([], dtype=float)
+
+     try:
+         coords_numeric = coords.astype(float)
+         return np.array(
+             [floor(np.nanmean(coords_numeric[s : s + window])) for s in starts], dtype=float
+         )
+     except Exception:
+         mid = np.clip(starts + (window // 2), 0, coords.size - 1)
+         return coords[mid]
+
+
  def _pack_bool_to_u64(B: np.ndarray) -> np.ndarray:
      """
      Pack a boolean (or 0/1) matrix (n, w) into uint64 blocks (n, ceil(w/64)).
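
The two new DataFrame builders flatten the nested ``.uns`` payload written by the rolling scan into tidy tables. A minimal sketch of the data contract, using a hand-built payload rather than a real scan (the AnnData contents and the ``demo_zero_pairs`` key are illustrative only):

import anndata as ad
import numpy as np

from smftools.tools.rolling_nn_distance import zero_pairs_to_dataframe

# Hypothetical 4-read AnnData; only obs_names and .uns matter here.
adata = ad.AnnData(X=np.zeros((4, 20)))
adata.obs_names = [f"read_{i}" for i in range(4)]

# Two windows of width 10 starting at 0 and 5: reads 0/2 match exactly
# in the first window, reads 1/3 in the second.
adata.uns["demo_zero_pairs"] = [np.array([[0, 2]]), np.array([[1, 3]])]
adata.uns["demo_zero_pairs_starts"] = np.array([0, 5])
adata.uns["demo_zero_pairs_window"] = 10

pairs_df = zero_pairs_to_dataframe(adata, "demo_zero_pairs")
# One row per pair per window: window_index, window_start, window_end,
# read_i, read_j, read_i_name, read_j_name.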
@@ -57,6 +200,9 @@ def rolling_window_nn_distance(
      block_rows: int = 256,
      block_cols: int = 2048,
      store_obsm: Optional[str] = "rolling_nn_dist",
+     collect_zero_pairs: bool = False,
+     zero_pairs_uns_key: Optional[str] = None,
+     sample_labels: Optional[np.ndarray] = None,
  ) -> Tuple[np.ndarray, np.ndarray]:
      """
      Rolling-window nearest-neighbor distance per read, overlap-aware.
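
The three new keyword arguments opt into zero-pair collection and cross-sample-only comparisons. A hedged usage sketch (``window`` and ``step`` are assumed to be keyword arguments of this function; only the tail of the signature appears in this hunk, and the data here is synthetic):

import anndata as ad
import numpy as np

from smftools.tools.rolling_nn_distance import rolling_window_nn_distance

rng = np.random.default_rng(0)
adata = ad.AnnData(X=(rng.random((8, 60)) > 0.5).astype(float))
adata.obs["sample"] = ["a"] * 4 + ["b"] * 4

out, starts = rolling_window_nn_distance(
    adata,
    window=20,  # assumed keyword
    step=10,    # assumed keyword
    collect_zero_pairs=True,
    sample_labels=adata.obs["sample"].to_numpy(),
)
# sample_labels masks same-sample comparisons to inf, so each read's nearest
# neighbor always comes from the other sample; perfect cross-sample matches
# are recorded per window under adata.uns["rolling_nn_dist_zero_pairs"]
# (the default store_obsm is "rolling_nn_dist").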
@@ -73,6 +219,8 @@ def rolling_window_nn_distance(
          Nearest-neighbor distance per read per window (NaN if no valid neighbor).
      starts : (n_windows,) int
          Window start indices in var-space.
+     centers : (n_windows,) array-like
+         Window center coordinates derived from AnnData var positions (stored in ``.uns``).
      """
      X = adata.layers[layer] if layer is not None else adata.X
      X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
@@ -91,6 +239,8 @@ def rolling_window_nn_distance(
      nW = len(starts)
      out = np.full((n, nW), np.nan, dtype=float)

+     zero_pairs_by_window = [] if collect_zero_pairs else None
+
      for wi, s in enumerate(starts):
          wX = X[:, s : s + window]  # (n, window)

@@ -106,6 +256,8 @@ def rolling_window_nn_distance(

          best = np.full(n, np.inf, dtype=float)

+         window_pairs = [] if collect_zero_pairs else None
+
          for i0 in range(0, n, block_rows):
              i1 = min(n, i0 + block_rows)
              Mi = M64[i0:i1]  # (bi, nb)
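
The block loop above compares ``_pack_bool_to_u64`` blocks; the mismatch counting itself lives in unchanged code that these hunks do not show. A standalone sketch of the underlying XOR-popcount trick (``packed_hamming`` is a hypothetical helper for illustration, not the package's kernel, which also tracks overlap masks):

import numpy as np

def packed_hamming(a: np.ndarray, b: np.ndarray) -> int:
    """Hamming distance between two equal-length 0/1 vectors via uint64 blocks."""
    def pack(v: np.ndarray) -> np.ndarray:
        bits = np.packbits(v.astype(np.uint8))      # bit-pack into uint8 blocks
        bits = np.pad(bits, (0, (-bits.size) % 8))  # pad to a uint64 boundary
        return bits.view(np.uint64)
    diff = pack(a) ^ pack(b)  # set bits mark mismatching positions
    return int(sum(bin(int(word)).count("1") for word in diff))  # popcount

rng = np.random.default_rng(0)
a, b = rng.integers(0, 2, 50), rng.integers(0, 2, 50)
assert packed_hamming(a, b) == int(np.sum(a != b))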
@@ -147,25 +299,500 @@ def rolling_window_nn_distance(
                  if jj.size:
                      dist[(jj - i0), (jj - j0)] = np.inf

+                 # exclude same-sample comparisons when sample_labels provided
+                 if sample_labels is not None:
+                     sl_i = sample_labels[i0:i1]
+                     sl_j = sample_labels[j0:j1]
+                     same_sample = sl_i[:, None] == sl_j[None, :]
+                     dist[same_sample] = np.inf
+
                  local_best = np.minimum(local_best, dist.min(axis=1))

+                 if collect_zero_pairs:
+                     zero_mask = ok & (mismatch_counts == 0)
+                     if np.any(zero_mask):
+                         i_idx, j_idx = np.where(zero_mask)
+                         gi = i0 + i_idx
+                         gj = j0 + j_idx
+                         keep = gi < gj
+                         if sample_labels is not None:
+                             keep = keep & (sample_labels[gi] != sample_labels[gj])
+                         if np.any(keep):
+                             window_pairs.append(np.stack([gi[keep], gj[keep]], axis=1))
+
              best[i0:i1] = local_best

          best[~np.isfinite(best)] = np.nan
          out[:, wi] = best
+         if collect_zero_pairs:
+             if window_pairs:
+                 zero_pairs_by_window.append(np.vstack(window_pairs))
+             else:
+                 zero_pairs_by_window.append(np.empty((0, 2), dtype=int))

      if store_obsm is not None:
          adata.obsm[store_obsm] = out
          adata.uns[f"{store_obsm}_starts"] = starts
+         adata.uns[f"{store_obsm}_centers"] = _window_center_coordinates(adata, starts, window)
          adata.uns[f"{store_obsm}_window"] = int(window)
          adata.uns[f"{store_obsm}_step"] = int(step)
          adata.uns[f"{store_obsm}_min_overlap"] = int(min_overlap)
          adata.uns[f"{store_obsm}_return_fraction"] = bool(return_fraction)
          adata.uns[f"{store_obsm}_layer"] = layer if layer is not None else "X"
+     if collect_zero_pairs:
+         if zero_pairs_uns_key is None:
+             zero_pairs_uns_key = (
+                 f"{store_obsm}_zero_pairs" if store_obsm is not None else "rolling_nn_zero_pairs"
+             )
+         adata.uns[zero_pairs_uns_key] = zero_pairs_by_window
+         adata.uns[f"{zero_pairs_uns_key}_starts"] = starts
+         adata.uns[f"{zero_pairs_uns_key}_window"] = int(window)
+         adata.uns[f"{zero_pairs_uns_key}_step"] = int(step)
+         adata.uns[f"{zero_pairs_uns_key}_min_overlap"] = int(min_overlap)
+         adata.uns[f"{zero_pairs_uns_key}_return_fraction"] = bool(return_fraction)
+         adata.uns[f"{zero_pairs_uns_key}_layer"] = layer if layer is not None else "X"

      return out, starts


+ def annotate_zero_hamming_segments(
+     adata,
+     zero_pairs_uns_key: Optional[str] = None,
+     output_uns_key: str = "zero_hamming_segments",
+     layer: Optional[str] = None,
+     min_overlap: Optional[int] = None,
+     refine_segments: bool = True,
+     max_nan_run: Optional[int] = None,
+     merge_gap: int = 0,
+     max_segments_per_read: Optional[int] = None,
+     max_segment_overlap: Optional[int] = None,
+ ) -> list[dict]:
+     """
+     Merge zero-Hamming windows into maximal segments and annotate them onto the AnnData.
+
+     Args:
+         adata: AnnData containing zero-pair window data in ``.uns``.
+         zero_pairs_uns_key: Key for zero-pair window data in ``adata.uns``.
+         output_uns_key: Key to store merged/refined segments in ``adata.uns``.
+         layer: Layer to use for refinement (defaults to ``adata.X``).
+         min_overlap: Minimum overlap required to keep a refined segment.
+         refine_segments: Whether to refine merged windows to maximal segments.
+         max_nan_run: Maximum consecutive NaN positions allowed when expanding segments.
+             If reached, expansion stops before the NaN run. Set to ``None`` to ignore NaNs.
+         merge_gap: Merge segments separated by gaps of at most this size (in positions).
+         max_segments_per_read: Maximum number of segments to retain per read pair.
+         max_segment_overlap: Maximum allowed overlap between retained segments (inclusive, in
+             var-index coordinates).
+
+     Returns:
+         List of segment records stored in ``adata.uns[output_uns_key]``.
+     """
+     if zero_pairs_uns_key is None:
+         candidate_keys = [key for key in adata.uns if key.endswith("_zero_pairs")]
+         if len(candidate_keys) == 1:
+             zero_pairs_uns_key = candidate_keys[0]
+         elif not candidate_keys:
+             raise KeyError("No zero-pair data found in adata.uns.")
+         else:
+             raise KeyError(
+                 "Multiple zero-pair keys found in adata.uns; please specify zero_pairs_uns_key."
+             )
+
+     if zero_pairs_uns_key not in adata.uns:
+         raise KeyError(f"Missing zero-pair data in adata.uns[{zero_pairs_uns_key!r}].")
+
+     zero_pairs_by_window = adata.uns[zero_pairs_uns_key]
+     starts = np.asarray(adata.uns.get(f"{zero_pairs_uns_key}_starts"))
+     window = int(adata.uns.get(f"{zero_pairs_uns_key}_window", 0))
+     if starts.size == 0 or window <= 0:
+         raise ValueError("Zero-pair metadata missing starts/window information.")
+
+     if min_overlap is None:
+         min_overlap = int(adata.uns.get(f"{zero_pairs_uns_key}_min_overlap", 1))
+
+     X = adata.layers[layer] if layer is not None else adata.X
+     X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
+     observed = ~np.isnan(X)
+     values = (np.where(observed, X, 0.0) > 0).astype(np.uint8)
+     max_nan_run = None if max_nan_run is None else max(int(max_nan_run), 1)
+
+     pair_segments: dict[tuple[int, int], list[tuple[int, int]]] = {}
+     for wi, pairs in enumerate(zero_pairs_by_window):
+         if pairs is None or len(pairs) == 0:
+             continue
+         start = int(starts[wi])
+         end = start + window
+         for i, j in pairs:
+             key = (int(i), int(j))
+             pair_segments.setdefault(key, []).append((start, end))
+
+     merge_gap = max(int(merge_gap), 0)
+
+     def _merge_segments(segments: list[tuple[int, int]]) -> list[tuple[int, int]]:
+         if not segments:
+             return []
+         segments = sorted(segments, key=lambda seg: seg[0])
+         merged = [segments[0]]
+         for seg_start, seg_end in segments[1:]:
+             last_start, last_end = merged[-1]
+             if seg_start <= last_end + merge_gap:
+                 merged[-1] = (last_start, max(last_end, seg_end))
+             else:
+                 merged.append((seg_start, seg_end))
+         return merged
+
+     def _refine_segment(
+         read_i: int,
+         read_j: int,
+         start: int,
+         end: int,
+     ) -> Optional[tuple[int, int]]:
+         if not refine_segments:
+             return (start, end)
+         left = start
+         right = end
+         nan_run = 0
+         while left > 0:
+             idx = left - 1
+             both_observed = observed[read_i, idx] and observed[read_j, idx]
+             if both_observed:
+                 nan_run = 0
+                 if values[read_i, idx] != values[read_j, idx]:
+                     break
+             else:
+                 nan_run += 1
+                 if max_nan_run is not None and nan_run >= max_nan_run:
+                     break
+             left -= 1
+         n_vars = values.shape[1]
+         nan_run = 0
+         while right < n_vars:
+             idx = right
+             both_observed = observed[read_i, idx] and observed[read_j, idx]
+             if both_observed:
+                 nan_run = 0
+                 if values[read_i, idx] != values[read_j, idx]:
+                     break
+             else:
+                 nan_run += 1
+                 if max_nan_run is not None and nan_run >= max_nan_run:
+                     break
+             right += 1
+         overlap = np.sum(observed[read_i, left:right] & observed[read_j, left:right])
+         if overlap < min_overlap:
+             return None
+         return (left, right)
+
+     def _segment_length(segment: tuple[int, int]) -> int:
+         return int(segment[1]) - int(segment[0])
+
+     def _segment_overlap(first: tuple[int, int], second: tuple[int, int]) -> int:
+         return max(0, min(first[1], second[1]) - max(first[0], second[0]))
+
+     def _select_segments(segments: list[tuple[int, int]]) -> list[tuple[int, int]]:
+         if not segments:
+             return []
+         if max_segments_per_read is None and max_segment_overlap is None:
+             return segments
+         ordered = sorted(
+             segments,
+             key=lambda seg: (_segment_length(seg), -seg[0]),
+             reverse=True,
+         )
+         max_segments = len(ordered) if max_segments_per_read is None else max_segments_per_read
+         if max_segment_overlap is None:
+             return ordered[:max_segments]
+         selected: list[tuple[int, int]] = []
+         for segment in ordered:
+             if len(selected) >= max_segments:
+                 break
+             if all(_segment_overlap(segment, other) <= max_segment_overlap for other in selected):
+                 selected.append(segment)
+         return selected
+
+     records: list[dict] = []
+     obs_names = adata.obs_names
+     for (read_i, read_j), segments in pair_segments.items():
+         merged = _merge_segments(segments)
+         refined_segments = []
+         for seg_start, seg_end in merged:
+             refined = _refine_segment(read_i, read_j, seg_start, seg_end)
+             if refined is not None:
+                 refined_segments.append(refined)
+         refined_segments = _select_segments(refined_segments)
+         if refined_segments:
+             records.append(
+                 {
+                     "read_i": read_i,
+                     "read_j": read_j,
+                     "read_i_name": str(obs_names[read_i]),
+                     "read_j_name": str(obs_names[read_j]),
+                     "segments": refined_segments,
+                 }
+             )
+
+     adata.uns[output_uns_key] = records
+     return records
+
+
+ def assign_per_read_segments_layer(
+     parent_adata: "ad.AnnData",
+     subset_adata: "ad.AnnData",
+     per_read_segments: "pd.DataFrame",
+     layer_key: str,
+ ) -> None:
+     """
+     Assign per-read segments into a summed span layer on a parent AnnData.
+
+     Args:
+         parent_adata: AnnData that should receive the span layer.
+         subset_adata: AnnData used to compute per-read segments.
+         per_read_segments: DataFrame with ``read_id``, ``segment_start``, and
+             ``segment_end_exclusive`` columns. If ``segment_start_label`` and
+             ``segment_end_label`` are present and numeric, they are used to
+             map segments using label coordinates.
+         layer_key: Name of the layer to store in ``parent_adata.layers``.
+     """
+     import pandas as pd
+
+     if per_read_segments.empty:
+         parent_adata.layers[layer_key] = np.zeros(
+             (parent_adata.n_obs, parent_adata.n_vars), dtype=np.uint16
+         )
+         return
+     required_cols = {"read_id", "segment_start", "segment_end_exclusive"}
+     missing = required_cols.difference(per_read_segments.columns)
+     if missing:
+         raise KeyError(f"per_read_segments missing required columns: {sorted(missing)}")
+
+     target_layer = np.zeros((parent_adata.n_obs, parent_adata.n_vars), dtype=np.uint16)
+
+     parent_obs_indexer = parent_adata.obs_names.get_indexer(subset_adata.obs_names)
+     if (parent_obs_indexer < 0).any():
+         raise ValueError("Subset AnnData contains obs not present in parent AnnData.")
+     parent_var_indexer = parent_adata.var_names.get_indexer(subset_adata.var_names)
+     if (parent_var_indexer < 0).any():
+         raise ValueError("Subset AnnData contains vars not present in parent AnnData.")
+
+     label_indexer = None
+     label_columns = {"segment_start_label", "segment_end_label"}
+     if label_columns.issubset(per_read_segments.columns):
+         try:
+             parent_label_values = [int(label) for label in parent_adata.var_names]
+             label_indexer = {label: idx for idx, label in enumerate(parent_label_values)}
+         except (TypeError, ValueError):
+             label_indexer = None
+
+     def _label_to_index(value: object) -> Optional[int]:
+         if label_indexer is None or value is None or pd.isna(value):
+             return None
+         try:
+             return label_indexer.get(int(value))
+         except (TypeError, ValueError):
+             return None
+
+     for row in per_read_segments.itertuples(index=False):
+         read_id = int(row.read_id)
+         seg_start = int(row.segment_start)
+         seg_end = int(row.segment_end_exclusive)
+         if seg_start >= seg_end:
+             continue
+         target_read = parent_obs_indexer[read_id]
+         if target_read < 0:
+             raise ValueError("Segment read_id not found in parent AnnData.")
+
+         label_start = _label_to_index(getattr(row, "segment_start_label", None))
+         label_end = _label_to_index(getattr(row, "segment_end_label", None))
+         if label_start is not None and label_end is not None:
+             parent_start = min(label_start, label_end)
+             parent_end = max(label_start, label_end)
+         else:
+             parent_positions = parent_var_indexer[seg_start:seg_end]
+             if parent_positions.size == 0:
+                 continue
+             parent_start = int(parent_positions.min())
+             parent_end = int(parent_positions.max())
+
+         target_layer[target_read, parent_start : parent_end + 1] += 1
+
+     parent_adata.layers[layer_key] = target_layer
+
+
+ def segments_to_per_read_dataframe(
+     records: list[dict],
+     var_names: np.ndarray,
+ ) -> "pd.DataFrame":
+     """
+     Build a per-read DataFrame of zero-Hamming segments.
+
+     Args:
+         records: Output records from ``annotate_zero_hamming_segments``.
+         var_names: AnnData var names for labeling segment coordinates.
+
+     Returns:
+         DataFrame with one row per segment per read.
+     """
+     import pandas as pd
+
+     var_names = np.asarray(var_names, dtype=object)
+
+     def _label_at(idx: int) -> Optional[str]:
+         if 0 <= idx < var_names.size:
+             return str(var_names[idx])
+         return None
+
+     rows = []
+     for record in records:
+         read_i = int(record["read_i"])
+         read_j = int(record["read_j"])
+         read_i_name = record.get("read_i_name")
+         read_j_name = record.get("read_j_name")
+         for seg_start, seg_end in record.get("segments", []):
+             seg_start = int(seg_start)
+             seg_end = int(seg_end)
+             end_inclusive = max(seg_start, seg_end - 1)
+             start_label = _label_at(seg_start)
+             end_label = _label_at(end_inclusive)
+             rows.append(
+                 {
+                     "read_id": read_i,
+                     "partner_id": read_j,
+                     "read_name": read_i_name,
+                     "partner_name": read_j_name,
+                     "segment_start": seg_start,
+                     "segment_end_exclusive": seg_end,
+                     "segment_end_inclusive": end_inclusive,
+                     "segment_start_label": start_label,
+                     "segment_end_label": end_label,
+                 }
+             )
+             rows.append(
+                 {
+                     "read_id": read_j,
+                     "partner_id": read_i,
+                     "read_name": read_j_name,
+                     "partner_name": read_i_name,
+                     "segment_start": seg_start,
+                     "segment_end_exclusive": seg_end,
+                     "segment_end_inclusive": end_inclusive,
+                     "segment_start_label": start_label,
+                     "segment_end_label": end_label,
+                 }
+             )
+
+     return pd.DataFrame(
+         rows,
+         columns=[
+             "read_id",
+             "partner_id",
+             "read_name",
+             "partner_name",
+             "segment_start",
+             "segment_end_exclusive",
+             "segment_end_inclusive",
+             "segment_start_label",
+             "segment_end_label",
+         ],
+     )
+
+
+ def select_top_segments_per_read(
+     records: list[dict],
+     var_names: np.ndarray,
+     max_segments_per_read: Optional[int] = None,
+     max_segment_overlap: Optional[int] = None,
+     min_span: Optional[float] = None,
+ ) -> tuple["pd.DataFrame", "pd.DataFrame"]:
+     """
+     Select the top segments per read across distinct partner pairs.
+
+     Args:
+         records: Output records from ``annotate_zero_hamming_segments``.
+         var_names: AnnData var names for labeling segment coordinates.
+         max_segments_per_read: Maximum number of segments to keep per read.
+         max_segment_overlap: Maximum allowed overlap between kept segments.
+         min_span: Minimum span length to keep (var-name coordinate if numeric, else index span).
+
+     Returns:
+         Tuple of (raw per-read segments, filtered per-read segments).
+     """
+     import pandas as pd
+
+     raw_df = segments_to_per_read_dataframe(records, var_names)
+     if raw_df.empty:
+         raw_df = raw_df.copy()
+         raw_df["segment_length_index"] = pd.Series(dtype=int)
+         raw_df["segment_length_label"] = pd.Series(dtype=float)
+         return raw_df, raw_df.copy()
+
+     def _span_length(row) -> float:
+         try:
+             start = float(row["segment_start_label"])
+             end = float(row["segment_end_label"])
+             return abs(end - start)
+         except (TypeError, ValueError):
+             return float(row["segment_end_exclusive"] - row["segment_start"])
+
+     raw_df = raw_df.copy()
+     raw_df["segment_length_index"] = (
+         raw_df["segment_end_exclusive"] - raw_df["segment_start"]
+     ).astype(int)
+     raw_df["segment_length_label"] = raw_df.apply(_span_length, axis=1)
+     if min_span is not None:
+         raw_df = raw_df[raw_df["segment_length_label"] >= float(min_span)]
+
+     if raw_df.empty:
+         return raw_df, raw_df.copy()
+
+     def _segment_overlap(a, b) -> int:
+         return max(0, min(a[1], b[1]) - max(a[0], b[0]))
+
+     filtered_rows = []
+     max_segments = max_segments_per_read
+     for read_id, read_df in raw_df.groupby("read_id", sort=False):
+         per_partner = (
+             read_df.sort_values(
+                 ["segment_length_label", "segment_start"],
+                 ascending=[False, True],
+             )
+             .groupby("partner_id", sort=False)
+             .head(1)
+         )
+         ordered = per_partner.sort_values(
+             ["segment_length_label", "segment_start"],
+             ascending=[False, True],
+         ).itertuples(index=False)
+         selected = []
+         for row in ordered:
+             if max_segments is not None and len(selected) >= max_segments:
+                 break
+             seg = (row.segment_start, row.segment_end_exclusive)
+             if max_segment_overlap is not None:
+                 if any(
+                     _segment_overlap(seg, (s.segment_start, s.segment_end_exclusive))
+                     > max_segment_overlap
+                     for s in selected
+                 ):
+                     continue
+             selected.append(row)
+         for row in selected:
+             filtered_rows.append(row._asdict())
+
+     filtered_df = pd.DataFrame(filtered_rows, columns=raw_df.columns)
+     if not filtered_df.empty:
+         filtered_df["selection_rank"] = (
+             filtered_df.sort_values(
+                 ["read_id", "segment_length_label", "segment_start"],
+                 ascending=[True, False, True],
+             )
+             .groupby("read_id", sort=False)
+             .cumcount()
+             + 1
+         )
+
+     return raw_df, filtered_df
+
+
  def assign_rolling_nn_results(
      parent_adata: "ad.AnnData",
      subset_adata: "ad.AnnData",
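
Taken together, the additions in this hunk form a small pipeline: collect zero-Hamming pairs during the rolling scan, merge them into maximal segments, pick the top spans per read, and paint those spans into a layer. A hedged sketch, continuing the ``adata`` from the example after the signature hunk above (parameter values are illustrative):

from smftools.tools.rolling_nn_distance import (
    annotate_zero_hamming_segments,
    assign_per_read_segments_layer,
    select_top_segments_per_read,
)

# Merge per-window zero-Hamming hits into maximal segments; the zero-pair
# key is auto-detected when exactly one "*_zero_pairs" entry exists in .uns.
records = annotate_zero_hamming_segments(adata, merge_gap=1, max_nan_run=5)

# Rank segments per read across distinct partners, keeping at most two
# spans that overlap each other by no more than 10 positions.
raw_df, top_df = select_top_segments_per_read(
    records,
    adata.var_names,
    max_segments_per_read=2,
    max_segment_overlap=10,
)

# Paint the selected spans onto a parent AnnData as a summed uint16 layer
# (here parent and subset are the same object, so indices map one-to-one).
assign_per_read_segments_layer(adata, adata, top_df, layer_key="zero_hamming_spans")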
@@ -210,6 +837,9 @@ def assign_rolling_nn_results(
      if obsm_key not in parent_adata.obsm:
          parent_adata.obsm[obsm_key] = np.full((n_obs, n_windows), np.nan, dtype=float)
          parent_adata.uns[f"{obsm_key}_starts"] = starts
+         parent_adata.uns[f"{obsm_key}_centers"] = _window_center_coordinates(
+             subset_adata, starts, window
+         )
          parent_adata.uns[f"{obsm_key}_window"] = int(window)
          parent_adata.uns[f"{obsm_key}_step"] = int(step)
          parent_adata.uns[f"{obsm_key}_min_overlap"] = int(min_overlap)
@@ -227,6 +857,13 @@ def assign_rolling_nn_results(
              raise ValueError(
                  f"Existing obsm[{obsm_key!r}] has different window starts than new values."
              )
+         existing_centers = parent_adata.uns.get(f"{obsm_key}_centers")
+         if existing_centers is not None:
+             expected_centers = _window_center_coordinates(subset_adata, starts, window)
+             if not np.array_equal(existing_centers, expected_centers):
+                 raise ValueError(
+                     f"Existing obsm[{obsm_key!r}] has different window centers than new values."
+                 )

      parent_indexer = parent_adata.obs_names.get_indexer(subset_adata.obs_names)
      if (parent_indexer < 0).any():
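
The consistency check above compares against ``_window_center_coordinates``, whose numeric/non-numeric fallback is easy to see in isolation. A standalone sketch (``_window_center_coordinates`` is private, so this is for illustration only; the AnnData and var names are made up):

import anndata as ad
import numpy as np

from smftools.tools.rolling_nn_distance import _window_center_coordinates

adata = ad.AnnData(X=np.zeros((2, 30)))

# Numeric var_names: floored mean coordinate per window.
adata.var_names = [str(1000 + i) for i in range(30)]
print(_window_center_coordinates(adata, np.array([0, 10]), 20))
# -> [1009. 1019.]  (floor of the mean of 1000..1019 and 1010..1029)

# Non-numeric var_names: the midpoint label is returned instead.
adata.var_names = [f"pos_{i}" for i in range(30)]
print(_window_center_coordinates(adata, np.array([0, 10]), 20))
# -> ['pos_10' 'pos_20']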