smftools 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. smftools/_version.py +1 -1
  2. smftools/cli/helpers.py +48 -0
  3. smftools/cli/hmm_adata.py +168 -145
  4. smftools/cli/load_adata.py +155 -95
  5. smftools/cli/preprocess_adata.py +222 -130
  6. smftools/cli/spatial_adata.py +441 -308
  7. smftools/cli_entry.py +4 -5
  8. smftools/config/conversion.yaml +12 -5
  9. smftools/config/deaminase.yaml +11 -9
  10. smftools/config/default.yaml +123 -19
  11. smftools/config/direct.yaml +3 -0
  12. smftools/config/experiment_config.py +120 -19
  13. smftools/hmm/HMM.py +12 -1
  14. smftools/hmm/__init__.py +0 -6
  15. smftools/hmm/archived/call_hmm_peaks.py +106 -0
  16. smftools/hmm/call_hmm_peaks.py +318 -90
  17. smftools/informatics/bam_functions.py +28 -29
  18. smftools/informatics/h5ad_functions.py +1 -1
  19. smftools/plotting/general_plotting.py +97 -51
  20. smftools/plotting/position_stats.py +3 -3
  21. smftools/preprocessing/__init__.py +2 -4
  22. smftools/preprocessing/append_base_context.py +34 -25
  23. smftools/preprocessing/append_binary_layer_by_base_context.py +2 -2
  24. smftools/preprocessing/binarize_on_Youden.py +10 -8
  25. smftools/preprocessing/calculate_complexity_II.py +1 -1
  26. smftools/preprocessing/calculate_coverage.py +16 -13
  27. smftools/preprocessing/calculate_position_Youden.py +41 -25
  28. smftools/preprocessing/calculate_read_modification_stats.py +1 -1
  29. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +1 -1
  30. smftools/preprocessing/filter_reads_on_modification_thresholds.py +1 -1
  31. smftools/preprocessing/flag_duplicate_reads.py +1 -1
  32. smftools/preprocessing/invert_adata.py +1 -1
  33. smftools/preprocessing/load_sample_sheet.py +1 -1
  34. smftools/preprocessing/reindex_references_adata.py +37 -0
  35. smftools/readwrite.py +94 -0
  36. {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/METADATA +18 -12
  37. {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/RECORD +46 -43
  38. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  39. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  40. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  41. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  42. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archives/add_read_length_and_mapping_qc.py} +0 -0
  43. /smftools/preprocessing/{calculate_complexity.py → archives/calculate_complexity.py} +0 -0
  44. {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/WHEEL +0 -0
  45. {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/entry_points.txt +0 -0
  46. {smftools-0.2.3.dist-info → smftools-0.2.4.dist-info}/licenses/LICENSE +0 -0
@@ -70,24 +70,15 @@ def _index_bam_with_pysam(bam_path: Union[str, Path], threads: Optional[int] = N
 
 def align_and_sort_BAM(fasta,
                        input,
-                       bam_suffix='.bam',
-                       output_directory='aligned_outputs',
-                       make_bigwigs=False,
-                       threads=None,
-                       aligner='minimap2',
-                       aligner_args=['-a', '-x', 'map-ont', '--MD', '-Y', '-y', '-N', '5', '--secondary=no']):
+                       cfg,
+                       ):
     """
     A wrapper for running dorado aligner and samtools functions
 
     Parameters:
         fasta (str): File path to the reference genome to align to.
         input (str): File path to the basecalled file to align. Works for .bam and .fastq files
-        bam_suffix (str): The suffix to use for the BAM file.
-        output_directory (str): A file path to the directory to output all the analyses.
-        make_bigwigs (bool): Whether to make bigwigs
-        threads (int): Number of additional threads to use
-        aligner (str): Aligner to use. minimap2 and dorado options
-        aligner_args (list): list of optional parameters to use for the alignment
+        cfg: The configuration object
 
     Returns:
         None
@@ -97,40 +88,48 @@ def align_and_sort_BAM(fasta,
     input_suffix = input.suffix
     input_as_fastq = input.with_name(input.stem + '.fastq')
 
-    output_path_minus_suffix = output_directory / input.stem
+    output_path_minus_suffix = cfg.output_directory / input.stem
 
     aligned_BAM = output_path_minus_suffix.with_name(output_path_minus_suffix.stem + "_aligned")
-    aligned_output = aligned_BAM.with_suffix(bam_suffix)
+    aligned_output = aligned_BAM.with_suffix(cfg.bam_suffix)
     aligned_sorted_BAM =aligned_BAM.with_name(aligned_BAM.stem + "_sorted")
-    aligned_sorted_output = aligned_sorted_BAM.with_suffix(bam_suffix)
+    aligned_sorted_output = aligned_sorted_BAM.with_suffix(cfg.bam_suffix)
 
-    if threads:
-        threads = str(threads)
+    if cfg.threads:
+        threads = str(cfg.threads)
     else:
-        pass
+        threads = None
 
-    if aligner == 'minimap2':
-        print(f"Converting BAM to FASTQ: {input}")
-        _bam_to_fastq_with_pysam(input, input_as_fastq)
-        print(f"Aligning FASTQ to Reference: {input_as_fastq}")
+    if cfg.aligner == 'minimap2':
+        if not cfg.align_from_bam:
+            print(f"Converting BAM to FASTQ: {input}")
+            _bam_to_fastq_with_pysam(input, input_as_fastq)
+            print(f"Aligning FASTQ to Reference: {input_as_fastq}")
+            mm_input = input_as_fastq
+        else:
+            print(f"Aligning BAM to Reference: {input}")
+            mm_input = input
+
        if threads:
-            minimap_command = ['minimap2'] + aligner_args + ['-t', threads, str(fasta), str(input_as_fastq)]
+            minimap_command = ['minimap2'] + cfg.aligner_args + ['-t', threads, str(fasta), str(mm_input)]
        else:
-            minimap_command = ['minimap2'] + aligner_args + [str(fasta), str(input_as_fastq)]
+            minimap_command = ['minimap2'] + cfg.aligner_args + [str(fasta), str(mm_input)]
        subprocess.run(minimap_command, stdout=open(aligned_output, "wb"))
-        os.remove(input_as_fastq)
 
-    elif aligner == 'dorado':
+        if not cfg.align_from_bam:
+            os.remove(input_as_fastq)
+
+    elif cfg.aligner == 'dorado':
         # Run dorado aligner
         print(f"Aligning BAM to Reference: {input}")
         if threads:
-            alignment_command = ["dorado", "aligner", "-t", threads] + aligner_args + [str(fasta), str(input)]
+            alignment_command = ["dorado", "aligner", "-t", threads] + cfg.aligner_args + [str(fasta), str(input)]
         else:
-            alignment_command = ["dorado", "aligner"] + aligner_args + [str(fasta), str(input)]
+            alignment_command = ["dorado", "aligner"] + cfg.aligner_args + [str(fasta), str(input)]
         subprocess.run(alignment_command, stdout=open(aligned_output, "wb"))
 
     else:
-        print(f'Aligner not recognized: {aligner}. Choose from minimap2 and dorado')
+        print(f'Aligner not recognized: {cfg.aligner}. Choose from minimap2 and dorado')
         return
 
     # --- Sort & Index with pysam ---
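For orientation, 0.2.4 moves the per-run alignment options off the signature and onto the shared config object. Below is a minimal sketch of a call under the new API, assuming a plain namespace can stand in for the config and that the attribute names match the ones read above (`output_directory`, `bam_suffix`, `threads`, `aligner`, `align_from_bam`, `aligner_args`); the import path is inferred from the file list, not confirmed by the diff.

```python
from pathlib import Path
from types import SimpleNamespace

from smftools.informatics.bam_functions import align_and_sort_BAM  # path assumed from the file list

# Hypothetical stand-in for the smftools experiment config; the real object
# is built from the YAML files under smftools/config/.
cfg = SimpleNamespace(
    output_directory=Path("aligned_outputs"),
    bam_suffix=".bam",
    threads=4,
    aligner="minimap2",
    align_from_bam=False,  # new in 0.2.4: when True, the BAM is aligned directly and BAM -> FASTQ is skipped
    aligner_args=["-a", "-x", "map-ont", "--MD", "-Y", "-y", "-N", "5", "--secondary=no"],
)

align_and_sort_BAM(Path("reference.fa"), Path("basecalls.bam"), cfg)
```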
@@ -75,7 +75,7 @@ def add_read_length_and_mapping_qc(
     adata,
     bam_files: Optional[List[str]] = None,
     read_metrics: Optional[Dict[str, Union[list, tuple]]] = None,
-    uns_flag: str = "read_lenth_and_mapping_qc_performed",
+    uns_flag: str = "add_read_length_and_mapping_qc_performed",
     extract_read_features_from_bam_callable = None,
     bypass: bool = False,
     force_redo: bool = True
@@ -9,9 +9,62 @@ import os
 import math
 import pandas as pd
 
-from typing import Optional, Mapping, Sequence, Any, Dict, List
+from typing import Optional, Mapping, Sequence, Any, Dict, List, Tuple
 from pathlib import Path
 
+def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
+    """
+    Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
+    Always includes 0 and n_positions-1 when possible.
+    """
+    n_ticks = int(max(2, n_ticks))
+    if n_positions <= n_ticks:
+        return np.arange(n_positions)
+
+    # linspace gives fixed count
+    pos = np.linspace(0, n_positions - 1, n_ticks)
+    return np.unique(np.round(pos).astype(int))
+
+def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix: str | None):
+    """
+    Select tick labels for the heatmap axis.
+
+    Parameters
+    ----------
+    subset : AnnData view
+        The per-bin subset of the AnnData.
+    sites : np.ndarray[int]
+        Indices of the subset.var positions to annotate.
+    reference : str
+        Reference name (e.g., '6B6_top').
+    index_col_suffix : None or str
+        If None → use subset.var_names
+        Else → use subset.var[f"{reference}_{index_col_suffix}"]
+
+    Returns
+    -------
+    np.ndarray[str]
+        The labels to use for tick positions.
+    """
+    if sites.size == 0:
+        return np.array([])
+
+    # Default behavior: use var_names
+    if index_col_suffix is None:
+        return subset.var_names[sites].astype(str)
+
+    # Otherwise: use a computed column adata.var[f"{reference}_{suffix}"]
+    colname = f"{reference}_{index_col_suffix}"
+
+    if colname not in subset.var:
+        raise KeyError(
+            f"index_col_suffix='{index_col_suffix}' requires var column '{colname}', "
+            f"but it is not present in adata.var."
+        )
+
+    labels = subset.var[colname].astype(str).values
+    return labels[sites]
+
 def normalized_mean(matrix: np.ndarray) -> np.ndarray:
     mean = np.nanmean(matrix, axis=0)
     denom = (mean.max() - mean.min()) + 1e-9
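As a quick check of what the new tick helper produces, here is a standalone mirror of its two-line core (plain numpy, not the package function itself):

```python
import numpy as np

# Mirror of _fixed_tick_positions(100, 5): ~5 evenly spaced indices with both
# endpoints included; np.unique deduplicates after rounding, so the result
# can contain fewer than n_ticks entries.
pos = np.linspace(0, 100 - 1, 5)
print(np.unique(np.round(pos).astype(int)))  # -> [ 0 25 50 74 99]
```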
@@ -266,7 +319,6 @@ def clean_barplot(ax, mean_values, title):
     # traceback.print_exc()
     # continue
 
-
 def combined_hmm_raw_clustermap(
     adata,
     sample_col: str = "Sample_Names",
@@ -276,13 +328,13 @@ def combined_hmm_raw_clustermap(
 
     layer_gpc: str = "nan0_0minus1",
     layer_cpg: str = "nan0_0minus1",
-    layer_any_c: str = "nan0_0minus1",
+    layer_c: str = "nan0_0minus1",
     layer_a: str = "nan0_0minus1",
 
     cmap_hmm: str = "tab10",
     cmap_gpc: str = "coolwarm",
     cmap_cpg: str = "viridis",
-    cmap_any_c: str = "coolwarm",
+    cmap_c: str = "coolwarm",
     cmap_a: str = "coolwarm",
 
     min_quality: int = 20,
@@ -305,15 +357,17 @@
     n_xticks_gpc: int = 8,
     n_xticks_cpg: int = 8,
     n_xticks_a: int = 8,
+
+    index_col_suffix: str | None = None,
 ):
     """
     Makes a multi-panel clustermap per (sample, reference):
-    HMM panel (always) + optional raw panels for any_C, GpC, CpG, and A sites.
+    HMM panel (always) + optional raw panels for C, GpC, CpG, and A sites.
 
     Panels are added only if the corresponding site mask exists AND has >0 sites.
 
     sort_by options:
-        'gpc', 'cpg', 'any_c', 'any_a', 'gpc_cpg', 'none', or 'obs:<col>'
+        'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
     """
     def pick_xticks(labels: np.ndarray, n_ticks: int):
         if labels.size == 0:
@@ -363,18 +417,19 @@
                 return np.where(subset.var[k].values)[0]
             return np.array([], dtype=int)
 
-        gpc_sites = _sites(f"{ref}_GpC_site")
-        cpg_sites = _sites(f"{ref}_CpG_site")
+        gpc_sites   = _sites(f"{ref}_GpC_site")
+        cpg_sites   = _sites(f"{ref}_CpG_site")
         any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
         any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
 
-        def _labels(sites):
-            return subset.var_names[sites].astype(int) if sites.size else np.array([])
-
-        gpc_labels = _labels(gpc_sites)
-        cpg_labels = _labels(cpg_sites)
-        any_c_labels = _labels(any_c_sites)
-        any_a_labels = _labels(any_a_sites)
+        # ---- labels via _select_labels ----
+        # HMM uses *all* columns
+        hmm_sites = np.arange(subset.n_vars, dtype=int)
+        hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
+        gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
+        cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
+        any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
+        any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
 
         # storage
         stacked_hmm = []
@@ -411,17 +466,21 @@
                 linkage = sch.linkage(sb[:, cpg_sites].layers[layer_cpg], method="ward")
                 order = sch.leaves_list(linkage)
 
-            elif sort_by == "any_c" and any_c_sites.size:
-                linkage = sch.linkage(sb[:, any_c_sites].layers[layer_any_c], method="ward")
+            elif sort_by == "c" and any_c_sites.size:
+                linkage = sch.linkage(sb[:, any_c_sites].layers[layer_c], method="ward")
                 order = sch.leaves_list(linkage)
 
-            elif sort_by == "any_a" and any_a_sites.size:
+            elif sort_by == "a" and any_a_sites.size:
                 linkage = sch.linkage(sb[:, any_a_sites].layers[layer_a], method="ward")
                 order = sch.leaves_list(linkage)
 
             elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
                 linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
                 order = sch.leaves_list(linkage)
+
+            elif sort_by == "hmm" and hmm_sites.size:
+                linkage = sch.linkage(sb[:, hmm_sites].layers[hmm_feature_layer], method="ward")
+                order = sch.leaves_list(linkage)
 
             else:
                 order = np.arange(n)
@@ -431,7 +490,7 @@
             # ---- collect matrices ----
             stacked_hmm.append(sb.layers[hmm_feature_layer])
             if any_c_sites.size:
-                stacked_any_c.append(sb[:, any_c_sites].layers[layer_any_c])
+                stacked_any_c.append(sb[:, any_c_sites].layers[layer_c])
             if gpc_sites.size:
                 stacked_gpc.append(sb[:, gpc_sites].layers[layer_gpc])
             if cpg_sites.size:
@@ -449,12 +508,12 @@
         mean_hmm = normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
 
         panels = [
-            ("HMM", hmm_matrix, subset.var_names.astype(int), cmap_hmm, mean_hmm, n_xticks_hmm),
+            (f"HMM - {hmm_feature_layer}", hmm_matrix, hmm_labels, cmap_hmm, mean_hmm, n_xticks_hmm),
         ]
 
         if stacked_any_c:
             m = np.vstack(stacked_any_c)
-            panels.append(("any_C", m, any_c_labels, cmap_any_c, methylation_fraction(m), n_xticks_any_c))
+            panels.append(("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c))
 
         if stacked_gpc:
             m = np.vstack(stacked_gpc)
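Taken together, these hunks rename the public knobs of this plotter (`layer_any_c`/`cmap_any_c` become `layer_c`/`cmap_c`), add an `'hmm'` option to `sort_by`, and let `index_col_suffix` redirect tick labels to a per-reference `var` column. A hypothetical call sketch follows; the `adata`, the HMM layer name, and the `position` column are illustrative assumptions, not values shipped with the package:

```python
# Illustrative only: 'adata', the HMM layer name, and the
# '<reference>_position' var column are assumed for this sketch.
combined_hmm_raw_clustermap(
    adata,
    hmm_feature_layer="hmm_states",  # assumed layer name; parameter used in the body above
    sort_by="hmm",                   # new in 0.2.4: cluster rows on the HMM layer itself
    layer_c="nan0_0minus1",          # renamed from layer_any_c
    cmap_c="coolwarm",               # renamed from cmap_any_c
    index_col_suffix="position",     # tick labels from adata.var['<reference>_position']
)
```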
@@ -777,29 +836,16 @@ def combined_hmm_raw_clustermap(
     # traceback.print_exc()
     # continue
 
-def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
-    """
-    Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
-    Always includes 0 and n_positions-1 when possible.
-    """
-    n_ticks = int(max(2, n_ticks))
-    if n_positions <= n_ticks:
-        return np.arange(n_positions)
-
-    # linspace gives fixed count
-    pos = np.linspace(0, n_positions - 1, n_ticks)
-    return np.unique(np.round(pos).astype(int))
-
 def combined_raw_clustermap(
     adata,
     sample_col: str = "Sample_Names",
     reference_col: str = "Reference_strand",
     mod_target_bases: Sequence[str] = ("GpC", "CpG"),
-    layer_any_c: str = "nan0_0minus1",
+    layer_c: str = "nan0_0minus1",
     layer_gpc: str = "nan0_0minus1",
     layer_cpg: str = "nan0_0minus1",
     layer_a: str = "nan0_0minus1",
-    cmap_any_c: str = "coolwarm",
+    cmap_c: str = "coolwarm",
     cmap_gpc: str = "coolwarm",
     cmap_cpg: str = "viridis",
     cmap_a: str = "coolwarm",
@@ -809,20 +855,20 @@ def combined_raw_clustermap(
     min_position_valid_fraction: float = 0.5,
     sample_mapping: Optional[Mapping[str, str]] = None,
     save_path: str | Path | None = None,
-    sort_by: str = "gpc",  # 'gpc','cpg','any_c','gpc_cpg','any_a','none','obs:<col>'
+    sort_by: str = "gpc",  # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
     bins: Optional[Dict[str, Any]] = None,
     deaminase: bool = False,
     min_signal: float = 0,
-    # NEW tick controls
     n_xticks_any_c: int = 10,
     n_xticks_gpc: int = 10,
     n_xticks_cpg: int = 10,
     n_xticks_any_a: int = 10,
     xtick_rotation: int = 90,
     xtick_fontsize: int = 9,
+    index_col_suffix: str | None = None,
 ):
     """
-    Plot stacked heatmaps + per-position mean barplots for any_C, GpC, CpG, and optional A.
+    Plot stacked heatmaps + per-position mean barplots for C, GpC, CpG, and optional A.
 
     Key fixes vs old version:
       - order computed ONCE per bin, applied to all matrices
@@ -898,14 +944,14 @@
 
         num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
 
-        any_c_labels = subset.var_names[any_c_sites].astype(str)
-        gpc_labels = subset.var_names[gpc_sites].astype(str)
-        cpg_labels = subset.var_names[cpg_sites].astype(str)
+        any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
+        gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
+        cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
 
         if include_any_a:
             any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
             num_any_a = len(any_a_sites)
-            any_a_labels = subset.var_names[any_a_sites].astype(str)
+            any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
 
         stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
         row_labels, bin_labels, bin_boundaries = [], [], []
@@ -939,15 +985,15 @@
                 linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
                 order = sch.leaves_list(linkage)
 
-            elif sort_by == "any_c" and num_any_c > 0:
-                linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
+            elif sort_by == "c" and num_any_c > 0:
+                linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_c], method="ward")
                 order = sch.leaves_list(linkage)
 
             elif sort_by == "gpc_cpg":
                 linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
                 order = sch.leaves_list(linkage)
 
-            elif sort_by == "any_a" and num_any_a > 0:
+            elif sort_by == "a" and num_any_a > 0:
                 linkage = sch.linkage(subset_bin[:, any_a_sites].layers[layer_a], method="ward")
                 order = sch.leaves_list(linkage)
 
@@ -961,7 +1007,7 @@
 
             # stack consistently
             if include_any_c and num_any_c > 0:
-                stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_any_c])
+                stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_c])
             if include_any_c and num_gpc > 0:
                 stacked_gpc.append(subset_bin[:, gpc_sites].layers[layer_gpc])
             if include_any_c and num_cpg > 0:
@@ -990,11 +1036,11 @@
 
         if any_c_matrix.size:
             blocks.append(dict(
-                name="any_c",
+                name="c",
                 matrix=any_c_matrix,
                 mean=mean_any_c,
                 labels=any_c_labels,
-                cmap=cmap_any_c,
+                cmap=cmap_c,
                 n_xticks=n_xticks_any_c,
                 title="any C site Modification Signal"
             ))
@@ -1024,7 +1070,7 @@
         mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
         if any_a_matrix.size:
             blocks.append(dict(
-                name="any_a",
+                name="a",
                 matrix=any_a_matrix,
                 mean=mean_any_a,
                 labels=any_a_labels,
@@ -1141,7 +1187,7 @@ def plot_hmm_layers_rolling_by_sample_ref(
     output_dir: Optional[str] = None,
     save: bool = True,
     show_raw: bool = False,
-    cmap: str = "tab10",
+    cmap: str = "tab20",
     use_var_coords: bool = True,
 ):
     """
@@ -90,7 +90,7 @@ def plot_volcano_relative_risk(
             safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
             out_file = os.path.join(save_path, f"{safe_name}.png")
             plt.savefig(out_file, dpi=300)
-            print(f"📁 Saved: {out_file}")
+            print(f"Saved: {out_file}")
 
         plt.show()
 
@@ -449,7 +449,7 @@ def plot_positionwise_matrix_grid(
             os.makedirs(save_path, exist_ok=True)
             fname = outer_label.replace("_", "").replace("=", "") + ".png"
             plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches='tight')
-            print(f"Saved {fname}")
+            print(f"Saved {fname}")
 
         plt.close(fig)
 
@@ -459,4 +459,4 @@
     for outer_label in parsed['outer'].unique():
         plot_one_grid(outer_label)
 
-    print("Finished plotting all grids.")
+    print("Finished plotting all grids.")
@@ -1,9 +1,7 @@
-from .add_read_length_and_mapping_qc import add_read_length_and_mapping_qc
 from .append_base_context import append_base_context
 from .append_binary_layer_by_base_context import append_binary_layer_by_base_context
 from .binarize_on_Youden import binarize_on_Youden
 from .binarize import binarize_adata
-from .calculate_complexity import calculate_complexity
 from .calculate_complexity_II import calculate_complexity_II
 from .calculate_read_modification_stats import calculate_read_modification_stats
 from .calculate_coverage import calculate_coverage
@@ -16,15 +14,15 @@ from .filter_reads_on_length_quality_mapping import filter_reads_on_length_quali
 from .invert_adata import invert_adata
 from .load_sample_sheet import load_sample_sheet
 from .flag_duplicate_reads import flag_duplicate_reads
+from .reindex_references_adata import reindex_references_adata
 from .subsample_adata import subsample_adata
 
 __all__ = [
-    "add_read_length_and_mapping_qc",
     "append_base_context",
     "append_binary_layer_by_base_context",
     "binarize_on_Youden",
     "binarize_adata",
-    "calculate_complexity",
+    "calculate_complexity_II",
    "calculate_read_modification_stats",
    "calculate_coverage",
    "calculate_position_Youden",
@@ -1,18 +1,19 @@
 def append_base_context(adata,
-                        obs_column='Reference_strand',
+                        ref_column='Reference_strand',
                         use_consensus=False,
                         native=False,
                         mod_target_bases=['GpC', 'CpG'],
                         bypass=False,
                         force_redo=False,
-                        uns_flag='base_context_added'
+                        uns_flag='append_base_context_performed'
                         ):
     """
     Adds nucleobase context to the position within the given category. When use_consensus is True, it uses the consensus sequence, otherwise it defaults to the FASTA sequence.
+    This needs to be performed prior to AnnData inversion step.
 
     Parameters:
         adata (AnnData): The input adata object.
-        obs_column (str): The observation column in which to stratify on. Default is 'Reference_strand', which should not be changed for most purposes.
+        ref_column (str): The observation column in which to stratify on. Default is 'Reference_strand', which should not be changed for most purposes.
         use_consensus (bool): A truth statement indicating whether to use the consensus sequence from the reads mapped to the reference. If False, the reference FASTA is used instead.
         native (bool): If False, perform conversion SMF assumptions. If True, perform native SMF assumptions
         mod_target_bases (list): Base contexts that may be modified.
@@ -30,7 +31,7 @@ def append_base_context(adata,
         return
 
     print('Adding base context based on reference FASTA sequence for sample')
-    categories = adata.obs[obs_column].cat.categories
+    references = adata.obs[ref_column].cat.categories
     site_types = []
 
     if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
@@ -39,59 +40,60 @@
     if 'A' in mod_target_bases:
         site_types += ['A_site']
 
-    for cat in categories:
+    for ref in references:
         # Assess if the strand is the top or bottom strand converted
-        if 'top' in cat:
+        if 'top' in ref:
             strand = 'top'
-        elif 'bottom' in cat:
+        elif 'bottom' in ref:
             strand = 'bottom'
 
         if native:
-            basename = cat.split(f"_{strand}")[0]
+            basename = ref.split(f"_{strand}")[0]
             if use_consensus:
                 sequence = adata.uns[f'{basename}_consensus_sequence']
             else:
                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
                 sequence = adata.uns[f'{basename}_FASTA_sequence']
         else:
-            basename = cat.split(f"_{strand}")[0]
+            basename = ref.split(f"_{strand}")[0]
             if use_consensus:
                 sequence = adata.uns[f'{basename}_consensus_sequence']
             else:
                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
                 sequence = adata.uns[f'{basename}_FASTA_sequence']
+
         # Init a dict keyed by reference site type that points to a bool of whether the position is that site type.
         boolean_dict = {}
         for site_type in site_types:
-            boolean_dict[f'{cat}_{site_type}'] = np.full(len(sequence), False, dtype=bool)
+            boolean_dict[f'{ref}_{site_type}'] = np.full(len(sequence), False, dtype=bool)
 
         if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
             if strand == 'top':
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
                     if sequence[i] == 'C':
-                        boolean_dict[f'{cat}_C_site'][i] = True
+                        boolean_dict[f'{ref}_C_site'][i] = True
                         if sequence[i - 1] == 'G' and sequence[i + 1] != 'G':
-                            boolean_dict[f'{cat}_GpC_site'][i] = True
+                            boolean_dict[f'{ref}_GpC_site'][i] = True
                         elif sequence[i - 1] == 'G' and sequence[i + 1] == 'G':
-                            boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
+                            boolean_dict[f'{ref}_ambiguous_GpC_CpG_site'][i] = True
                         elif sequence[i - 1] != 'G' and sequence[i + 1] == 'G':
-                            boolean_dict[f'{cat}_CpG_site'][i] = True
+                            boolean_dict[f'{ref}_CpG_site'][i] = True
                         elif sequence[i - 1] != 'G' and sequence[i + 1] != 'G':
-                            boolean_dict[f'{cat}_other_C_site'][i] = True
+                            boolean_dict[f'{ref}_other_C_site'][i] = True
             elif strand == 'bottom':
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
                     if sequence[i] == 'G':
-                        boolean_dict[f'{cat}_C_site'][i] = True
+                        boolean_dict[f'{ref}_C_site'][i] = True
                         if sequence[i + 1] == 'C' and sequence[i - 1] != 'C':
-                            boolean_dict[f'{cat}_GpC_site'][i] = True
+                            boolean_dict[f'{ref}_GpC_site'][i] = True
                         elif sequence[i - 1] == 'C' and sequence[i + 1] == 'C':
-                            boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
+                            boolean_dict[f'{ref}_ambiguous_GpC_CpG_site'][i] = True
                         elif sequence[i - 1] == 'C' and sequence[i + 1] != 'C':
-                            boolean_dict[f'{cat}_CpG_site'][i] = True
+                            boolean_dict[f'{ref}_CpG_site'][i] = True
                         elif sequence[i - 1] != 'C' and sequence[i + 1] != 'C':
-                            boolean_dict[f'{cat}_other_C_site'][i] = True
+                            boolean_dict[f'{ref}_other_C_site'][i] = True
             else:
                 print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')
 
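The top-strand branch above reduces to one rule per cytosine: classify it by its two flanking bases. A standalone sketch of that rule on a toy sequence (same logic, not the package code itself):

```python
# Flanking-base rule from the top-strand loop above.
def classify_c(prev: str, nxt: str) -> str:
    if prev == 'G' and nxt != 'G':
        return 'GpC_site'
    if prev == 'G' and nxt == 'G':
        return 'ambiguous_GpC_CpG_site'
    if nxt == 'G':
        return 'CpG_site'
    return 'other_C_site'

seq = "AGCTACGAGCGGTCAT"  # toy sequence covering all four contexts
for i in range(1, len(seq) - 1):
    if seq[i] == 'C':
        print(i, classify_c(seq[i - 1], seq[i + 1]))
# 2 GpC_site
# 5 CpG_site
# 9 ambiguous_GpC_CpG_site
# 13 other_C_site
```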
@@ -100,21 +102,28 @@ def append_base_context(adata,
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
                     if sequence[i] == 'A':
-                        boolean_dict[f'{cat}_A_site'][i] = True
+                        boolean_dict[f'{ref}_A_site'][i] = True
             elif strand == 'bottom':
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
                     if sequence[i] == 'T':
-                        boolean_dict[f'{cat}_A_site'][i] = True
+                        boolean_dict[f'{ref}_A_site'][i] = True
             else:
                 print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')
 
         for site_type in site_types:
-            adata.var[f'{cat}_{site_type}'] = boolean_dict[f'{cat}_{site_type}'].astype(bool)
+            # Site context annotations for each reference
+            adata.var[f'{ref}_{site_type}'] = boolean_dict[f'{ref}_{site_type}'].astype(bool)
+            # Restrict the site type labels to only be in positions that occur at a high enough frequency in the dataset
+            if adata.uns["calculate_coverage_performed"] == True:
+                adata.var[f'{ref}_{site_type}'] = (adata.var[f'{ref}_{site_type}']) & (adata.var[f'position_in_{ref}'])
+            else:
+                pass
+
             if native:
-                adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].layers['binarized_methylation']
+                adata.obsm[f'{ref}_{site_type}'] = adata[:, adata.var[f'{ref}_{site_type}'] == True].layers['binarized_methylation']
             else:
-                adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].X
+                adata.obsm[f'{ref}_{site_type}'] = adata[:, adata.var[f'{ref}_{site_type}'] == True].X
 
     # mark as done
     adata.uns[uns_flag] = True
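The new guard intersects each site mask with the per-reference coverage mask once `calculate_coverage` has run, so low-coverage positions lose their site labels. A toy version of that intersection (pandas only; the column names are illustrative):

```python
import pandas as pd

var = pd.DataFrame({
    "6B6_top_GpC_site":    [True, True, False, True],
    "position_in_6B6_top": [True, False, True, True],  # coverage mask, illustrative
})
# Same intersection as above: keep site labels only at well-covered positions.
var["6B6_top_GpC_site"] = var["6B6_top_GpC_site"] & var["position_in_6B6_top"]
print(var["6B6_top_GpC_site"].tolist())  # [True, False, False, True]
```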
@@ -6,7 +6,7 @@ def append_binary_layer_by_base_context(
     reference_column: str,
     smf_modality: str = "conversion",
     verbose: bool = True,
-    uns_flag: str = "binary_layers_by_base_context_added",
+    uns_flag: str = "append_binary_layer_by_base_context_performed",
     bypass: bool = False,
     force_redo: bool = False
 ):
@@ -27,7 +27,7 @@
 
     # Only run if not already performed
     already = bool(adata.uns.get(uns_flag, False))
-    if (already and not force_redo) or bypass or ("base_context_added" not in adata.uns):
+    if (already and not force_redo) or bypass or ("append_base_context_performed" not in adata.uns):
         # QC already performed; nothing to do
         return adata
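This gate is the idempotency pattern 0.2.4 standardizes on: each step records a `<step>_performed` flag in `adata.uns`, and downstream steps check both their own flag and their prerequisite's flag. A minimal sketch with a plain dict standing in for `adata.uns`:

```python
# Plain-dict sketch of the uns-flag gating used above (illustration only).
uns = {"append_base_context_performed": True}

def should_skip(uns: dict, uns_flag: str, bypass: bool = False,
                force_redo: bool = False,
                prerequisite: str = "append_base_context_performed") -> bool:
    already = bool(uns.get(uns_flag, False))
    return (already and not force_redo) or bypass or (prerequisite not in uns)

print(should_skip(uns, "append_binary_layer_by_base_context_performed"))  # False -> run the step
uns["append_binary_layer_by_base_context_performed"] = True
print(should_skip(uns, "append_binary_layer_by_base_context_performed"))  # True  -> skip on rerun
```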