smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0

smftools/plotting/qc_plotting.py
@@ -1,12 +1,8 @@
  import os
- import numpy as np
- import pandas as pd
- import matplotlib.pyplot as plt
 
- import os
+ import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
- import matplotlib.pyplot as plt
 
 
  def plot_read_qc_histograms(
@@ -83,7 +79,11 @@ def plot_read_qc_histograms(
      for key in valid_keys:
          if not is_numeric[key]:
              continue
-         s = pd.to_numeric(adata.obs[key], errors="coerce").replace([np.inf, -np.inf], np.nan).dropna()
+         s = (
+             pd.to_numeric(adata.obs[key], errors="coerce")
+             .replace([np.inf, -np.inf], np.nan)
+             .dropna()
+         )
          if s.size < min_non_nan:
              # still set something to avoid errors; just use min/max or (0,1)
              lo, hi = (0.0, 1.0) if s.size == 0 else (float(s.min()), float(s.max()))
@@ -99,6 +99,7 @@ def plot_read_qc_histograms(
          global_ranges[key] = (lo, hi)
 
      def _sanitize(name: str) -> str:
+         """Sanitize a string for use in filenames."""
          return "".join(c if c.isalnum() or c in "-._" else "_" for c in str(name))
 
      ncols = len(valid_keys)
@@ -107,17 +108,18 @@ def plot_read_qc_histograms(
      fig_h_unit = figsize_cell[1]
 
      for start in range(0, len(sample_levels), rows_per_fig):
-         chunk = sample_levels[start:start + rows_per_fig]
+         chunk = sample_levels[start : start + rows_per_fig]
          nrows = len(chunk)
          fig, axes = plt.subplots(
-             nrows=nrows, ncols=ncols,
+             nrows=nrows,
+             ncols=ncols,
              figsize=(fig_w, fig_h_unit * nrows),
              dpi=dpi,
              squeeze=False,
          )
 
         for r, sample_val in enumerate(chunk):
-             row_mask = (adata.obs[sample_key].values == sample_val)
+             row_mask = adata.obs[sample_key].values == sample_val
             n_in_row = int(row_mask.sum())
 
             for c, key in enumerate(valid_keys):
@@ -125,7 +127,11 @@ def plot_read_qc_histograms(
                  series = adata.obs.loc[row_mask, key]
 
                  if is_numeric[key]:
-                     x = pd.to_numeric(series, errors="coerce").replace([np.inf, -np.inf], np.nan).dropna()
+                     x = (
+                         pd.to_numeric(series, errors="coerce")
+                         .replace([np.inf, -np.inf], np.nan)
+                         .dropna()
+                     )
                      if x.size < min_non_nan:
                          ax.text(0.5, 0.5, f"n={x.size} (<{min_non_nan})", ha="center", va="center")
                      else:
@@ -143,7 +149,9 @@ def plot_read_qc_histograms(
                  else:
                      vc = series.astype("category").value_counts(dropna=False)
                      if vc.sum() < min_non_nan:
-                         ax.text(0.5, 0.5, f"n={vc.sum()} (<{min_non_nan})", ha="center", va="center")
+                         ax.text(
+                             0.5, 0.5, f"n={vc.sum()} (<{min_non_nan})", ha="center", va="center"
+                         )
                      else:
                          vc_top = vc.iloc[:topn_categories][::-1]  # show top-N, reversed for barh
                          ax.barh(vc_top.index.astype(str), vc_top.values)
@@ -267,4 +275,4 @@ def plot_read_qc_histograms(
      # fname = f"{key}_{sample_key}_{safe_group}.png" if sample_key else f"{key}.png"
      # fname = fname.replace("/", "_")
      # fig.savefig(os.path.join(outdir, fname))
-     # plt.close(fig)
+     # plt.close(fig)
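
Aside: the pd.to_numeric chain that this qc_plotting.py diff merely reflows is the plotter's core coercion trick, so a standalone sketch may help (toy series; all names here are invented for illustration):

import numpy as np
import pandas as pd

# A mixed obs column of the kind the QC plotter must tolerate.
raw = pd.Series(["3.5", "oops", np.inf, -np.inf, None, 7])

# Coerce to numeric, map infinities to NaN, then drop NaNs -- the same
# chain used for global axis ranges and per-sample histograms above.
clean = (
    pd.to_numeric(raw, errors="coerce")
    .replace([np.inf, -np.inf], np.nan)
    .dropna()
)

print(clean.tolist())  # [3.5, 7.0]
print(clean.size)      # 2; compared against min_non_nan before plotting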

smftools/preprocessing/__init__.py
@@ -1,40 +1,38 @@
- from .add_read_length_and_mapping_qc import add_read_length_and_mapping_qc
  from .append_base_context import append_base_context
  from .append_binary_layer_by_base_context import append_binary_layer_by_base_context
- from .binarize_on_Youden import binarize_on_Youden
  from .binarize import binarize_adata
- from .calculate_complexity import calculate_complexity
+ from .binarize_on_Youden import binarize_on_Youden
  from .calculate_complexity_II import calculate_complexity_II
- from .calculate_read_modification_stats import calculate_read_modification_stats
  from .calculate_coverage import calculate_coverage
  from .calculate_position_Youden import calculate_position_Youden
  from .calculate_read_length_stats import calculate_read_length_stats
+ from .calculate_read_modification_stats import calculate_read_modification_stats
  from .clean_NaN import clean_NaN
  from .filter_adata_by_nan_proportion import filter_adata_by_nan_proportion
- from .filter_reads_on_modification_thresholds import filter_reads_on_modification_thresholds
  from .filter_reads_on_length_quality_mapping import filter_reads_on_length_quality_mapping
+ from .filter_reads_on_modification_thresholds import filter_reads_on_modification_thresholds
+ from .flag_duplicate_reads import flag_duplicate_reads
  from .invert_adata import invert_adata
  from .load_sample_sheet import load_sample_sheet
- from .flag_duplicate_reads import flag_duplicate_reads
+ from .reindex_references_adata import reindex_references_adata
  from .subsample_adata import subsample_adata
 
  __all__ = [
-     "add_read_length_and_mapping_qc",
      "append_base_context",
      "append_binary_layer_by_base_context",
      "binarize_on_Youden",
      "binarize_adata",
-     "calculate_complexity",
+     "calculate_complexity_II",
      "calculate_read_modification_stats",
-     "calculate_coverage",
+     "calculate_coverage",
      "calculate_position_Youden",
      "calculate_read_length_stats",
-     "clean_NaN",
+     "clean_NaN",
      "filter_adata_by_nan_proportion",
      "filter_reads_on_modification_thresholds",
      "filter_reads_on_length_quality_mapping",
      "invert_adata",
      "load_sample_sheet",
      "flag_duplicate_reads",
-     "subsample_adata"
- ]
+     "subsample_adata",
+ ]

smftools/preprocessing/append_base_context.py
@@ -1,27 +1,38 @@
- def append_base_context(adata,
-                         obs_column='Reference_strand',
-                         use_consensus=False,
-                         native=False,
-                         mod_target_bases=['GpC', 'CpG'],
-                         bypass=False,
-                         force_redo=False,
-                         uns_flag='base_context_added'
-                         ):
-     """
-     Adds nucleobase context to the position within the given category. When use_consensus is True, it uses the consensus sequence, otherwise it defaults to the FASTA sequence.
-
-     Parameters:
-         adata (AnnData): The input adata object.
-         obs_column (str): The observation column in which to stratify on. Default is 'Reference_strand', which should not be changed for most purposes.
-         use_consensus (bool): A truth statement indicating whether to use the consensus sequence from the reads mapped to the reference. If False, the reference FASTA is used instead.
-         native (bool): If False, perform conversion SMF assumptions. If True, perform native SMF assumptions
-         mod_target_bases (list): Base contexts that may be modified.
-
-     Returns:
-         None
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ from smftools.logging_utils import get_logger
+
+ if TYPE_CHECKING:
+     import anndata as ad
+
+ logger = get_logger(__name__)
+
+
+ def append_base_context(
+     adata: "ad.AnnData",
+     ref_column: str = "Reference_strand",
+     use_consensus: bool = False,
+     native: bool = False,
+     mod_target_bases: list[str] = ["GpC", "CpG"],
+     bypass: bool = False,
+     force_redo: bool = False,
+     uns_flag: str = "append_base_context_performed",
+ ) -> None:
+     """Append base context annotations to ``adata``.
+
+     Args:
+         adata: AnnData object.
+         ref_column: Obs column used to stratify references.
+         use_consensus: Whether to use consensus sequences rather than FASTA references.
+         native: If ``True``, use native SMF assumptions; otherwise use conversion assumptions.
+         mod_target_bases: Base contexts that may be modified.
+         bypass: Whether to skip processing.
+         force_redo: Whether to rerun even if ``uns_flag`` is set.
+         uns_flag: Flag in ``adata.uns`` indicating prior completion.
      """
      import numpy as np
-     import anndata as ad
 
      # Only run if not already performed
      already = bool(adata.uns.get(uns_flag, False))
@@ -29,94 +40,118 @@ def append_base_context(
          # QC already performed; nothing to do
          return
 
-     print('Adding base context based on reference FASTA sequence for sample')
-     categories = adata.obs[obs_column].cat.categories
+     logger.info("Adding base context based on reference FASTA sequence for sample")
+     references = adata.obs[ref_column].cat.categories
      site_types = []
-
-     if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
-         site_types += ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C_site', 'C_site']
-
-     if 'A' in mod_target_bases:
-         site_types += ['A_site']
-
-     for cat in categories:
+
+     if any(base in mod_target_bases for base in ["GpC", "CpG", "C"]):
+         site_types += ["GpC_site", "CpG_site", "ambiguous_GpC_CpG_site", "other_C_site", "C_site"]
+
+     if "A" in mod_target_bases:
+         site_types += ["A_site"]
+
+     for ref in references:
          # Assess if the strand is the top or bottom strand converted
-         if 'top' in cat:
-             strand = 'top'
-         elif 'bottom' in cat:
-             strand = 'bottom'
+         if "top" in ref:
+             strand = "top"
+         elif "bottom" in ref:
+             strand = "bottom"
 
          if native:
-             basename = cat.split(f"_{strand}")[0]
+             basename = ref.split(f"_{strand}")[0]
             if use_consensus:
-                 sequence = adata.uns[f'{basename}_consensus_sequence']
+                 sequence = adata.uns[f"{basename}_consensus_sequence"]
             else:
                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
-                 sequence = adata.uns[f'{basename}_FASTA_sequence']
+                 sequence = adata.uns[f"{basename}_FASTA_sequence"]
         else:
-             basename = cat.split(f"_{strand}")[0]
+             basename = ref.split(f"_{strand}")[0]
             if use_consensus:
-                 sequence = adata.uns[f'{basename}_consensus_sequence']
+                 sequence = adata.uns[f"{basename}_consensus_sequence"]
             else:
                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
-                 sequence = adata.uns[f'{basename}_FASTA_sequence']
-         # Init a dict keyed by reference site type that points to a bool of whether the position is that site type.
+                 sequence = adata.uns[f"{basename}_FASTA_sequence"]
+
+         # Init a dict keyed by reference site type that points to a bool of whether the position is that site type.
          boolean_dict = {}
         for site_type in site_types:
-             boolean_dict[f'{cat}_{site_type}'] = np.full(len(sequence), False, dtype=bool)
+             boolean_dict[f"{ref}_{site_type}"] = np.full(len(sequence), False, dtype=bool)
 
-         if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
-             if strand == 'top':
+         if any(base in mod_target_bases for base in ["GpC", "CpG", "C"]):
+             if strand == "top":
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
-                     if sequence[i] == 'C':
-                         boolean_dict[f'{cat}_C_site'][i] = True
-                         if sequence[i - 1] == 'G' and sequence[i + 1] != 'G':
-                             boolean_dict[f'{cat}_GpC_site'][i] = True
-                         elif sequence[i - 1] == 'G' and sequence[i + 1] == 'G':
-                             boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
-                         elif sequence[i - 1] != 'G' and sequence[i + 1] == 'G':
-                             boolean_dict[f'{cat}_CpG_site'][i] = True
-                         elif sequence[i - 1] != 'G' and sequence[i + 1] != 'G':
-                             boolean_dict[f'{cat}_other_C_site'][i] = True
-             elif strand == 'bottom':
+                     if sequence[i] == "C":
+                         boolean_dict[f"{ref}_C_site"][i] = True
+                         if sequence[i - 1] == "G" and sequence[i + 1] != "G":
+                             boolean_dict[f"{ref}_GpC_site"][i] = True
+                         elif sequence[i - 1] == "G" and sequence[i + 1] == "G":
+                             boolean_dict[f"{ref}_ambiguous_GpC_CpG_site"][i] = True
+                         elif sequence[i - 1] != "G" and sequence[i + 1] == "G":
+                             boolean_dict[f"{ref}_CpG_site"][i] = True
+                         elif sequence[i - 1] != "G" and sequence[i + 1] != "G":
+                             boolean_dict[f"{ref}_other_C_site"][i] = True
+             elif strand == "bottom":
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
-                     if sequence[i] == 'G':
-                         boolean_dict[f'{cat}_C_site'][i] = True
-                         if sequence[i + 1] == 'C' and sequence[i - 1] != 'C':
-                             boolean_dict[f'{cat}_GpC_site'][i] = True
-                         elif sequence[i - 1] == 'C' and sequence[i + 1] == 'C':
-                             boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
-                         elif sequence[i - 1] == 'C' and sequence[i + 1] != 'C':
-                             boolean_dict[f'{cat}_CpG_site'][i] = True
-                         elif sequence[i - 1] != 'C' and sequence[i + 1] != 'C':
-                             boolean_dict[f'{cat}_other_C_site'][i] = True
+                     if sequence[i] == "G":
+                         boolean_dict[f"{ref}_C_site"][i] = True
+                         if sequence[i + 1] == "C" and sequence[i - 1] != "C":
+                             boolean_dict[f"{ref}_GpC_site"][i] = True
+                         elif sequence[i - 1] == "C" and sequence[i + 1] == "C":
+                             boolean_dict[f"{ref}_ambiguous_GpC_CpG_site"][i] = True
+                         elif sequence[i - 1] == "C" and sequence[i + 1] != "C":
+                             boolean_dict[f"{ref}_CpG_site"][i] = True
+                         elif sequence[i - 1] != "C" and sequence[i + 1] != "C":
+                             boolean_dict[f"{ref}_other_C_site"][i] = True
             else:
-                 print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')
+                 logger.error(
+                     "Top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name."
+                 )
 
-         if 'A' in mod_target_bases:
-             if strand == 'top':
+         if "A" in mod_target_bases:
+             if strand == "top":
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
-                     if sequence[i] == 'A':
-                         boolean_dict[f'{cat}_A_site'][i] = True
-             elif strand == 'bottom':
+                     if sequence[i] == "A":
+                         boolean_dict[f"{ref}_A_site"][i] = True
+             elif strand == "bottom":
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
-                     if sequence[i] == 'T':
-                         boolean_dict[f'{cat}_A_site'][i] = True
+                     if sequence[i] == "T":
+                         boolean_dict[f"{ref}_A_site"][i] = True
             else:
-                 print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')
+                 logger.error(
+                     "Top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name."
+                 )
 
         for site_type in site_types:
-             adata.var[f'{cat}_{site_type}'] = boolean_dict[f'{cat}_{site_type}'].astype(bool)
+             # Site context annotations for each reference
+             adata.var[f"{ref}_{site_type}"] = boolean_dict[f"{ref}_{site_type}"].astype(bool)
+             # Restrict the site type labels to only be in positions that occur at a high enough frequency in the dataset
+             if adata.uns.get("calculate_coverage_performed", False):
+                 adata.var[f"{ref}_{site_type}_valid_coverage"] = (
+                     (adata.var[f"{ref}_{site_type}"]) & (adata.var[f"position_in_{ref}"])
+                 )
+                 if native:
+                     adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
+                         :, adata.var[f"{ref}_{site_type}_valid_coverage"]
+                     ].layers["binarized_methylation"]
+                 else:
+                     adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
+                         :, adata.var[f"{ref}_{site_type}_valid_coverage"]
+                     ].X
+             else:
+                 pass
+
             if native:
-                 adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].layers['binarized_methylation']
+                 adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].layers[
+                     "binarized_methylation"
+                 ]
             else:
-                 adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].X
+                 adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].X
 
      # mark as done
      adata.uns[uns_flag] = True
 
-     return None
+     return None
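
The cat → ref renames in this append_base_context.py diff leave the classification rule itself unchanged, and it is easiest to see on a toy top-strand sequence. A self-contained sketch of the same logic (function name invented here):

import numpy as np

def classify_top_strand_c_sites(sequence):
    """Mirror of the per-reference top-strand C classification above."""
    kinds = ["C", "GpC", "ambiguous_GpC_CpG", "CpG", "other_C"]
    masks = {k: np.zeros(len(sequence), dtype=bool) for k in kinds}
    for i in range(1, len(sequence) - 1):
        if sequence[i] == "C":
            masks["C"][i] = True
            if sequence[i - 1] == "G" and sequence[i + 1] != "G":
                masks["GpC"][i] = True  # GpC-MTase footprinting site
            elif sequence[i - 1] == "G" and sequence[i + 1] == "G":
                masks["ambiguous_GpC_CpG"][i] = True  # GCG: cannot be assigned
            elif sequence[i - 1] != "G" and sequence[i + 1] == "G":
                masks["CpG"][i] = True  # endogenous CpG site
            else:
                masks["other_C"][i] = True  # neither context
    return masks

masks = classify_top_strand_c_sites("AGCGGCATCGA")
print(np.flatnonzero(masks["GpC"]))                # [5]
print(np.flatnonzero(masks["CpG"]))                # [8]
print(np.flatnonzero(masks["ambiguous_GpC_CpG"]))  # [2]

The bottom-strand branch applies the mirror rule to G positions, reading the complementary context.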

smftools/preprocessing/append_binary_layer_by_base_context.py
@@ -1,33 +1,51 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
  import numpy as np
  import scipy.sparse as sp
 
+ from smftools.logging_utils import get_logger
+
+ if TYPE_CHECKING:
+     import anndata as ad
+
+ logger = get_logger(__name__)
+
+
  def append_binary_layer_by_base_context(
-     adata,
+     adata: "ad.AnnData",
      reference_column: str,
      smf_modality: str = "conversion",
      verbose: bool = True,
-     uns_flag: str = "binary_layers_by_base_context_added",
+     uns_flag: str = "append_binary_layer_by_base_context_performed",
      bypass: bool = False,
-     force_redo: bool = False
- ):
-     """
-     Build per-reference C/G-site masked layers:
-       - GpC_site_binary
-       - CpG_site_binary
-       - GpC_CpG_combined_site_binary (numeric sum where present; NaN where neither present)
-       - C_site_binary
-       - other_C_site_binary
-
-     Behavior:
-       - If X is sparse it will be converted to dense for these layers (keeps original adata.X untouched).
-       - Missing var columns are warned about but do not crash.
-       - Masked positions are filled with np.nan to make masked vs unmasked explicit.
-       - Requires append_base_context to be run first
+     force_redo: bool = False,
+     from_valid_sites_only: bool = False,
+     valid_site_col_suffix: str = "_valid_coverage",
+ ) -> "ad.AnnData":
+     """Build per-reference masked layers for base-context sites.
+
+     Args:
+         adata: AnnData object to annotate.
+         reference_column: Obs column containing reference identifiers.
+         smf_modality: SMF modality identifier.
+         verbose: Whether to log layer summary information.
+         uns_flag: Flag in ``adata.uns`` indicating prior completion.
+         bypass: Whether to skip processing.
+         force_redo: Whether to rerun even if ``uns_flag`` is set.
+         from_valid_sites_only: Whether to use valid-coverage site masks only.
+         valid_site_col_suffix: Suffix for valid-coverage site columns.
+
+     Returns:
+         anndata.AnnData: AnnData object with new masked layers.
      """
+     if not from_valid_sites_only:
+         valid_site_col_suffix = ""
 
      # Only run if not already performed
      already = bool(adata.uns.get(uns_flag, False))
-     if (already and not force_redo) or bypass or ("base_context_added" not in adata.uns):
+     if (already and not force_redo) or bypass or ("append_base_context_performed" not in adata.uns):
          # QC already performed; nothing to do
          return adata
 
@@ -46,17 +64,25 @@ def append_binary_layer_by_base_context(
 
      # expected per-reference var column names
      references = adata.obs[reference_column].astype("category").cat.categories
-     reference_to_gpc_column = {ref: f"{ref}_GpC_site" for ref in references}
-     reference_to_cpg_column = {ref: f"{ref}_CpG_site" for ref in references}
-     reference_to_c_column = {ref: f"{ref}_C_site" for ref in references}
-     reference_to_other_c_column = {ref: f"{ref}_other_C_site" for ref in references}
+     reference_to_gpc_column = {ref: f"{ref}_GpC_site{valid_site_col_suffix}" for ref in references}
+     reference_to_cpg_column = {ref: f"{ref}_CpG_site{valid_site_col_suffix}" for ref in references}
+     reference_to_c_column = {ref: f"{ref}_C_site{valid_site_col_suffix}" for ref in references}
+     reference_to_other_c_column = {
+         ref: f"{ref}_other_C_site{valid_site_col_suffix}" for ref in references
+     }
+     reference_to_a_column = {ref: f"{ref}_A_site{valid_site_col_suffix}" for ref in references}
 
      # verify var columns exist and build boolean masks per ref (len = n_vars)
      n_obs, n_vars = adata.shape
+
      def _col_mask_or_warn(colname):
+         """Return a boolean mask for a var column, or all-False if missing."""
          if colname not in adata.var.columns:
             if verbose:
-                 print(f"Warning: var column '{colname}' not found; treating as all-False mask.")
+                 logger.warning(
+                     "Var column '%s' not found; treating as all-False mask.",
+                     colname,
+                 )
             return np.zeros(n_vars, dtype=bool)
         vals = adata.var[colname].values
         # coerce truthiness
@@ -67,14 +93,17 @@ def append_binary_layer_by_base_context(
 
      gpc_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_gpc_column.items()}
      cpg_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_cpg_column.items()}
-     c_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_c_column.items()}
-     other_c_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_other_c_column.items()}
+     c_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_c_column.items()}
+     other_c_var_masks = {
+         ref: _col_mask_or_warn(col) for ref, col in reference_to_other_c_column.items()
+     }
+     a_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_a_column.items()}
 
      # prepare X as dense float32 for layer filling (we leave adata.X untouched)
      X = adata.X
      if sp.issparse(X):
          if verbose:
-             print("Converting sparse X to dense array for layer construction (temporary).")
+             logger.info("Converting sparse X to dense array for layer construction (temporary).")
          X = X.toarray()
      X = np.asarray(X, dtype=np.float32)
 
@@ -83,11 +112,12 @@ def append_binary_layer_by_base_context(
      masked_cpg = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
      masked_any_c = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
      masked_other_c = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
+     masked_a = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
 
      # fill row-blocks per reference (this avoids creating a full row×var boolean mask)
      obs_ref_series = adata.obs[reference_column]
      for ref in references:
-         rows_mask = (obs_ref_series.values == ref)
+         rows_mask = obs_ref_series.values == ref
          if not rows_mask.any():
             continue
         row_idx = np.nonzero(rows_mask)[0]  # integer indices of rows for this ref
@@ -95,8 +125,9 @@ def append_binary_layer_by_base_context(
          # column masks for this ref
          gpc_cols = gpc_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
          cpg_cols = cpg_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
-         c_cols = c_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
+         c_cols = c_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
          other_c_cols = other_c_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
+         a_cols = a_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
 
          if gpc_cols.any():
             # assign only the submatrix (rows x selected cols)
@@ -107,6 +138,8 @@ def append_binary_layer_by_base_context(
              masked_any_c[np.ix_(row_idx, c_cols)] = X[np.ix_(row_idx, c_cols)]
          if other_c_cols.any():
              masked_other_c[np.ix_(row_idx, other_c_cols)] = X[np.ix_(row_idx, other_c_cols)]
+         if a_cols.any():
+             masked_a[np.ix_(row_idx, other_c_cols)] = X[np.ix_(row_idx, other_c_cols)]
 
      # Build combined layer:
      #   - numeric_sum: sum where either exists, NaN where neither exists
@@ -121,21 +154,26 @@ def append_binary_layer_by_base_context(
      # combined_bool = (~gpc_nan & (masked_gpc != 0)) | (~cpg_nan & (masked_cpg != 0))
      # combined_layer = combined_bool.astype(np.float32)
 
-     adata.layers['GpC_site_binary'] = masked_gpc
-     adata.layers['CpG_site_binary'] = masked_cpg
-     adata.layers['GpC_CpG_combined_site_binary'] = combined_sum
-     adata.layers['C_site_binary'] = masked_any_c
-     adata.layers['other_C_site_binary'] = masked_other_c
+     adata.layers["GpC_site_binary"] = masked_gpc
+     adata.layers["CpG_site_binary"] = masked_cpg
+     adata.layers["GpC_CpG_combined_site_binary"] = combined_sum
+     adata.layers["C_site_binary"] = masked_any_c
+     adata.layers["other_C_site_binary"] = masked_other_c
+     adata.layers["A_site_binary"] = masked_a
 
      if verbose:
+
          def _filled_positions(arr):
+             """Count the number of non-NaN positions in an array."""
              return int(np.sum(~np.isnan(arr)))
-         print("Layer build summary (non-NaN cell counts):")
-         print(f"  GpC: {_filled_positions(masked_gpc)}")
-         print(f"  CpG: {_filled_positions(masked_cpg)}")
-         print(f"  GpC+CpG combined: {_filled_positions(combined_sum)}")
-         print(f"  C: {_filled_positions(masked_any_c)}")
-         print(f"  other_C: {_filled_positions(masked_other_c)}")
+
+         logger.info("Layer build summary (non-NaN cell counts):")
+         logger.info("  GpC: %s", _filled_positions(masked_gpc))
+         logger.info("  CpG: %s", _filled_positions(masked_cpg))
+         logger.info("  GpC+CpG combined: %s", _filled_positions(combined_sum))
+         logger.info("  C: %s", _filled_positions(masked_any_c))
+         logger.info("  other_C: %s", _filled_positions(masked_other_c))
+         logger.info("  A: %s", _filled_positions(masked_a))
 
      # mark as done
      adata.uns[uns_flag] = True
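
Two details in this append_binary_layer_by_base_context.py diff deserve a closer look. Layer filling relies on np.ix_ so that only each reference's reads-by-sites submatrix is copied, and the new A-site block indexes its fill with other_c_cols rather than a_cols, which reads like a copy-paste slip. A toy sketch of the presumably intended assignment (all names are local to this example):

import numpy as np

# Toy matrix: 4 reads x 6 positions.
X = np.arange(24, dtype=np.float32).reshape(4, 6)

row_idx = np.array([0, 2])  # reads assigned to one reference
a_cols = np.array([False, True, False, True, False, False])  # A-site mask

masked_a = np.full(X.shape, np.nan, dtype=np.float32)
if a_cols.any():
    # np.ix_ builds an open mesh from the row indices and the boolean
    # column mask, so only the selected submatrix is read and written.
    masked_a[np.ix_(row_idx, a_cols)] = X[np.ix_(row_idx, a_cols)]

print(masked_a[0])  # [nan  1. nan  3. nan nan]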

smftools/preprocessing/archived/calculate_complexity.py
@@ -21,9 +21,11 @@ def calculate_complexity(adata, output_directory='', obs_column='Reference', sam
      from scipy.optimize import curve_fit
 
      def lander_waterman(x, C0):
+         """Lander-Waterman curve for complexity estimation."""
          return C0 * (1 - np.exp(-x / C0))
 
      def count_unique_reads(reads, depth):
+         """Count unique reads in a subsample of the given depth."""
          subsample = np.random.choice(reads, depth, replace=False)
          return len(np.unique(subsample))
 
@@ -69,4 +71,4 @@ def calculate_complexity(adata, output_directory='', obs_column='Reference', sam
          plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
          plt.close()
      else:
-         plt.show()
+         plt.show()
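
For context, the helpers gaining docstrings here implement the Lander-Waterman saturation model, C(x) = C0 * (1 - exp(-x / C0)): the expected number of unique reads after sampling x reads from a library of complexity C0. A sketch of how such a fit works with scipy (synthetic data; not the package's actual call site):

import numpy as np
from scipy.optimize import curve_fit

def lander_waterman(x, C0):
    """Expected unique reads after drawing x reads from a library of complexity C0."""
    return C0 * (1 - np.exp(-x / C0))

# Synthetic saturation curve with true C0 = 1000; in the package these
# points come from subsampling reads at increasing depths.
depths = np.array([100.0, 500.0, 1000.0, 2000.0, 4000.0])
rng = np.random.default_rng(0)
unique = lander_waterman(depths, 1000.0) + rng.normal(0.0, 5.0, depths.size)

(c0_fit,), _ = curve_fit(lander_waterman, depths, unique, p0=[unique.max()])
print(round(c0_fit))  # close to 1000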

smftools/preprocessing/archived/preprocessing.py
@@ -322,12 +322,14 @@ def min_non_diagonal(matrix):
          min_values.append(np.min(row))
      return min_values
 
- def lander_waterman(x, C0):
-     return C0 * (1 - np.exp(-x / C0))
+ def lander_waterman(x, C0):
+     """Lander-Waterman curve for complexity estimation."""
+     return C0 * (1 - np.exp(-x / C0))
 
- def count_unique_reads(reads, depth):
-     subsample = np.random.choice(reads, depth, replace=False)
-     return len(np.unique(subsample))
+ def count_unique_reads(reads, depth):
+     """Count unique reads in a subsample of the given depth."""
+     subsample = np.random.choice(reads, depth, replace=False)
+     return len(np.unique(subsample))
 
  def mark_duplicates(adata, layers, obs_column='Reference', sample_col='Sample_names'):
      """
@@ -611,4 +613,4 @@ def binarize_on_Youden(adata, obs_column='Reference'):
      # Pull back the new binarized layers into the original adata object
      adata.layers['binarized_methylation'] = temp_adata.layers['binarized_methylation']
 
- ######################################################################################################
+ ######################################################################################################

smftools/preprocessing/binarize.py
@@ -1,9 +1,26 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
  import numpy as np
 
- def binarize_adata(adata, source="X", target_layer="binary", threshold=0.8):
-     """
-     Binarize a dense matrix and preserve NaN.
-     source: "X" or layer name
+ if TYPE_CHECKING:
+     import anndata as ad
+
+
+ def binarize_adata(
+     adata: "ad.AnnData",
+     source: str = "X",
+     target_layer: str = "binary",
+     threshold: float = 0.8,
+ ) -> None:
+     """Binarize a dense matrix and preserve NaNs.
+
+     Args:
+         adata: AnnData object with input matrix or layer.
+         source: ``"X"`` to use the main matrix or a layer name.
+         target_layer: Layer name to store the binarized values.
+         threshold: Threshold above which values are set to 1.
      """
      X = adata.X if source == "X" else adata.layers[source]
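
The binarize.py diff is truncated before the thresholding body, but the new docstring pins down the contract: finite values are thresholded to {0, 1} while NaNs pass through into target_layer. One NaN-preserving way to meet that contract (the helper name and the strict > comparison are assumptions, not the package's code):

import numpy as np

def binarize_preserving_nan(X, threshold=0.8):
    """Map finite values to {0, 1} by threshold; leave NaNs untouched."""
    out = np.full(X.shape, np.nan, dtype=np.float32)
    finite = ~np.isnan(X)
    out[finite] = (X[finite] > threshold).astype(np.float32)
    return out

X = np.array([[0.9, 0.1, np.nan], [0.5, 0.95, 0.0]])
print(binarize_preserving_nan(X))
# [[ 1.  0. nan]
#  [ 0.  1.  0.]]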