smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,9 +1,11 @@
1
1
  import math
2
- from typing import List, Optional, Tuple, Union
3
- import numpy as np
2
+ from typing import Optional, Tuple, Union
3
+
4
4
  import matplotlib.pyplot as plt
5
+ import numpy as np
5
6
  from matplotlib.backends.backend_pdf import PdfPages
6
7
 
8
+
7
9
  def plot_hmm_size_contours(
8
10
  adata,
9
11
  length_layer: str,
@@ -36,32 +38,41 @@ def plot_hmm_size_contours(
36
38
 
37
39
  Other args are the same as prior function.
38
40
  """
41
+
39
42
  # --- helper: gaussian smoothing (scipy fallback -> numpy separable conv) ---
40
43
  def _gaussian_1d_kernel(sigma: float, eps: float = 1e-12):
44
+ """Build a normalized 1D Gaussian kernel."""
41
45
  if sigma <= 0 or sigma is None:
42
46
  return np.array([1.0], dtype=float)
43
47
  # choose kernel size = odd ~ 6*sigma (covers +/-3 sigma)
44
48
  radius = max(1, int(math.ceil(3.0 * float(sigma))))
45
49
  xs = np.arange(-radius, radius + 1, dtype=float)
46
- k = np.exp(-(xs ** 2) / (2.0 * sigma ** 2))
50
+ k = np.exp(-(xs**2) / (2.0 * sigma**2))
47
51
  k_sum = k.sum()
48
52
  if k_sum <= eps:
49
53
  k = np.array([1.0], dtype=float)
50
54
  k_sum = 1.0
51
55
  return k / k_sum
52
56
 
53
- def _smooth_with_numpy_separable(Z: np.ndarray, sigma_len: float, sigma_pos: float) -> np.ndarray:
57
+ def _smooth_with_numpy_separable(
58
+ Z: np.ndarray, sigma_len: float, sigma_pos: float
59
+ ) -> np.ndarray:
60
+ """Apply separable Gaussian smoothing with NumPy."""
54
61
  # Z shape: (n_lengths, n_positions)
55
62
  out = Z.copy()
56
63
  # smooth along length axis (axis=0)
57
64
  if sigma_len and sigma_len > 0:
58
65
  k_len = _gaussian_1d_kernel(sigma_len)
59
66
  # convolve each column
60
- out = np.apply_along_axis(lambda col: np.convolve(col, k_len, mode="same"), axis=0, arr=out)
67
+ out = np.apply_along_axis(
68
+ lambda col: np.convolve(col, k_len, mode="same"), axis=0, arr=out
69
+ )
61
70
  # smooth along position axis (axis=1)
62
71
  if sigma_pos and sigma_pos > 0:
63
72
  k_pos = _gaussian_1d_kernel(sigma_pos)
64
- out = np.apply_along_axis(lambda row: np.convolve(row, k_pos, mode="same"), axis=1, arr=out)
73
+ out = np.apply_along_axis(
74
+ lambda row: np.convolve(row, k_pos, mode="same"), axis=1, arr=out
75
+ )
65
76
  return out
66
77
 
67
78
  # prefer scipy.ndimage if available (faster and better boundary handling)
@@ -69,11 +80,13 @@ def plot_hmm_size_contours(
69
80
  if use_scipy_if_available:
70
81
  try:
71
82
  from scipy.ndimage import gaussian_filter as _scipy_gaussian_filter
83
+
72
84
  _have_scipy = True
73
85
  except Exception:
74
86
  _have_scipy = False
75
87
 
76
88
  def _smooth_Z(Z: np.ndarray, sigma_len: float, sigma_pos: float) -> np.ndarray:
89
+ """Smooth a matrix using scipy if available or NumPy fallback."""
77
90
  if (sigma_len is None or sigma_len == 0) and (sigma_pos is None or sigma_pos == 0):
78
91
  return Z
79
92
  if _have_scipy:
@@ -84,8 +97,16 @@ def plot_hmm_size_contours(
84
97
  return _smooth_with_numpy_separable(Z, float(sigma_len or 0.0), float(sigma_pos or 0.0))
85
98
 
86
99
  # --- gather unique ordered labels ---
87
- samples = list(adata.obs[sample_col].cat.categories) if getattr(adata.obs[sample_col], "dtype", None) == "category" else list(pd.Categorical(adata.obs[sample_col]).categories)
88
- refs = list(adata.obs[ref_obs_col].cat.categories) if getattr(adata.obs[ref_obs_col], "dtype", None) == "category" else list(pd.Categorical(adata.obs[ref_obs_col]).categories)
100
+ samples = (
101
+ list(adata.obs[sample_col].cat.categories)
102
+ if getattr(adata.obs[sample_col], "dtype", None) == "category"
103
+ else list(pd.Categorical(adata.obs[sample_col]).categories)
104
+ )
105
+ refs = (
106
+ list(adata.obs[ref_obs_col].cat.categories)
107
+ if getattr(adata.obs[ref_obs_col], "dtype", None) == "category"
108
+ else list(pd.Categorical(adata.obs[ref_obs_col]).categories)
109
+ )
89
110
 
90
111
  n_samples = len(samples)
91
112
  n_refs = len(refs)
@@ -102,6 +123,7 @@ def plot_hmm_size_contours(
102
123
 
103
124
  # helper to get dense layer array for subset
104
125
  def _get_layer_array(layer):
126
+ """Convert a layer to a dense NumPy array."""
105
127
  arr = layer
106
128
  # sparse -> toarray
107
129
  if hasattr(arr, "toarray"):
@@ -146,7 +168,7 @@ def plot_hmm_size_contours(
146
168
  fig_w = n_refs * figsize_per_cell[0]
147
169
  fig_h = rows_on_page * figsize_per_cell[1]
148
170
  fig, axes = plt.subplots(rows_on_page, n_refs, figsize=(fig_w, fig_h), squeeze=False)
149
- fig.suptitle(f"HMM size contours (page {p+1}/{pages})", fontsize=12)
171
+ fig.suptitle(f"HMM size contours (page {p + 1}/{pages})", fontsize=12)
150
172
 
151
173
  # for each panel compute p(length | position)
152
174
  for i_row, sample in enumerate(page_samples):
@@ -160,7 +182,9 @@ def plot_hmm_size_contours(
160
182
  ax.set_title(f"{sample} / {ref}")
161
183
  continue
162
184
 
163
- row_idx = np.nonzero(panel_mask.values if hasattr(panel_mask, "values") else np.asarray(panel_mask))[0]
185
+ row_idx = np.nonzero(
186
+ panel_mask.values if hasattr(panel_mask, "values") else np.asarray(panel_mask)
187
+ )[0]
164
188
  if row_idx.size == 0:
165
189
  ax.text(0.5, 0.5, "no reads", ha="center", va="center")
166
190
  ax.set_title(f"{sample} / {ref}")
@@ -178,7 +202,9 @@ def plot_hmm_size_contours(
178
202
  max_len_here = min(max_len, max_len_local)
179
203
 
180
204
  lengths_range = np.arange(1, max_len_here + 1, dtype=int)
181
- Z = np.zeros((len(lengths_range), n_positions), dtype=float) # rows=length, cols=pos
205
+ Z = np.zeros(
206
+ (len(lengths_range), n_positions), dtype=float
207
+ ) # rows=length, cols=pos
182
208
 
183
209
  # fill Z by efficient bincount across columns
184
210
  for j in range(n_positions):
@@ -222,7 +248,9 @@ def plot_hmm_size_contours(
222
248
  dy = 1.0
223
249
  y_edges = np.concatenate([y - 0.5, [y[-1] + 0.5]])
224
250
 
225
- pcm = ax.pcolormesh(x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax)
251
+ pcm = ax.pcolormesh(
252
+ x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax
253
+ )
226
254
  ax.set_title(f"{sample} / {ref}")
227
255
  ax.set_ylabel("length")
228
256
  if i_row == rows_on_page - 1:
@@ -243,9 +271,10 @@ def plot_hmm_size_contours(
243
271
  # saving per page if requested
244
272
  if save_path is not None:
245
273
  import os
274
+
246
275
  os.makedirs(save_path, exist_ok=True)
247
276
  if save_each_page:
248
- fname = f"hmm_size_page_{p+1:03d}.png"
277
+ fname = f"hmm_size_page_{p + 1:03d}.png"
249
278
  out = os.path.join(save_path, fname)
250
279
  fig.savefig(out, dpi=dpi, bbox_inches="tight")
251
280
 
@@ -20,10 +20,10 @@ def plot_volcano_relative_risk(
20
20
  xlim (tuple): Optional x-axis limit.
21
21
  ylim (tuple): Optional y-axis limit.
22
22
  """
23
- import matplotlib.pyplot as plt
24
- import numpy as np
25
23
  import os
26
24
 
25
+ import matplotlib.pyplot as plt
26
+
27
27
  for ref, group_results in results_dict.items():
28
28
  for group_label, (results_df, _) in group_results.items():
29
29
  if results_df.empty:
@@ -31,8 +31,8 @@ def plot_volcano_relative_risk(
31
31
  continue
32
32
 
33
33
  # Split by site type
34
- gpc_df = results_df[results_df['GpC_Site']]
35
- cpg_df = results_df[results_df['CpG_Site']]
34
+ gpc_df = results_df[results_df["GpC_Site"]]
35
+ cpg_df = results_df[results_df["CpG_Site"]]
36
36
 
37
37
  fig, ax = plt.subplots(figsize=(12, 6))
38
38
 
@@ -43,29 +43,29 @@ def plot_volcano_relative_risk(
43
43
 
44
44
  # GpC as circles
45
45
  sc1 = ax.scatter(
46
- gpc_df['Genomic_Position'],
47
- gpc_df['log2_Relative_Risk'],
48
- c=gpc_df['-log10_Adj_P'],
49
- cmap='coolwarm',
50
- edgecolor='k',
46
+ gpc_df["Genomic_Position"],
47
+ gpc_df["log2_Relative_Risk"],
48
+ c=gpc_df["-log10_Adj_P"],
49
+ cmap="coolwarm",
50
+ edgecolor="k",
51
51
  s=40,
52
- marker='o',
53
- label='GpC'
52
+ marker="o",
53
+ label="GpC",
54
54
  )
55
55
 
56
56
  # CpG as stars
57
57
  sc2 = ax.scatter(
58
- cpg_df['Genomic_Position'],
59
- cpg_df['log2_Relative_Risk'],
60
- c=cpg_df['-log10_Adj_P'],
61
- cmap='coolwarm',
62
- edgecolor='k',
58
+ cpg_df["Genomic_Position"],
59
+ cpg_df["log2_Relative_Risk"],
60
+ c=cpg_df["-log10_Adj_P"],
61
+ cmap="coolwarm",
62
+ edgecolor="k",
63
63
  s=60,
64
- marker='*',
65
- label='CpG'
64
+ marker="*",
65
+ label="CpG",
66
66
  )
67
67
 
68
- ax.axhline(y=0, color='gray', linestyle='--')
68
+ ax.axhline(y=0, color="gray", linestyle="--")
69
69
  ax.set_xlabel("Genomic Position")
70
70
  ax.set_ylabel("log2(Relative Risk)")
71
71
  ax.set_title(f"{ref} / {group_label} — Relative Risk vs Genomic Position")
@@ -75,8 +75,8 @@ def plot_volcano_relative_risk(
75
75
  if ylim:
76
76
  ax.set_ylim(ylim)
77
77
 
78
- ax.spines['top'].set_visible(False)
79
- ax.spines['right'].set_visible(False)
78
+ ax.spines["top"].set_visible(False)
79
+ ax.spines["right"].set_visible(False)
80
80
 
81
81
  cbar = plt.colorbar(sc1, ax=ax)
82
82
  cbar.set_label("-log10(Adjusted P-Value)")
@@ -87,13 +87,19 @@ def plot_volcano_relative_risk(
87
87
  # Save if requested
88
88
  if save_path:
89
89
  os.makedirs(save_path, exist_ok=True)
90
- safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
90
+ safe_name = (
91
+ f"{ref}_{group_label}".replace("=", "")
92
+ .replace("__", "_")
93
+ .replace(",", "_")
94
+ .replace(" ", "_")
95
+ )
91
96
  out_file = os.path.join(save_path, f"{safe_name}.png")
92
97
  plt.savefig(out_file, dpi=300)
93
- print(f"📁 Saved: {out_file}")
98
+ print(f"Saved: {out_file}")
94
99
 
95
100
  plt.show()
96
101
 
102
+
97
103
  def plot_bar_relative_risk(
98
104
  results_dict,
99
105
  sort_by_position=True,
@@ -102,7 +108,7 @@ def plot_bar_relative_risk(
102
108
  save_path=None,
103
109
  highlight_regions=None, # List of (start, end) tuples
104
110
  highlight_color="lightgray",
105
- highlight_alpha=0.3
111
+ highlight_alpha=0.3,
106
112
  ):
107
113
  """
108
114
  Plot log2(Relative Risk) as a bar plot across genomic positions for each group within each reference.
@@ -116,10 +122,10 @@ def plot_bar_relative_risk(
116
122
  highlight_color (str): Color of shaded region.
117
123
  highlight_alpha (float): Transparency of shaded region.
118
124
  """
119
- import matplotlib.pyplot as plt
120
- import numpy as np
121
125
  import os
122
126
 
127
+ import matplotlib.pyplot as plt
128
+
123
129
  for ref, group_data in results_dict.items():
124
130
  for group_label, (df, _) in group_data.items():
125
131
  if df.empty:
@@ -127,14 +133,14 @@ def plot_bar_relative_risk(
127
133
  continue
128
134
 
129
135
  df = df.copy()
130
- df['Genomic_Position'] = df['Genomic_Position'].astype(int)
136
+ df["Genomic_Position"] = df["Genomic_Position"].astype(int)
131
137
 
132
138
  if sort_by_position:
133
- df = df.sort_values('Genomic_Position')
139
+ df = df.sort_values("Genomic_Position")
134
140
 
135
- gpc_mask = df['GpC_Site'] & ~df['CpG_Site']
136
- cpg_mask = df['CpG_Site'] & ~df['GpC_Site']
137
- both_mask = df['GpC_Site'] & df['CpG_Site']
141
+ gpc_mask = df["GpC_Site"] & ~df["CpG_Site"]
142
+ cpg_mask = df["CpG_Site"] & ~df["GpC_Site"]
143
+ both_mask = df["GpC_Site"] & df["CpG_Site"]
138
144
 
139
145
  fig, ax = plt.subplots(figsize=(14, 6))
140
146
 
@@ -145,36 +151,36 @@ def plot_bar_relative_risk(
145
151
 
146
152
  # Bar plots
147
153
  ax.bar(
148
- df['Genomic_Position'][gpc_mask],
149
- df['log2_Relative_Risk'][gpc_mask],
154
+ df["Genomic_Position"][gpc_mask],
155
+ df["log2_Relative_Risk"][gpc_mask],
150
156
  width=10,
151
- color='steelblue',
152
- label='GpC Site',
153
- edgecolor='black'
157
+ color="steelblue",
158
+ label="GpC Site",
159
+ edgecolor="black",
154
160
  )
155
161
 
156
162
  ax.bar(
157
- df['Genomic_Position'][cpg_mask],
158
- df['log2_Relative_Risk'][cpg_mask],
163
+ df["Genomic_Position"][cpg_mask],
164
+ df["log2_Relative_Risk"][cpg_mask],
159
165
  width=10,
160
- color='darkorange',
161
- label='CpG Site',
162
- edgecolor='black'
166
+ color="darkorange",
167
+ label="CpG Site",
168
+ edgecolor="black",
163
169
  )
164
170
 
165
171
  if both_mask.any():
166
172
  ax.bar(
167
- df['Genomic_Position'][both_mask],
168
- df['log2_Relative_Risk'][both_mask],
173
+ df["Genomic_Position"][both_mask],
174
+ df["log2_Relative_Risk"][both_mask],
169
175
  width=10,
170
- color='purple',
171
- label='GpC + CpG',
172
- edgecolor='black'
176
+ color="purple",
177
+ label="GpC + CpG",
178
+ edgecolor="black",
173
179
  )
174
180
 
175
- ax.axhline(y=0, color='gray', linestyle='--')
176
- ax.set_xlabel('Genomic Position')
177
- ax.set_ylabel('log2(Relative Risk)')
181
+ ax.axhline(y=0, color="gray", linestyle="--")
182
+ ax.set_xlabel("Genomic Position")
183
+ ax.set_ylabel("log2(Relative Risk)")
178
184
  ax.set_title(f"{ref} — {group_label}")
179
185
  ax.legend()
180
186
 
@@ -183,20 +189,23 @@ def plot_bar_relative_risk(
183
189
  if ylim:
184
190
  ax.set_ylim(ylim)
185
191
 
186
- ax.spines['top'].set_visible(False)
187
- ax.spines['right'].set_visible(False)
192
+ ax.spines["top"].set_visible(False)
193
+ ax.spines["right"].set_visible(False)
188
194
 
189
195
  plt.tight_layout()
190
196
 
191
197
  if save_path:
192
198
  os.makedirs(save_path, exist_ok=True)
193
- safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_")
199
+ safe_name = (
200
+ f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_")
201
+ )
194
202
  out_file = os.path.join(save_path, f"{safe_name}.png")
195
203
  plt.savefig(out_file, dpi=300)
196
204
  print(f"📁 Saved: {out_file}")
197
205
 
198
206
  plt.show()
199
207
 
208
+
200
209
  def plot_positionwise_matrix(
201
210
  adata,
202
211
  key="positionwise_result",
@@ -210,35 +219,39 @@ def plot_positionwise_matrix(
210
219
  xtick_step=10,
211
220
  ytick_step=10,
212
221
  save_path=None,
213
- highlight_position=None, # Can be a single int/float or list of them
214
- highlight_axis="row", # "row" or "column"
215
- annotate_points=False # ✅ New option
222
+ highlight_position=None, # Can be a single int/float or list of them
223
+ highlight_axis="row", # "row" or "column"
224
+ annotate_points=False, # ✅ New option
216
225
  ):
217
226
  """
218
227
  Plots positionwise matrices stored in adata.uns[key], with an optional line plot
219
228
  for specified row(s) or column(s), and highlights them on the heatmap.
220
229
  """
230
+ import os
231
+
221
232
  import matplotlib.pyplot as plt
222
- import seaborn as sns
223
233
  import numpy as np
224
234
  import pandas as pd
225
- import os
235
+ import seaborn as sns
226
236
 
227
237
  def find_closest_index(index, target):
238
+ """Find the index value closest to a target value."""
228
239
  index_vals = pd.to_numeric(index, errors="coerce")
229
240
  target_val = pd.to_numeric([target], errors="coerce")[0]
230
241
  diffs = pd.Series(np.abs(index_vals - target_val), index=index)
231
242
  return diffs.idxmin()
232
243
 
233
244
  # Ensure highlight_position is a list
234
- if highlight_position is not None and not isinstance(highlight_position, (list, tuple, np.ndarray)):
245
+ if highlight_position is not None and not isinstance(
246
+ highlight_position, (list, tuple, np.ndarray)
247
+ ):
235
248
  highlight_position = [highlight_position]
236
249
 
237
250
  for group, mat_df in adata.uns[key].items():
238
251
  mat = mat_df.copy()
239
252
 
240
253
  if log_transform:
241
- with np.errstate(divide='ignore', invalid='ignore'):
254
+ with np.errstate(divide="ignore", invalid="ignore"):
242
255
  if log_base == "log1p":
243
256
  mat = np.log1p(mat)
244
257
  elif log_base == "log2":
@@ -276,7 +289,7 @@ def plot_positionwise_matrix(
276
289
  vmin=vmin,
277
290
  vmax=vmax,
278
291
  cbar_kws={"label": f"{key} ({log_base})" if log_transform else key},
279
- ax=heat_ax
292
+ ax=heat_ax,
280
293
  )
281
294
 
282
295
  heat_ax.set_title(f"{key} — {group}", pad=20)
@@ -295,17 +308,27 @@ def plot_positionwise_matrix(
295
308
  series = mat.loc[closest]
296
309
  x_vals = pd.to_numeric(series.index, errors="coerce")
297
310
  idx = mat.index.get_loc(closest)
298
- heat_ax.axhline(idx, color=colors[i % len(colors)], linestyle="--", linewidth=1)
311
+ heat_ax.axhline(
312
+ idx, color=colors[i % len(colors)], linestyle="--", linewidth=1
313
+ )
299
314
  label = f"Row {pos} → {closest}"
300
315
  else:
301
316
  closest = find_closest_index(mat.columns, pos)
302
317
  series = mat[closest]
303
318
  x_vals = pd.to_numeric(series.index, errors="coerce")
304
319
  idx = mat.columns.get_loc(closest)
305
- heat_ax.axvline(idx, color=colors[i % len(colors)], linestyle="--", linewidth=1)
320
+ heat_ax.axvline(
321
+ idx, color=colors[i % len(colors)], linestyle="--", linewidth=1
322
+ )
306
323
  label = f"Col {pos} → {closest}"
307
324
 
308
- line = line_ax.plot(x_vals, series.values, marker='o', label=label, color=colors[i % len(colors)])
325
+ line = line_ax.plot(
326
+ x_vals,
327
+ series.values,
328
+ marker="o",
329
+ label=label,
330
+ color=colors[i % len(colors)],
331
+ )
309
332
 
310
333
  # Annotate each point
311
334
  if annotate_points:
@@ -316,12 +339,18 @@ def plot_positionwise_matrix(
316
339
  xy=(x, y),
317
340
  textcoords="offset points",
318
341
  xytext=(0, 5),
319
- ha='center',
320
- fontsize=8
342
+ ha="center",
343
+ fontsize=8,
321
344
  )
322
345
  except Exception as e:
323
- line_ax.text(0.5, 0.5, f"⚠️ Error plotting {highlight_axis} @ {pos}",
324
- ha='center', va='center', fontsize=10)
346
+ line_ax.text(
347
+ 0.5,
348
+ 0.5,
349
+ f"⚠️ Error plotting {highlight_axis} @ {pos}",
350
+ ha="center",
351
+ va="center",
352
+ fontsize=10,
353
+ )
325
354
  print(f"Error plotting line for {highlight_axis}={pos}: {e}")
326
355
 
327
356
  line_ax.set_title(f"{highlight_axis.capitalize()} Profile(s)")
@@ -342,6 +371,7 @@ def plot_positionwise_matrix(
342
371
 
343
372
  plt.show()
344
373
 
374
+
345
375
  def plot_positionwise_matrix_grid(
346
376
  adata,
347
377
  key,
@@ -356,32 +386,61 @@ def plot_positionwise_matrix_grid(
356
386
  xtick_step=10,
357
387
  ytick_step=10,
358
388
  parallel=False,
359
- max_threads=None
389
+ max_threads=None,
360
390
  ):
391
+ """Plot a grid of positionwise matrices grouped by metadata.
392
+
393
+ Args:
394
+ adata: AnnData containing matrices in ``adata.uns``.
395
+ key: Key for positionwise matrices.
396
+ outer_keys: Keys for outer grouping.
397
+ inner_keys: Keys for inner grouping.
398
+ log_transform: Optional log transform (``log2`` or ``log1p``).
399
+ vmin: Minimum color scale value.
400
+ vmax: Maximum color scale value.
401
+ cmap: Matplotlib colormap.
402
+ save_path: Optional path to save plots.
403
+ figsize: Figure size.
404
+ xtick_step: X-axis tick step.
405
+ ytick_step: Y-axis tick step.
406
+ parallel: Whether to plot in parallel.
407
+ max_threads: Max thread count for parallel plotting.
408
+ """
409
+ import os
410
+
361
411
  import matplotlib.pyplot as plt
362
- import seaborn as sns
363
412
  import numpy as np
364
413
  import pandas as pd
365
- import os
366
- from matplotlib.gridspec import GridSpec
414
+ import seaborn as sns
367
415
  from joblib import Parallel, delayed
416
+ from matplotlib.gridspec import GridSpec
368
417
 
369
418
  matrices = adata.uns[key]
370
419
  group_labels = list(matrices.keys())
371
420
 
372
- parsed_inner = pd.DataFrame([dict(zip(inner_keys, g.split("_")[-len(inner_keys):])) for g in group_labels])
373
- parsed_outer = pd.Series(["_".join(g.split("_")[:-len(inner_keys)]) for g in group_labels], name="outer")
421
+ parsed_inner = pd.DataFrame(
422
+ [dict(zip(inner_keys, g.split("_")[-len(inner_keys) :])) for g in group_labels]
423
+ )
424
+ parsed_outer = pd.Series(
425
+ ["_".join(g.split("_")[: -len(inner_keys)]) for g in group_labels], name="outer"
426
+ )
374
427
  parsed = pd.concat([parsed_outer, parsed_inner], axis=1)
375
428
 
376
429
  def plot_one_grid(outer_label):
377
- selected = parsed[parsed['outer'] == outer_label].copy()
378
- selected["group_str"] = [f"{outer_label}_{row[inner_keys[0]]}_{row[inner_keys[1]]}" for _, row in selected.iterrows()]
430
+ """Plot one grid for a specific outer label."""
431
+ selected = parsed[parsed["outer"] == outer_label].copy()
432
+ selected["group_str"] = [
433
+ f"{outer_label}_{row[inner_keys[0]]}_{row[inner_keys[1]]}"
434
+ for _, row in selected.iterrows()
435
+ ]
379
436
 
380
437
  row_vals = sorted(selected[inner_keys[0]].unique())
381
438
  col_vals = sorted(selected[inner_keys[1]].unique())
382
439
 
383
440
  fig = plt.figure(figsize=figsize)
384
- gs = GridSpec(len(row_vals), len(col_vals) + 1, width_ratios=[1]*len(col_vals) + [0.05], wspace=0.3)
441
+ gs = GridSpec(
442
+ len(row_vals), len(col_vals) + 1, width_ratios=[1] * len(col_vals) + [0.05], wspace=0.3
443
+ )
385
444
  axes = np.empty((len(row_vals), len(col_vals)), dtype=object)
386
445
 
387
446
  local_vmin, local_vmax = vmin, vmax
@@ -397,10 +456,7 @@ def plot_positionwise_matrix_grid(
397
456
  local_vmin = -vmax_auto if vmin is None else vmin
398
457
  local_vmax = vmax_auto if vmax is None else vmax
399
458
 
400
- cbar_label = {
401
- "log2": "log2(Value)",
402
- "log1p": "log1p(Value)"
403
- }.get(log_transform, "Value")
459
+ cbar_label = {"log2": "log2(Value)", "log1p": "log1p(Value)"}.get(log_transform, "Value")
404
460
 
405
461
  cbar_ax = fig.add_subplot(gs[:, -1])
406
462
 
@@ -431,9 +487,11 @@ def plot_positionwise_matrix_grid(
431
487
  vmax=local_vmax,
432
488
  cbar=(i == 0 and j == 0),
433
489
  cbar_ax=cbar_ax if (i == 0 and j == 0) else None,
434
- cbar_kws={"label": cbar_label if (i == 0 and j == 0) else ""}
490
+ cbar_kws={"label": cbar_label if (i == 0 and j == 0) else ""},
491
+ )
492
+ ax.set_title(
493
+ f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8
435
494
  )
436
- ax.set_title(f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8)
437
495
 
438
496
  xticks = data.columns.astype(int)
439
497
  yticks = data.index.astype(int)
@@ -448,15 +506,17 @@ def plot_positionwise_matrix_grid(
448
506
  if save_path:
449
507
  os.makedirs(save_path, exist_ok=True)
450
508
  fname = outer_label.replace("_", "").replace("=", "") + ".png"
451
- plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches='tight')
452
- print(f"Saved {fname}")
509
+ plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches="tight")
510
+ print(f"Saved {fname}")
453
511
 
454
512
  plt.close(fig)
455
513
 
456
514
  if parallel:
457
- Parallel(n_jobs=max_threads)(delayed(plot_one_grid)(outer_label) for outer_label in parsed['outer'].unique())
515
+ Parallel(n_jobs=max_threads)(
516
+ delayed(plot_one_grid)(outer_label) for outer_label in parsed["outer"].unique()
517
+ )
458
518
  else:
459
- for outer_label in parsed['outer'].unique():
519
+ for outer_label in parsed["outer"].unique():
460
520
  plot_one_grid(outer_label)
461
521
 
462
- print("Finished plotting all grids.")
522
+ print("Finished plotting all grids.")