smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
--- a/smftools/plotting/qc_plotting.py
+++ b/smftools/plotting/qc_plotting.py
@@ -1,12 +1,13 @@
-import os
-import numpy as np
-import pandas as pd
-import matplotlib.pyplot as plt
+from __future__ import annotations

 import os
+
 import numpy as np
 import pandas as pd
-import matplotlib.pyplot as plt
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="QC plots")


 def plot_read_qc_histograms(
@@ -83,7 +84,11 @@ def plot_read_qc_histograms(
     for key in valid_keys:
         if not is_numeric[key]:
             continue
-        s = pd.to_numeric(adata.obs[key], errors="coerce").replace([np.inf, -np.inf], np.nan).dropna()
+        s = (
+            pd.to_numeric(adata.obs[key], errors="coerce")
+            .replace([np.inf, -np.inf], np.nan)
+            .dropna()
+        )
         if s.size < min_non_nan:
             # still set something to avoid errors; just use min/max or (0,1)
             lo, hi = (0.0, 1.0) if s.size == 0 else (float(s.min()), float(s.max()))
@@ -99,6 +104,7 @@ def plot_read_qc_histograms(
         global_ranges[key] = (lo, hi)

     def _sanitize(name: str) -> str:
+        """Sanitize a string for use in filenames."""
         return "".join(c if c.isalnum() or c in "-._" else "_" for c in str(name))

     ncols = len(valid_keys)
@@ -107,17 +113,18 @@ def plot_read_qc_histograms(
     fig_h_unit = figsize_cell[1]

     for start in range(0, len(sample_levels), rows_per_fig):
-        chunk = sample_levels[start:start + rows_per_fig]
+        chunk = sample_levels[start : start + rows_per_fig]
         nrows = len(chunk)
         fig, axes = plt.subplots(
-            nrows=nrows, ncols=ncols,
+            nrows=nrows,
+            ncols=ncols,
             figsize=(fig_w, fig_h_unit * nrows),
             dpi=dpi,
             squeeze=False,
         )

         for r, sample_val in enumerate(chunk):
-            row_mask = (adata.obs[sample_key].values == sample_val)
+            row_mask = adata.obs[sample_key].values == sample_val
             n_in_row = int(row_mask.sum())

             for c, key in enumerate(valid_keys):
@@ -125,7 +132,11 @@ def plot_read_qc_histograms(
                 series = adata.obs.loc[row_mask, key]

                 if is_numeric[key]:
-                    x = pd.to_numeric(series, errors="coerce").replace([np.inf, -np.inf], np.nan).dropna()
+                    x = (
+                        pd.to_numeric(series, errors="coerce")
+                        .replace([np.inf, -np.inf], np.nan)
+                        .dropna()
+                    )
                     if x.size < min_non_nan:
                         ax.text(0.5, 0.5, f"n={x.size} (<{min_non_nan})", ha="center", va="center")
                     else:
@@ -143,7 +154,9 @@ def plot_read_qc_histograms(
                 else:
                     vc = series.astype("category").value_counts(dropna=False)
                     if vc.sum() < min_non_nan:
-                        ax.text(0.5, 0.5, f"n={vc.sum()} (<{min_non_nan})", ha="center", va="center")
+                        ax.text(
+                            0.5, 0.5, f"n={vc.sum()} (<{min_non_nan})", ha="center", va="center"
+                        )
                     else:
                         vc_top = vc.iloc[:topn_categories][::-1]  # show top-N, reversed for barh
                         ax.barh(vc_top.index.astype(str), vc_top.values)
@@ -267,4 +280,4 @@ def plot_read_qc_histograms(
     # fname = f"{key}_{sample_key}_{safe_group}.png" if sample_key else f"{key}.png"
     # fname = fname.replace("/", "_")
     # fig.savefig(os.path.join(outdir, fname))
-    # plt.close(fig)
\ No newline at end of file
+    # plt.close(fig)
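The `require(...)` call above replaces a hard `import matplotlib.pyplot` with a guarded optional import, so plotting-only dependencies no longer break `import smftools` when they are missing. The body of the new `smftools/optional_imports.py` (+31 lines) is not shown in this diff; the sketch below is a hypothetical minimal implementation of such a helper, assuming only the call signature visible above (`extra=`, `purpose=`):

    # Hypothetical sketch of a require() helper; only the signature is taken
    # from the call site above, the body is an assumption.
    from importlib import import_module
    from types import ModuleType

    def require(module_name: str, extra: str = "", purpose: str = "") -> ModuleType:
        """Import an optional dependency or raise an actionable ImportError."""
        try:
            return import_module(module_name)
        except ImportError as err:
            why = f" (needed for {purpose})" if purpose else ""
            hint = f"pip install smftools[{extra}]" if extra else f"pip install {module_name}"
            raise ImportError(
                f"Optional dependency '{module_name}' is missing{why}. Try: {hint}"
            ) from err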
--- a/smftools/preprocessing/__init__.py
+++ b/smftools/preprocessing/__init__.py
@@ -1,38 +1,36 @@
-from .append_base_context import append_base_context
-from .append_binary_layer_by_base_context import append_binary_layer_by_base_context
-from .binarize_on_Youden import binarize_on_Youden
-from .binarize import binarize_adata
-from .calculate_complexity_II import calculate_complexity_II
-from .calculate_read_modification_stats import calculate_read_modification_stats
-from .calculate_coverage import calculate_coverage
-from .calculate_position_Youden import calculate_position_Youden
-from .calculate_read_length_stats import calculate_read_length_stats
-from .clean_NaN import clean_NaN
-from .filter_adata_by_nan_proportion import filter_adata_by_nan_proportion
-from .filter_reads_on_modification_thresholds import filter_reads_on_modification_thresholds
-from .filter_reads_on_length_quality_mapping import filter_reads_on_length_quality_mapping
-from .invert_adata import invert_adata
-from .load_sample_sheet import load_sample_sheet
-from .flag_duplicate_reads import flag_duplicate_reads
-from .reindex_references_adata import reindex_references_adata
-from .subsample_adata import subsample_adata
+from __future__ import annotations

-__all__ = [
-    "append_base_context",
-    "append_binary_layer_by_base_context",
-    "binarize_on_Youden",
-    "binarize_adata",
-    "calculate_complexity_II",
-    "calculate_read_modification_stats",
-    "calculate_coverage",
-    "calculate_position_Youden",
-    "calculate_read_length_stats",
-    "clean_NaN",
-    "filter_adata_by_nan_proportion",
-    "filter_reads_on_modification_thresholds",
-    "filter_reads_on_length_quality_mapping",
-    "invert_adata",
-    "load_sample_sheet",
-    "flag_duplicate_reads",
-    "subsample_adata"
-]
+from importlib import import_module
+
+_LAZY_ATTRS = {
+    "append_base_context": "smftools.preprocessing.append_base_context",
+    "append_binary_layer_by_base_context": "smftools.preprocessing.append_binary_layer_by_base_context",
+    "binarize_adata": "smftools.preprocessing.binarize",
+    "binarize_on_Youden": "smftools.preprocessing.binarize_on_Youden",
+    "calculate_complexity_II": "smftools.preprocessing.calculate_complexity_II",
+    "calculate_coverage": "smftools.preprocessing.calculate_coverage",
+    "calculate_position_Youden": "smftools.preprocessing.calculate_position_Youden",
+    "calculate_read_length_stats": "smftools.preprocessing.calculate_read_length_stats",
+    "calculate_read_modification_stats": "smftools.preprocessing.calculate_read_modification_stats",
+    "clean_NaN": "smftools.preprocessing.clean_NaN",
+    "filter_adata_by_nan_proportion": "smftools.preprocessing.filter_adata_by_nan_proportion",
+    "filter_reads_on_length_quality_mapping": "smftools.preprocessing.filter_reads_on_length_quality_mapping",
+    "filter_reads_on_modification_thresholds": "smftools.preprocessing.filter_reads_on_modification_thresholds",
+    "flag_duplicate_reads": "smftools.preprocessing.flag_duplicate_reads",
+    "invert_adata": "smftools.preprocessing.invert_adata",
+    "load_sample_sheet": "smftools.preprocessing.load_sample_sheet",
+    "reindex_references_adata": "smftools.preprocessing.reindex_references_adata",
+    "subsample_adata": "smftools.preprocessing.subsample_adata",
+}
+
+
+def __getattr__(name: str):
+    if name in _LAZY_ATTRS:
+        module = import_module(_LAZY_ATTRS[name])
+        attr = getattr(module, name)
+        globals()[name] = attr
+        return attr
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+__all__ = list(_LAZY_ATTRS.keys())
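The rewrite above swaps eager submodule imports for PEP 562 lazy loading: `smftools.preprocessing` now resolves each function on first attribute access via the module-level `__getattr__`, then caches it in the module namespace. Caller code is unchanged; for example:

    # Importing the package no longer imports all 19 submodules up front.
    import smftools.preprocessing as pp

    fn = pp.clean_NaN   # first access: __getattr__ imports the submodule,
                        # caches clean_NaN in globals(), and returns it
    fn = pp.clean_NaN   # later accesses hit the cached attribute directly

`from smftools.preprocessing import clean_NaN` goes through the same hook.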
--- a/smftools/preprocessing/append_base_context.py
+++ b/smftools/preprocessing/append_base_context.py
@@ -1,28 +1,38 @@
-def append_base_context(adata,
-                        ref_column='Reference_strand',
-                        use_consensus=False,
-                        native=False,
-                        mod_target_bases=['GpC', 'CpG'],
-                        bypass=False,
-                        force_redo=False,
-                        uns_flag='append_base_context_performed'
-                        ):
-    """
-    Adds nucleobase context to the position within the given category. When use_consensus is True, it uses the consensus sequence, otherwise it defaults to the FASTA sequence.
-    This needs to be performed prior to AnnData inversion step.
-
-    Parameters:
-        adata (AnnData): The input adata object.
-        ref_column (str): The observation column in which to stratify on. Default is 'Reference_strand', which should not be changed for most purposes.
-        use_consensus (bool): A truth statement indicating whether to use the consensus sequence from the reads mapped to the reference. If False, the reference FASTA is used instead.
-        native (bool): If False, perform conversion SMF assumptions. If True, perform native SMF assumptions
-        mod_target_bases (list): Base contexts that may be modified.
-
-    Returns:
-        None
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def append_base_context(
+    adata: "ad.AnnData",
+    ref_column: str = "Reference_strand",
+    use_consensus: bool = False,
+    native: bool = False,
+    mod_target_bases: list[str] = ["GpC", "CpG"],
+    bypass: bool = False,
+    force_redo: bool = False,
+    uns_flag: str = "append_base_context_performed",
+) -> None:
+    """Append base context annotations to ``adata``.
+
+    Args:
+        adata: AnnData object.
+        ref_column: Obs column used to stratify references.
+        use_consensus: Whether to use consensus sequences rather than FASTA references.
+        native: If ``True``, use native SMF assumptions; otherwise use conversion assumptions.
+        mod_target_bases: Base contexts that may be modified.
+        bypass: Whether to skip processing.
+        force_redo: Whether to rerun even if ``uns_flag`` is set.
+        uns_flag: Flag in ``adata.uns`` indicating prior completion.
     """
     import numpy as np
-    import anndata as ad

     # Only run if not already performed
     already = bool(adata.uns.get(uns_flag, False))
@@ -30,102 +40,118 @@ def append_base_context(
         # QC already performed; nothing to do
         return

-    print('Adding base context based on reference FASTA sequence for sample')
+    logger.info("Adding base context based on reference FASTA sequence for sample")
     references = adata.obs[ref_column].cat.categories
     site_types = []
-
-    if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
-        site_types += ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C_site', 'C_site']
-
-    if 'A' in mod_target_bases:
-        site_types += ['A_site']
+
+    if any(base in mod_target_bases for base in ["GpC", "CpG", "C"]):
+        site_types += ["GpC_site", "CpG_site", "ambiguous_GpC_CpG_site", "other_C_site", "C_site"]
+
+    if "A" in mod_target_bases:
+        site_types += ["A_site"]

     for ref in references:
         # Assess if the strand is the top or bottom strand converted
-        if 'top' in ref:
-            strand = 'top'
-        elif 'bottom' in ref:
-            strand = 'bottom'
+        if "top" in ref:
+            strand = "top"
+        elif "bottom" in ref:
+            strand = "bottom"

         if native:
             basename = ref.split(f"_{strand}")[0]
             if use_consensus:
-                sequence = adata.uns[f'{basename}_consensus_sequence']
+                sequence = adata.uns[f"{basename}_consensus_sequence"]
             else:
                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
-                sequence = adata.uns[f'{basename}_FASTA_sequence']
+                sequence = adata.uns[f"{basename}_FASTA_sequence"]
         else:
             basename = ref.split(f"_{strand}")[0]
             if use_consensus:
-                sequence = adata.uns[f'{basename}_consensus_sequence']
+                sequence = adata.uns[f"{basename}_consensus_sequence"]
             else:
                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
-                sequence = adata.uns[f'{basename}_FASTA_sequence']
+                sequence = adata.uns[f"{basename}_FASTA_sequence"]

-        # Init a dict keyed by reference site type that points to a bool of whether the position is that site type.
+        # Init a dict keyed by reference site type that points to a bool of whether the position is that site type.
         boolean_dict = {}
         for site_type in site_types:
-            boolean_dict[f'{ref}_{site_type}'] = np.full(len(sequence), False, dtype=bool)
+            boolean_dict[f"{ref}_{site_type}"] = np.full(len(sequence), False, dtype=bool)

-        if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
-            if strand == 'top':
+        if any(base in mod_target_bases for base in ["GpC", "CpG", "C"]):
+            if strand == "top":
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
-                    if sequence[i] == 'C':
-                        boolean_dict[f'{ref}_C_site'][i] = True
-                        if sequence[i - 1] == 'G' and sequence[i + 1] != 'G':
-                            boolean_dict[f'{ref}_GpC_site'][i] = True
-                        elif sequence[i - 1] == 'G' and sequence[i + 1] == 'G':
-                            boolean_dict[f'{ref}_ambiguous_GpC_CpG_site'][i] = True
-                        elif sequence[i - 1] != 'G' and sequence[i + 1] == 'G':
-                            boolean_dict[f'{ref}_CpG_site'][i] = True
-                        elif sequence[i - 1] != 'G' and sequence[i + 1] != 'G':
-                            boolean_dict[f'{ref}_other_C_site'][i] = True
-            elif strand == 'bottom':
+                    if sequence[i] == "C":
+                        boolean_dict[f"{ref}_C_site"][i] = True
+                        if sequence[i - 1] == "G" and sequence[i + 1] != "G":
+                            boolean_dict[f"{ref}_GpC_site"][i] = True
+                        elif sequence[i - 1] == "G" and sequence[i + 1] == "G":
+                            boolean_dict[f"{ref}_ambiguous_GpC_CpG_site"][i] = True
+                        elif sequence[i - 1] != "G" and sequence[i + 1] == "G":
+                            boolean_dict[f"{ref}_CpG_site"][i] = True
+                        elif sequence[i - 1] != "G" and sequence[i + 1] != "G":
+                            boolean_dict[f"{ref}_other_C_site"][i] = True
+            elif strand == "bottom":
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
-                    if sequence[i] == 'G':
-                        boolean_dict[f'{ref}_C_site'][i] = True
-                        if sequence[i + 1] == 'C' and sequence[i - 1] != 'C':
-                            boolean_dict[f'{ref}_GpC_site'][i] = True
-                        elif sequence[i - 1] == 'C' and sequence[i + 1] == 'C':
-                            boolean_dict[f'{ref}_ambiguous_GpC_CpG_site'][i] = True
-                        elif sequence[i - 1] == 'C' and sequence[i + 1] != 'C':
-                            boolean_dict[f'{ref}_CpG_site'][i] = True
-                        elif sequence[i - 1] != 'C' and sequence[i + 1] != 'C':
-                            boolean_dict[f'{ref}_other_C_site'][i] = True
+                    if sequence[i] == "G":
+                        boolean_dict[f"{ref}_C_site"][i] = True
+                        if sequence[i + 1] == "C" and sequence[i - 1] != "C":
+                            boolean_dict[f"{ref}_GpC_site"][i] = True
+                        elif sequence[i - 1] == "C" and sequence[i + 1] == "C":
+                            boolean_dict[f"{ref}_ambiguous_GpC_CpG_site"][i] = True
+                        elif sequence[i - 1] == "C" and sequence[i + 1] != "C":
+                            boolean_dict[f"{ref}_CpG_site"][i] = True
+                        elif sequence[i - 1] != "C" and sequence[i + 1] != "C":
+                            boolean_dict[f"{ref}_other_C_site"][i] = True
             else:
-                print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')
+                logger.error(
+                    "Top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name."
+                )

-        if 'A' in mod_target_bases:
-            if strand == 'top':
+        if "A" in mod_target_bases:
+            if strand == "top":
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
-                    if sequence[i] == 'A':
-                        boolean_dict[f'{ref}_A_site'][i] = True
-            elif strand == 'bottom':
+                    if sequence[i] == "A":
+                        boolean_dict[f"{ref}_A_site"][i] = True
+            elif strand == "bottom":
                 # Iterate through the sequence and apply the criteria
                 for i in range(1, len(sequence) - 1):
-                    if sequence[i] == 'T':
-                        boolean_dict[f'{ref}_A_site'][i] = True
+                    if sequence[i] == "T":
+                        boolean_dict[f"{ref}_A_site"][i] = True
             else:
-                print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')
+                logger.error(
+                    "Top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name."
+                )

         for site_type in site_types:
             # Site context annotations for each reference
-            adata.var[f'{ref}_{site_type}'] = boolean_dict[f'{ref}_{site_type}'].astype(bool)
+            adata.var[f"{ref}_{site_type}"] = boolean_dict[f"{ref}_{site_type}"].astype(bool)
             # Restrict the site type labels to only be in positions that occur at a high enough frequency in the dataset
-            if adata.uns["calculate_coverage_performed"] == True:
-                adata.var[f'{ref}_{site_type}'] = (adata.var[f'{ref}_{site_type}']) & (adata.var[f'position_in_{ref}'])
+            if adata.uns.get("calculate_coverage_performed", False):
+                adata.var[f"{ref}_{site_type}_valid_coverage"] = (
+                    (adata.var[f"{ref}_{site_type}"]) & (adata.var[f"position_in_{ref}"])
+                )
+                if native:
+                    adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
+                        :, adata.var[f"{ref}_{site_type}_valid_coverage"]
+                    ].layers["binarized_methylation"]
+                else:
+                    adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
+                        :, adata.var[f"{ref}_{site_type}_valid_coverage"]
+                    ].X
             else:
                 pass
-
+
             if native:
-                adata.obsm[f'{ref}_{site_type}'] = adata[:, adata.var[f'{ref}_{site_type}'] == True].layers['binarized_methylation']
+                adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].layers[
+                    "binarized_methylation"
+                ]
             else:
-                adata.obsm[f'{ref}_{site_type}'] = adata[:, adata.var[f'{ref}_{site_type}'] == True].X
+                adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].X

     # mark as done
     adata.uns[uns_flag] = True

-    return None
\ No newline at end of file
+    return None
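The strand-specific loops above assign each cytosine on the top strand (each guanine on the bottom strand, which marks the complementary C) to exactly one context based on its two neighbors: GpC, CpG, ambiguous G-C-G, or other. A standalone toy run of the same top-strand rules, assuming a plain string like the sequences stored in `adata.uns[f"{basename}_FASTA_sequence"]`:

    # Toy illustration of the top-strand C-context rules from the diff above.
    sequence = "AGCGTGCCAC"

    for i in range(1, len(sequence) - 1):  # edge positions are never classified
        if sequence[i] != "C":
            continue
        g_before = sequence[i - 1] == "G"
        g_after = sequence[i + 1] == "G"
        if g_before and not g_after:
            context = "GpC_site"                # G-C-x
        elif g_before and g_after:
            context = "ambiguous_GpC_CpG_site"  # G-C-G: both GpC and CpG
        elif g_after:
            context = "CpG_site"                # x-C-G
        else:
            context = "other_C_site"            # no G on either side
        print(i, context)                       # -> 2 ambiguous, 6 GpC, 7 other_C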
--- a/smftools/preprocessing/append_binary_layer_by_base_context.py
+++ b/smftools/preprocessing/append_binary_layer_by_base_context.py
@@ -1,29 +1,47 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import numpy as np
 import scipy.sparse as sp

+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
 def append_binary_layer_by_base_context(
-    adata,
+    adata: "ad.AnnData",
     reference_column: str,
     smf_modality: str = "conversion",
     verbose: bool = True,
     uns_flag: str = "append_binary_layer_by_base_context_performed",
     bypass: bool = False,
-    force_redo: bool = False
-):
-    """
-    Build per-reference C/G-site masked layers:
-      - GpC_site_binary
-      - CpG_site_binary
-      - GpC_CpG_combined_site_binary (numeric sum where present; NaN where neither present)
-      - C_site_binary
-      - other_C_site_binary
-
-    Behavior:
-      - If X is sparse it will be converted to dense for these layers (keeps original adata.X untouched).
-      - Missing var columns are warned about but do not crash.
-      - Masked positions are filled with np.nan to make masked vs unmasked explicit.
-      - Requires append_base_context to be run first
+    force_redo: bool = False,
+    from_valid_sites_only: bool = False,
+    valid_site_col_suffix: str = "_valid_coverage",
+) -> "ad.AnnData":
+    """Build per-reference masked layers for base-context sites.
+
+    Args:
+        adata: AnnData object to annotate.
+        reference_column: Obs column containing reference identifiers.
+        smf_modality: SMF modality identifier.
+        verbose: Whether to log layer summary information.
+        uns_flag: Flag in ``adata.uns`` indicating prior completion.
+        bypass: Whether to skip processing.
+        force_redo: Whether to rerun even if ``uns_flag`` is set.
+        from_valid_sites_only: Whether to use valid-coverage site masks only.
+        valid_site_col_suffix: Suffix for valid-coverage site columns.
+
+    Returns:
+        anndata.AnnData: AnnData object with new masked layers.
     """
+    if not from_valid_sites_only:
+        valid_site_col_suffix = ""

     # Only run if not already performed
     already = bool(adata.uns.get(uns_flag, False))
@@ -46,17 +64,25 @@ def append_binary_layer_by_base_context(

     # expected per-reference var column names
     references = adata.obs[reference_column].astype("category").cat.categories
-    reference_to_gpc_column = {ref: f"{ref}_GpC_site" for ref in references}
-    reference_to_cpg_column = {ref: f"{ref}_CpG_site" for ref in references}
-    reference_to_c_column = {ref: f"{ref}_C_site" for ref in references}
-    reference_to_other_c_column = {ref: f"{ref}_other_C_site" for ref in references}
+    reference_to_gpc_column = {ref: f"{ref}_GpC_site{valid_site_col_suffix}" for ref in references}
+    reference_to_cpg_column = {ref: f"{ref}_CpG_site{valid_site_col_suffix}" for ref in references}
+    reference_to_c_column = {ref: f"{ref}_C_site{valid_site_col_suffix}" for ref in references}
+    reference_to_other_c_column = {
+        ref: f"{ref}_other_C_site{valid_site_col_suffix}" for ref in references
+    }
+    reference_to_a_column = {ref: f"{ref}_A_site{valid_site_col_suffix}" for ref in references}

     # verify var columns exist and build boolean masks per ref (len = n_vars)
     n_obs, n_vars = adata.shape
+
     def _col_mask_or_warn(colname):
+        """Return a boolean mask for a var column, or all-False if missing."""
         if colname not in adata.var.columns:
             if verbose:
-                print(f"Warning: var column '{colname}' not found; treating as all-False mask.")
+                logger.warning(
+                    "Var column '%s' not found; treating as all-False mask.",
+                    colname,
+                )
             return np.zeros(n_vars, dtype=bool)
         vals = adata.var[colname].values
         # coerce truthiness
@@ -67,14 +93,17 @@ def append_binary_layer_by_base_context(

     gpc_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_gpc_column.items()}
     cpg_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_cpg_column.items()}
-    c_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_c_column.items()}
-    other_c_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_other_c_column.items()}
+    c_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_c_column.items()}
+    other_c_var_masks = {
+        ref: _col_mask_or_warn(col) for ref, col in reference_to_other_c_column.items()
+    }
+    a_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_a_column.items()}

     # prepare X as dense float32 for layer filling (we leave adata.X untouched)
     X = adata.X
     if sp.issparse(X):
         if verbose:
-            print("Converting sparse X to dense array for layer construction (temporary).")
+            logger.info("Converting sparse X to dense array for layer construction (temporary).")
         X = X.toarray()
     X = np.asarray(X, dtype=np.float32)

@@ -83,11 +112,12 @@ def append_binary_layer_by_base_context(
     masked_cpg = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
     masked_any_c = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
     masked_other_c = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
+    masked_a = np.full((n_obs, n_vars), np.nan, dtype=np.float32)

     # fill row-blocks per reference (this avoids creating a full row×var boolean mask)
     obs_ref_series = adata.obs[reference_column]
     for ref in references:
-        rows_mask = (obs_ref_series.values == ref)
+        rows_mask = obs_ref_series.values == ref
         if not rows_mask.any():
             continue
         row_idx = np.nonzero(rows_mask)[0]  # integer indices of rows for this ref
@@ -95,8 +125,9 @@ def append_binary_layer_by_base_context(
         # column masks for this ref
         gpc_cols = gpc_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
         cpg_cols = cpg_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
-        c_cols = c_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
+        c_cols = c_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
         other_c_cols = other_c_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
+        a_cols = a_var_masks.get(ref, np.zeros(n_vars, dtype=bool))

         if gpc_cols.any():
             # assign only the submatrix (rows x selected cols)
@@ -107,6 +138,8 @@ def append_binary_layer_by_base_context(
             masked_any_c[np.ix_(row_idx, c_cols)] = X[np.ix_(row_idx, c_cols)]
         if other_c_cols.any():
             masked_other_c[np.ix_(row_idx, other_c_cols)] = X[np.ix_(row_idx, other_c_cols)]
+        if a_cols.any():
+            masked_a[np.ix_(row_idx, a_cols)] = X[np.ix_(row_idx, a_cols)]

         # Build combined layer:
         #  - numeric_sum: sum where either exists, NaN where neither exists
@@ -121,21 +154,26 @@ def append_binary_layer_by_base_context(
     # combined_bool = (~gpc_nan & (masked_gpc != 0)) | (~cpg_nan & (masked_cpg != 0))
     # combined_layer = combined_bool.astype(np.float32)

-    adata.layers['GpC_site_binary'] = masked_gpc
-    adata.layers['CpG_site_binary'] = masked_cpg
-    adata.layers['GpC_CpG_combined_site_binary'] = combined_sum
-    adata.layers['C_site_binary'] = masked_any_c
-    adata.layers['other_C_site_binary'] = masked_other_c
+    adata.layers["GpC_site_binary"] = masked_gpc
+    adata.layers["CpG_site_binary"] = masked_cpg
+    adata.layers["GpC_CpG_combined_site_binary"] = combined_sum
+    adata.layers["C_site_binary"] = masked_any_c
+    adata.layers["other_C_site_binary"] = masked_other_c
+    adata.layers["A_site_binary"] = masked_a

     if verbose:
+
         def _filled_positions(arr):
+            """Count the number of non-NaN positions in an array."""
             return int(np.sum(~np.isnan(arr)))
-        print("Layer build summary (non-NaN cell counts):")
-        print(f" GpC: {_filled_positions(masked_gpc)}")
-        print(f" CpG: {_filled_positions(masked_cpg)}")
-        print(f" GpC+CpG combined: {_filled_positions(combined_sum)}")
-        print(f" C: {_filled_positions(masked_any_c)}")
-        print(f" other_C: {_filled_positions(masked_other_c)}")
+
+        logger.info("Layer build summary (non-NaN cell counts):")
+        logger.info(" GpC: %s", _filled_positions(masked_gpc))
+        logger.info(" CpG: %s", _filled_positions(masked_cpg))
+        logger.info(" GpC+CpG combined: %s", _filled_positions(combined_sum))
+        logger.info(" C: %s", _filled_positions(masked_any_c))
+        logger.info(" other_C: %s", _filled_positions(masked_other_c))
+        logger.info(" A: %s", _filled_positions(masked_a))

     # mark as done
     adata.uns[uns_flag] = True
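Note on the A-site block above: as published, 0.3.0 fills `masked_a` from `other_c_cols` rather than `a_cols`; the diff shown here corrects that index. The per-reference fill writes only the (rows for one reference) × (columns for one site type) submatrix into an otherwise all-NaN array, so NaN cleanly marks "not a site of this type for this read". A minimal standalone illustration of the `np.ix_` pattern, with made-up shapes:

    import numpy as np

    X = np.arange(12, dtype=np.float32).reshape(3, 4)   # stand-in for adata.X
    masked = np.full(X.shape, np.nan, dtype=np.float32)

    row_idx = np.array([0, 2])                       # reads mapped to one reference
    col_mask = np.array([True, False, True, False])  # e.g. that reference's GpC sites

    # np.ix_ builds an open mesh, so only the 2x2 submatrix is copied;
    # every other entry stays NaN.
    masked[np.ix_(row_idx, col_mask)] = X[np.ix_(row_idx, col_mask)]
    print(masked)
    # [[ 0. nan  2. nan]
    #  [nan nan nan nan]
    #  [ 8. nan 10. nan]]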
--- a/smftools/preprocessing/archives/add_read_length_and_mapping_qc.py
+++ b/smftools/preprocessing/archived/add_read_length_and_mapping_qc.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
 import scipy.sparse as sp
--- a/smftools/preprocessing/archives/calculate_complexity.py
+++ b/smftools/preprocessing/archived/calculate_complexity.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 ## calculate_complexity

 def calculate_complexity(adata, output_directory='', obs_column='Reference', sample_col='Sample_names', plot=True, save_plot=False):
@@ -21,9 +23,11 @@ def calculate_complexity(adata, output_directory='', obs_column='Reference', sam
     from scipy.optimize import curve_fit

     def lander_waterman(x, C0):
+        """Lander-Waterman curve for complexity estimation."""
         return C0 * (1 - np.exp(-x / C0))

     def count_unique_reads(reads, depth):
+        """Count unique reads in a subsample of the given depth."""
         subsample = np.random.choice(reads, depth, replace=False)
         return len(np.unique(subsample))

@@ -69,4 +73,4 @@ def calculate_complexity(adata, output_directory='', obs_column='Reference', sam
         plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
         plt.close()
     else:
-        plt.show()
\ No newline at end of file
+        plt.show()
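The `lander_waterman` helper above models library complexity as C(x) = C0 * (1 - exp(-x / C0)): the expected number of distinct molecules observed after drawing x reads from a library of C0 unique molecules. A minimal sketch of fitting C0 to subsampled unique-read counts with the same curve (synthetic data, not the smftools API):

    import numpy as np
    from scipy.optimize import curve_fit

    def lander_waterman(x, C0):
        """Expected unique reads after x total reads from C0 unique molecules."""
        return C0 * (1 - np.exp(-x / C0))

    # Synthetic saturation curve: 1000 unique molecules sampled at rising depth.
    rng = np.random.default_rng(0)
    depths = np.array([100, 500, 1000, 2000, 5000, 10000], dtype=float)
    unique = np.array([len(np.unique(rng.integers(0, 1000, int(d)))) for d in depths])

    (C0_fit,), _ = curve_fit(lander_waterman, depths, unique, p0=[unique.max()])
    print(f"estimated library complexity C0 ~ {C0_fit:.0f}")  # close to 1000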
--- a/smftools/preprocessing/archives/mark_duplicates.py
+++ b/smftools/preprocessing/archived/mark_duplicates.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 ## mark_duplicates

 def mark_duplicates(adata, layers, obs_column='Reference', sample_col='Sample_names', method='N_masked_distances', distance_thresholds={}):