smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. smftools/__init__.py +34 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/cli.py +184 -0
  5. smftools/config/__init__.py +1 -0
  6. smftools/config/conversion.yaml +33 -0
  7. smftools/config/deaminase.yaml +56 -0
  8. smftools/config/default.yaml +253 -0
  9. smftools/config/direct.yaml +17 -0
  10. smftools/config/experiment_config.py +1191 -0
  11. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  12. smftools/datasets/F1_sample_sheet.csv +5 -0
  13. smftools/datasets/__init__.py +9 -0
  14. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  15. smftools/datasets/datasets.py +28 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/hmm/apply_hmm_batched.py +242 -0
  19. smftools/hmm/calculate_distances.py +18 -0
  20. smftools/hmm/call_hmm_peaks.py +106 -0
  21. smftools/hmm/display_hmm.py +18 -0
  22. smftools/hmm/hmm_readwrite.py +16 -0
  23. smftools/hmm/nucleosome_hmm_refinement.py +104 -0
  24. smftools/hmm/train_hmm.py +78 -0
  25. smftools/informatics/__init__.py +14 -0
  26. smftools/informatics/archived/bam_conversion.py +59 -0
  27. smftools/informatics/archived/bam_direct.py +63 -0
  28. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  29. smftools/informatics/archived/conversion_smf.py +132 -0
  30. smftools/informatics/archived/deaminase_smf.py +132 -0
  31. smftools/informatics/archived/direct_smf.py +137 -0
  32. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  33. smftools/informatics/basecall_pod5s.py +80 -0
  34. smftools/informatics/fast5_to_pod5.py +24 -0
  35. smftools/informatics/helpers/__init__.py +73 -0
  36. smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
  37. smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
  38. smftools/informatics/helpers/archived/informatics.py +260 -0
  39. smftools/informatics/helpers/archived/load_adata.py +516 -0
  40. smftools/informatics/helpers/bam_qc.py +66 -0
  41. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  42. smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
  43. smftools/informatics/helpers/canoncall.py +34 -0
  44. smftools/informatics/helpers/complement_base_list.py +21 -0
  45. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
  46. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  47. smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
  48. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  49. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  50. smftools/informatics/helpers/discover_input_files.py +100 -0
  51. smftools/informatics/helpers/extract_base_identities.py +70 -0
  52. smftools/informatics/helpers/extract_mods.py +83 -0
  53. smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
  54. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  55. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  56. smftools/informatics/helpers/find_conversion_sites.py +51 -0
  57. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  58. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  59. smftools/informatics/helpers/get_native_references.py +28 -0
  60. smftools/informatics/helpers/index_fasta.py +12 -0
  61. smftools/informatics/helpers/make_dirs.py +21 -0
  62. smftools/informatics/helpers/make_modbed.py +27 -0
  63. smftools/informatics/helpers/modQC.py +27 -0
  64. smftools/informatics/helpers/modcall.py +36 -0
  65. smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
  66. smftools/informatics/helpers/ohe_batching.py +76 -0
  67. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  68. smftools/informatics/helpers/one_hot_decode.py +27 -0
  69. smftools/informatics/helpers/one_hot_encode.py +57 -0
  70. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  71. smftools/informatics/helpers/run_multiqc.py +28 -0
  72. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  73. smftools/informatics/helpers/split_and_index_BAM.py +32 -0
  74. smftools/informatics/readwrite.py +106 -0
  75. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  76. smftools/informatics/subsample_pod5.py +104 -0
  77. smftools/load_adata.py +1346 -0
  78. smftools/machine_learning/__init__.py +12 -0
  79. smftools/machine_learning/data/__init__.py +2 -0
  80. smftools/machine_learning/data/anndata_data_module.py +234 -0
  81. smftools/machine_learning/data/preprocessing.py +6 -0
  82. smftools/machine_learning/evaluation/__init__.py +2 -0
  83. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  84. smftools/machine_learning/evaluation/evaluators.py +223 -0
  85. smftools/machine_learning/inference/__init__.py +3 -0
  86. smftools/machine_learning/inference/inference_utils.py +27 -0
  87. smftools/machine_learning/inference/lightning_inference.py +68 -0
  88. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  89. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  90. smftools/machine_learning/models/__init__.py +9 -0
  91. smftools/machine_learning/models/base.py +295 -0
  92. smftools/machine_learning/models/cnn.py +138 -0
  93. smftools/machine_learning/models/lightning_base.py +345 -0
  94. smftools/machine_learning/models/mlp.py +26 -0
  95. smftools/machine_learning/models/positional.py +18 -0
  96. smftools/machine_learning/models/rnn.py +17 -0
  97. smftools/machine_learning/models/sklearn_models.py +273 -0
  98. smftools/machine_learning/models/transformer.py +303 -0
  99. smftools/machine_learning/models/wrappers.py +20 -0
  100. smftools/machine_learning/training/__init__.py +2 -0
  101. smftools/machine_learning/training/train_lightning_model.py +135 -0
  102. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  103. smftools/machine_learning/utils/__init__.py +2 -0
  104. smftools/machine_learning/utils/device.py +10 -0
  105. smftools/machine_learning/utils/grl.py +14 -0
  106. smftools/plotting/__init__.py +18 -0
  107. smftools/plotting/autocorrelation_plotting.py +611 -0
  108. smftools/plotting/classifiers.py +355 -0
  109. smftools/plotting/general_plotting.py +682 -0
  110. smftools/plotting/hmm_plotting.py +260 -0
  111. smftools/plotting/position_stats.py +462 -0
  112. smftools/plotting/qc_plotting.py +270 -0
  113. smftools/preprocessing/__init__.py +38 -0
  114. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  115. smftools/preprocessing/append_base_context.py +122 -0
  116. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  117. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  118. smftools/preprocessing/archives/preprocessing.py +614 -0
  119. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  120. smftools/preprocessing/binarize_on_Youden.py +45 -0
  121. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  122. smftools/preprocessing/calculate_complexity.py +72 -0
  123. smftools/preprocessing/calculate_complexity_II.py +248 -0
  124. smftools/preprocessing/calculate_consensus.py +47 -0
  125. smftools/preprocessing/calculate_coverage.py +51 -0
  126. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  127. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  128. smftools/preprocessing/calculate_position_Youden.py +115 -0
  129. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  130. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  131. smftools/preprocessing/clean_NaN.py +62 -0
  132. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  133. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  134. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  135. smftools/preprocessing/flag_duplicate_reads.py +1351 -0
  136. smftools/preprocessing/invert_adata.py +37 -0
  137. smftools/preprocessing/load_sample_sheet.py +53 -0
  138. smftools/preprocessing/make_dirs.py +21 -0
  139. smftools/preprocessing/min_non_diagonal.py +25 -0
  140. smftools/preprocessing/recipes.py +127 -0
  141. smftools/preprocessing/subsample_adata.py +58 -0
  142. smftools/readwrite.py +1004 -0
  143. smftools/tools/__init__.py +20 -0
  144. smftools/tools/archived/apply_hmm.py +202 -0
  145. smftools/tools/archived/classifiers.py +787 -0
  146. smftools/tools/archived/classify_methylated_features.py +66 -0
  147. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  148. smftools/tools/archived/subset_adata_v1.py +32 -0
  149. smftools/tools/archived/subset_adata_v2.py +46 -0
  150. smftools/tools/calculate_umap.py +62 -0
  151. smftools/tools/cluster_adata_on_methylation.py +105 -0
  152. smftools/tools/general_tools.py +69 -0
  153. smftools/tools/position_stats.py +601 -0
  154. smftools/tools/read_stats.py +184 -0
  155. smftools/tools/spatial_autocorrelation.py +562 -0
  156. smftools/tools/subset_adata.py +28 -0
  157. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
  158. smftools-0.2.1.dist-info/RECORD +161 -0
  159. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  160. smftools-0.1.6.dist-info/RECORD +0 -4
  161. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  162. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,270 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+
6
+ import os
7
+ import numpy as np
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+
11
+
12
def plot_read_qc_histograms(
    adata,
    outdir,
    obs_keys,
    sample_key,
    bins=60,
    clip_quantiles=(0.0, 0.995),
    min_non_nan=10,
    rows_per_fig=6,
    topn_categories=15,
    figsize_cell=(3.6, 2.6),
    dpi=150,
):
    """
    Plot a grid of QC histograms: rows = samples (from `sample_key`), columns = `obs_keys`.

    Numeric columns -> histogram per sample; categorical columns -> bar chart of
    the top categories per sample. Pages of `rows_per_fig` samples are saved as
    PNGs in `outdir`.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix; only `adata.obs` is read.
    outdir : str
        Output directory (created if missing).
    obs_keys : list[str]
        obs columns to plot; unknown keys are skipped with a warning.
    sample_key : str
        Column in adata.obs defining rows (samples/barcodes).
    bins : int
        Histogram bins for numeric metrics.
    clip_quantiles : tuple or None
        (low_q, high_q) used to clip numeric data globally per metric so all
        rows share consistent axes, e.g. (0.0, 0.995). None disables clipping.
    min_non_nan : int
        Minimum finite values required to draw a panel.
    rows_per_fig : int
        Number of samples per page.
    topn_categories : int
        For categorical metrics, show the top-N categories per sample.
    figsize_cell : (float, float)
        Size of each subplot cell (width, height) in inches.
    dpi : int
        Figure resolution.
    """
    os.makedirs(outdir, exist_ok=True)

    if sample_key not in adata.obs.columns:
        raise KeyError(f"'{sample_key}' not found in adata.obs")

    # Ensure sample_key is categorical for stable row ordering. The isinstance
    # check replaces pd.api.types.is_categorical_dtype, which is deprecated
    # since pandas 2.1 (removed in 3.0).
    samples = adata.obs[sample_key]
    if not isinstance(samples.dtype, pd.CategoricalDtype):
        samples = samples.astype("category")
    sample_levels = list(samples.cat.categories)

    # Validate keys and classify numeric vs categorical.
    valid_keys = []
    is_numeric = {}
    for key in obs_keys:
        if key not in adata.obs.columns:
            print(f"[WARN] '{key}' not found in obs; skipping.")
            continue
        valid_keys.append(key)
        is_numeric[key] = pd.api.types.is_numeric_dtype(adata.obs[key])
    if not valid_keys:
        # Message previously named a nonexistent function ("plot_read_qc_grid").
        print("[plot_read_qc_histograms] No valid obs_keys to plot.")
        return

    # Precompute global numeric ranges (after clipping) so every row shares the
    # same x-axis per column.
    global_ranges = {}
    for key in valid_keys:
        if not is_numeric[key]:
            continue
        s = pd.to_numeric(adata.obs[key], errors="coerce").replace([np.inf, -np.inf], np.nan).dropna()
        if s.size < min_non_nan:
            # Not enough data: fall back to min/max, or (0, 1) when empty.
            lo, hi = (0.0, 1.0) if s.size == 0 else (float(s.min()), float(s.max()))
        elif clip_quantiles:
            qlo = s.quantile(clip_quantiles[0]) if clip_quantiles[0] is not None else s.min()
            qhi = s.quantile(clip_quantiles[1]) if clip_quantiles[1] is not None else s.max()
            lo, hi = float(qlo), float(qhi)
            if not (np.isfinite(lo) and np.isfinite(hi) and hi > lo):
                lo, hi = float(s.min()), float(s.max())
        else:
            lo, hi = float(s.min()), float(s.max())
        global_ranges[key] = (lo, hi)

    def _sanitize(name: str) -> str:
        # Keep alphanumerics and "-._"; replace everything else for safe filenames.
        return "".join(c if c.isalnum() or c in "-._" else "_" for c in str(name))

    ncols = len(valid_keys)
    fig_w = figsize_cell[0] * ncols
    fig_h_unit = figsize_cell[1]  # height per sample row; total scales with rows on the page

    for start in range(0, len(sample_levels), rows_per_fig):
        chunk = sample_levels[start:start + rows_per_fig]
        nrows = len(chunk)
        fig, axes = plt.subplots(
            nrows=nrows, ncols=ncols,
            figsize=(fig_w, fig_h_unit * nrows),
            dpi=dpi,
            squeeze=False,
        )

        for r, sample_val in enumerate(chunk):
            row_mask = (adata.obs[sample_key].values == sample_val)
            n_in_row = int(row_mask.sum())

            for c, key in enumerate(valid_keys):
                ax = axes[r, c]
                series = adata.obs.loc[row_mask, key]

                if is_numeric[key]:
                    x = pd.to_numeric(series, errors="coerce").replace([np.inf, -np.inf], np.nan).dropna()
                    if x.size < min_non_nan:
                        ax.text(0.5, 0.5, f"n={x.size} (<{min_non_nan})", ha="center", va="center")
                    else:
                        # Clip to the global range for consistent axes across rows.
                        lo, hi = global_ranges[key]
                        x = x.clip(lo, hi)
                        ax.hist(x.values, bins=bins, range=(lo, hi), edgecolor="black", alpha=0.7)
                        ax.set_xlim(lo, hi)
                    if r == 0:
                        ax.set_title(key)
                    if c == 0:
                        ax.set_ylabel(f"{sample_val}\n(n={n_in_row})")
                    ax.grid(alpha=0.25)
                    ax.set_xlabel("")  # keep uncluttered; x-limit conveys scale
                else:
                    vc = series.astype("category").value_counts(dropna=False)
                    if vc.sum() < min_non_nan:
                        ax.text(0.5, 0.5, f"n={vc.sum()} (<{min_non_nan})", ha="center", va="center")
                    else:
                        vc_top = vc.iloc[:topn_categories][::-1]  # show top-N, reversed for barh
                        ax.barh(vc_top.index.astype(str), vc_top.values)
                        ax.invert_yaxis()
                    if r == 0:
                        ax.set_title(f"{key} (cat)")
                    if c == 0:
                        ax.set_ylabel(f"{sample_val}\n(n={n_in_row})")
                    ax.grid(alpha=0.25)
                    if vc.sum() >= min_non_nan:
                        # Trim labels to reduce clutter on busy bar charts.
                        ax.tick_params(axis="y", labelsize=8)

        plt.tight_layout()
        page = start // rows_per_fig + 1
        out_png = os.path.join(outdir, f"qc_grid_{_sanitize(sample_key)}_page{page}.png")
        plt.savefig(out_png, bbox_inches="tight")
        plt.close(fig)
165
+
166
+
167
+ # def plot_read_qc_histograms(
168
+ # adata,
169
+ # outdir,
170
+ # obs_keys,
171
+ # sample_key=None,
172
+ # *,
173
+ # bins=100,
174
+ # clip_quantiles=(0.0, 0.995),
175
+ # min_non_nan=10,
176
+ # figsize=(6, 4),
177
+ # dpi=150
178
+ # ):
179
+ # """
180
+ # Plots histograms for given obs_keys, optionally grouped by sample_key.
181
+
182
+ # Parameters
183
+ # ----------
184
+ # adata : AnnData
185
+ # AnnData object.
186
+ # outdir : str
187
+ # Output directory for PNG files.
188
+ # obs_keys : list[str]
189
+ # List of obs columns to plot.
190
+ # sample_key : str or None
191
+ # Column in adata.obs to group by (e.g., 'Barcode').
192
+ # If None, plots are for the full dataset only.
193
+ # bins : int
194
+ # Number of histogram bins for numeric data.
195
+ # clip_quantiles : tuple or None
196
+ # (low_q, high_q) to clip extreme values for plotting.
197
+ # min_non_nan : int
198
+ # Minimum number of finite values to plot.
199
+ # figsize : tuple
200
+ # Figure size.
201
+ # dpi : int
202
+ # Figure resolution.
203
+ # """
204
+ # os.makedirs(outdir, exist_ok=True)
205
+
206
+ # # Define grouping
207
+ # if sample_key and sample_key in adata.obs.columns:
208
+ # groups = adata.obs.groupby(sample_key)
209
+ # else:
210
+ # groups = [(None, adata.obs)] # single group
211
+
212
+ # for group_name, group_df in groups:
213
+ # # For each metric
214
+ # for key in obs_keys:
215
+ # if key not in group_df.columns:
216
+ # print(f"[WARN] '{key}' not found in obs; skipping.")
217
+ # continue
218
+
219
+ # series = group_df[key]
220
+
221
+ # # Numeric columns
222
+ # if pd.api.types.is_numeric_dtype(series):
223
+ # x = pd.to_numeric(series, errors="coerce").replace([np.inf, -np.inf], np.nan).dropna()
224
+ # if len(x) < min_non_nan:
225
+ # continue
226
+
227
+ # # Clip for better visualization
228
+ # if clip_quantiles:
229
+ # lo = x.quantile(clip_quantiles[0]) if clip_quantiles[0] is not None else x.min()
230
+ # hi = x.quantile(clip_quantiles[1]) if clip_quantiles[1] is not None else x.max()
231
+ # if np.isfinite(lo) and np.isfinite(hi) and hi > lo:
232
+ # x = x.clip(lo, hi)
233
+
234
+ # fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
235
+ # ax.hist(x, bins=bins, edgecolor="black", alpha=0.7)
236
+ # ax.set_xlabel(key)
237
+ # ax.set_ylabel("Count")
238
+
239
+ # title = f"{key}" if group_name is None else f"{key} — {sample_key}={group_name}"
240
+ # ax.set_title(title)
241
+
242
+ # plt.tight_layout()
243
+
244
+ # # Save PNG
245
+ # safe_group = "all" if group_name is None else str(group_name)
246
+ # fname = f"{key}_{sample_key}_{safe_group}.png" if sample_key else f"{key}.png"
247
+ # fname = fname.replace("/", "_")
248
+ # fig.savefig(os.path.join(outdir, fname))
249
+ # plt.close(fig)
250
+
251
+ # else:
252
+ # # Categorical columns
253
+ # vc = series.astype("category").value_counts(dropna=False)
254
+ # if vc.sum() < min_non_nan:
255
+ # continue
256
+
257
+ # fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
258
+ # vc.plot(kind="barh", ax=ax)
259
+ # ax.set_xlabel("Count")
260
+
261
+ # title = f"{key} (categorical)" if group_name is None else f"{key} — {sample_key}={group_name}"
262
+ # ax.set_title(title)
263
+
264
+ # plt.tight_layout()
265
+
266
+ # safe_group = "all" if group_name is None else str(group_name)
267
+ # fname = f"{key}_{sample_key}_{safe_group}.png" if sample_key else f"{key}.png"
268
+ # fname = fname.replace("/", "_")
269
+ # fig.savefig(os.path.join(outdir, fname))
270
+ # plt.close(fig)
@@ -0,0 +1,38 @@
1
from .add_read_length_and_mapping_qc import add_read_length_and_mapping_qc
from .append_base_context import append_base_context
from .append_binary_layer_by_base_context import append_binary_layer_by_base_context
from .binarize_on_Youden import binarize_on_Youden
from .calculate_complexity import calculate_complexity
from .calculate_complexity_II import calculate_complexity_II
from .calculate_coverage import calculate_coverage
from .calculate_position_Youden import calculate_position_Youden
from .calculate_read_length_stats import calculate_read_length_stats
from .calculate_read_modification_stats import calculate_read_modification_stats
from .clean_NaN import clean_NaN
from .filter_adata_by_nan_proportion import filter_adata_by_nan_proportion
from .filter_reads_on_length_quality_mapping import filter_reads_on_length_quality_mapping
from .filter_reads_on_modification_thresholds import filter_reads_on_modification_thresholds
from .flag_duplicate_reads import flag_duplicate_reads
from .invert_adata import invert_adata
from .load_sample_sheet import load_sample_sheet
from .subsample_adata import subsample_adata

# Public API of smftools.preprocessing. Keep this list in sync with the
# imports above; `calculate_complexity_II` was previously imported but
# accidentally omitted from __all__.
__all__ = [
    "add_read_length_and_mapping_qc",
    "append_base_context",
    "append_binary_layer_by_base_context",
    "binarize_on_Youden",
    "calculate_complexity",
    "calculate_complexity_II",
    "calculate_coverage",
    "calculate_position_Youden",
    "calculate_read_length_stats",
    "calculate_read_modification_stats",
    "clean_NaN",
    "filter_adata_by_nan_proportion",
    "filter_reads_on_length_quality_mapping",
    "filter_reads_on_modification_thresholds",
    "flag_duplicate_reads",
    "invert_adata",
    "load_sample_sheet",
    "subsample_adata",
]
@@ -0,0 +1,129 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import scipy.sparse as sp
4
+ from typing import Optional, List, Dict, Union
5
+
6
def add_read_length_and_mapping_qc(
    adata,
    bam_files: Optional[List[str]] = None,
    read_metrics: Optional[Dict[str, Union[list, tuple]]] = None,
    uns_flag: str = "read_lenth_and_mapping_qc_performed",
    extract_read_features_from_bam_callable = None,
    bypass: bool = False,
    force_redo: bool = True
):
    """
    Populate adata.obs with read/mapping QC columns.

    Adds per-read columns (read_length, mapped_length, reference_length,
    read_quality, mapping_quality), derived length ratios, and a per-read
    raw modification signal (row sum of adata.X).

    Parameters
    ----------
    adata
        AnnData to annotate (modified in-place).
    bam_files
        Optional list of BAM files to extract metrics from. Ignored if read_metrics supplied.
    read_metrics
        Optional dict mapping obs_name -> [read_length, read_quality, reference_length, mapped_length, mapping_quality].
        If provided, this is used directly and bam_files is ignored.
    uns_flag
        Key in adata.uns used to record that QC was performed (keeps the
        original misspelling for backward compatibility).
    extract_read_features_from_bam_callable
        Optional callable(bam_path) -> dict mapping read_name -> list/tuple of metrics.
        If not provided and bam_files is given, falls back to a global
        `extract_read_features_from_bam` helper.
    bypass
        Skip entirely when True.
    force_redo
        Re-run even if the uns flag says QC was already performed.

    Returns
    -------
    None (mutates adata in-place)
    """
    # Skip when already performed (unless forced) or explicitly bypassed.
    already = bool(adata.uns.get(uns_flag, False))
    if (already and not force_redo) or bypass:
        return

    # Build read_metrics dict either from the provided arg or from BAM files.
    if read_metrics is None:
        read_metrics = {}
        if bam_files:
            extractor = extract_read_features_from_bam_callable or globals().get("extract_read_features_from_bam")
            if extractor is None:
                raise ValueError("No `read_metrics` provided and `extract_read_features_from_bam` not found.")
            for bam in bam_files:
                bam_read_metrics = extractor(bam)
                if not isinstance(bam_read_metrics, dict):
                    raise ValueError(f"extract_read_features_from_bam returned non-dict for {bam}")
                read_metrics.update(bam_read_metrics)

    # Column order of the 5 expected metric values per read.
    metric_cols = ['read_length', 'read_quality', 'reference_length', 'mapped_length', 'mapping_quality']
    max_cols = len(metric_cols)

    if len(read_metrics) == 0:
        # No metric source at all: fill every QC column with NaN.
        n = adata.n_obs
        adata.obs['read_length'] = np.full(n, np.nan)
        adata.obs['mapped_length'] = np.full(n, np.nan)
        adata.obs['reference_length'] = np.full(n, np.nan)
        adata.obs['read_quality'] = np.full(n, np.nan)
        adata.obs['mapping_quality'] = np.full(n, np.nan)
    else:
        # Build the DataFrame robustly: list-likes are padded/truncated to 5
        # entries; scalars are replicated across all 5 columns (original behavior).
        rows = {}
        for read_name, vals_in in read_metrics.items():
            if isinstance(vals_in, (list, tuple, np.ndarray)):
                vals = list(vals_in)
            else:
                vals = [vals_in] * max_cols
            if len(vals) < max_cols:
                vals = vals + [np.nan] * (max_cols - len(vals))
            rows[read_name] = vals[:max_cols]

        df = pd.DataFrame.from_dict(rows, orient='index', columns=metric_cols)

        # Reindex to adata.obs_names so ordering matches; reads absent from the
        # metrics dict become NaN.
        df_reindexed = df.reindex(adata.obs_names).astype(float)

        adata.obs['read_length'] = df_reindexed['read_length'].values
        adata.obs['mapped_length'] = df_reindexed['mapped_length'].values
        adata.obs['reference_length'] = df_reindexed['reference_length'].values
        adata.obs['read_quality'] = df_reindexed['read_quality'].values
        adata.obs['mapping_quality'] = df_reindexed['mapping_quality'].values

    # Compute ratio columns safely (no divide-by-zero warnings; NaN preserved).
    rl = pd.to_numeric(adata.obs['read_length'], errors='coerce').to_numpy(dtype=float)
    ref_len = pd.to_numeric(adata.obs['reference_length'], errors='coerce').to_numpy(dtype=float)
    mapped_len = pd.to_numeric(adata.obs['mapped_length'], errors='coerce').to_numpy(dtype=float)

    with np.errstate(divide='ignore', invalid='ignore'):
        ref_ok = (ref_len != 0) & np.isfinite(ref_len)
        rl_ok = (rl != 0) & np.isfinite(rl)
        rl_to_ref = np.where(ref_ok, rl / ref_len, np.nan)
        mapped_to_ref = np.where(ref_ok, mapped_len / ref_len, np.nan)
        mapped_to_read = np.where(rl_ok, mapped_len / rl, np.nan)

    adata.obs['read_length_to_reference_length_ratio'] = rl_to_ref
    adata.obs['mapped_length_to_reference_length_ratio'] = mapped_to_ref
    adata.obs['mapped_length_to_read_length_ratio'] = mapped_to_read

    # Read-level raw modification signal: row sums of X. np.asarray(...).ravel()
    # handles both dense arrays and scipy sparse matrices identically, so the
    # previous duplicated issparse/else branches were removed.
    adata.obs['Raw_modification_signal'] = np.asarray(adata.X.sum(axis=1)).ravel()

    # Mark QC as done.
    adata.uns[uns_flag] = True

    return None
@@ -0,0 +1,122 @@
1
def append_base_context(adata,
                        obs_column='Reference_strand',
                        use_consensus=False,
                        native=False,
                        mod_target_bases=('GpC', 'CpG'),
                        bypass=False,
                        force_redo=False,
                        uns_flag='base_context_added'
                        ):
    """
    Add nucleobase context annotations for each reference/strand category.

    For every category in `adata.obs[obs_column]`, marks positions that are
    GpC / CpG / ambiguous GpC-CpG / other-C / any-C sites (and A sites when
    requested) as boolean columns in `adata.var`, and stores the matching
    per-read signal matrices in `adata.obsm`. When `use_consensus` is True the
    consensus sequence is used, otherwise the reference FASTA sequence.

    Parameters
    ----------
    adata : AnnData
        Input object; modified in place.
    obs_column : str
        Categorical obs column to stratify on. Default 'Reference_strand',
        which should not be changed for most purposes.
    use_consensus : bool
        Use the '<ref>_consensus_sequence' entry of adata.uns instead of the
        '<ref>_FASTA_sequence' entry.
    native : bool
        If True (native SMF), pull signal from layers['binarized_methylation'];
        otherwise (conversion SMF) from adata.X.
    mod_target_bases : sequence of str
        Base contexts that may be modified ('GpC', 'CpG', 'C', 'A').
        Default is an immutable tuple to avoid the mutable-default pitfall;
        lists work too.
    bypass : bool
        Skip entirely when True.
    force_redo : bool
        Re-run even if the uns flag says context was already added.
    uns_flag : str
        Key in adata.uns recording completion.

    Returns
    -------
    None
    """
    import numpy as np

    # Only run if not already performed (unless forced) or bypassed.
    already = bool(adata.uns.get(uns_flag, False))
    if (already and not force_redo) or bypass:
        return

    print('Adding base context based on reference FASTA sequence for sample')
    categories = adata.obs[obs_column].cat.categories

    wants_c = any(base in mod_target_bases for base in ('GpC', 'CpG', 'C'))
    wants_a = 'A' in mod_target_bases

    site_types = []
    if wants_c:
        site_types += ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C_site', 'any_C_site']
    if wants_a:
        site_types += ['A_site']

    for cat in categories:
        # Assess whether this reference is the top- or bottom-strand entry.
        if 'top' in cat:
            strand = 'top'
        elif 'bottom' in cat:
            strand = 'bottom'
        else:
            # Previously `strand` was left unbound here (UnboundLocalError on
            # the first iteration, or a stale value carried over from the
            # previous category); skip the category with a message instead.
            print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')
            continue

        # The native and non-native paths performed identical lookups; unified.
        basename = cat.split(f"_{strand}")[0]
        if use_consensus:
            sequence = adata.uns[f'{basename}_consensus_sequence']
        else:
            # Unconverted FASTA sequence of the original input FASTA for the locus.
            sequence = adata.uns[f'{basename}_FASTA_sequence']

        # site-type key -> boolean mask over reference positions.
        boolean_dict = {
            f'{cat}_{site_type}': np.full(len(sequence), False, dtype=bool)
            for site_type in site_types
        }

        if wants_c:
            if strand == 'top':
                # C at position i; flanking bases decide the site class.
                for i in range(1, len(sequence) - 1):
                    if sequence[i] == 'C':
                        boolean_dict[f'{cat}_any_C_site'][i] = True
                        if sequence[i - 1] == 'G' and sequence[i + 1] != 'G':
                            boolean_dict[f'{cat}_GpC_site'][i] = True
                        elif sequence[i - 1] == 'G' and sequence[i + 1] == 'G':
                            boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
                        elif sequence[i - 1] != 'G' and sequence[i + 1] == 'G':
                            boolean_dict[f'{cat}_CpG_site'][i] = True
                        else:
                            boolean_dict[f'{cat}_other_C_site'][i] = True
            else:
                # Bottom strand: a reference G marks a C on this strand; the
                # flanking tests mirror the top-strand logic.
                for i in range(1, len(sequence) - 1):
                    if sequence[i] == 'G':
                        boolean_dict[f'{cat}_any_C_site'][i] = True
                        if sequence[i + 1] == 'C' and sequence[i - 1] != 'C':
                            boolean_dict[f'{cat}_GpC_site'][i] = True
                        elif sequence[i - 1] == 'C' and sequence[i + 1] == 'C':
                            boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
                        elif sequence[i - 1] == 'C' and sequence[i + 1] != 'C':
                            boolean_dict[f'{cat}_CpG_site'][i] = True
                        else:
                            boolean_dict[f'{cat}_other_C_site'][i] = True

        if wants_a:
            # A on the top strand; a reference T marks an A on the bottom strand.
            target = 'A' if strand == 'top' else 'T'
            for i in range(1, len(sequence) - 1):
                if sequence[i] == target:
                    boolean_dict[f'{cat}_A_site'][i] = True

        for site_type in site_types:
            adata.var[f'{cat}_{site_type}'] = boolean_dict[f'{cat}_{site_type}'].astype(bool)
            if native:
                adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].layers['binarized_methylation']
            else:
                adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].X

    # Mark as done.
    adata.uns[uns_flag] = True

    return None