smftools 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. smftools/__init__.py +39 -7
  2. smftools/_settings.py +2 -0
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +34 -6
  7. smftools/cli/hmm_adata.py +239 -33
  8. smftools/cli/latent_adata.py +318 -0
  9. smftools/cli/load_adata.py +167 -131
  10. smftools/cli/preprocess_adata.py +180 -53
  11. smftools/cli/spatial_adata.py +152 -100
  12. smftools/cli_entry.py +38 -1
  13. smftools/config/__init__.py +2 -0
  14. smftools/config/conversion.yaml +11 -1
  15. smftools/config/default.yaml +42 -2
  16. smftools/config/experiment_config.py +59 -1
  17. smftools/constants.py +65 -0
  18. smftools/datasets/__init__.py +2 -0
  19. smftools/hmm/HMM.py +97 -3
  20. smftools/hmm/__init__.py +24 -13
  21. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  22. smftools/hmm/archived/calculate_distances.py +2 -0
  23. smftools/hmm/archived/call_hmm_peaks.py +2 -0
  24. smftools/hmm/archived/train_hmm.py +2 -0
  25. smftools/hmm/call_hmm_peaks.py +5 -2
  26. smftools/hmm/display_hmm.py +4 -1
  27. smftools/hmm/hmm_readwrite.py +7 -2
  28. smftools/hmm/nucleosome_hmm_refinement.py +2 -0
  29. smftools/informatics/__init__.py +59 -34
  30. smftools/informatics/archived/bam_conversion.py +2 -0
  31. smftools/informatics/archived/bam_direct.py +2 -0
  32. smftools/informatics/archived/basecall_pod5s.py +2 -0
  33. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  34. smftools/informatics/archived/conversion_smf.py +2 -0
  35. smftools/informatics/archived/deaminase_smf.py +1 -0
  36. smftools/informatics/archived/direct_smf.py +2 -0
  37. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  38. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  39. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
  40. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  41. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  42. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  43. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  44. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  45. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  46. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  47. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  48. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  49. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  50. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  52. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  53. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  54. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  55. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  56. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  57. smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
  58. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  59. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  60. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  61. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  62. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  63. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  64. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  65. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
  66. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  67. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  68. smftools/informatics/archived/print_bam_query_seq.py +2 -0
  69. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  70. smftools/informatics/archived/subsample_pod5.py +2 -0
  71. smftools/informatics/bam_functions.py +1093 -176
  72. smftools/informatics/basecalling.py +2 -0
  73. smftools/informatics/bed_functions.py +271 -61
  74. smftools/informatics/binarize_converted_base_identities.py +3 -0
  75. smftools/informatics/complement_base_list.py +2 -0
  76. smftools/informatics/converted_BAM_to_adata.py +641 -176
  77. smftools/informatics/fasta_functions.py +94 -10
  78. smftools/informatics/h5ad_functions.py +123 -4
  79. smftools/informatics/modkit_extract_to_adata.py +1019 -431
  80. smftools/informatics/modkit_functions.py +2 -0
  81. smftools/informatics/ohe.py +2 -0
  82. smftools/informatics/pod5_functions.py +3 -2
  83. smftools/informatics/sequence_encoding.py +72 -0
  84. smftools/logging_utils.py +21 -2
  85. smftools/machine_learning/__init__.py +22 -6
  86. smftools/machine_learning/data/__init__.py +2 -0
  87. smftools/machine_learning/data/anndata_data_module.py +18 -4
  88. smftools/machine_learning/data/preprocessing.py +2 -0
  89. smftools/machine_learning/evaluation/__init__.py +2 -0
  90. smftools/machine_learning/evaluation/eval_utils.py +2 -0
  91. smftools/machine_learning/evaluation/evaluators.py +14 -9
  92. smftools/machine_learning/inference/__init__.py +2 -0
  93. smftools/machine_learning/inference/inference_utils.py +2 -0
  94. smftools/machine_learning/inference/lightning_inference.py +6 -1
  95. smftools/machine_learning/inference/sklearn_inference.py +2 -0
  96. smftools/machine_learning/inference/sliding_window_inference.py +2 -0
  97. smftools/machine_learning/models/__init__.py +2 -0
  98. smftools/machine_learning/models/base.py +7 -2
  99. smftools/machine_learning/models/cnn.py +7 -2
  100. smftools/machine_learning/models/lightning_base.py +16 -11
  101. smftools/machine_learning/models/mlp.py +5 -1
  102. smftools/machine_learning/models/positional.py +7 -2
  103. smftools/machine_learning/models/rnn.py +5 -1
  104. smftools/machine_learning/models/sklearn_models.py +14 -9
  105. smftools/machine_learning/models/transformer.py +7 -2
  106. smftools/machine_learning/models/wrappers.py +6 -2
  107. smftools/machine_learning/training/__init__.py +2 -0
  108. smftools/machine_learning/training/train_lightning_model.py +13 -3
  109. smftools/machine_learning/training/train_sklearn_model.py +2 -0
  110. smftools/machine_learning/utils/__init__.py +2 -0
  111. smftools/machine_learning/utils/device.py +5 -1
  112. smftools/machine_learning/utils/grl.py +5 -1
  113. smftools/metadata.py +1 -1
  114. smftools/optional_imports.py +31 -0
  115. smftools/plotting/__init__.py +41 -31
  116. smftools/plotting/autocorrelation_plotting.py +9 -5
  117. smftools/plotting/classifiers.py +16 -4
  118. smftools/plotting/general_plotting.py +2415 -629
  119. smftools/plotting/hmm_plotting.py +97 -9
  120. smftools/plotting/position_stats.py +15 -7
  121. smftools/plotting/qc_plotting.py +6 -1
  122. smftools/preprocessing/__init__.py +36 -37
  123. smftools/preprocessing/append_base_context.py +17 -17
  124. smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
  125. smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
  126. smftools/preprocessing/archived/calculate_complexity.py +2 -0
  127. smftools/preprocessing/archived/mark_duplicates.py +2 -0
  128. smftools/preprocessing/archived/preprocessing.py +2 -0
  129. smftools/preprocessing/archived/remove_duplicates.py +2 -0
  130. smftools/preprocessing/binary_layers_to_ohe.py +2 -1
  131. smftools/preprocessing/calculate_complexity_II.py +4 -1
  132. smftools/preprocessing/calculate_consensus.py +1 -1
  133. smftools/preprocessing/calculate_pairwise_differences.py +2 -0
  134. smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
  135. smftools/preprocessing/calculate_position_Youden.py +9 -2
  136. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  137. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
  138. smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
  139. smftools/preprocessing/flag_duplicate_reads.py +42 -54
  140. smftools/preprocessing/make_dirs.py +2 -1
  141. smftools/preprocessing/min_non_diagonal.py +2 -0
  142. smftools/preprocessing/recipes.py +2 -0
  143. smftools/readwrite.py +53 -17
  144. smftools/schema/anndata_schema_v1.yaml +15 -1
  145. smftools/tools/__init__.py +30 -18
  146. smftools/tools/archived/apply_hmm.py +2 -0
  147. smftools/tools/archived/classifiers.py +2 -0
  148. smftools/tools/archived/classify_methylated_features.py +2 -0
  149. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  150. smftools/tools/archived/subset_adata_v1.py +2 -0
  151. smftools/tools/archived/subset_adata_v2.py +2 -0
  152. smftools/tools/calculate_leiden.py +57 -0
  153. smftools/tools/calculate_nmf.py +119 -0
  154. smftools/tools/calculate_umap.py +93 -8
  155. smftools/tools/cluster_adata_on_methylation.py +7 -1
  156. smftools/tools/position_stats.py +17 -27
  157. smftools/tools/rolling_nn_distance.py +235 -0
  158. smftools/tools/tensor_factorization.py +169 -0
  159. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
  160. smftools-0.3.1.dist-info/RECORD +189 -0
  161. smftools-0.2.5.dist-info/RECORD +0 -181
  162. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
  163. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
  164. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0

smftools/plotting/hmm_plotting.py
@@ -1,9 +1,21 @@
+from __future__ import annotations
+
 import math
-from typing import Optional, Tuple, Union
+from typing import Optional, Sequence, Tuple, Union
 
-import matplotlib.pyplot as plt
 import numpy as np
-from matplotlib.backends.backend_pdf import PdfPages
+import pandas as pd
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM plots")
+mpl_colors = require("matplotlib.colors", extra="plotting", purpose="HMM plots")
+pdf_backend = require(
+    "matplotlib.backends.backend_pdf",
+    extra="plotting",
+    purpose="PDF output",
+)
+PdfPages = pdf_backend.PdfPages
 
 
 def plot_hmm_size_contours(
@@ -22,6 +34,9 @@ def plot_hmm_size_contours(
     dpi: int = 150,
     vmin: Optional[float] = None,
     vmax: Optional[float] = None,
+    feature_ranges: Optional[Sequence[Tuple[int, int, str]]] = None,
+    zero_color: str = "#f5f1e8",
+    nan_color: str = "#E6E6E6",
     # ---------------- smoothing params ----------------
     smoothing_sigma: Optional[Union[float, Tuple[float, float]]] = None,
     normalize_after_smoothing: bool = True,
@@ -30,6 +45,9 @@ def plot_hmm_size_contours(
     """
     Create contour/pcolormesh plots of P(length | position) using a length-encoded HMM layer.
     Optional Gaussian smoothing applied to the 2D probability grid before plotting.
+    When feature_ranges is provided, each length row is assigned a base color based
+    on the matching (min_len, max_len) range and the probability value modulates
+    the color intensity.
 
     smoothing_sigma: None or 0 -> no smoothing.
         float -> same sigma applied to (length_axis, position_axis)
@@ -38,6 +56,51 @@ def plot_hmm_size_contours(
 
     Other args are the same as prior function.
     """
+    feature_ranges = tuple(feature_ranges or ())
+
+    def _resolve_length_color(length: int, fallback: str) -> Tuple[float, float, float, float]:
+        for min_len, max_len, color in feature_ranges:
+            if min_len <= length <= max_len:
+                return mpl_colors.to_rgba(color)
+        return mpl_colors.to_rgba(fallback)
+
+    def _build_length_facecolors(
+        Z_values: np.ndarray,
+        lengths: np.ndarray,
+        fallback_color: str,
+        *,
+        vmin_local: Optional[float],
+        vmax_local: Optional[float],
+    ) -> np.ndarray:
+        zero_rgba = np.array(mpl_colors.to_rgba(zero_color))
+        nan_rgba = np.array(mpl_colors.to_rgba(nan_color))
+        base_colors = np.array(
+            [_resolve_length_color(int(length), fallback_color) for length in lengths],
+            dtype=float,
+        )
+        base_colors[:, 3] = 1.0
+
+        scale = np.array(Z_values, copy=True, dtype=float)
+        finite_mask = np.isfinite(scale)
+        if not finite_mask.any():
+            facecolors = np.zeros(scale.shape + (4,), dtype=float)
+            facecolors[:] = nan_rgba
+            return facecolors.reshape(-1, 4)
+
+        vmin_use = np.nanmin(scale) if vmin_local is None else vmin_local
+        vmax_use = np.nanmax(scale) if vmax_local is None else vmax_local
+        denom = vmax_use - vmin_use
+        if denom <= 0:
+            norm = np.zeros_like(scale)
+        else:
+            norm = (scale - vmin_use) / denom
+        norm = np.clip(norm, 0, 1)
+
+        row_colors = base_colors[:, None, :]
+        facecolors = zero_rgba + norm[..., None] * (row_colors - zero_rgba)
+        facecolors[..., 3] = 1.0
+        facecolors[~finite_mask] = nan_rgba
+        return facecolors.reshape(-1, 4)
 
     # --- helper: gaussian smoothing (scipy fallback -> numpy separable conv) ---
     def _gaussian_1d_kernel(sigma: float, eps: float = 1e-12):
@@ -140,7 +203,8 @@ def plot_hmm_size_contours(
     figs = []
 
     # decide global max length to allocate y axis (cap to avoid huge memory)
-    observed_max_len = int(np.max(full_layer)) if full_layer.size > 0 else 0
+    finite_lengths = full_layer[np.isfinite(full_layer) & (full_layer > 0)]
+    observed_max_len = int(np.nanmax(finite_lengths)) if finite_lengths.size > 0 else 0
     if max_length_cap is None:
         max_len = observed_max_len
     else:
@@ -195,10 +259,15 @@ def plot_hmm_size_contours(
                 ax.text(0.5, 0.5, "no data", ha="center", va="center")
                 ax.set_title(f"{sample} / {ref}")
                 continue
+            valid_lengths = sub[np.isfinite(sub) & (sub > 0)]
+            if valid_lengths.size == 0:
+                ax.text(0.5, 0.5, "no data", ha="center", va="center")
+                ax.set_title(f"{sample} / {ref}")
+                continue
 
             # compute counts per length per position
             n_positions = sub.shape[1]
-            max_len_local = int(sub.max()) if sub.size > 0 else 0
+            max_len_local = int(valid_lengths.max()) if valid_lengths.size > 0 else 0
             max_len_here = min(max_len, max_len_local)
 
             lengths_range = np.arange(1, max_len_here + 1, dtype=int)
@@ -209,7 +278,7 @@ def plot_hmm_size_contours(
             # fill Z by efficient bincount across columns
             for j in range(n_positions):
                 col_vals = sub[:, j]
-                pos_vals = col_vals[col_vals > 0].astype(int)
+                pos_vals = col_vals[np.isfinite(col_vals) & (col_vals > 0)].astype(int)
                 if pos_vals.size == 0:
                     continue
                 clipped = np.clip(pos_vals, 1, max_len_here)
@@ -248,9 +317,28 @@ def plot_hmm_size_contours(
                 dy = 1.0
             y_edges = np.concatenate([y - 0.5, [y[-1] + 0.5]])
 
-            pcm = ax.pcolormesh(
-                x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax
-            )
+            if feature_ranges:
+                fallback_color = mpl_colors.to_rgba(plt.get_cmap(cmap)(1.0))
+                facecolors = _build_length_facecolors(
+                    Z_plot,
+                    lengths_range,
+                    fallback_color,
+                    vmin_local=vmin,
+                    vmax_local=vmax,
+                )
+                pcm = ax.pcolormesh(
+                    x_edges,
+                    y_edges,
+                    Z_plot,
+                    shading="auto",
+                    vmin=vmin,
+                    vmax=vmax,
+                    facecolors=facecolors,
+                )
+            else:
+                pcm = ax.pcolormesh(
+                    x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax
+                )
             ax.set_title(f"{sample} / {ref}")
             ax.set_ylabel("length")
            if i_row == rows_on_page - 1:
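
Note on the new feature_ranges coloring: each cell is blended linearly between zero_color and the row's base color, so a probability of 0 renders as the background tint and 1 as the full feature color. A standalone sketch of the same blend on toy data (not the smftools function itself):

import numpy as np
from matplotlib import colors as mpl_colors

zero_rgba = np.array(mpl_colors.to_rgba("#f5f1e8"))
base_rgba = np.array(mpl_colors.to_rgba("tab:blue"))

# toy probability grid, already normalized to [0, 1]
norm = np.array([[0.0, 0.25, 0.5, 1.0]])

# blend: color = zero + p * (base - zero); alpha forced opaque
facecolors = zero_rgba + norm[..., None] * (base_rgba - zero_rgba)
facecolors[..., 3] = 1.0
print(facecolors.shape)  # (1, 4, 4): one RGBA tuple per grid cell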

smftools/plotting/position_stats.py
@@ -1,3 +1,8 @@
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+
 def plot_volcano_relative_risk(
     results_dict,
     save_path=None,
@@ -22,7 +27,7 @@ def plot_volcano_relative_risk(
     """
     import os
 
-    import matplotlib.pyplot as plt
+    plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")
 
     for ref, group_results in results_dict.items():
         for group_label, (results_df, _) in group_results.items():
@@ -124,7 +129,7 @@ def plot_bar_relative_risk(
     """
     import os
 
-    import matplotlib.pyplot as plt
+    plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")
 
     for ref, group_data in results_dict.items():
         for group_label, (df, _) in group_data.items():
@@ -229,10 +234,11 @@ def plot_positionwise_matrix(
     """
     import os
 
-    import matplotlib.pyplot as plt
     import numpy as np
     import pandas as pd
-    import seaborn as sns
+
+    plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
+    sns = require("seaborn", extra="plotting", purpose="position stats plots")
 
     def find_closest_index(index, target):
         """Find the index value closest to a target value."""
@@ -408,12 +414,14 @@ def plot_positionwise_matrix_grid(
     """
     import os
 
-    import matplotlib.pyplot as plt
     import numpy as np
     import pandas as pd
-    import seaborn as sns
     from joblib import Parallel, delayed
-    from matplotlib.gridspec import GridSpec
+
+    plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
+    sns = require("seaborn", extra="plotting", purpose="position stats plots")
+    grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="position stats plots")
+    GridSpec = grid_spec.GridSpec
 
     matrices = adata.uns[key]
     group_labels = list(matrices.keys())
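
These hunks route all matplotlib/seaborn imports through smftools.optional_imports.require (a new module, +31 lines in the file list above). Its body is not shown in this diff; a minimal sketch of what such a helper typically looks like, with the message wording assumed:

import importlib
from typing import Optional

def require(module_name: str, extra: Optional[str] = None, purpose: Optional[str] = None):
    # Import an optional dependency, or fail with an actionable hint.
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        hint = f" (needed for {purpose})" if purpose else ""
        pip_target = f"smftools[{extra}]" if extra else module_name
        raise ImportError(
            f"Optional dependency '{module_name}' is missing{hint}. "
            f"Install it with: pip install '{pip_target}'"
        ) from err

The net effect is that a missing matplotlib only fails when a plotting module is actually loaded, with a message naming the extra to install.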

smftools/plotting/qc_plotting.py
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import os
 
-import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="QC plots")
+
 
 def plot_read_qc_histograms(
     adata,

smftools/preprocessing/__init__.py
@@ -1,38 +1,37 @@
-from .append_base_context import append_base_context
-from .append_binary_layer_by_base_context import append_binary_layer_by_base_context
-from .binarize import binarize_adata
-from .binarize_on_Youden import binarize_on_Youden
-from .calculate_complexity_II import calculate_complexity_II
-from .calculate_coverage import calculate_coverage
-from .calculate_position_Youden import calculate_position_Youden
-from .calculate_read_length_stats import calculate_read_length_stats
-from .calculate_read_modification_stats import calculate_read_modification_stats
-from .clean_NaN import clean_NaN
-from .filter_adata_by_nan_proportion import filter_adata_by_nan_proportion
-from .filter_reads_on_length_quality_mapping import filter_reads_on_length_quality_mapping
-from .filter_reads_on_modification_thresholds import filter_reads_on_modification_thresholds
-from .flag_duplicate_reads import flag_duplicate_reads
-from .invert_adata import invert_adata
-from .load_sample_sheet import load_sample_sheet
-from .reindex_references_adata import reindex_references_adata
-from .subsample_adata import subsample_adata
+from __future__ import annotations
 
-__all__ = [
-    "append_base_context",
-    "append_binary_layer_by_base_context",
-    "binarize_on_Youden",
-    "binarize_adata",
-    "calculate_complexity_II",
-    "calculate_read_modification_stats",
-    "calculate_coverage",
-    "calculate_position_Youden",
-    "calculate_read_length_stats",
-    "clean_NaN",
-    "filter_adata_by_nan_proportion",
-    "filter_reads_on_modification_thresholds",
-    "filter_reads_on_length_quality_mapping",
-    "invert_adata",
-    "load_sample_sheet",
-    "flag_duplicate_reads",
-    "subsample_adata",
-]
+from importlib import import_module
+
+_LAZY_ATTRS = {
+    "append_base_context": "smftools.preprocessing.append_base_context",
+    "append_binary_layer_by_base_context": "smftools.preprocessing.append_binary_layer_by_base_context",
+    "append_mismatch_frequency_sites": "smftools.preprocessing.append_mismatch_frequency_sites",
+    "binarize_adata": "smftools.preprocessing.binarize",
+    "binarize_on_Youden": "smftools.preprocessing.binarize_on_Youden",
+    "calculate_complexity_II": "smftools.preprocessing.calculate_complexity_II",
+    "calculate_coverage": "smftools.preprocessing.calculate_coverage",
+    "calculate_position_Youden": "smftools.preprocessing.calculate_position_Youden",
+    "calculate_read_length_stats": "smftools.preprocessing.calculate_read_length_stats",
+    "calculate_read_modification_stats": "smftools.preprocessing.calculate_read_modification_stats",
+    "clean_NaN": "smftools.preprocessing.clean_NaN",
+    "filter_adata_by_nan_proportion": "smftools.preprocessing.filter_adata_by_nan_proportion",
+    "filter_reads_on_length_quality_mapping": "smftools.preprocessing.filter_reads_on_length_quality_mapping",
+    "filter_reads_on_modification_thresholds": "smftools.preprocessing.filter_reads_on_modification_thresholds",
+    "flag_duplicate_reads": "smftools.preprocessing.flag_duplicate_reads",
+    "invert_adata": "smftools.preprocessing.invert_adata",
+    "load_sample_sheet": "smftools.preprocessing.load_sample_sheet",
+    "reindex_references_adata": "smftools.preprocessing.reindex_references_adata",
+    "subsample_adata": "smftools.preprocessing.subsample_adata",
+}
+
+
+def __getattr__(name: str):
+    if name in _LAZY_ATTRS:
+        module = import_module(_LAZY_ATTRS[name])
+        attr = getattr(module, name)
+        globals()[name] = attr
+        return attr
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+__all__ = list(_LAZY_ATTRS.keys())
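
The rewritten __init__ relies on PEP 562 module-level __getattr__: importing smftools.preprocessing is now cheap, and each submodule loads on first attribute access, then gets cached in the module namespace. A usage sketch of the observable behavior, assuming smftools is installed:

import smftools.preprocessing as pp   # fast: no submodules imported yet

fn = pp.flag_duplicate_reads          # triggers import_module(...) once
assert fn is pp.flag_duplicate_reads  # second access hits the globals() cache

One consequence worth noting: a misspelled name now fails with AttributeError at first access rather than at import time.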

smftools/preprocessing/append_base_context.py
@@ -133,23 +133,23 @@ def append_base_context(
             adata.var[f"{ref}_{site_type}_valid_coverage"] = (
                 (adata.var[f"{ref}_{site_type}"]) & (adata.var[f"position_in_{ref}"])
             )
-            if native:
-                adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
-                    :, adata.var[f"{ref}_{site_type}_valid_coverage"]
-                ].layers["binarized_methylation"]
-            else:
-                adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
-                    :, adata.var[f"{ref}_{site_type}_valid_coverage"]
-                ].X
-        else:
-            pass
-
-        if native:
-            adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].layers[
-                "binarized_methylation"
-            ]
-        else:
-            adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].X
+            # if native:
+            #     adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
+            #         :, adata.var[f"{ref}_{site_type}_valid_coverage"]
+            #     ].layers["binarized_methylation"]
+            # else:
+            #     adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
+            #         :, adata.var[f"{ref}_{site_type}_valid_coverage"]
+            #     ].X
+        # else:
+        #     pass
+
+        # if native:
+        #     adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].layers[
+        #         "binarized_methylation"
+        #     ]
+        # else:
+        #     adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].X
 
     # mark as done
     adata.uns[uns_flag] = True
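
With the obsm caching commented out above, the same matrices are computed on demand downstream (see the calculate_read_modification_stats hunk below). A toy sketch of the equivalent on-the-fly selection by a boolean var column, assuming anndata is installed and using a hypothetical column name:

import anndata as ad
import numpy as np
import pandas as pd

adata = ad.AnnData(
    X=np.arange(12, dtype=float).reshape(3, 4),
    var=pd.DataFrame({"ref1_GpC_site_valid_coverage": [True, False, True, False]}),
)

# same matrix the removed adata.obsm[...] cache used to store, computed on demand
site_matrix = adata[:, adata.var["ref1_GpC_site_valid_coverage"]].X
print(site_matrix.shape)  # (3, 2)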

smftools/preprocessing/append_mismatch_frequency_sites.py
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable, Sequence
+
+import numpy as np
+import pandas as pd
+
+from smftools.constants import MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def append_mismatch_frequency_sites(
+    adata: "ad.AnnData",
+    ref_column: str = "Reference_strand",
+    mismatch_layer: str = "mismatch_integer_encoding",
+    read_span_layer: str = "read_span_mask",
+    mismatch_frequency_range: Sequence[float] | None = (0.05, 0.95),
+    uns_flag: str = "append_mismatch_frequency_sites_performed",
+    force_redo: bool = False,
+    bypass: bool = False,
+) -> None:
+    """Append mismatch frequency metadata and variable-site flags per reference.
+
+    Args:
+        adata: AnnData object.
+        ref_column: Obs column defining reference categories.
+        mismatch_layer: Layer containing mismatch integer encodings.
+        read_span_layer: Layer containing read span masks (1=covered, 0=not covered).
+        mismatch_frequency_range: Lower/upper bounds (inclusive) for variable site flagging.
+        uns_flag: Flag in ``adata.uns`` indicating prior completion.
+        force_redo: Whether to rerun even if ``uns_flag`` is set.
+        bypass: Whether to skip running this step.
+    """
+    if bypass:
+        return
+
+    already = bool(adata.uns.get(uns_flag, False))
+    if already and not force_redo:
+        return
+
+    if mismatch_layer not in adata.layers:
+        logger.debug(
+            "Mismatch layer '%s' not found; skipping mismatch frequency step.", mismatch_layer
+        )
+        return
+
+    mismatch_map = adata.uns.get("mismatch_integer_encoding_map", {})
+    if not mismatch_map:
+        logger.debug("Mismatch encoding map not found; skipping mismatch frequency step.")
+        return
+
+    n_value = mismatch_map.get("N", MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT["N"])
+    pad_value = mismatch_map.get("PAD", MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT["PAD"])
+
+    base_int_to_label = {
+        int(value): str(base)
+        for base, value in mismatch_map.items()
+        if base not in {"N", "PAD"} and isinstance(value, (int, np.integer))
+    }
+    if not base_int_to_label:
+        logger.debug("Mismatch encoding map missing base labels; skipping mismatch frequency step.")
+        return
+
+    has_span_mask = read_span_layer in adata.layers
+    if not has_span_mask:
+        logger.debug(
+            "Read span mask '%s' not found; mismatch frequencies will be computed over all reads.",
+            read_span_layer,
+        )
+
+    references = adata.obs[ref_column].cat.categories
+    n_vars = adata.shape[1]
+
+    if mismatch_frequency_range is None:
+        mismatch_frequency_range = (0.0, 1.0)
+
+    lower_bound, upper_bound = mismatch_frequency_range
+
+    for ref in references:
+        ref_mask = adata.obs[ref_column] == ref
+        ref_position_mask = adata.var.get(f"position_in_{ref}")
+        if ref_position_mask is None:
+            ref_position_mask = pd.Series(np.ones(n_vars, dtype=bool), index=adata.var.index)
+        else:
+            ref_position_mask = ref_position_mask.astype(bool)
+
+        frequency_values = np.full(n_vars, np.nan, dtype=float)
+        variable_flags = np.zeros(n_vars, dtype=bool)
+        mismatch_base_frequencies: list[list[tuple[str, float]]] = [[] for _ in range(n_vars)]
+
+        if ref_mask.sum() == 0:
+            adata.var[f"{ref}_mismatch_frequency"] = pd.Series(
+                frequency_values, index=adata.var.index
+            )
+            adata.var[f"{ref}_variable_sequence_site"] = pd.Series(
+                variable_flags, index=adata.var.index
+            )
+            adata.var[f"{ref}_mismatch_base_frequencies"] = pd.Series(
+                mismatch_base_frequencies, index=adata.var.index
+            )
+            continue
+
+        mismatch_matrix = np.asarray(adata.layers[mismatch_layer][ref_mask])
+        if has_span_mask:
+            span_matrix = np.asarray(adata.layers[read_span_layer][ref_mask])
+            coverage_mask = span_matrix > 0
+            coverage_counts = coverage_mask.sum(axis=0).astype(float)
+        else:
+            coverage_mask = np.ones_like(mismatch_matrix, dtype=bool)
+            coverage_counts = np.full(n_vars, ref_mask.sum(), dtype=float)
+
+        mismatch_mask = (~np.isin(mismatch_matrix, [n_value, pad_value])) & coverage_mask
+        mismatch_counts = mismatch_mask.sum(axis=0)
+
+        frequency_values = np.divide(
+            mismatch_counts,
+            coverage_counts,
+            out=np.full(n_vars, np.nan, dtype=float),
+            where=coverage_counts > 0,
+        )
+        frequency_values = np.where(ref_position_mask.values, frequency_values, np.nan)
+
+        variable_flags = (
+            (frequency_values >= lower_bound)
+            & (frequency_values <= upper_bound)
+            & ref_position_mask.values
+        )
+
+        base_counts_by_int: dict[int, np.ndarray] = {}
+        for base_int in base_int_to_label:
+            base_counts_by_int[base_int] = ((mismatch_matrix == base_int) & coverage_mask).sum(
+                axis=0
+            )
+
+        for idx in range(n_vars):
+            if not ref_position_mask.iloc[idx] or coverage_counts[idx] == 0:
+                continue
+            base_freqs: list[tuple[str, float]] = []
+            for base_int, base_label in base_int_to_label.items():
+                count = base_counts_by_int[base_int][idx]
+                if count > 0:
+                    base_freqs.append((base_label, float(count / coverage_counts[idx])))
+            mismatch_base_frequencies[idx] = base_freqs
+
+        adata.var[f"{ref}_mismatch_frequency"] = pd.Series(frequency_values, index=adata.var.index)
+        adata.var[f"{ref}_variable_sequence_site"] = pd.Series(
+            variable_flags, index=adata.var.index
+        )
+        adata.var[f"{ref}_mismatch_base_frequencies"] = pd.Series(
+            mismatch_base_frequencies, index=adata.var.index
+        )
+
+    adata.uns[uns_flag] = True
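
A minimal usage sketch for the new function, assuming anndata is installed; the toy encoding map below is illustrative, not the real MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT values:

import anndata as ad
import numpy as np
import pandas as pd

from smftools.preprocessing import append_mismatch_frequency_sites

adata = ad.AnnData(X=np.zeros((2, 3)))
adata.obs["Reference_strand"] = pd.Categorical(["ref1", "ref1"])
adata.var["position_in_ref1"] = [True, True, True]
adata.layers["mismatch_integer_encoding"] = np.array([[0, 4, 5], [1, 4, 5]])
adata.layers["read_span_mask"] = np.ones((2, 3))
adata.uns["mismatch_integer_encoding_map"] = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, "PAD": 5}

append_mismatch_frequency_sites(adata)
print(adata.var["ref1_mismatch_frequency"])      # per-position mismatch rate
print(adata.var["ref1_variable_sequence_site"])  # True where the rate falls in the 0.05-0.95 band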

smftools/preprocessing/archived/add_read_length_and_mapping_qc.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
 import scipy.sparse as sp

smftools/preprocessing/archived/calculate_complexity.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 ## calculate_complexity
 
 def calculate_complexity(adata, output_directory='', obs_column='Reference', sample_col='Sample_names', plot=True, save_plot=False):

smftools/preprocessing/archived/mark_duplicates.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 ## mark_duplicates
 
 def mark_duplicates(adata, layers, obs_column='Reference', sample_col='Sample_names', method='N_masked_distances', distance_thresholds={}):

smftools/preprocessing/archived/preprocessing.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 ## preprocessing
 from .. import readwrite
 

smftools/preprocessing/archived/remove_duplicates.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # remove_duplicates
 
 def remove_duplicates(adata):

smftools/preprocessing/binary_layers_to_ohe.py
@@ -1,5 +1,6 @@
-## binary_layers_to_ohe
+from __future__ import annotations
 
+## binary_layers_to_ohe
 from smftools.logging_utils import get_logger
 
 logger = get_logger(__name__)

smftools/preprocessing/calculate_complexity_II.py
@@ -3,6 +3,8 @@ from __future__ import annotations
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional
 
+from smftools.optional_imports import require
+
 if TYPE_CHECKING:
     import anndata as ad
 
@@ -46,11 +48,12 @@ def calculate_complexity_II(
     """
     import os
 
-    import matplotlib.pyplot as plt
     import numpy as np
     import pandas as pd
     from scipy.optimize import curve_fit
 
+    plt = require("matplotlib.pyplot", extra="plotting", purpose="complexity plots")
+
     # early exits
     already = bool(adata.uns.get(uns_flag, False))
     if already and not force_redo:

smftools/preprocessing/calculate_consensus.py
@@ -53,4 +53,4 @@ def calculate_consensus(
     else:
         adata.var[f"{reference}_consensus_across_samples"] = consensus_sequence_list
 
-    adata.uns[f"{reference}_consensus_sequence"] = consensus_sequence_list
+    adata.uns[f"{reference}_consensus_sequence"] = str(consensus_sequence_list)

smftools/preprocessing/calculate_pairwise_differences.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # calculate_pairwise_differences
 
 

smftools/preprocessing/calculate_pairwise_hamming_distances.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 ## calculate_pairwise_hamming_distances
 
+
 ## Conversion SMF Specific
 def calculate_pairwise_hamming_distances(arrays):
     """

smftools/preprocessing/calculate_position_Youden.py
@@ -6,6 +6,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING
 
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 if TYPE_CHECKING:
     import anndata as ad
@@ -40,9 +41,15 @@ def calculate_position_Youden(
         save: Whether to save ROC plots to disk.
         output_directory: Output directory for ROC plots.
     """
-    import matplotlib.pyplot as plt
    import numpy as np
-    from sklearn.metrics import roc_curve
+
+    plt = require("matplotlib.pyplot", extra="plotting", purpose="Youden ROC plots")
+    sklearn_metrics = require(
+        "sklearn.metrics",
+        extra="ml-base",
+        purpose="Youden ROC curve calculation",
+    )
+    roc_curve = sklearn_metrics.roc_curve
 
     control_samples = [positive_control_sample, negative_control_sample]
     references = adata.obs[ref_column].cat.categories
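
calculate_position_Youden now obtains roc_curve via require(..., extra="ml-base"). For orientation, the Youden statistic it is named for picks the ROC point maximizing J = TPR - FPR; a generic sketch of that computation (not the smftools internals, which this hunk does not show):

import numpy as np
from sklearn.metrics import roc_curve

labels = np.array([0, 0, 1, 1, 1, 0])      # toy ground-truth calls
scores = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2])

fpr, tpr, thresholds = roc_curve(labels, scores)
j = tpr - fpr                        # Youden's J at each candidate threshold
best_threshold = thresholds[np.argmax(j)]
print(best_threshold)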

smftools/preprocessing/calculate_read_modification_stats.py
@@ -20,6 +20,7 @@ def calculate_read_modification_stats(
     force_redo: bool = False,
     valid_sites_only: bool = False,
     valid_site_suffix: str = "_valid_coverage",
+    smf_modality: str = "conversion",
 ) -> None:
     """Add methylation/deamination statistics for each read.
 
@@ -80,8 +81,12 @@ def calculate_read_modification_stats(
     for ref in references:
         ref_subset = adata[adata.obs[reference_column] == ref]
         for site_type in site_types:
+            site_subset = ref_subset[:, ref_subset.var[f"{ref}_{site_type}{valid_site_suffix}"]]
             logger.info("Iterating over %s_%s", ref, site_type)
-            observation_matrix = ref_subset.obsm[f"{ref}_{site_type}{valid_site_suffix}"]
+            if smf_modality == "native":
+                observation_matrix = site_subset.layers["binarized_methylation"]
+            else:
+                observation_matrix = site_subset.X
             total_positions_in_read = np.nansum(~np.isnan(observation_matrix), axis=1)
             total_positions_in_reference = observation_matrix.shape[1]
             fraction_valid_positions_in_read_vs_ref = (
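
The new smf_modality switch selects the per-read observation matrix from the valid-coverage site subset: the binarized_methylation layer for native calls, X otherwise. A toy sketch of that selection pattern, assuming anndata is installed:

import anndata as ad
import numpy as np

adata = ad.AnnData(X=np.random.rand(4, 6))
adata.layers["binarized_methylation"] = (adata.X > 0.5).astype(float)

smf_modality = "native"  # or "conversion" (the default)
observation_matrix = (
    adata.layers["binarized_methylation"] if smf_modality == "native" else adata.X
)
# per-read position counts, as computed in the hunk above
total_positions_in_read = np.nansum(~np.isnan(observation_matrix), axis=1)
print(total_positions_in_read)  # array of 6s: no NaNs in the toy data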

smftools/preprocessing/filter_reads_on_length_quality_mapping.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Optional, Sequence, Union
 
 import anndata as ad

smftools/preprocessing/filter_reads_on_modification_thresholds.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import gc
 from typing import List, Optional, Sequence