smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +18 -2
  4. smftools/cli/hmm_adata.py +18 -1
  5. smftools/cli/latent_adata.py +522 -67
  6. smftools/cli/load_adata.py +2 -2
  7. smftools/cli/preprocess_adata.py +32 -93
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +23 -109
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +41 -5
  12. smftools/config/conversion.yaml +0 -10
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +49 -13
  15. smftools/config/experiment_config.py +96 -3
  16. smftools/constants.py +4 -0
  17. smftools/hmm/call_hmm_peaks.py +1 -1
  18. smftools/informatics/binarize_converted_base_identities.py +2 -89
  19. smftools/informatics/converted_BAM_to_adata.py +53 -13
  20. smftools/informatics/h5ad_functions.py +83 -0
  21. smftools/informatics/modkit_extract_to_adata.py +4 -0
  22. smftools/plotting/__init__.py +26 -12
  23. smftools/plotting/autocorrelation_plotting.py +22 -4
  24. smftools/plotting/chimeric_plotting.py +1893 -0
  25. smftools/plotting/classifiers.py +28 -14
  26. smftools/plotting/general_plotting.py +58 -3362
  27. smftools/plotting/hmm_plotting.py +1586 -2
  28. smftools/plotting/latent_plotting.py +804 -0
  29. smftools/plotting/plotting_utils.py +243 -0
  30. smftools/plotting/position_stats.py +16 -8
  31. smftools/plotting/preprocess_plotting.py +281 -0
  32. smftools/plotting/qc_plotting.py +8 -3
  33. smftools/plotting/spatial_plotting.py +1134 -0
  34. smftools/plotting/variant_plotting.py +1231 -0
  35. smftools/preprocessing/__init__.py +3 -0
  36. smftools/preprocessing/append_base_context.py +1 -1
  37. smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
  38. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  39. smftools/preprocessing/append_variant_call_layer.py +480 -0
  40. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  41. smftools/preprocessing/invert_adata.py +1 -0
  42. smftools/readwrite.py +109 -85
  43. smftools/tools/__init__.py +6 -0
  44. smftools/tools/calculate_knn.py +121 -0
  45. smftools/tools/calculate_nmf.py +18 -7
  46. smftools/tools/calculate_pca.py +180 -0
  47. smftools/tools/calculate_umap.py +70 -154
  48. smftools/tools/position_stats.py +4 -4
  49. smftools/tools/rolling_nn_distance.py +640 -3
  50. smftools/tools/sequence_alignment.py +140 -0
  51. smftools/tools/tensor_factorization.py +52 -4
  52. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
  53. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
  54. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  55. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  56. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
smftools/cli/chimeric_adata.py
@@ -0,0 +1,1563 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Optional, Tuple
6
+
7
+ import anndata as ad
8
+
9
+ from smftools.constants import CHIMERIC_DIR, LOGGING_DIR
10
+ from smftools.logging_utils import get_logger, setup_logging
11
+
12
+ logger = get_logger(__name__)
13
+
14
+ ZERO_HAMMING_DISTANCE_SPANS = "zero_hamming_distance_spans"
15
+
16
+
17
+ def _max_positive_span_length(delta_row: "np.ndarray") -> int:
18
+ """Return the max contiguous run length where delta span values are > 0."""
19
+ import numpy as np
20
+
21
+ values = np.asarray(delta_row)
22
+ if values.ndim != 1 or values.size == 0:
23
+ return 0
24
+
25
+ positive_mask = values > 0
26
+ if not np.any(positive_mask):
27
+ return 0
28
+
29
+ transitions = np.diff(positive_mask.astype(np.int8))
30
+ starts = np.flatnonzero(transitions == 1) + 1
31
+ ends = np.flatnonzero(transitions == -1) + 1
32
+
33
+ if positive_mask[0]:
34
+ starts = np.r_[0, starts]
35
+ if positive_mask[-1]:
36
+ ends = np.r_[ends, positive_mask.size]
37
+
38
+ return int(np.max(ends - starts))
39
+
40
+
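For orientation, here is a minimal sketch (not part of the diff) of what the run-length bookkeeping above computes on a made-up delta row; only numpy is assumed:

import numpy as np

# Hypothetical per-position delta values for one read.
delta_row = np.array([0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1])

positive_mask = delta_row > 0
transitions = np.diff(positive_mask.astype(np.int8))
starts = np.flatnonzero(transitions == 1) + 1
ends = np.flatnonzero(transitions == -1) + 1
if positive_mask[0]:
    starts = np.r_[0, starts]
if positive_mask[-1]:
    ends = np.r_[ends, positive_mask.size]

print(ends - starts)               # [3 1 2] -> lengths of the positive runs
print(int(np.max(ends - starts)))  # 3, i.e. _max_positive_span_length(delta_row)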
41
+ def _compute_chimeric_by_mod_hamming_distance(
42
+ delta_layer: "np.ndarray",
43
+ span_threshold: int,
44
+ ) -> "np.ndarray":
45
+ """Flag reads with any delta-hamming span strictly larger than ``span_threshold``."""
46
+ import numpy as np
47
+
48
+ delta_values = np.asarray(delta_layer)
49
+ if delta_values.ndim != 2:
50
+ raise ValueError("delta_layer must be a 2D array with shape (n_obs, n_vars).")
51
+
52
+ flags = np.zeros(delta_values.shape[0], dtype=bool)
53
+ for obs_idx, row in enumerate(delta_values):
54
+ flags[obs_idx] = _max_positive_span_length(row) > span_threshold
55
+ return flags
56
+
57
+
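A short usage sketch for the flagging helper above; the two-read array and the threshold of 2 are invented, and the helper is assumed importable from this module:

import numpy as np
from smftools.cli.chimeric_adata import _compute_chimeric_by_mod_hamming_distance

# Two reads x six positions of (within - cross) span counts.
delta_layer = np.array(
    [
        [0, 1, 1, 1, 0, 0],  # longest positive run = 3
        [1, 0, 1, 0, 1, 0],  # longest positive run = 1
    ]
)

flags = _compute_chimeric_by_mod_hamming_distance(delta_layer, span_threshold=2)
print(flags)  # [ True False] -> only the first read is flagged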
58
+ def _build_top_segments_obs_tuples(
59
+ read_df: "pd.DataFrame",
60
+ obs_names: "pd.Index",
61
+ ) -> list[tuple[int, int, str]]:
62
+ """
63
+ Build per-read top-segment tuples with integer spans and partner names.
64
+
65
+ Args:
66
+ read_df: DataFrame for a single read containing segment and partner fields.
67
+ obs_names: AnnData obs names used to resolve partner indices.
68
+
69
+ Returns:
70
+ List of ``(segment_start, segment_end_exclusive, partner_name)`` tuples.
71
+ """
72
+ import pandas as pd
73
+
74
+ tuples: list[tuple[int, int, str]] = []
75
+ for row in read_df.itertuples(index=False):
76
+ start_val = int(row.segment_start_label)
77
+ end_val = int(row.segment_end_label)
78
+ partner_name = row.partner_name
79
+ if partner_name is None or pd.isna(partner_name):
80
+ partner_id = int(row.partner_id)
81
+ if 0 <= partner_id < len(obs_names):
82
+ partner_name = str(obs_names[partner_id])
83
+ else:
84
+ partner_name = str(partner_id)
85
+ tuples.append((start_val, end_val, str(partner_name)))
86
+ return tuples
87
+
88
+
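To show the expected columns and the resulting tuple format for the helper above, a made-up single-read DataFrame (read names and coordinates are illustrative; the helper is assumed importable from this module):

import pandas as pd
from smftools.cli.chimeric_adata import _build_top_segments_obs_tuples

obs_names = pd.Index(["read_0", "read_1", "read_2"])
read_df = pd.DataFrame(
    {
        "segment_start_label": [1200, 1450],
        "segment_end_label": [1399, 1600],
        "partner_name": [None, "read_2"],  # None falls back to the partner_id lookup
        "partner_id": [1, 2],
    }
)

print(_build_top_segments_obs_tuples(read_df, obs_names))
# [(1200, 1399, 'read_1'), (1450, 1600, 'read_2')]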
89
+ def _build_zero_hamming_span_layer_from_obs(
90
+ adata: ad.AnnData,
91
+ obs_key: str,
92
+ layer_key: str,
93
+ ) -> None:
94
+ """
95
+ Populate a count-based span layer from per-read obs tuples.
96
+
97
+ Args:
98
+ adata: AnnData to receive/update the layer.
99
+ obs_key: obs column containing ``(start_label, end_label, partner)`` tuples.
100
+ layer_key: Name of the layer to create/update.
101
+ """
102
+ import numpy as np
103
+ import pandas as pd
104
+
105
+ if obs_key not in adata.obs:
106
+ return
107
+
108
+ try:
109
+ label_indexer = {int(label): idx for idx, label in enumerate(adata.var_names)}
110
+ except (TypeError, ValueError):
111
+ logger.warning(
112
+ "Unable to build span layer %s: adata.var_names are not numeric labels.",
113
+ layer_key,
114
+ )
115
+ return
116
+
117
+ if layer_key in adata.layers:
118
+ target_layer = np.asarray(adata.layers[layer_key])
119
+ if target_layer.shape != (adata.n_obs, adata.n_vars):
120
+ target_layer = np.zeros((adata.n_obs, adata.n_vars), dtype=np.uint16)
121
+ else:
122
+ target_layer = np.zeros((adata.n_obs, adata.n_vars), dtype=np.uint16)
123
+
124
+ for obs_idx, spans in enumerate(adata.obs[obs_key].tolist()):
125
+ if not isinstance(spans, list):
126
+ continue
127
+ for span in spans:
128
+ if span is None or len(span) < 2:
129
+ continue
130
+ start_val, end_val = span[0], span[1]
131
+ if start_val is None or end_val is None:
132
+ continue
133
+ if pd.isna(start_val) or pd.isna(end_val):
134
+ continue
135
+ try:
136
+ start_label = int(start_val)
137
+ end_label = int(end_val)
138
+ except (TypeError, ValueError):
139
+ continue
140
+ if start_label not in label_indexer or end_label not in label_indexer:
141
+ continue
142
+ start_idx = label_indexer[start_label]
143
+ end_idx = label_indexer[end_label]
144
+ if start_idx > end_idx:
145
+ start_idx, end_idx = end_idx, start_idx
146
+ target_layer[obs_idx, start_idx : end_idx + 1] = 1
147
+
148
+ adata.layers[layer_key] = target_layer
149
+
150
+
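A toy round trip for the span-layer builder above, using a 2x5 AnnData whose var_names are numeric position labels (all names and values are invented; the helper is assumed importable from this module):

import anndata as ad
import numpy as np
import pandas as pd
from smftools.cli.chimeric_adata import _build_zero_hamming_span_layer_from_obs

toy = ad.AnnData(
    X=np.zeros((2, 5)),
    obs=pd.DataFrame(index=["read_0", "read_1"]),
    var=pd.DataFrame(index=["100", "101", "102", "103", "104"]),
)
toy.obs["top_segments"] = pd.Series(
    [[(101, 103, "read_1")], []], index=toy.obs_names, dtype=object
)

_build_zero_hamming_span_layer_from_obs(toy, "top_segments", "zero_hamming_distance_spans")
print(toy.layers["zero_hamming_distance_spans"])
# [[0 1 1 1 0]
#  [0 0 0 0 0]]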
151
+ def chimeric_adata(
152
+ config_path: str,
153
+ ) -> Tuple[Optional[ad.AnnData], Optional[Path]]:
154
+ """
155
+ CLI-facing wrapper for chimeric analyses.
156
+
157
+ Called by: `smftools chimeric <config_path>`
158
+
159
+ Responsibilities:
160
+ - Ensure a usable AnnData exists.
161
+ - Determine which AnnData stages exist.
162
+ - Call `chimeric_adata_core(...)` when actual work is needed.
163
+ """
164
+ from ..readwrite import safe_read_h5ad
165
+ from .helpers import get_adata_paths, load_experiment_config
166
+
167
+ # 1) Ensure config + basic paths via load_adata
168
+ cfg = load_experiment_config(config_path)
169
+
170
+ paths = get_adata_paths(cfg)
171
+
172
+ pp_path = paths.pp
173
+ pp_dedup_path = paths.pp_dedup
174
+ spatial_path = paths.spatial
175
+ chimeric_path = paths.chimeric
176
+ variant_path = paths.variant
177
+ hmm_path = paths.hmm
178
+ latent_path = paths.latent
179
+
180
+ # 2) Stage-skipping logic
181
+ if not getattr(cfg, "force_redo_chimeric_analyses", False):
182
+ if chimeric_path.exists():
183
+ logger.info(f"Chimeric AnnData found: {chimeric_path}\nSkipping smftools chimeric")
184
+ return None, chimeric_path
185
+
186
+ # Helper to load an AnnData from disk
187
+ def _load(path: Path):
188
+ adata, _ = safe_read_h5ad(path)
189
+ return adata
190
+
191
+ # 3) Decide which AnnData to use as the *starting point* for analyses
192
+ if hmm_path.exists():
193
+ start_adata = _load(hmm_path)
194
+ source_path = hmm_path
195
+ elif latent_path.exists():
196
+ start_adata = _load(latent_path)
197
+ source_path = latent_path
198
+ elif spatial_path.exists():
199
+ start_adata = _load(spatial_path)
200
+ source_path = spatial_path
201
+ elif chimeric_path.exists():
202
+ start_adata = _load(chimeric_path)
203
+ source_path = chimeric_path
204
+ elif variant_path.exists():
205
+ start_adata = _load(variant_path)
206
+ source_path = variant_path
207
+ elif pp_dedup_path.exists():
208
+ start_adata = _load(pp_dedup_path)
209
+ source_path = pp_dedup_path
210
+ elif pp_path.exists():
211
+ start_adata = _load(pp_path)
212
+ source_path = pp_path
213
+ else:
214
+ logger.warning(
215
+ "No suitable AnnData found for chimeric analyses (need at least preprocessed)."
216
+ )
217
+ return None, None
218
+
219
+ # 4) Run the core
220
+ adata_chimeric, chimeric_path = chimeric_adata_core(
221
+ adata=start_adata,
222
+ cfg=cfg,
223
+ paths=paths,
224
+ source_adata_path=source_path,
225
+ config_path=config_path,
226
+ )
227
+
228
+ return adata_chimeric, chimeric_path
229
+
230
+
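The stage-selection chain in the wrapper above amounts to taking the first existing file from a fixed priority order; a compact equivalent sketch (the attribute names mirror the `paths` object used above, everything else is illustrative):

from pathlib import Path
from typing import Optional

def _pick_start_adata_path(paths) -> Optional[Path]:
    """Return the most-processed existing AnnData path, or None if nothing is found."""
    priority = (
        paths.hmm,
        paths.latent,
        paths.spatial,
        paths.chimeric,
        paths.variant,
        paths.pp_dedup,
        paths.pp,
    )
    return next((candidate for candidate in priority if candidate.exists()), None)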
231
+ def chimeric_adata_core(
232
+ adata: ad.AnnData,
233
+ cfg,
234
+ paths: AdataPaths,
235
+ source_adata_path: Optional[Path] = None,
236
+ config_path: Optional[str] = None,
237
+ ) -> Tuple[ad.AnnData, Path]:
238
+ """
239
+ Core chimeric analysis pipeline.
240
+
241
+ Assumes:
242
+ - `cfg` is the ExperimentConfig.
243
+
244
+ Does:
245
+ - Compute rolling nearest-neighbor distances, zero-Hamming spans, and chimeric-read flags, with accompanying plots.
246
+ - Save AnnData.
247
+ """
248
+ import os
249
+ import warnings
250
+ from datetime import datetime
251
+ from pathlib import Path
252
+
253
+ import numpy as np
254
+ import pandas as pd
255
+
256
+ from ..metadata import record_smftools_metadata
257
+ from ..plotting import (
258
+ plot_delta_hamming_summary,
259
+ plot_rolling_nn_and_layer,
260
+ plot_rolling_nn_and_two_layers,
261
+ plot_segment_length_histogram,
262
+ plot_span_length_distributions,
263
+ plot_zero_hamming_pair_counts,
264
+ plot_zero_hamming_span_and_layer,
265
+ )
266
+ from ..preprocessing import (
267
+ load_sample_sheet,
268
+ )
269
+ from ..readwrite import make_dirs
270
+ from ..tools import (
271
+ annotate_zero_hamming_segments,
272
+ rolling_window_nn_distance,
273
+ select_top_segments_per_read,
274
+ )
275
+ from ..tools.rolling_nn_distance import (
276
+ assign_rolling_nn_results,
277
+ zero_hamming_segments_to_dataframe,
278
+ )
279
+ from .helpers import write_gz_h5ad
280
+
281
+ # -----------------------------
282
+ # General setup
283
+ # -----------------------------
284
+ date_str = datetime.today().strftime("%y%m%d")
285
+ now = datetime.now()
286
+ time_str = now.strftime("%H%M%S")
287
+ log_level = getattr(logging, cfg.log_level.upper(), logging.INFO)
288
+
289
+ output_directory = Path(cfg.output_directory)
290
+ chimeric_directory = output_directory / CHIMERIC_DIR
291
+ logging_directory = chimeric_directory / LOGGING_DIR
292
+
293
+ make_dirs([output_directory, chimeric_directory])
294
+
295
+ if cfg.emit_log_file:
296
+ log_file = logging_directory / f"{date_str}_{time_str}_log.log"
297
+ make_dirs([logging_directory])
298
+ else:
299
+ log_file = None
300
+
301
+ setup_logging(level=log_level, log_file=log_file, reconfigure=log_file is not None)
302
+
303
+ smf_modality = cfg.smf_modality
304
+ if smf_modality == "conversion":
305
+ deaminase = False
306
+ else:
307
+ deaminase = True
308
+ if smf_modality == "direct":
309
+ rolling_nn_layer = "binarized_methylation"
310
+ else:
311
+ rolling_nn_layer = None
312
+
313
+ # -----------------------------
314
+ # Optional sample sheet metadata
315
+ # -----------------------------
316
+ if getattr(cfg, "sample_sheet_path", None):
317
+ load_sample_sheet(
318
+ adata,
319
+ cfg.sample_sheet_path,
320
+ mapping_key_column=cfg.sample_sheet_mapping_column,
321
+ as_category=True,
322
+ force_reload=cfg.force_reload_sample_sheet,
323
+ )
324
+
325
+ references = adata.obs[cfg.reference_column].cat.categories
326
+
327
+ # Auto-detect variant call layer from a prior variant_adata run
328
+ variant_call_layer_name = None
329
+ variant_seq1_label = "seq1"
330
+ variant_seq2_label = "seq2"
331
+ seq1_col, seq2_col = getattr(cfg, "references_to_align_for_variant_annotation", [None, None])
332
+ if seq1_col and seq2_col:
333
+ candidate = f"{seq1_col}__{seq2_col}_variant_call"
334
+ if candidate in adata.layers:
335
+ variant_call_layer_name = candidate
336
+ suffix = "_strand_FASTA_base"
337
+ variant_seq1_label = seq1_col[: -len(suffix)] if seq1_col.endswith(suffix) else seq1_col
338
+ variant_seq2_label = seq2_col[: -len(suffix)] if seq2_col.endswith(suffix) else seq2_col
339
+ logger.info(
340
+ "Detected variant call layer '%s'; will overlay on span clustermaps.",
341
+ variant_call_layer_name,
342
+ )
343
+
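To make the layer-name convention above concrete, a hypothetical pair of configured columns and the names they resolve to (the reference names are invented):

# Hypothetical cfg.references_to_align_for_variant_annotation entries.
seq1_col = "refA_top_strand_FASTA_base"
seq2_col = "refB_top_strand_FASTA_base"

candidate = f"{seq1_col}__{seq2_col}_variant_call"
# -> "refA_top_strand_FASTA_base__refB_top_strand_FASTA_base_variant_call"

suffix = "_strand_FASTA_base"
variant_seq1_label = seq1_col[: -len(suffix)] if seq1_col.endswith(suffix) else seq1_col
variant_seq2_label = seq2_col[: -len(suffix)] if seq2_col.endswith(suffix) else seq2_col
# -> labels "refA_top" and "refB_top" used on the span clustermap overlays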
344
+ # ============================================================
345
+ # 1) Rolling NN distances + layer clustermaps
346
+ # ============================================================
347
+ rolling_nn_dir = chimeric_directory / "01_rolling_nn_clustermaps"
348
+
349
+ if rolling_nn_dir.is_dir() and not getattr(cfg, "force_redo_chimeric_analyses", False):
350
+ logger.debug(f"{rolling_nn_dir} already exists. Skipping rolling NN distance plots.")
351
+ else:
352
+ make_dirs([rolling_nn_dir])
353
+ samples = (
354
+ adata.obs[cfg.sample_name_col_for_plotting].astype("category").cat.categories.tolist()
355
+ )
356
+ references = adata.obs[cfg.reference_column].astype("category").cat.categories.tolist()
357
+
358
+ for reference in references:
359
+ for sample in samples:
360
+ mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (
361
+ adata.obs[cfg.reference_column] == reference
362
+ )
363
+ if not mask.any():
364
+ continue
365
+
366
+ subset = adata[mask]
367
+ position_col = f"position_in_{reference}"
368
+ site_cols = [f"{reference}_{st}_site" for st in cfg.rolling_nn_site_types]
369
+ missing_cols = [
370
+ col for col in [position_col, *site_cols] if col not in adata.var.columns
371
+ ]
372
+ if missing_cols:
373
+ raise KeyError(
374
+ f"Required site mask columns missing in adata.var: {missing_cols}"
375
+ )
376
+ mod_site_mask = adata.var[site_cols].fillna(False).any(axis=1)
377
+ site_mask = mod_site_mask & adata.var[position_col].fillna(False)
378
+ subset = subset[:, site_mask].copy()
379
+ try:
380
+ rolling_values, rolling_starts = rolling_window_nn_distance(
381
+ subset,
382
+ layer=rolling_nn_layer,
383
+ window=cfg.rolling_nn_window,
384
+ step=cfg.rolling_nn_step,
385
+ min_overlap=cfg.rolling_nn_min_overlap,
386
+ return_fraction=cfg.rolling_nn_return_fraction,
387
+ store_obsm=cfg.rolling_nn_obsm_key,
388
+ collect_zero_pairs=True,
389
+ )
390
+ except Exception as exc:
391
+ logger.warning(
392
+ "Rolling NN distance computation failed for sample=%s ref=%s: %s",
393
+ sample,
394
+ reference,
395
+ exc,
396
+ )
397
+ continue
398
+
399
+ safe_sample = str(sample).replace(os.sep, "_")
400
+ safe_ref = str(reference).replace(os.sep, "_")
401
+ map_key = f"{safe_sample}__{safe_ref}"
402
+ parent_obsm_key = f"{cfg.rolling_nn_obsm_key}__{safe_ref}"
403
+ try:
404
+ assign_rolling_nn_results(
405
+ adata,
406
+ subset,
407
+ rolling_values,
408
+ rolling_starts,
409
+ obsm_key=parent_obsm_key,
410
+ window=cfg.rolling_nn_window,
411
+ step=cfg.rolling_nn_step,
412
+ min_overlap=cfg.rolling_nn_min_overlap,
413
+ return_fraction=cfg.rolling_nn_return_fraction,
414
+ layer=rolling_nn_layer,
415
+ )
416
+ except Exception as exc:
417
+ logger.warning(
418
+ "Failed to merge rolling NN results for sample=%s ref=%s: %s",
419
+ sample,
420
+ reference,
421
+ exc,
422
+ )
423
+ resolved_zero_pairs_key = f"{cfg.rolling_nn_obsm_key}_zero_pairs"
424
+ parent_zero_pairs_key = f"{parent_obsm_key}__zero_pairs"
425
+ zero_pairs_data = subset.uns.get(resolved_zero_pairs_key)
426
+ rolling_zero_pairs_out_dir = rolling_nn_dir / "01_rolling_nn_zero_pairs"
427
+ if zero_pairs_data is not None:
428
+ adata.uns[parent_zero_pairs_key] = zero_pairs_data
429
+ for suffix in (
430
+ "starts",
431
+ "window",
432
+ "step",
433
+ "min_overlap",
434
+ "return_fraction",
435
+ "layer",
436
+ ):
437
+ value = subset.uns.get(f"{resolved_zero_pairs_key}_{suffix}")
438
+ if value is not None:
439
+ adata.uns[f"{parent_zero_pairs_key}_{suffix}"] = value
440
+ adata.uns.setdefault(
441
+ f"{cfg.rolling_nn_obsm_key}_zero_pairs_map", {}
442
+ ).setdefault(map_key, {})["zero_pairs_key"] = parent_zero_pairs_key
443
+ else:
444
+ logger.warning(
445
+ "Zero-pair data missing for sample=%s ref=%s (key=%s).",
446
+ sample,
447
+ reference,
448
+ resolved_zero_pairs_key,
449
+ )
450
+ try:
451
+ segments_uns_key = f"{parent_obsm_key}__zero_hamming_segments"
452
+ segment_records = annotate_zero_hamming_segments(
453
+ subset,
454
+ zero_pairs_uns_key=resolved_zero_pairs_key,
455
+ output_uns_key=segments_uns_key,
456
+ layer=rolling_nn_layer,
457
+ min_overlap=cfg.rolling_nn_min_overlap,
458
+ refine_segments=getattr(cfg, "rolling_nn_zero_pairs_refine", True),
459
+ max_nan_run=getattr(cfg, "rolling_nn_zero_pairs_max_nan_run", None),
460
+ merge_gap=getattr(cfg, "rolling_nn_zero_pairs_merge_gap", 0),
461
+ max_segments_per_read=getattr(
462
+ cfg, "rolling_nn_zero_pairs_max_segments_per_read", None
463
+ ),
464
+ max_segment_overlap=getattr(cfg, "rolling_nn_zero_pairs_max_overlap", None),
465
+ )
466
+ adata.uns.setdefault(
467
+ f"{cfg.rolling_nn_obsm_key}_zero_pairs_map", {}
468
+ ).setdefault(map_key, {}).update({"segments_key": segments_uns_key})
469
+ if getattr(cfg, "rolling_nn_write_zero_pairs_csvs", True):
470
+ try:
471
+ make_dirs([rolling_zero_pairs_out_dir])
472
+ segments_df = zero_hamming_segments_to_dataframe(
473
+ segment_records, subset.var_names.to_numpy()
474
+ )
475
+ segments_df.to_csv(
476
+ rolling_zero_pairs_out_dir
477
+ / f"{safe_sample}__{safe_ref}__zero_pairs_segments.csv",
478
+ index=False,
479
+ )
480
+ except Exception as exc:
481
+ logger.warning(
482
+ "Failed to write zero-pair segments CSV for sample=%s ref=%s: %s",
483
+ sample,
484
+ reference,
485
+ exc,
486
+ )
487
+ top_segments_per_read = getattr(
488
+ cfg, "rolling_nn_zero_pairs_top_segments_per_read", None
489
+ )
490
+ if top_segments_per_read is not None:
491
+ raw_df, filtered_df = select_top_segments_per_read(
492
+ segment_records,
493
+ subset.var_names.to_numpy(),
494
+ max_segments_per_read=top_segments_per_read,
495
+ max_segment_overlap=getattr(
496
+ cfg, "rolling_nn_zero_pairs_top_segments_max_overlap", None
497
+ ),
498
+ min_span=getattr(
499
+ cfg, "rolling_nn_zero_pairs_top_segments_min_span", None
500
+ ),
501
+ )
502
+ per_read_layer_key = ZERO_HAMMING_DISTANCE_SPANS
503
+ per_read_obs_key = f"{parent_obsm_key}__top_segments"
504
+ if per_read_obs_key in adata.obs:
505
+ per_read_obs_series = adata.obs[per_read_obs_key].copy()
506
+ per_read_obs_series = per_read_obs_series.apply(
507
+ lambda value: value if isinstance(value, list) else []
508
+ )
509
+ else:
510
+ per_read_obs_series = pd.Series(
511
+ [list() for _ in range(adata.n_obs)],
512
+ index=adata.obs_names,
513
+ dtype=object,
514
+ )
515
+ if not filtered_df.empty:
516
+ for read_id, read_df in filtered_df.groupby("read_id", sort=False):
517
+ read_index = int(read_id)
518
+ if read_index < 0 or read_index >= subset.n_obs:
519
+ continue
520
+ tuples = _build_top_segments_obs_tuples(
521
+ read_df,
522
+ subset.obs_names,
523
+ )
524
+ per_read_obs_series.at[subset.obs_names[read_index]] = tuples
525
+ adata.obs[per_read_obs_key] = per_read_obs_series
526
+ _build_zero_hamming_span_layer_from_obs(
527
+ adata=adata,
528
+ obs_key=per_read_obs_key,
529
+ layer_key=per_read_layer_key,
530
+ )
531
+ adata.uns.setdefault(
532
+ f"{cfg.rolling_nn_obsm_key}_zero_pairs_map", {}
533
+ ).setdefault(map_key, {}).update(
534
+ {
535
+ "per_read_layer_key": per_read_layer_key,
536
+ "per_read_obs_key": per_read_obs_key,
537
+ }
538
+ )
539
+ if getattr(cfg, "rolling_nn_zero_pairs_top_segments_write_csvs", True):
540
+ try:
541
+ make_dirs([rolling_zero_pairs_out_dir])
542
+ filtered_df.to_csv(
543
+ rolling_zero_pairs_out_dir
544
+ / f"{safe_sample}__{safe_ref}__zero_pairs_top_segments_per_read.csv",
545
+ index=False,
546
+ )
547
+ except Exception as exc:
548
+ logger.warning(
549
+ "Failed to write top segments CSV for sample=%s ref=%s: %s",
550
+ sample,
551
+ reference,
552
+ exc,
553
+ )
554
+ histogram_dir = rolling_zero_pairs_out_dir / "segment_histograms"
555
+ try:
556
+ make_dirs([histogram_dir])
557
+ raw_lengths = raw_df["segment_length_label"].to_numpy()
558
+ filtered_lengths = filtered_df["segment_length_label"].to_numpy()
559
+ hist_title = f"{sample} {reference} (n={subset.n_obs})"
560
+ plot_segment_length_histogram(
561
+ raw_lengths,
562
+ filtered_lengths,
563
+ bins=getattr(
564
+ cfg,
565
+ "rolling_nn_zero_pairs_segment_histogram_bins",
566
+ 30,
567
+ ),
568
+ title=hist_title,
569
+ density=True,
570
+ save_name=histogram_dir
571
+ / f"{safe_sample}__{safe_ref}__segment_lengths.png",
572
+ )
573
+ except Exception as exc:
574
+ logger.warning(
575
+ "Failed to plot segment length histogram for sample=%s ref=%s: %s",
576
+ sample,
577
+ reference,
578
+ exc,
579
+ )
580
+ except Exception as exc:
581
+ logger.warning(
582
+ "Failed to annotate zero-pair segments for sample=%s ref=%s: %s",
583
+ sample,
584
+ reference,
585
+ exc,
586
+ )
587
+ adata.uns.setdefault(f"{cfg.rolling_nn_obsm_key}_reference_map", {})[reference] = (
588
+ parent_obsm_key
589
+ )
590
+ out_png = rolling_nn_dir / f"{safe_sample}__{safe_ref}.png"
591
+ title = f"{sample} {reference} (n={subset.n_obs}) | window={cfg.rolling_nn_window}"
592
+ try:
593
+ plot_rolling_nn_and_layer(
594
+ subset,
595
+ obsm_key=cfg.rolling_nn_obsm_key,
596
+ layer_key=cfg.rolling_nn_plot_layer,
597
+ fill_nn_with_colmax=False,
598
+ drop_all_nan_windows=False,
599
+ max_nan_fraction=cfg.position_max_nan_threshold,
600
+ var_valid_fraction_col=f"{reference}_valid_fraction",
601
+ title=title,
602
+ save_name=out_png,
603
+ )
604
+ except Exception as exc:
605
+ logger.warning(
606
+ "Failed rolling NN plot for sample=%s ref=%s: %s",
607
+ sample,
608
+ reference,
609
+ exc,
610
+ )
611
+
612
+ # ============================================================
613
+ # 2) Zero-Hamming span clustermaps
614
+ # ============================================================
615
+ zero_hamming_dir = chimeric_directory / "02_zero_hamming_span_clustermaps"
616
+
617
+ if zero_hamming_dir.is_dir() and not getattr(cfg, "force_redo_chimeric_analyses", False):
618
+ logger.debug(f"{zero_hamming_dir} already exists. Skipping zero-Hamming plots.")
619
+ else:
620
+ zero_pairs_map = adata.uns.get(f"{cfg.rolling_nn_obsm_key}_zero_pairs_map", {})
621
+ if zero_pairs_map:
622
+ make_dirs([zero_hamming_dir])
623
+ samples = (
624
+ adata.obs[cfg.sample_name_col_for_plotting]
625
+ .astype("category")
626
+ .cat.categories.tolist()
627
+ )
628
+ references = adata.obs[cfg.reference_column].astype("category").cat.categories.tolist()
629
+ for reference in references:
630
+ for sample in samples:
631
+ mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (
632
+ adata.obs[cfg.reference_column] == reference
633
+ )
634
+ if not mask.any():
635
+ continue
636
+
637
+ safe_sample = str(sample).replace(os.sep, "_")
638
+ safe_ref = str(reference).replace(os.sep, "_")
639
+ map_key = f"{safe_sample}__{safe_ref}"
640
+ map_entry = zero_pairs_map.get(map_key)
641
+ if not map_entry:
642
+ continue
643
+
644
+ layer_key = map_entry.get("per_read_layer_key")
645
+ if not layer_key or layer_key not in adata.layers:
646
+ logger.warning(
647
+ "Zero-Hamming span layer %s missing for sample=%s ref=%s.",
648
+ layer_key,
649
+ sample,
650
+ reference,
651
+ )
652
+ continue
653
+
654
+ subset = adata[mask]
655
+ position_col = f"position_in_{reference}"
656
+ site_cols = [f"{reference}_{st}_site" for st in cfg.rolling_nn_site_types]
657
+ missing_cols = [
658
+ col for col in [position_col, *site_cols] if col not in adata.var.columns
659
+ ]
660
+ if missing_cols:
661
+ raise KeyError(
662
+ f"Required site mask columns missing in adata.var: {missing_cols}"
663
+ )
664
+ mod_site_mask = adata.var[site_cols].fillna(False).any(axis=1)
665
+ site_mask = mod_site_mask & adata.var[position_col].fillna(False)
666
+ # Build variant call DataFrame before column filtering
667
+ _variant_call_df = None
668
+ if variant_call_layer_name and variant_call_layer_name in adata.layers:
669
+ _vc = adata[mask].layers[variant_call_layer_name]
670
+ _vc = _vc.toarray() if hasattr(_vc, "toarray") else np.asarray(_vc)
671
+ _variant_call_df = pd.DataFrame(
672
+ _vc,
673
+ index=adata[mask].obs_names.astype(str),
674
+ columns=adata.var_names,
675
+ )
676
+
677
+ subset = subset[:, site_mask].copy()
678
+ title = f"{sample} {reference} (n={subset.n_obs})"
679
+ out_png = zero_hamming_dir / f"{safe_sample}__{safe_ref}.png"
680
+ try:
681
+ plot_zero_hamming_span_and_layer(
682
+ subset,
683
+ span_layer_key=layer_key,
684
+ layer_key=cfg.rolling_nn_plot_layer,
685
+ max_nan_fraction=cfg.position_max_nan_threshold,
686
+ var_valid_fraction_col=f"{reference}_valid_fraction",
687
+ variant_call_data=_variant_call_df,
688
+ seq1_label=variant_seq1_label,
689
+ seq2_label=variant_seq2_label,
690
+ ref1_marker_color=getattr(cfg, "variant_overlay_seq1_color", "white"),
691
+ ref2_marker_color=getattr(cfg, "variant_overlay_seq2_color", "black"),
692
+ variant_marker_size=getattr(cfg, "variant_overlay_marker_size", 4.0),
693
+ title=title,
694
+ save_name=out_png,
695
+ )
696
+ except Exception as exc:
697
+ logger.warning(
698
+ "Failed zero-Hamming span plot for sample=%s ref=%s: %s",
699
+ sample,
700
+ reference,
701
+ exc,
702
+ )
703
+ else:
704
+ logger.debug("No zero-pair map found; skipping zero-Hamming span clustermaps.")
705
+
706
+ # ============================================================
707
+ # 3) Rolling NN + two-layer clustermaps
708
+ # ============================================================
709
+ rolling_nn_layers_dir = chimeric_directory / "03_rolling_nn_two_layer_clustermaps"
710
+ zero_pairs_map = adata.uns.get(f"{cfg.rolling_nn_obsm_key}_zero_pairs_map", {})
711
+
712
+ if rolling_nn_layers_dir.is_dir() and not getattr(cfg, "force_redo_chimeric_analyses", False):
713
+ logger.debug(
714
+ "%s already exists. Skipping rolling NN two-layer clustermaps.",
715
+ rolling_nn_layers_dir,
716
+ )
717
+ else:
718
+ plot_layers = list(getattr(cfg, "rolling_nn_plot_layers", []) or [])
719
+ if len(plot_layers) != 2:
720
+ logger.warning(
721
+ "rolling_nn_plot_layers should list exactly two layers; got %s. Skipping.",
722
+ plot_layers,
723
+ )
724
+ else:
725
+ make_dirs([rolling_nn_layers_dir])
726
+ samples = (
727
+ adata.obs[cfg.sample_name_col_for_plotting]
728
+ .astype("category")
729
+ .cat.categories.tolist()
730
+ )
731
+ references = adata.obs[cfg.reference_column].astype("category").cat.categories.tolist()
732
+ for reference in references:
733
+ for sample in samples:
734
+ mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & (
735
+ adata.obs[cfg.reference_column] == reference
736
+ )
737
+ if not mask.any():
738
+ continue
739
+
740
+ safe_sample = str(sample).replace(os.sep, "_")
741
+ safe_ref = str(reference).replace(os.sep, "_")
742
+ parent_obsm_key = f"{cfg.rolling_nn_obsm_key}__{safe_ref}"
743
+ map_key = f"{safe_sample}__{safe_ref}"
744
+
745
+ subset = adata[mask]
746
+ position_col = f"position_in_{reference}"
747
+ site_cols = [f"{reference}_{st}_site" for st in cfg.rolling_nn_site_types]
748
+ missing_cols = [
749
+ col for col in [position_col, *site_cols] if col not in adata.var.columns
750
+ ]
751
+ if missing_cols:
752
+ raise KeyError(
753
+ f"Required site mask columns missing in adata.var: {missing_cols}"
754
+ )
755
+ mod_site_mask = adata.var[site_cols].fillna(False).any(axis=1)
756
+ site_mask = mod_site_mask & adata.var[position_col].fillna(False)
757
+ subset = subset[:, site_mask].copy()
758
+
759
+ if (
760
+ parent_obsm_key not in subset.obsm
761
+ and cfg.rolling_nn_obsm_key not in subset.obsm
762
+ ):
763
+ logger.warning(
764
+ "Rolling NN results missing for sample=%s ref=%s (key=%s).",
765
+ sample,
766
+ reference,
767
+ parent_obsm_key,
768
+ )
769
+ continue
770
+ plot_layers_resolved = list(plot_layers)
771
+ map_entry = zero_pairs_map.get(map_key, {})
772
+ zero_hamming_layer_key = map_entry.get("per_read_layer_key")
773
+ if zero_hamming_layer_key and len(plot_layers_resolved) == 2:
774
+ plot_layers_resolved[1] = zero_hamming_layer_key
775
+ elif (
776
+ ZERO_HAMMING_DISTANCE_SPANS in subset.layers
777
+ and len(plot_layers_resolved) == 2
778
+ ):
779
+ plot_layers_resolved[1] = ZERO_HAMMING_DISTANCE_SPANS
780
+
781
+ if (
782
+ cfg.rolling_nn_obsm_key not in subset.obsm
783
+ and parent_obsm_key in subset.obsm
784
+ ):
785
+ subset.obsm[cfg.rolling_nn_obsm_key] = subset.obsm[parent_obsm_key]
786
+ for suffix in (
787
+ "starts",
788
+ "centers",
789
+ "window",
790
+ "step",
791
+ "min_overlap",
792
+ "return_fraction",
793
+ "layer",
794
+ ):
795
+ parent_key = f"{parent_obsm_key}_{suffix}"
796
+ if parent_key in adata.uns:
797
+ subset.uns.setdefault(
798
+ f"{cfg.rolling_nn_obsm_key}_{suffix}", adata.uns[parent_key]
799
+ )
800
+
801
+ missing_layers = [
802
+ layer_key
803
+ for layer_key in plot_layers_resolved
804
+ if layer_key not in subset.layers
805
+ ]
806
+ if missing_layers:
807
+ logger.warning(
808
+ "Layer(s) %s missing for sample=%s ref=%s.",
809
+ missing_layers,
810
+ sample,
811
+ reference,
812
+ )
813
+ continue
814
+
815
+ out_png = rolling_nn_layers_dir / f"{safe_sample}__{safe_ref}.png"
816
+ title = (
817
+ f"{sample} {reference} (n={subset.n_obs}) | window={cfg.rolling_nn_window}"
818
+ )
819
+ try:
820
+ plot_rolling_nn_and_two_layers(
821
+ subset,
822
+ obsm_key=cfg.rolling_nn_obsm_key,
823
+ layer_keys=plot_layers_resolved,
824
+ fill_nn_with_colmax=False,
825
+ drop_all_nan_windows=False,
826
+ max_nan_fraction=cfg.position_max_nan_threshold,
827
+ var_valid_fraction_col=f"{reference}_valid_fraction",
828
+ title=title,
829
+ save_name=out_png,
830
+ )
831
+ except Exception as exc:
832
+ logger.warning(
833
+ "Failed rolling NN two-layer plot for sample=%s ref=%s: %s",
834
+ sample,
835
+ reference,
836
+ exc,
837
+ )
838
+
839
+ # ============================================================
840
+ # Cross-sample rolling NN analysis
841
+ # ============================================================
842
+ if getattr(cfg, "cross_sample_analysis", False):
843
+ CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS = "cross_sample_zero_hamming_distance_spans"
844
+ cross_nn_dir = chimeric_directory / "cross_sample_01_rolling_nn_clustermaps"
845
+ cross_zh_dir = chimeric_directory / "cross_sample_02_zero_hamming_span_clustermaps"
846
+ cross_two_dir = chimeric_directory / "cross_sample_03_rolling_nn_two_layer_clustermaps"
847
+
848
+ if cross_nn_dir.is_dir() and not getattr(cfg, "force_redo_chimeric_analyses", False):
849
+ logger.debug("Cross-sample dirs exist. Skipping cross-sample analysis.")
850
+ else:
851
+ make_dirs([cross_nn_dir, cross_zh_dir, cross_two_dir])
852
+ samples = (
853
+ adata.obs[cfg.sample_name_col_for_plotting]
854
+ .astype("category")
855
+ .cat.categories.tolist()
856
+ )
857
+ references = adata.obs[cfg.reference_column].astype("category").cat.categories.tolist()
858
+ rng = np.random.RandomState(getattr(cfg, "cross_sample_random_seed", 42))
859
+
860
+ for reference in references:
861
+ ref_mask = adata.obs[cfg.reference_column] == reference
862
+ position_col = f"position_in_{reference}"
863
+ site_cols = [f"{reference}_{st}_site" for st in cfg.rolling_nn_site_types]
864
+ missing_cols = [
865
+ col for col in [position_col, *site_cols] if col not in adata.var.columns
866
+ ]
867
+ if missing_cols:
868
+ logger.warning(
869
+ "Cross-sample: missing var columns %s for ref=%s, skipping.",
870
+ missing_cols,
871
+ reference,
872
+ )
873
+ continue
874
+ mod_site_mask = adata.var[site_cols].fillna(False).any(axis=1)
875
+ site_mask = mod_site_mask & adata.var[position_col].fillna(False)
876
+
877
+ for sample in samples:
878
+ sample_mask = (adata.obs[cfg.sample_name_col_for_plotting] == sample) & ref_mask
879
+ if not sample_mask.any():
880
+ continue
881
+
882
+ # Build cross-sample pool
883
+ grouping_col = getattr(cfg, "cross_sample_grouping_col", None)
884
+ if grouping_col and grouping_col in adata.obs.columns:
885
+ sample_group_val = adata.obs.loc[sample_mask, grouping_col].iloc[0]
886
+ pool_mask = ref_mask & (adata.obs[grouping_col] == sample_group_val)
887
+ else:
888
+ pool_mask = ref_mask
889
+
890
+ other_mask = pool_mask & ~sample_mask
891
+ if not other_mask.any():
892
+ logger.debug(
893
+ "Cross-sample: no other-sample reads for sample=%s ref=%s.",
894
+ sample,
895
+ reference,
896
+ )
897
+ continue
898
+
899
+ n_sample = int(sample_mask.sum())
900
+ n_other = int(other_mask.sum())
901
+ n_use = min(n_sample, n_other)
902
+
903
+ other_indices = np.where(other_mask.values)[0]
904
+ if n_other > n_use:
905
+ other_indices = rng.choice(other_indices, size=n_use, replace=False)
906
+
907
+ sample_indices = np.where(sample_mask.values)[0]
908
+ combined_indices = np.concatenate([sample_indices, other_indices])
909
+ cross_subset = adata[combined_indices][:, site_mask].copy()
910
+
911
+ # Build sample_labels: 0 = current sample, 1 = other
912
+ cross_labels = np.zeros(len(combined_indices), dtype=np.int32)
913
+ cross_labels[len(sample_indices) :] = 1
914
+
915
+ cross_obsm_key = "cross_sample_rolling_nn_dist"
916
+ try:
917
+ rolling_values, rolling_starts = rolling_window_nn_distance(
918
+ cross_subset,
919
+ layer=rolling_nn_layer,
920
+ window=cfg.rolling_nn_window,
921
+ step=cfg.rolling_nn_step,
922
+ min_overlap=cfg.rolling_nn_min_overlap,
923
+ return_fraction=cfg.rolling_nn_return_fraction,
924
+ store_obsm=cross_obsm_key,
925
+ collect_zero_pairs=True,
926
+ sample_labels=cross_labels,
927
+ )
928
+ except Exception as exc:
929
+ logger.warning(
930
+ "Cross-sample rolling NN failed for sample=%s ref=%s: %s",
931
+ sample,
932
+ reference,
933
+ exc,
934
+ )
935
+ continue
936
+
937
+ safe_sample = str(sample).replace(os.sep, "_")
938
+ safe_ref = str(reference).replace(os.sep, "_")
939
+
940
+ # Assign results back to adata for sample reads only
941
+ parent_obsm_key = f"cross_sample_rolling_nn_dist__{safe_ref}"
942
+ sample_rolling = rolling_values[: len(sample_indices)]
943
+ try:
944
+ assign_rolling_nn_results(
945
+ adata,
946
+ cross_subset[: len(sample_indices)],
947
+ sample_rolling,
948
+ rolling_starts,
949
+ obsm_key=parent_obsm_key,
950
+ window=cfg.rolling_nn_window,
951
+ step=cfg.rolling_nn_step,
952
+ min_overlap=cfg.rolling_nn_min_overlap,
953
+ return_fraction=cfg.rolling_nn_return_fraction,
954
+ layer=rolling_nn_layer,
955
+ )
956
+ except Exception as exc:
957
+ logger.warning(
958
+ "Failed to merge cross-sample rolling NN for sample=%s ref=%s: %s",
959
+ sample,
960
+ reference,
961
+ exc,
962
+ )
963
+
964
+ # Zero-pair segments
965
+ resolved_zero_pairs_key = f"{cross_obsm_key}_zero_pairs"
966
+ zero_pairs_data = cross_subset.uns.get(resolved_zero_pairs_key)
967
+ if zero_pairs_data is not None:
968
+ try:
969
+ segments_uns_key = f"{parent_obsm_key}__zero_hamming_segments"
970
+ segment_records = annotate_zero_hamming_segments(
971
+ cross_subset,
972
+ zero_pairs_uns_key=resolved_zero_pairs_key,
973
+ output_uns_key=segments_uns_key,
974
+ layer=rolling_nn_layer,
975
+ min_overlap=cfg.rolling_nn_min_overlap,
976
+ refine_segments=getattr(cfg, "rolling_nn_zero_pairs_refine", True),
977
+ max_nan_run=getattr(cfg, "rolling_nn_zero_pairs_max_nan_run", None),
978
+ merge_gap=getattr(cfg, "rolling_nn_zero_pairs_merge_gap", 0),
979
+ max_segments_per_read=getattr(
980
+ cfg, "rolling_nn_zero_pairs_max_segments_per_read", None
981
+ ),
982
+ max_segment_overlap=getattr(
983
+ cfg, "rolling_nn_zero_pairs_max_overlap", None
984
+ ),
985
+ )
986
+
987
+ top_segments_per_read = getattr(
988
+ cfg, "rolling_nn_zero_pairs_top_segments_per_read", None
989
+ )
990
+ if top_segments_per_read is not None:
991
+ raw_df, filtered_df = select_top_segments_per_read(
992
+ segment_records,
993
+ cross_subset.var_names.to_numpy(),
994
+ max_segments_per_read=top_segments_per_read,
995
+ max_segment_overlap=getattr(
996
+ cfg, "rolling_nn_zero_pairs_top_segments_max_overlap", None
997
+ ),
998
+ min_span=getattr(
999
+ cfg, "rolling_nn_zero_pairs_top_segments_min_span", None
1000
+ ),
1001
+ )
1002
+ per_read_layer_key = CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS
1003
+ per_read_obs_key = f"{parent_obsm_key}__top_segments"
1004
+ if per_read_obs_key in adata.obs:
1005
+ per_read_obs_series = adata.obs[per_read_obs_key].copy()
1006
+ per_read_obs_series = per_read_obs_series.apply(
1007
+ lambda value: value if isinstance(value, list) else []
1008
+ )
1009
+ else:
1010
+ per_read_obs_series = pd.Series(
1011
+ [list() for _ in range(adata.n_obs)],
1012
+ index=adata.obs_names,
1013
+ dtype=object,
1014
+ )
1015
+ if not filtered_df.empty:
1016
+ for read_id, read_df in filtered_df.groupby(
1017
+ "read_id", sort=False
1018
+ ):
1019
+ read_index = int(read_id)
1020
+ if read_index < 0 or read_index >= cross_subset.n_obs:
1021
+ continue
1022
+ # Only assign for sample reads
1023
+ if read_index >= len(sample_indices):
1024
+ continue
1025
+ tuples = _build_top_segments_obs_tuples(
1026
+ read_df,
1027
+ cross_subset.obs_names,
1028
+ )
1029
+ per_read_obs_series.at[
1030
+ cross_subset.obs_names[read_index]
1031
+ ] = tuples
1032
+ adata.obs[per_read_obs_key] = per_read_obs_series
1033
+ _build_zero_hamming_span_layer_from_obs(
1034
+ adata=adata,
1035
+ obs_key=per_read_obs_key,
1036
+ layer_key=per_read_layer_key,
1037
+ )
1038
+ except Exception as exc:
1039
+ logger.warning(
1040
+ "Cross-sample zero-pair segments failed for sample=%s ref=%s: %s",
1041
+ sample,
1042
+ reference,
1043
+ exc,
1044
+ )
1045
+
1046
+ # Build variant call DataFrame before column filtering
1047
+ _cross_variant_call_df = None
1048
+ if variant_call_layer_name and variant_call_layer_name in adata.layers:
1049
+ _vc = adata[sample_mask].layers[variant_call_layer_name]
1050
+ _vc = _vc.toarray() if hasattr(_vc, "toarray") else np.asarray(_vc)
1051
+ _cross_variant_call_df = pd.DataFrame(
1052
+ _vc,
1053
+ index=adata[sample_mask].obs_names.astype(str),
1054
+ columns=adata.var_names,
1055
+ )
1056
+
1057
+ # --- Plots ---
1058
+ # Use the sample-only subset for plotting
1059
+ plot_subset = adata[sample_mask][:, site_mask].copy()
1060
+
1061
+ # Copy cross-sample obsm into plot_subset
1062
+ if parent_obsm_key in adata.obsm:
1063
+ plot_subset.obsm[cfg.rolling_nn_obsm_key] = adata[sample_mask].obsm.get(
1064
+ parent_obsm_key
1065
+ )
1066
+ for suffix in (
1067
+ "starts",
1068
+ "centers",
1069
+ "window",
1070
+ "step",
1071
+ "min_overlap",
1072
+ "return_fraction",
1073
+ "layer",
1074
+ ):
1075
+ parent_key = f"{parent_obsm_key}_{suffix}"
1076
+ if parent_key in adata.uns:
1077
+ plot_subset.uns[f"{cfg.rolling_nn_obsm_key}_{suffix}"] = adata.uns[
1078
+ parent_key
1079
+ ]
1080
+
1081
+ if grouping_col and grouping_col in adata.obs.columns:
1082
+ cross_pool_desc = f"cross-sample ({grouping_col}={sample_group_val})"
1083
+ else:
1084
+ cross_pool_desc = "cross-sample (all samples)"
1085
+ title = (
1086
+ f"{sample} {reference} (n={n_sample})"
1087
+ f" | {cross_pool_desc}"
1088
+ f" | subsample={len(other_indices)}"
1089
+ f" | window={cfg.rolling_nn_window}"
1090
+ )
1091
+
1092
+ # Plot 1: rolling NN clustermap
1093
+ try:
1094
+ out_png = cross_nn_dir / f"{safe_sample}__{safe_ref}.png"
1095
+ plot_rolling_nn_and_layer(
1096
+ plot_subset,
1097
+ obsm_key=cfg.rolling_nn_obsm_key,
1098
+ layer_key=cfg.rolling_nn_plot_layer,
1099
+ fill_nn_with_colmax=False,
1100
+ drop_all_nan_windows=False,
1101
+ max_nan_fraction=cfg.position_max_nan_threshold,
1102
+ var_valid_fraction_col=f"{reference}_valid_fraction",
1103
+ title=title,
1104
+ save_name=out_png,
1105
+ )
1106
+ except Exception as exc:
1107
+ logger.warning(
1108
+ "Cross-sample rolling NN plot failed for sample=%s ref=%s: %s",
1109
+ sample,
1110
+ reference,
1111
+ exc,
1112
+ )
1113
+
1114
+ # Plot 2: zero-hamming span clustermap
1115
+ if CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS in adata.layers:
1116
+ try:
1117
+ out_png = cross_zh_dir / f"{safe_sample}__{safe_ref}.png"
1118
+ plot_zero_hamming_span_and_layer(
1119
+ plot_subset,
1120
+ span_layer_key=CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS,
1121
+ layer_key=cfg.rolling_nn_plot_layer,
1122
+ max_nan_fraction=cfg.position_max_nan_threshold,
1123
+ var_valid_fraction_col=f"{reference}_valid_fraction",
1124
+ variant_call_data=_cross_variant_call_df,
1125
+ seq1_label=variant_seq1_label,
1126
+ seq2_label=variant_seq2_label,
1127
+ ref1_marker_color=getattr(
1128
+ cfg, "variant_overlay_seq1_color", "white"
1129
+ ),
1130
+ ref2_marker_color=getattr(
1131
+ cfg, "variant_overlay_seq2_color", "black"
1132
+ ),
1133
+ variant_marker_size=getattr(
1134
+ cfg, "variant_overlay_marker_size", 4.0
1135
+ ),
1136
+ title=title,
1137
+ save_name=out_png,
1138
+ )
1139
+ except Exception as exc:
1140
+ logger.warning(
1141
+ "Cross-sample zero-Hamming span plot failed for sample=%s ref=%s: %s",
1142
+ sample,
1143
+ reference,
1144
+ exc,
1145
+ )
1146
+
1147
+ # Plot 3: two-layer clustermap
1148
+ plot_layers = list(getattr(cfg, "rolling_nn_plot_layers", []) or [])
1149
+ if len(plot_layers) == 2:
1150
+ plot_layers_resolved = list(plot_layers)
1151
+ if CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS in plot_subset.layers:
1152
+ plot_layers_resolved[1] = CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS
1153
+ missing_layers = [
1154
+ lk for lk in plot_layers_resolved if lk not in plot_subset.layers
1155
+ ]
1156
+ if not missing_layers:
1157
+ try:
1158
+ out_png = cross_two_dir / f"{safe_sample}__{safe_ref}.png"
1159
+ plot_rolling_nn_and_two_layers(
1160
+ plot_subset,
1161
+ obsm_key=cfg.rolling_nn_obsm_key,
1162
+ layer_keys=plot_layers_resolved,
1163
+ fill_nn_with_colmax=False,
1164
+ drop_all_nan_windows=False,
1165
+ max_nan_fraction=cfg.position_max_nan_threshold,
1166
+ var_valid_fraction_col=f"{reference}_valid_fraction",
1167
+ title=title,
1168
+ save_name=out_png,
1169
+ )
1170
+ except Exception as exc:
1171
+ logger.warning(
1172
+ "Cross-sample two-layer plot failed for sample=%s ref=%s: %s",
1173
+ sample,
1174
+ reference,
1175
+ exc,
1176
+ )
1177
+
1178
+ # ============================================================
1179
+ # Delta: within-sample minus cross-sample hamming spans (clamped >= 0)
1180
+ # ============================================================
1181
+ if getattr(cfg, "cross_sample_analysis", False):
1182
+ DELTA_ZERO_HAMMING_DISTANCE_SPANS = "delta_zero_hamming_distance_spans"
1183
+ delta_summary_dir = chimeric_directory / "delta_hamming_summary"
1184
+
1185
+ if delta_summary_dir.is_dir() and not getattr(cfg, "force_redo_chimeric_analyses", False):
1186
+ logger.debug("Delta summary dir exists. Skipping delta analysis.")
1187
+ else:
1188
+ make_dirs([delta_summary_dir])
1189
+ samples = (
1190
+ adata.obs[cfg.sample_name_col_for_plotting]
1191
+ .astype("category")
1192
+ .cat.categories.tolist()
1193
+ )
1194
+ references = adata.obs[cfg.reference_column].astype("category").cat.categories.tolist()
1195
+
1196
+ # Build delta layer: within - cross, clamped at 0
1197
+ if (
1198
+ ZERO_HAMMING_DISTANCE_SPANS in adata.layers
1199
+ and CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS in adata.layers
1200
+ ):
1201
+ within_layer = np.asarray(
1202
+ adata.layers[ZERO_HAMMING_DISTANCE_SPANS], dtype=np.float64
1203
+ )
1204
+ cross_layer = np.asarray(
1205
+ adata.layers[CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS], dtype=np.float64
1206
+ )
1207
+ delta_layer = np.clip(within_layer - cross_layer, 0, None)
1208
+ adata.layers[DELTA_ZERO_HAMMING_DISTANCE_SPANS] = delta_layer
1209
+ threshold = getattr(cfg, "delta_hamming_chimeric_span_threshold", 200)
1210
+ try:
1211
+ threshold = int(threshold)
1212
+ except (TypeError, ValueError):
1213
+ logger.warning(
1214
+ "Invalid delta_hamming_chimeric_span_threshold=%s; using default 200.",
1215
+ threshold,
1216
+ )
1217
+ threshold = 200
1218
+ if threshold < 0:
1219
+ logger.warning(
1220
+ "delta_hamming_chimeric_span_threshold=%s is negative; clamping to 0.",
1221
+ threshold,
1222
+ )
1223
+ threshold = 0
1224
+ adata.obs["chimeric_by_mod_hamming_distance"] = (
1225
+ _compute_chimeric_by_mod_hamming_distance(delta_layer, threshold)
1226
+ )
1227
+ else:
1228
+ logger.warning(
1229
+ "Cannot compute delta: missing %s or %s layer.",
1230
+ ZERO_HAMMING_DISTANCE_SPANS,
1231
+ CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS,
1232
+ )
1233
+ adata.obs["chimeric_by_mod_hamming_distance"] = False
1234
+
1235
+ if DELTA_ZERO_HAMMING_DISTANCE_SPANS in adata.layers:
1236
+ for reference in references:
1237
+ ref_mask = adata.obs[cfg.reference_column] == reference
1238
+ position_col = f"position_in_{reference}"
1239
+ site_cols = [f"{reference}_{st}_site" for st in cfg.rolling_nn_site_types]
1240
+ missing_cols = [
1241
+ col for col in [position_col, *site_cols] if col not in adata.var.columns
1242
+ ]
1243
+ if missing_cols:
1244
+ continue
1245
+ mod_site_mask = adata.var[site_cols].fillna(False).any(axis=1)
1246
+ site_mask = mod_site_mask & adata.var[position_col].fillna(False)
1247
+
1248
+ for sample in samples:
1249
+ sample_mask = (
1250
+ adata.obs[cfg.sample_name_col_for_plotting] == sample
1251
+ ) & ref_mask
1252
+ if not sample_mask.any():
1253
+ continue
1254
+
1255
+ safe_sample = str(sample).replace(os.sep, "_")
1256
+ safe_ref = str(reference).replace(os.sep, "_")
1257
+ within_obsm_key = f"{cfg.rolling_nn_obsm_key}__{safe_ref}"
1258
+ cross_obsm_key = f"cross_sample_rolling_nn_dist__{safe_ref}"
1259
+
1260
+ plot_subset = adata[sample_mask][:, site_mask].copy()
1261
+
1262
+ # Wire self NN obsm
1263
+ self_nn_key = "self_rolling_nn_dist"
1264
+ if within_obsm_key in plot_subset.obsm:
1265
+ plot_subset.obsm[self_nn_key] = plot_subset.obsm[within_obsm_key]
1266
+ elif cfg.rolling_nn_obsm_key in plot_subset.obsm:
1267
+ plot_subset.obsm[self_nn_key] = plot_subset.obsm[
1268
+ cfg.rolling_nn_obsm_key
1269
+ ]
1270
+ else:
1271
+ logger.debug(
1272
+ "Delta: missing self NN obsm for sample=%s ref=%s.",
1273
+ sample,
1274
+ reference,
1275
+ )
1276
+ continue
1277
+
1278
+ # Wire cross NN obsm
1279
+ cross_nn_key = "cross_rolling_nn_dist"
1280
+ if cross_obsm_key in plot_subset.obsm:
1281
+ plot_subset.obsm[cross_nn_key] = plot_subset.obsm[cross_obsm_key]
1282
+ else:
1283
+ logger.debug(
1284
+ "Delta: missing cross NN obsm for sample=%s ref=%s.",
1285
+ sample,
1286
+ reference,
1287
+ )
1288
+ continue
1289
+
1290
+ # Copy uns metadata for both NN keys
1291
+ for src_obsm, dst_obsm in (
1292
+ (within_obsm_key, self_nn_key),
1293
+ (cross_obsm_key, cross_nn_key),
1294
+ ):
1295
+ for suffix in (
1296
+ "starts",
1297
+ "centers",
1298
+ "window",
1299
+ "step",
1300
+ "min_overlap",
1301
+ "return_fraction",
1302
+ "layer",
1303
+ ):
1304
+ src_k = f"{src_obsm}_{suffix}"
1305
+ if src_k in adata.uns:
1306
+ plot_subset.uns[f"{dst_obsm}_{suffix}"] = adata.uns[src_k]
1307
+
1308
+ # Check required span layers
1309
+ required_layers = [
1310
+ ZERO_HAMMING_DISTANCE_SPANS,
1311
+ CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS,
1312
+ DELTA_ZERO_HAMMING_DISTANCE_SPANS,
1313
+ ]
1314
+ missing_layers = [
1315
+ lk for lk in required_layers if lk not in plot_subset.layers
1316
+ ]
1317
+ if missing_layers:
1318
+ logger.debug(
1319
+ "Delta: missing layers %s for sample=%s ref=%s.",
1320
+ missing_layers,
1321
+ sample,
1322
+ reference,
1323
+ )
1324
+ continue
1325
+
1326
+ title = (
1327
+ f"{sample} {reference}"
1328
+ f" (n={int(sample_mask.sum())})"
1329
+ f" | window={cfg.rolling_nn_window}"
1330
+ )
1331
+ out_png = delta_summary_dir / f"{safe_sample}__{safe_ref}.png"
1332
+ try:
1333
+ plot_delta_hamming_summary(
1334
+ plot_subset,
1335
+ self_obsm_key=self_nn_key,
1336
+ cross_obsm_key=cross_nn_key,
1337
+ layer_key=cfg.rolling_nn_plot_layer,
1338
+ self_span_layer_key=ZERO_HAMMING_DISTANCE_SPANS,
1339
+ cross_span_layer_key=CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS,
1340
+ delta_span_layer_key=DELTA_ZERO_HAMMING_DISTANCE_SPANS,
1341
+ fill_nn_with_colmax=False,
1342
+ drop_all_nan_windows=False,
1343
+ max_nan_fraction=cfg.position_max_nan_threshold,
1344
+ var_valid_fraction_col=f"{reference}_valid_fraction",
1345
+ title=title,
1346
+ save_name=out_png,
1347
+ )
1348
+ except Exception as exc:
1349
+ logger.warning(
1350
+ "Delta hamming summary plot failed for sample=%s ref=%s: %s",
1351
+ sample,
1352
+ reference,
1353
+ exc,
1354
+ )
1355
+
1356
+ # ============================================================
1357
+ # Hamming span trio (self, cross, delta) — no column subsetting
1358
+ # ============================================================
1359
+ if getattr(cfg, "cross_sample_analysis", False):
1360
+ span_trio_dir = chimeric_directory / "hamming_span_trio"
1361
+
1362
+ if span_trio_dir.is_dir() and not getattr(cfg, "force_redo_chimeric_analyses", False):
1363
+ logger.debug("Hamming span trio dir exists. Skipping.")
1364
+ else:
1365
+ _self_key = ZERO_HAMMING_DISTANCE_SPANS
1366
+ _cross_key = CROSS_SAMPLE_ZERO_HAMMING_DISTANCE_SPANS
1367
+ _delta_key = DELTA_ZERO_HAMMING_DISTANCE_SPANS
1368
+ has_layers = (
1369
+ _self_key in adata.layers
1370
+ and _cross_key in adata.layers
1371
+ and _delta_key in adata.layers
1372
+ )
1373
+ if has_layers:
1374
+ from smftools.plotting import plot_hamming_span_trio
1375
+
1376
+ make_dirs([span_trio_dir])
1377
+ samples = (
1378
+ adata.obs[cfg.sample_name_col_for_plotting]
1379
+ .astype("category")
1380
+ .cat.categories.tolist()
1381
+ )
1382
+ references = (
1383
+ adata.obs[cfg.reference_column].astype("category").cat.categories.tolist()
1384
+ )
1385
+
1386
+ for reference in references:
1387
+ ref_mask = adata.obs[cfg.reference_column] == reference
1388
+ position_col = f"position_in_{reference}"
1389
+ if position_col not in adata.var.columns:
1390
+ continue
1391
+ pos_mask = adata.var[position_col].fillna(False).astype(bool)
1392
+
1393
+ for sample in samples:
1394
+ sample_mask = (
1395
+ adata.obs[cfg.sample_name_col_for_plotting] == sample
1396
+ ) & ref_mask
1397
+ if not sample_mask.any():
1398
+ continue
1399
+
1400
+ # Build variant call DataFrame (full width, no subsetting)
1401
+ _variant_call_df = None
1402
+ if variant_call_layer_name and variant_call_layer_name in adata.layers:
1403
+ _vc = adata[sample_mask].layers[variant_call_layer_name]
1404
+ _vc = _vc.toarray() if hasattr(_vc, "toarray") else np.asarray(_vc)
1405
+ _variant_call_df = pd.DataFrame(
1406
+ _vc,
1407
+ index=adata[sample_mask].obs_names.astype(str),
1408
+ columns=adata.var_names,
1409
+ )
1410
+
1411
+ plot_subset = adata[sample_mask][:, pos_mask].copy()
1412
+
1413
+ safe_sample = str(sample).replace(os.sep, "_")
1414
+ safe_ref = str(reference).replace(os.sep, "_")
1415
+ n_reads = int(sample_mask.sum())
1416
+ trio_title = f"{sample} {reference} (n={n_reads})"
1417
+ out_png = span_trio_dir / f"{safe_sample}__{safe_ref}.png"
1418
+ try:
1419
+ plot_hamming_span_trio(
1420
+ plot_subset,
1421
+ self_span_layer_key=_self_key,
1422
+ cross_span_layer_key=_cross_key,
1423
+ delta_span_layer_key=_delta_key,
1424
+ variant_call_data=_variant_call_df,
1425
+ seq1_label=variant_seq1_label,
1426
+ seq2_label=variant_seq2_label,
1427
+ ref1_marker_color=getattr(
1428
+ cfg, "variant_overlay_seq1_color", "white"
1429
+ ),
1430
+ ref2_marker_color=getattr(
1431
+ cfg, "variant_overlay_seq2_color", "black"
1432
+ ),
1433
+ variant_marker_size=getattr(
1434
+ cfg, "variant_overlay_marker_size", 4.0
1435
+ ),
1436
+ title=trio_title,
1437
+ save_name=out_png,
1438
+ )
1439
+ except Exception as exc:
1440
+ logger.warning(
1441
+ "Hamming span trio plot failed for sample=%s ref=%s: %s",
1442
+ sample,
1443
+ reference,
1444
+ exc,
1445
+ )
1446
+
1447
+ # ============================================================
1448
+ # Span length distribution histograms
1449
+ # ============================================================
1450
+ if getattr(cfg, "cross_sample_analysis", False):
1451
+ span_hist_dir = chimeric_directory / "span_length_distributions"
1452
+ if span_hist_dir.is_dir() and not getattr(cfg, "force_redo_chimeric_analyses", False):
1453
+ logger.debug("Span length distribution dir exists. Skipping.")
1454
+ else:
1455
+ _self_key = ZERO_HAMMING_DISTANCE_SPANS
1456
+ _cross_key = "cross_sample_zero_hamming_distance_spans"
1457
+ _delta_key = "delta_zero_hamming_distance_spans"
1458
+ has_layers = (
1459
+ _self_key in adata.layers
1460
+ and _cross_key in adata.layers
1461
+ and _delta_key in adata.layers
1462
+ )
1463
+ if has_layers:
1464
+ make_dirs([span_hist_dir])
1465
+ samples = (
1466
+ adata.obs[cfg.sample_name_col_for_plotting]
1467
+ .astype("category")
1468
+ .cat.categories.tolist()
1469
+ )
1470
+ references = (
1471
+ adata.obs[cfg.reference_column].astype("category").cat.categories.tolist()
1472
+ )
1473
+ for reference in references:
1474
+ ref_mask = adata.obs[cfg.reference_column] == reference
1475
+ position_col = f"position_in_{reference}"
1476
+ site_cols = [f"{reference}_{st}_site" for st in cfg.rolling_nn_site_types]
1477
+ missing_cols = [
1478
+ col for col in [position_col, *site_cols] if col not in adata.var.columns
1479
+ ]
1480
+ if missing_cols:
1481
+ continue
1482
+ mod_site_mask = adata.var[site_cols].fillna(False).any(axis=1)
1483
+ site_mask = mod_site_mask & adata.var[position_col].fillna(False)
1484
+
1485
+ for sample in samples:
1486
+ sample_mask = (
1487
+ adata.obs[cfg.sample_name_col_for_plotting] == sample
1488
+ ) & ref_mask
1489
+ if not sample_mask.any():
1490
+ continue
1491
+
1492
+ safe_sample = str(sample).replace(os.sep, "_")
1493
+ safe_ref = str(reference).replace(os.sep, "_")
1494
+ plot_subset = adata[sample_mask][:, site_mask].copy()
1495
+
1496
+ title = f"{sample} {reference} (n={int(sample_mask.sum())})"
1497
+ out_png = span_hist_dir / f"{safe_sample}__{safe_ref}.png"
1498
+ try:
1499
+ plot_span_length_distributions(
1500
+ plot_subset,
1501
+ self_span_layer_key=_self_key,
1502
+ cross_span_layer_key=_cross_key,
1503
+ delta_span_layer_key=_delta_key,
1504
+ bins=getattr(
1505
+ cfg,
1506
+ "rolling_nn_zero_pairs_segment_histogram_bins",
1507
+ 30,
1508
+ ),
1509
+ title=title,
1510
+ save_name=out_png,
1511
+ )
1512
+ except Exception as exc:
1513
+ logger.warning(
1514
+ "Span length distribution plot failed for sample=%s ref=%s: %s",
1515
+ sample,
1516
+ reference,
1517
+ exc,
1518
+ )
1519
+ else:
1520
+ logger.debug("Span length distribution: missing required layers, skipping.")
1521
+
1522
+ # ============================================================
1523
+ # 4) Save AnnData
1524
+ # ============================================================
1525
+ zero_pairs_map_key = f"{cfg.rolling_nn_obsm_key}_zero_pairs_map"
1526
+ zero_pairs_map = adata.uns.get(zero_pairs_map_key, {})
1527
+ if not getattr(cfg, "rolling_nn_zero_pairs_keep_uns", True):
1528
+ for entry in zero_pairs_map.values():
1529
+ zero_pairs_key = entry.get("zero_pairs_key")
1530
+ if zero_pairs_key and zero_pairs_key in adata.uns:
1531
+ del adata.uns[zero_pairs_key]
1532
+ for suffix in (
1533
+ "starts",
1534
+ "window",
1535
+ "step",
1536
+ "min_overlap",
1537
+ "return_fraction",
1538
+ "layer",
1539
+ ):
1540
+ meta_key = f"{zero_pairs_key}_{suffix}"
1541
+ if meta_key in adata.uns:
1542
+ del adata.uns[meta_key]
1543
+ if zero_pairs_map_key in adata.uns:
1544
+ del adata.uns[zero_pairs_map_key]
1545
+ if not getattr(cfg, "rolling_nn_zero_pairs_segments_keep_uns", True):
1546
+ for entry in zero_pairs_map.values():
1547
+ segments_key = entry.get("segments_key")
1548
+ if segments_key and segments_key in adata.uns:
1549
+ del adata.uns[segments_key]
1550
+
1551
+ if not paths.chimeric.exists():
1552
+ logger.info("Saving chimeric analyzed AnnData")
1553
+ record_smftools_metadata(
1554
+ adata,
1555
+ step_name="chimeric",
1556
+ cfg=cfg,
1557
+ config_path=config_path,
1558
+ input_paths=[source_adata_path] if source_adata_path else None,
1559
+ output_path=paths.chimeric,
1560
+ )
1561
+ write_gz_h5ad(adata, paths.chimeric)
1562
+
1563
+ return adata, paths.chimeric
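As a closing illustration of the delta-span flag computed in `chimeric_adata_core`, a one-read numeric sketch (the span layers and the threshold of 2 are invented; the real pipeline uses `delta_hamming_chimeric_span_threshold`, defaulting to 200, and builds the span layers per sample/reference as shown above):

import numpy as np
from smftools.cli.chimeric_adata import _compute_chimeric_by_mod_hamming_distance

# Within-sample and cross-sample zero-Hamming span layers for one read
# (1 marks positions covered by a span); values are illustrative only.
within_spans = np.array([[0, 1, 1, 1, 1, 1, 0, 0]], dtype=np.float64)
cross_spans = np.array([[0, 1, 1, 0, 0, 0, 0, 0]], dtype=np.float64)

# Keep only the part of the within-sample span not explained by other samples.
delta = np.clip(within_spans - cross_spans, 0, None)
print(delta)  # [[0. 0. 0. 1. 1. 1. 0. 0.]]

# Flag reads whose longest contiguous positive delta run exceeds the threshold.
flags = _compute_chimeric_by_mod_hamming_distance(delta, span_threshold=2)
print(flags)  # [ True]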