smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,8 +1,20 @@
1
+ from __future__ import annotations
2
+
1
3
  import math
2
- from typing import List, Optional, Tuple, Union
4
+ from typing import Optional, Tuple, Union
5
+
3
6
  import numpy as np
4
- import matplotlib.pyplot as plt
5
- from matplotlib.backends.backend_pdf import PdfPages
7
+
8
+ from smftools.optional_imports import require
9
+
10
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM plots")
11
+ pdf_backend = require(
12
+ "matplotlib.backends.backend_pdf",
13
+ extra="plotting",
14
+ purpose="PDF output",
15
+ )
16
+ PdfPages = pdf_backend.PdfPages
17
+
6
18
 
7
19
  def plot_hmm_size_contours(
8
20
  adata,
@@ -36,32 +48,41 @@ def plot_hmm_size_contours(
36
48
 
37
49
  Other args are the same as prior function.
38
50
  """
51
+
39
52
  # --- helper: gaussian smoothing (scipy fallback -> numpy separable conv) ---
40
53
  def _gaussian_1d_kernel(sigma: float, eps: float = 1e-12):
54
+ """Build a normalized 1D Gaussian kernel."""
41
55
  if sigma <= 0 or sigma is None:
42
56
  return np.array([1.0], dtype=float)
43
57
  # choose kernel size = odd ~ 6*sigma (covers +/-3 sigma)
44
58
  radius = max(1, int(math.ceil(3.0 * float(sigma))))
45
59
  xs = np.arange(-radius, radius + 1, dtype=float)
46
- k = np.exp(-(xs ** 2) / (2.0 * sigma ** 2))
60
+ k = np.exp(-(xs**2) / (2.0 * sigma**2))
47
61
  k_sum = k.sum()
48
62
  if k_sum <= eps:
49
63
  k = np.array([1.0], dtype=float)
50
64
  k_sum = 1.0
51
65
  return k / k_sum
52
66
 
53
- def _smooth_with_numpy_separable(Z: np.ndarray, sigma_len: float, sigma_pos: float) -> np.ndarray:
67
+ def _smooth_with_numpy_separable(
68
+ Z: np.ndarray, sigma_len: float, sigma_pos: float
69
+ ) -> np.ndarray:
70
+ """Apply separable Gaussian smoothing with NumPy."""
54
71
  # Z shape: (n_lengths, n_positions)
55
72
  out = Z.copy()
56
73
  # smooth along length axis (axis=0)
57
74
  if sigma_len and sigma_len > 0:
58
75
  k_len = _gaussian_1d_kernel(sigma_len)
59
76
  # convolve each column
60
- out = np.apply_along_axis(lambda col: np.convolve(col, k_len, mode="same"), axis=0, arr=out)
77
+ out = np.apply_along_axis(
78
+ lambda col: np.convolve(col, k_len, mode="same"), axis=0, arr=out
79
+ )
61
80
  # smooth along position axis (axis=1)
62
81
  if sigma_pos and sigma_pos > 0:
63
82
  k_pos = _gaussian_1d_kernel(sigma_pos)
64
- out = np.apply_along_axis(lambda row: np.convolve(row, k_pos, mode="same"), axis=1, arr=out)
83
+ out = np.apply_along_axis(
84
+ lambda row: np.convolve(row, k_pos, mode="same"), axis=1, arr=out
85
+ )
65
86
  return out
66
87
 
67
88
  # prefer scipy.ndimage if available (faster and better boundary handling)
@@ -69,11 +90,13 @@ def plot_hmm_size_contours(
69
90
  if use_scipy_if_available:
70
91
  try:
71
92
  from scipy.ndimage import gaussian_filter as _scipy_gaussian_filter
93
+
72
94
  _have_scipy = True
73
95
  except Exception:
74
96
  _have_scipy = False
75
97
 
76
98
  def _smooth_Z(Z: np.ndarray, sigma_len: float, sigma_pos: float) -> np.ndarray:
99
+ """Smooth a matrix using scipy if available or NumPy fallback."""
77
100
  if (sigma_len is None or sigma_len == 0) and (sigma_pos is None or sigma_pos == 0):
78
101
  return Z
79
102
  if _have_scipy:
@@ -84,8 +107,16 @@ def plot_hmm_size_contours(
84
107
  return _smooth_with_numpy_separable(Z, float(sigma_len or 0.0), float(sigma_pos or 0.0))
85
108
 
86
109
  # --- gather unique ordered labels ---
87
- samples = list(adata.obs[sample_col].cat.categories) if getattr(adata.obs[sample_col], "dtype", None) == "category" else list(pd.Categorical(adata.obs[sample_col]).categories)
88
- refs = list(adata.obs[ref_obs_col].cat.categories) if getattr(adata.obs[ref_obs_col], "dtype", None) == "category" else list(pd.Categorical(adata.obs[ref_obs_col]).categories)
110
+ samples = (
111
+ list(adata.obs[sample_col].cat.categories)
112
+ if getattr(adata.obs[sample_col], "dtype", None) == "category"
113
+ else list(pd.Categorical(adata.obs[sample_col]).categories)
114
+ )
115
+ refs = (
116
+ list(adata.obs[ref_obs_col].cat.categories)
117
+ if getattr(adata.obs[ref_obs_col], "dtype", None) == "category"
118
+ else list(pd.Categorical(adata.obs[ref_obs_col]).categories)
119
+ )
89
120
 
90
121
  n_samples = len(samples)
91
122
  n_refs = len(refs)
@@ -102,6 +133,7 @@ def plot_hmm_size_contours(
102
133
 
103
134
  # helper to get dense layer array for subset
104
135
  def _get_layer_array(layer):
136
+ """Convert a layer to a dense NumPy array."""
105
137
  arr = layer
106
138
  # sparse -> toarray
107
139
  if hasattr(arr, "toarray"):
@@ -146,7 +178,7 @@ def plot_hmm_size_contours(
146
178
  fig_w = n_refs * figsize_per_cell[0]
147
179
  fig_h = rows_on_page * figsize_per_cell[1]
148
180
  fig, axes = plt.subplots(rows_on_page, n_refs, figsize=(fig_w, fig_h), squeeze=False)
149
- fig.suptitle(f"HMM size contours (page {p+1}/{pages})", fontsize=12)
181
+ fig.suptitle(f"HMM size contours (page {p + 1}/{pages})", fontsize=12)
150
182
 
151
183
  # for each panel compute p(length | position)
152
184
  for i_row, sample in enumerate(page_samples):
@@ -160,7 +192,9 @@ def plot_hmm_size_contours(
160
192
  ax.set_title(f"{sample} / {ref}")
161
193
  continue
162
194
 
163
- row_idx = np.nonzero(panel_mask.values if hasattr(panel_mask, "values") else np.asarray(panel_mask))[0]
195
+ row_idx = np.nonzero(
196
+ panel_mask.values if hasattr(panel_mask, "values") else np.asarray(panel_mask)
197
+ )[0]
164
198
  if row_idx.size == 0:
165
199
  ax.text(0.5, 0.5, "no reads", ha="center", va="center")
166
200
  ax.set_title(f"{sample} / {ref}")
@@ -178,7 +212,9 @@ def plot_hmm_size_contours(
178
212
  max_len_here = min(max_len, max_len_local)
179
213
 
180
214
  lengths_range = np.arange(1, max_len_here + 1, dtype=int)
181
- Z = np.zeros((len(lengths_range), n_positions), dtype=float) # rows=length, cols=pos
215
+ Z = np.zeros(
216
+ (len(lengths_range), n_positions), dtype=float
217
+ ) # rows=length, cols=pos
182
218
 
183
219
  # fill Z by efficient bincount across columns
184
220
  for j in range(n_positions):
@@ -222,7 +258,9 @@ def plot_hmm_size_contours(
222
258
  dy = 1.0
223
259
  y_edges = np.concatenate([y - 0.5, [y[-1] + 0.5]])
224
260
 
225
- pcm = ax.pcolormesh(x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax)
261
+ pcm = ax.pcolormesh(
262
+ x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax
263
+ )
226
264
  ax.set_title(f"{sample} / {ref}")
227
265
  ax.set_ylabel("length")
228
266
  if i_row == rows_on_page - 1:
@@ -243,9 +281,10 @@ def plot_hmm_size_contours(
243
281
  # saving per page if requested
244
282
  if save_path is not None:
245
283
  import os
284
+
246
285
  os.makedirs(save_path, exist_ok=True)
247
286
  if save_each_page:
248
- fname = f"hmm_size_page_{p+1:03d}.png"
287
+ fname = f"hmm_size_page_{p + 1:03d}.png"
249
288
  out = os.path.join(save_path, fname)
250
289
  fig.savefig(out, dpi=dpi, bbox_inches="tight")
251
290
 
@@ -1,3 +1,8 @@
1
+ from __future__ import annotations
2
+
3
+ from smftools.optional_imports import require
4
+
5
+
1
6
  def plot_volcano_relative_risk(
2
7
  results_dict,
3
8
  save_path=None,
@@ -20,10 +25,10 @@ def plot_volcano_relative_risk(
20
25
  xlim (tuple): Optional x-axis limit.
21
26
  ylim (tuple): Optional y-axis limit.
22
27
  """
23
- import matplotlib.pyplot as plt
24
- import numpy as np
25
28
  import os
26
29
 
30
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")
31
+
27
32
  for ref, group_results in results_dict.items():
28
33
  for group_label, (results_df, _) in group_results.items():
29
34
  if results_df.empty:
@@ -31,8 +36,8 @@ def plot_volcano_relative_risk(
31
36
  continue
32
37
 
33
38
  # Split by site type
34
- gpc_df = results_df[results_df['GpC_Site']]
35
- cpg_df = results_df[results_df['CpG_Site']]
39
+ gpc_df = results_df[results_df["GpC_Site"]]
40
+ cpg_df = results_df[results_df["CpG_Site"]]
36
41
 
37
42
  fig, ax = plt.subplots(figsize=(12, 6))
38
43
 
@@ -43,29 +48,29 @@ def plot_volcano_relative_risk(
43
48
 
44
49
  # GpC as circles
45
50
  sc1 = ax.scatter(
46
- gpc_df['Genomic_Position'],
47
- gpc_df['log2_Relative_Risk'],
48
- c=gpc_df['-log10_Adj_P'],
49
- cmap='coolwarm',
50
- edgecolor='k',
51
+ gpc_df["Genomic_Position"],
52
+ gpc_df["log2_Relative_Risk"],
53
+ c=gpc_df["-log10_Adj_P"],
54
+ cmap="coolwarm",
55
+ edgecolor="k",
51
56
  s=40,
52
- marker='o',
53
- label='GpC'
57
+ marker="o",
58
+ label="GpC",
54
59
  )
55
60
 
56
61
  # CpG as stars
57
62
  sc2 = ax.scatter(
58
- cpg_df['Genomic_Position'],
59
- cpg_df['log2_Relative_Risk'],
60
- c=cpg_df['-log10_Adj_P'],
61
- cmap='coolwarm',
62
- edgecolor='k',
63
+ cpg_df["Genomic_Position"],
64
+ cpg_df["log2_Relative_Risk"],
65
+ c=cpg_df["-log10_Adj_P"],
66
+ cmap="coolwarm",
67
+ edgecolor="k",
63
68
  s=60,
64
- marker='*',
65
- label='CpG'
69
+ marker="*",
70
+ label="CpG",
66
71
  )
67
72
 
68
- ax.axhline(y=0, color='gray', linestyle='--')
73
+ ax.axhline(y=0, color="gray", linestyle="--")
69
74
  ax.set_xlabel("Genomic Position")
70
75
  ax.set_ylabel("log2(Relative Risk)")
71
76
  ax.set_title(f"{ref} / {group_label} — Relative Risk vs Genomic Position")
@@ -75,8 +80,8 @@ def plot_volcano_relative_risk(
75
80
  if ylim:
76
81
  ax.set_ylim(ylim)
77
82
 
78
- ax.spines['top'].set_visible(False)
79
- ax.spines['right'].set_visible(False)
83
+ ax.spines["top"].set_visible(False)
84
+ ax.spines["right"].set_visible(False)
80
85
 
81
86
  cbar = plt.colorbar(sc1, ax=ax)
82
87
  cbar.set_label("-log10(Adjusted P-Value)")
@@ -87,13 +92,19 @@ def plot_volcano_relative_risk(
87
92
  # Save if requested
88
93
  if save_path:
89
94
  os.makedirs(save_path, exist_ok=True)
90
- safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
95
+ safe_name = (
96
+ f"{ref}_{group_label}".replace("=", "")
97
+ .replace("__", "_")
98
+ .replace(",", "_")
99
+ .replace(" ", "_")
100
+ )
91
101
  out_file = os.path.join(save_path, f"{safe_name}.png")
92
102
  plt.savefig(out_file, dpi=300)
93
103
  print(f"Saved: {out_file}")
94
104
 
95
105
  plt.show()
96
106
 
107
+
97
108
  def plot_bar_relative_risk(
98
109
  results_dict,
99
110
  sort_by_position=True,
@@ -102,7 +113,7 @@ def plot_bar_relative_risk(
102
113
  save_path=None,
103
114
  highlight_regions=None, # List of (start, end) tuples
104
115
  highlight_color="lightgray",
105
- highlight_alpha=0.3
116
+ highlight_alpha=0.3,
106
117
  ):
107
118
  """
108
119
  Plot log2(Relative Risk) as a bar plot across genomic positions for each group within each reference.
@@ -116,10 +127,10 @@ def plot_bar_relative_risk(
116
127
  highlight_color (str): Color of shaded region.
117
128
  highlight_alpha (float): Transparency of shaded region.
118
129
  """
119
- import matplotlib.pyplot as plt
120
- import numpy as np
121
130
  import os
122
131
 
132
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")
133
+
123
134
  for ref, group_data in results_dict.items():
124
135
  for group_label, (df, _) in group_data.items():
125
136
  if df.empty:
@@ -127,14 +138,14 @@ def plot_bar_relative_risk(
127
138
  continue
128
139
 
129
140
  df = df.copy()
130
- df['Genomic_Position'] = df['Genomic_Position'].astype(int)
141
+ df["Genomic_Position"] = df["Genomic_Position"].astype(int)
131
142
 
132
143
  if sort_by_position:
133
- df = df.sort_values('Genomic_Position')
144
+ df = df.sort_values("Genomic_Position")
134
145
 
135
- gpc_mask = df['GpC_Site'] & ~df['CpG_Site']
136
- cpg_mask = df['CpG_Site'] & ~df['GpC_Site']
137
- both_mask = df['GpC_Site'] & df['CpG_Site']
146
+ gpc_mask = df["GpC_Site"] & ~df["CpG_Site"]
147
+ cpg_mask = df["CpG_Site"] & ~df["GpC_Site"]
148
+ both_mask = df["GpC_Site"] & df["CpG_Site"]
138
149
 
139
150
  fig, ax = plt.subplots(figsize=(14, 6))
140
151
 
@@ -145,36 +156,36 @@ def plot_bar_relative_risk(
145
156
 
146
157
  # Bar plots
147
158
  ax.bar(
148
- df['Genomic_Position'][gpc_mask],
149
- df['log2_Relative_Risk'][gpc_mask],
159
+ df["Genomic_Position"][gpc_mask],
160
+ df["log2_Relative_Risk"][gpc_mask],
150
161
  width=10,
151
- color='steelblue',
152
- label='GpC Site',
153
- edgecolor='black'
162
+ color="steelblue",
163
+ label="GpC Site",
164
+ edgecolor="black",
154
165
  )
155
166
 
156
167
  ax.bar(
157
- df['Genomic_Position'][cpg_mask],
158
- df['log2_Relative_Risk'][cpg_mask],
168
+ df["Genomic_Position"][cpg_mask],
169
+ df["log2_Relative_Risk"][cpg_mask],
159
170
  width=10,
160
- color='darkorange',
161
- label='CpG Site',
162
- edgecolor='black'
171
+ color="darkorange",
172
+ label="CpG Site",
173
+ edgecolor="black",
163
174
  )
164
175
 
165
176
  if both_mask.any():
166
177
  ax.bar(
167
- df['Genomic_Position'][both_mask],
168
- df['log2_Relative_Risk'][both_mask],
178
+ df["Genomic_Position"][both_mask],
179
+ df["log2_Relative_Risk"][both_mask],
169
180
  width=10,
170
- color='purple',
171
- label='GpC + CpG',
172
- edgecolor='black'
181
+ color="purple",
182
+ label="GpC + CpG",
183
+ edgecolor="black",
173
184
  )
174
185
 
175
- ax.axhline(y=0, color='gray', linestyle='--')
176
- ax.set_xlabel('Genomic Position')
177
- ax.set_ylabel('log2(Relative Risk)')
186
+ ax.axhline(y=0, color="gray", linestyle="--")
187
+ ax.set_xlabel("Genomic Position")
188
+ ax.set_ylabel("log2(Relative Risk)")
178
189
  ax.set_title(f"{ref} — {group_label}")
179
190
  ax.legend()
180
191
 
@@ -183,20 +194,23 @@ def plot_bar_relative_risk(
183
194
  if ylim:
184
195
  ax.set_ylim(ylim)
185
196
 
186
- ax.spines['top'].set_visible(False)
187
- ax.spines['right'].set_visible(False)
197
+ ax.spines["top"].set_visible(False)
198
+ ax.spines["right"].set_visible(False)
188
199
 
189
200
  plt.tight_layout()
190
201
 
191
202
  if save_path:
192
203
  os.makedirs(save_path, exist_ok=True)
193
- safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_")
204
+ safe_name = (
205
+ f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_")
206
+ )
194
207
  out_file = os.path.join(save_path, f"{safe_name}.png")
195
208
  plt.savefig(out_file, dpi=300)
196
209
  print(f"📁 Saved: {out_file}")
197
210
 
198
211
  plt.show()
199
212
 
213
+
200
214
  def plot_positionwise_matrix(
201
215
  adata,
202
216
  key="positionwise_result",
@@ -210,35 +224,40 @@ def plot_positionwise_matrix(
210
224
  xtick_step=10,
211
225
  ytick_step=10,
212
226
  save_path=None,
213
- highlight_position=None, # Can be a single int/float or list of them
214
- highlight_axis="row", # "row" or "column"
215
- annotate_points=False # ✅ New option
227
+ highlight_position=None, # Can be a single int/float or list of them
228
+ highlight_axis="row", # "row" or "column"
229
+ annotate_points=False, # ✅ New option
216
230
  ):
217
231
  """
218
232
  Plots positionwise matrices stored in adata.uns[key], with an optional line plot
219
233
  for specified row(s) or column(s), and highlights them on the heatmap.
220
234
  """
221
- import matplotlib.pyplot as plt
222
- import seaborn as sns
235
+ import os
236
+
223
237
  import numpy as np
224
238
  import pandas as pd
225
- import os
239
+
240
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
241
+ sns = require("seaborn", extra="plotting", purpose="position stats plots")
226
242
 
227
243
  def find_closest_index(index, target):
244
+ """Find the index value closest to a target value."""
228
245
  index_vals = pd.to_numeric(index, errors="coerce")
229
246
  target_val = pd.to_numeric([target], errors="coerce")[0]
230
247
  diffs = pd.Series(np.abs(index_vals - target_val), index=index)
231
248
  return diffs.idxmin()
232
249
 
233
250
  # Ensure highlight_position is a list
234
- if highlight_position is not None and not isinstance(highlight_position, (list, tuple, np.ndarray)):
251
+ if highlight_position is not None and not isinstance(
252
+ highlight_position, (list, tuple, np.ndarray)
253
+ ):
235
254
  highlight_position = [highlight_position]
236
255
 
237
256
  for group, mat_df in adata.uns[key].items():
238
257
  mat = mat_df.copy()
239
258
 
240
259
  if log_transform:
241
- with np.errstate(divide='ignore', invalid='ignore'):
260
+ with np.errstate(divide="ignore", invalid="ignore"):
242
261
  if log_base == "log1p":
243
262
  mat = np.log1p(mat)
244
263
  elif log_base == "log2":
@@ -276,7 +295,7 @@ def plot_positionwise_matrix(
276
295
  vmin=vmin,
277
296
  vmax=vmax,
278
297
  cbar_kws={"label": f"{key} ({log_base})" if log_transform else key},
279
- ax=heat_ax
298
+ ax=heat_ax,
280
299
  )
281
300
 
282
301
  heat_ax.set_title(f"{key} — {group}", pad=20)
@@ -295,17 +314,27 @@ def plot_positionwise_matrix(
295
314
  series = mat.loc[closest]
296
315
  x_vals = pd.to_numeric(series.index, errors="coerce")
297
316
  idx = mat.index.get_loc(closest)
298
- heat_ax.axhline(idx, color=colors[i % len(colors)], linestyle="--", linewidth=1)
317
+ heat_ax.axhline(
318
+ idx, color=colors[i % len(colors)], linestyle="--", linewidth=1
319
+ )
299
320
  label = f"Row {pos} → {closest}"
300
321
  else:
301
322
  closest = find_closest_index(mat.columns, pos)
302
323
  series = mat[closest]
303
324
  x_vals = pd.to_numeric(series.index, errors="coerce")
304
325
  idx = mat.columns.get_loc(closest)
305
- heat_ax.axvline(idx, color=colors[i % len(colors)], linestyle="--", linewidth=1)
326
+ heat_ax.axvline(
327
+ idx, color=colors[i % len(colors)], linestyle="--", linewidth=1
328
+ )
306
329
  label = f"Col {pos} → {closest}"
307
330
 
308
- line = line_ax.plot(x_vals, series.values, marker='o', label=label, color=colors[i % len(colors)])
331
+ line = line_ax.plot(
332
+ x_vals,
333
+ series.values,
334
+ marker="o",
335
+ label=label,
336
+ color=colors[i % len(colors)],
337
+ )
309
338
 
310
339
  # Annotate each point
311
340
  if annotate_points:
@@ -316,12 +345,18 @@ def plot_positionwise_matrix(
316
345
  xy=(x, y),
317
346
  textcoords="offset points",
318
347
  xytext=(0, 5),
319
- ha='center',
320
- fontsize=8
348
+ ha="center",
349
+ fontsize=8,
321
350
  )
322
351
  except Exception as e:
323
- line_ax.text(0.5, 0.5, f"⚠️ Error plotting {highlight_axis} @ {pos}",
324
- ha='center', va='center', fontsize=10)
352
+ line_ax.text(
353
+ 0.5,
354
+ 0.5,
355
+ f"⚠️ Error plotting {highlight_axis} @ {pos}",
356
+ ha="center",
357
+ va="center",
358
+ fontsize=10,
359
+ )
325
360
  print(f"Error plotting line for {highlight_axis}={pos}: {e}")
326
361
 
327
362
  line_ax.set_title(f"{highlight_axis.capitalize()} Profile(s)")
@@ -342,6 +377,7 @@ def plot_positionwise_matrix(
342
377
 
343
378
  plt.show()
344
379
 
380
+
345
381
  def plot_positionwise_matrix_grid(
346
382
  adata,
347
383
  key,
@@ -356,32 +392,63 @@ def plot_positionwise_matrix_grid(
356
392
  xtick_step=10,
357
393
  ytick_step=10,
358
394
  parallel=False,
359
- max_threads=None
395
+ max_threads=None,
360
396
  ):
361
- import matplotlib.pyplot as plt
362
- import seaborn as sns
397
+ """Plot a grid of positionwise matrices grouped by metadata.
398
+
399
+ Args:
400
+ adata: AnnData containing matrices in ``adata.uns``.
401
+ key: Key for positionwise matrices.
402
+ outer_keys: Keys for outer grouping.
403
+ inner_keys: Keys for inner grouping.
404
+ log_transform: Optional log transform (``log2`` or ``log1p``).
405
+ vmin: Minimum color scale value.
406
+ vmax: Maximum color scale value.
407
+ cmap: Matplotlib colormap.
408
+ save_path: Optional path to save plots.
409
+ figsize: Figure size.
410
+ xtick_step: X-axis tick step.
411
+ ytick_step: Y-axis tick step.
412
+ parallel: Whether to plot in parallel.
413
+ max_threads: Max thread count for parallel plotting.
414
+ """
415
+ import os
416
+
363
417
  import numpy as np
364
418
  import pandas as pd
365
- import os
366
- from matplotlib.gridspec import GridSpec
367
419
  from joblib import Parallel, delayed
368
420
 
421
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
422
+ sns = require("seaborn", extra="plotting", purpose="position stats plots")
423
+ grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="position stats plots")
424
+ GridSpec = grid_spec.GridSpec
425
+
369
426
  matrices = adata.uns[key]
370
427
  group_labels = list(matrices.keys())
371
428
 
372
- parsed_inner = pd.DataFrame([dict(zip(inner_keys, g.split("_")[-len(inner_keys):])) for g in group_labels])
373
- parsed_outer = pd.Series(["_".join(g.split("_")[:-len(inner_keys)]) for g in group_labels], name="outer")
429
+ parsed_inner = pd.DataFrame(
430
+ [dict(zip(inner_keys, g.split("_")[-len(inner_keys) :])) for g in group_labels]
431
+ )
432
+ parsed_outer = pd.Series(
433
+ ["_".join(g.split("_")[: -len(inner_keys)]) for g in group_labels], name="outer"
434
+ )
374
435
  parsed = pd.concat([parsed_outer, parsed_inner], axis=1)
375
436
 
376
437
  def plot_one_grid(outer_label):
377
- selected = parsed[parsed['outer'] == outer_label].copy()
378
- selected["group_str"] = [f"{outer_label}_{row[inner_keys[0]]}_{row[inner_keys[1]]}" for _, row in selected.iterrows()]
438
+ """Plot one grid for a specific outer label."""
439
+ selected = parsed[parsed["outer"] == outer_label].copy()
440
+ selected["group_str"] = [
441
+ f"{outer_label}_{row[inner_keys[0]]}_{row[inner_keys[1]]}"
442
+ for _, row in selected.iterrows()
443
+ ]
379
444
 
380
445
  row_vals = sorted(selected[inner_keys[0]].unique())
381
446
  col_vals = sorted(selected[inner_keys[1]].unique())
382
447
 
383
448
  fig = plt.figure(figsize=figsize)
384
- gs = GridSpec(len(row_vals), len(col_vals) + 1, width_ratios=[1]*len(col_vals) + [0.05], wspace=0.3)
449
+ gs = GridSpec(
450
+ len(row_vals), len(col_vals) + 1, width_ratios=[1] * len(col_vals) + [0.05], wspace=0.3
451
+ )
385
452
  axes = np.empty((len(row_vals), len(col_vals)), dtype=object)
386
453
 
387
454
  local_vmin, local_vmax = vmin, vmax
@@ -397,10 +464,7 @@ def plot_positionwise_matrix_grid(
397
464
  local_vmin = -vmax_auto if vmin is None else vmin
398
465
  local_vmax = vmax_auto if vmax is None else vmax
399
466
 
400
- cbar_label = {
401
- "log2": "log2(Value)",
402
- "log1p": "log1p(Value)"
403
- }.get(log_transform, "Value")
467
+ cbar_label = {"log2": "log2(Value)", "log1p": "log1p(Value)"}.get(log_transform, "Value")
404
468
 
405
469
  cbar_ax = fig.add_subplot(gs[:, -1])
406
470
 
@@ -431,9 +495,11 @@ def plot_positionwise_matrix_grid(
431
495
  vmax=local_vmax,
432
496
  cbar=(i == 0 and j == 0),
433
497
  cbar_ax=cbar_ax if (i == 0 and j == 0) else None,
434
- cbar_kws={"label": cbar_label if (i == 0 and j == 0) else ""}
498
+ cbar_kws={"label": cbar_label if (i == 0 and j == 0) else ""},
499
+ )
500
+ ax.set_title(
501
+ f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8
435
502
  )
436
- ax.set_title(f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8)
437
503
 
438
504
  xticks = data.columns.astype(int)
439
505
  yticks = data.index.astype(int)
@@ -448,15 +514,17 @@ def plot_positionwise_matrix_grid(
448
514
  if save_path:
449
515
  os.makedirs(save_path, exist_ok=True)
450
516
  fname = outer_label.replace("_", "").replace("=", "") + ".png"
451
- plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches='tight')
517
+ plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches="tight")
452
518
  print(f"Saved {fname}")
453
519
 
454
520
  plt.close(fig)
455
521
 
456
522
  if parallel:
457
- Parallel(n_jobs=max_threads)(delayed(plot_one_grid)(outer_label) for outer_label in parsed['outer'].unique())
523
+ Parallel(n_jobs=max_threads)(
524
+ delayed(plot_one_grid)(outer_label) for outer_label in parsed["outer"].unique()
525
+ )
458
526
  else:
459
- for outer_label in parsed['outer'].unique():
527
+ for outer_label in parsed["outer"].unique():
460
528
  plot_one_grid(outer_label)
461
529
 
462
- print("Finished plotting all grids.")
530
+ print("Finished plotting all grids.")