smftools-0.2.4-py3-none-any.whl → smftools-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,13 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Optional
2
4
 
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from smftools.optional_imports import require
9
+
10
+
3
11
  def plot_spatial_autocorr_grid(
4
12
  adata,
5
13
  out_dir: str,
@@ -14,6 +22,7 @@ def plot_spatial_autocorr_grid(
14
22
  references: Optional[list] = None,
15
23
  annotate_periodicity: bool = True,
16
24
  counts_key_suffix: str = "_counts",
25
+ normalization_method: str = "pearson",
17
26
  # plotting thresholds
18
27
  plot_min_count: int = 10,
19
28
  ):
@@ -28,14 +37,15 @@ def plot_spatial_autocorr_grid(
28
37
  fall back to running the analyzer for that group (slow) and cache the result into adata.uns.
29
38
  """
30
39
  import os
31
- import numpy as np
32
- import pandas as pd
33
- import matplotlib.pyplot as plt
34
40
  import warnings
35
41
 
42
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="autocorrelation plots")
43
+
36
44
  # Try importing analyzer (used only as fallback)
37
45
  try:
38
- from ..tools.spatial_autocorrelation import analyze_autocorr_matrix # prefer packaged analyzer
46
+ from ..tools.spatial_autocorrelation import (
47
+ analyze_autocorr_matrix,
48
+ ) # prefer packaged analyzer
39
49
  except Exception:
40
50
  analyze_autocorr_matrix = globals().get("analyze_autocorr_matrix", None)
41
51
 
@@ -44,6 +54,7 @@ def plot_spatial_autocorr_grid(
44
54
 
45
55
  # small rolling average helper for smoother visualization
46
56
  def _rolling_1d(arr: np.ndarray, win: int) -> np.ndarray:
57
+ """Compute a rolling mean with NaN-aware normalization."""
47
58
  if win <= 1:
48
59
  return arr
49
60
  valid = np.isfinite(arr).astype(float)
@@ -58,6 +69,7 @@ def plot_spatial_autocorr_grid(
58
69
 
59
70
  # group summary extractor: returns (lags, mean_curve_smoothed, std_curve_smoothed, counts_block_or_None)
60
71
  def _compute_group_summary_for_mask(site: str, mask: np.ndarray):
72
+ """Extract summary curves for a site and mask."""
61
73
  obsm_key = f"{site}_spatial_autocorr"
62
74
  lags_key = f"{site}_spatial_autocorr_lags"
63
75
  counts_key = f"{site}_spatial_autocorr{counts_key_suffix}"
@@ -75,7 +87,12 @@ def plot_spatial_autocorr_grid(
75
87
  if counts_key in adata.obsm:
76
88
  counts_mat = np.asarray(adata.obsm[counts_key])
77
89
  counts = counts_mat[mask, :].astype(int)
78
- return np.asarray(adata.uns[lags_key]), _rolling_1d(mean_per_lag, window), _rolling_1d(std_per_lag, window), counts
90
+ return (
91
+ np.asarray(adata.uns[lags_key]),
92
+ _rolling_1d(mean_per_lag, window),
93
+ _rolling_1d(std_per_lag, window),
94
+ counts,
95
+ )
79
96
 
80
97
  # samples meta
81
98
  if sample_col not in adata.obs:
@@ -116,7 +133,8 @@ def plot_spatial_autocorr_grid(
116
133
  nrows = len(chunk)
117
134
 
118
135
  fig, axes = plt.subplots(
119
- nrows=nrows, ncols=ncols,
136
+ nrows=nrows,
137
+ ncols=ncols,
120
138
  figsize=(4.2 * ncols, 2.4 * nrows),
121
139
  dpi=dpi,
122
140
  squeeze=False,
@@ -141,9 +159,9 @@ def plot_spatial_autocorr_grid(
141
159
  ax = axes[r, col_idx]
142
160
 
143
161
  # compute mask
144
- sample_mask = (adata.obs[sample_col].values == sample_name)
162
+ sample_mask = adata.obs[sample_col].values == sample_name
145
163
  if col_kind == "ref":
146
- ref_mask = (adata.obs[reference_col].values == col_val)
164
+ ref_mask = adata.obs[reference_col].values == col_val
147
165
  mask = sample_mask & ref_mask
148
166
  else:
149
167
  mask = sample_mask
@@ -152,7 +170,9 @@ def plot_spatial_autocorr_grid(
152
170
  n_reads_grp = int(mask.sum())
153
171
 
154
172
  # group summary (mean/std and counts_block)
155
- lags_local, mean_curve, std_curve, counts_block = _compute_group_summary_for_mask(site, mask)
173
+ lags_local, mean_curve, std_curve, counts_block = (
174
+ _compute_group_summary_for_mask(site, mask)
175
+ )
156
176
 
157
177
  # plot title for top row
158
178
  if r == 0:
@@ -164,9 +184,12 @@ def plot_spatial_autocorr_grid(
164
184
  ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=8)
165
185
  ax.set_xlim(0, 1)
166
186
  ax.set_xlabel("Lag (bp)", fontsize=7)
167
- ax.tick_params(axis='both', which='major', labelsize=6)
187
+ ax.set_ylabel(
188
+ f"Autocorrelation {normalization_method} normalized", fontsize=7
189
+ )
190
+ ax.tick_params(axis="both", which="major", labelsize=6)
168
191
  ax.grid(True, alpha=0.22)
169
- #col_idx += 1
192
+ # col_idx += 1
170
193
  continue
171
194
 
172
195
  # mask low-support lags if counts available
@@ -186,7 +209,13 @@ def plot_spatial_autocorr_grid(
186
209
 
187
210
  # plot a faint grey line for the low-support regions (context only)
188
211
  if low_support.any():
189
- ax.plot(lags_local[low_support], mean_curve_smooth[low_support], color="0.85", lw=0.6, label="_nolegend_")
212
+ ax.plot(
213
+ lags_local[low_support],
214
+ mean_curve_smooth[low_support],
215
+ color="0.85",
216
+ lw=0.6,
217
+ label="_nolegend_",
218
+ )
190
219
 
191
220
  # plot mean (high-support only) and +/- std (std is computed from all molecules)
192
221
  ax.plot(lags_local, mean_plot, lw=1.1)
@@ -201,16 +230,25 @@ def plot_spatial_autocorr_grid(
201
230
  # metrics_by_group_precomp can be dict-like
202
231
  res = metrics_by_group_precomp.get(group_key, None)
203
232
 
204
- if res is None and annotate_periodicity and (analyze_autocorr_matrix is not None) and (ac_full is not None):
233
+ if (
234
+ res is None
235
+ and annotate_periodicity
236
+ and (analyze_autocorr_matrix is not None)
237
+ and (ac_full is not None)
238
+ ):
205
239
  # fallback: run analyzer on the subset (warn + cache)
206
240
  ac_sel = ac_full[mask, :]
207
241
  cnt_sel = counts_full[mask, :] if counts_full is not None else None
208
242
  if ac_sel.size:
209
- warnings.warn(f"Precomputed periodicity metrics for {site} {group_key} not found — running analyzer as fallback (slow).")
243
+ warnings.warn(
244
+ f"Precomputed periodicity metrics for {site} {group_key} not found — running analyzer as fallback (slow)."
245
+ )
210
246
  try:
211
247
  res = analyze_autocorr_matrix(
212
248
  ac_sel,
213
- cnt_sel if cnt_sel is not None else np.zeros_like(ac_sel, dtype=int),
249
+ cnt_sel
250
+ if cnt_sel is not None
251
+ else np.zeros_like(ac_sel, dtype=int),
214
252
  lags_local,
215
253
  nrl_search_bp=(120, 260),
216
254
  pad_factor=4,
@@ -239,19 +277,38 @@ def plot_spatial_autocorr_grid(
239
277
 
240
278
  # vertical NRL line & harmonics (safe check)
241
279
  if (nrl is not None) and np.isfinite(nrl):
242
- ax.axvline(float(nrl), color="C3", linestyle="--", linewidth=1.0, alpha=0.9)
280
+ ax.axvline(
281
+ float(nrl), color="C3", linestyle="--", linewidth=1.0, alpha=0.9
282
+ )
243
283
  for m in range(2, 5):
244
- ax.axvline(float(nrl) * m, color="C3", linestyle=":", linewidth=0.7, alpha=0.6)
284
+ ax.axvline(
285
+ float(nrl) * m,
286
+ color="C3",
287
+ linestyle=":",
288
+ linewidth=0.7,
289
+ alpha=0.6,
290
+ )
245
291
 
246
292
  # envelope points + fitted exponential
247
293
  if sample_lags.size:
248
294
  ax.scatter(sample_lags, envelope_heights, s=18, color="C2")
249
- if (xi_val is not None) and np.isfinite(xi_val) and np.isfinite(res.get("xi_A", np.nan)):
295
+ if (
296
+ (xi_val is not None)
297
+ and np.isfinite(xi_val)
298
+ and np.isfinite(res.get("xi_A", np.nan))
299
+ ):
250
300
  A = float(res.get("xi_A", np.nan))
251
301
  xi_val = float(xi_val)
252
302
  env_x = np.linspace(np.min(sample_lags), np.max(sample_lags), 200)
253
303
  env_y = A * np.exp(-env_x / xi_val)
254
- ax.plot(env_x, env_y, linestyle="--", color="C2", linewidth=1.0, alpha=0.9)
304
+ ax.plot(
305
+ env_x,
306
+ env_y,
307
+ linestyle="--",
308
+ color="C2",
309
+ linewidth=1.0,
310
+ alpha=0.9,
311
+ )
255
312
 
256
313
  # inset PSD plotted vs NRL (linear x-axis)
257
314
  freqs = res.get("freqs", None)
@@ -266,7 +323,12 @@ def plot_spatial_autocorr_grid(
266
323
  nrl_vals = 1.0 / freqs[valid] # convert freq -> NRL (bp)
267
324
  inset.plot(nrl_vals, power[valid], lw=0.7)
268
325
  if peak_f is not None and peak_f > 0:
269
- inset.axvline(1.0 / float(peak_f), color="C3", linestyle="--", linewidth=0.8)
326
+ inset.axvline(
327
+ 1.0 / float(peak_f),
328
+ color="C3",
329
+ linestyle="--",
330
+ linewidth=0.8,
331
+ )
270
332
  # choose a reasonable linear x-limits (prefer typical NRL range but fallback to data)
271
333
  default_xlim = (60, 400)
272
334
  data_xlim = (float(np.nanmin(nrl_vals)), 600)
@@ -278,17 +340,29 @@ def plot_spatial_autocorr_grid(
278
340
  inset.set_ylabel("power", fontsize=6)
279
341
  inset.tick_params(labelsize=6)
280
342
  if (snr is not None) and np.isfinite(snr):
281
- inset.text(0.95, 0.95, f"SNR={float(snr):.1f}", transform=inset.transAxes,
282
- ha="right", va="top", fontsize=6, bbox=dict(facecolor="white", alpha=0.6, edgecolor="none"))
343
+ inset.text(
344
+ 0.95,
345
+ 0.95,
346
+ f"SNR={float(snr):.1f}",
347
+ transform=inset.transAxes,
348
+ ha="right",
349
+ va="top",
350
+ fontsize=6,
351
+ bbox=dict(facecolor="white", alpha=0.6, edgecolor="none"),
352
+ )
283
353
 
284
354
  # set x-limits based on finite lags
285
355
  finite_mask = np.isfinite(lags_local)
286
356
  if finite_mask.any():
287
- ax.set_xlim(float(np.nanmin(lags_local[finite_mask])), float(np.nanmax(lags_local[finite_mask])))
357
+ ax.set_xlim(
358
+ float(np.nanmin(lags_local[finite_mask])),
359
+ float(np.nanmax(lags_local[finite_mask])),
360
+ )
288
361
 
289
362
  # small cosmetics
290
363
  ax.set_xlabel("Lag (bp)", fontsize=7)
291
- ax.tick_params(axis='both', which='major', labelsize=6)
364
+ ax.set_ylabel(f"Autocorrelation {normalization_method} normalized", fontsize=7)
365
+ ax.tick_params(axis="both", which="major", labelsize=6)
292
366
  ax.grid(True, alpha=0.22)
293
367
 
294
368
  col_idx += 1
@@ -301,9 +375,13 @@ def plot_spatial_autocorr_grid(
301
375
  ycenter = pos.y0 + pos.height / 2.0
302
376
  n_reads_grp = int((adata.obs[sample_col].values == sample_name).sum())
303
377
  label = f"{sample_name}\n(n={n_reads_grp})"
304
- fig.text(0.02, ycenter, label, va='center', ha='left', rotation='vertical', fontsize=9)
378
+ fig.text(0.02, ycenter, label, va="center", ha="left", rotation="vertical", fontsize=9)
305
379
 
306
- fig.suptitle("Spatial autocorrelation by sample × (site_type × reference)", y=0.995, fontsize=11)
380
+ fig.suptitle(
381
+ f"Spatial autocorrelation ({normalization_method}) by sample × (site_type × reference)",
382
+ y=0.995,
383
+ fontsize=11,
384
+ )
307
385
 
308
386
  page_idx = start_idx // rows_per_fig + 1
309
387
  out_png = os.path.join(out_dir, f"{filename_prefix}_page{page_idx}.png")
@@ -365,6 +443,7 @@ def plot_spatial_autocorr_grid(
365
443
  return arr.tolist()
366
444
 
367
445
  def _safe_float(x):
446
+ """Coerce a value to float, returning NaN on failure."""
368
447
  try:
369
448
  return float(x)
370
449
  except Exception:
@@ -381,15 +460,33 @@ def plot_spatial_autocorr_grid(
381
460
  "site": site,
382
461
  "sample": sample_name,
383
462
  "reference": ref,
384
- "nrl_bp": _safe_float(entry.get("nrl_bp", float("nan"))) if entry is not None else float("nan"),
385
- "snr": _safe_float(entry.get("snr", float("nan"))) if entry is not None else float("nan"),
386
- "fwhm_bp": _safe_float(entry.get("fwhm_bp", float("nan"))) if entry is not None else float("nan"),
387
- "xi": _safe_float(entry.get("xi", float("nan"))) if entry is not None else float("nan"),
388
- "xi_A": _safe_float(entry.get("xi_A", float("nan"))) if entry is not None else float("nan"),
389
- "xi_r2": _safe_float(entry.get("xi_r2", float("nan"))) if entry is not None else float("nan"),
390
- "envelope_sample_lags": ";".join(map(str, env_lags_list)) if len(env_lags_list) else "",
391
- "envelope_heights": ";".join(map(str, env_heights_list)) if len(env_heights_list) else "",
392
- "analyzer_error": entry.get("error", entry.get("analyzer_error", None)) if entry is not None else "no_metrics",
463
+ "nrl_bp": _safe_float(entry.get("nrl_bp", float("nan")))
464
+ if entry is not None
465
+ else float("nan"),
466
+ "snr": _safe_float(entry.get("snr", float("nan")))
467
+ if entry is not None
468
+ else float("nan"),
469
+ "fwhm_bp": _safe_float(entry.get("fwhm_bp", float("nan")))
470
+ if entry is not None
471
+ else float("nan"),
472
+ "xi": _safe_float(entry.get("xi", float("nan")))
473
+ if entry is not None
474
+ else float("nan"),
475
+ "xi_A": _safe_float(entry.get("xi_A", float("nan")))
476
+ if entry is not None
477
+ else float("nan"),
478
+ "xi_r2": _safe_float(entry.get("xi_r2", float("nan")))
479
+ if entry is not None
480
+ else float("nan"),
481
+ "envelope_sample_lags": ";".join(map(str, env_lags_list))
482
+ if len(env_lags_list)
483
+ else "",
484
+ "envelope_heights": ";".join(map(str, env_heights_list))
485
+ if len(env_heights_list)
486
+ else "",
487
+ "analyzer_error": entry.get("error", entry.get("analyzer_error", None))
488
+ if entry is not None
489
+ else "no_metrics",
393
490
  }
394
491
  rows.append(row)
395
492
  combined_rows.append(row)
@@ -404,6 +501,7 @@ def plot_spatial_autocorr_grid(
404
501
  except Exception as e:
405
502
  # don't fail the whole pipeline for a single write error; log and continue
406
503
  import warnings
504
+
407
505
  warnings.warn(f"Failed to write {out_csv}: {e}")
408
506
 
409
507
  # write the single combined CSV (one row per sample x ref x site)
@@ -413,16 +511,19 @@ def plot_spatial_autocorr_grid(
413
511
  combined_df.to_csv(combined_out, index=False)
414
512
  except Exception as e:
415
513
  import warnings
514
+
416
515
  warnings.warn(f"Failed to write combined CSV {combined_out}: {e}")
417
516
 
418
517
  return saved_pages
419
518
 
519
+
420
520
  def plot_rolling_metrics(df, out_png=None, title=None, figsize=(10, 3.5), dpi=160, show=False):
421
521
  """
422
522
  Plot NRL and SNR vs window center from the dataframe returned by rolling_autocorr_metrics.
423
523
  If out_png is None, returns the matplotlib Figure object; otherwise saves PNG and returns path.
424
524
  """
425
- import matplotlib.pyplot as plt
525
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="autocorrelation plots")
526
+
426
527
  # sort by center
427
528
  df2 = df.sort_values("center")
428
529
  x = df2["center"].values
@@ -446,16 +547,16 @@ def plot_rolling_metrics(df, out_png=None, title=None, figsize=(10, 3.5), dpi=16
446
547
  if out_png:
447
548
  fig.savefig(out_png, bbox_inches="tight")
448
549
  if not show:
449
- import matplotlib
550
+ matplotlib = require("matplotlib", extra="plotting", purpose="autocorrelation plots")
551
+
450
552
  matplotlib.pyplot.close(fig)
451
553
  return out_png
452
554
  if not show:
453
- import matplotlib
555
+ matplotlib = require("matplotlib", extra="plotting", purpose="autocorrelation plots")
556
+
454
557
  matplotlib.pyplot.close(fig)
455
558
  return fig
456
559
 
457
- import numpy as np
458
- import pandas as pd
459
560
 
460
561
  def plot_rolling_grid(
461
562
  rolling_dict,
@@ -502,10 +603,8 @@ def plot_rolling_grid(
502
603
  pages_by_metric : dict mapping metric -> [out_png_paths]
503
604
  """
504
605
  import os
505
- import math
506
- import matplotlib.pyplot as plt
507
- import numpy as np
508
- import pandas as pd
606
+
607
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="autocorrelation plots")
509
608
 
510
609
  if per_metric_ylim is None:
511
610
  per_metric_ylim = {}
@@ -520,7 +619,7 @@ def plot_rolling_grid(
520
619
 
521
620
  # normalize reference labels and keep mapping to original
522
621
  label_to_orig = {}
523
- for (_sample, ref) in keys:
622
+ for _sample, ref in keys:
524
623
  label = "all" if (ref is None) else str(ref)
525
624
  if label not in label_to_orig:
526
625
  label_to_orig[label] = ref
@@ -532,7 +631,11 @@ def plot_rolling_grid(
532
631
  # reference labels ordering
533
632
  default_ref_labels = sorted(label_to_orig.keys(), key=lambda s: s)
534
633
  if reference_order is not None:
535
- ref_labels = [("all" if r is None else str(r)) for r in reference_order if (("all" if r is None else str(r)) in label_to_orig)]
634
+ ref_labels = [
635
+ ("all" if r is None else str(r))
636
+ for r in reference_order
637
+ if (("all" if r is None else str(r)) in label_to_orig)
638
+ ]
536
639
  else:
537
640
  ref_labels = default_ref_labels
538
641
 
@@ -553,9 +656,11 @@ def plot_rolling_grid(
553
656
  nrows = len(page_samples)
554
657
 
555
658
  fig, axes = plt.subplots(
556
- nrows=nrows, ncols=cols_per_page,
659
+ nrows=nrows,
660
+ ncols=cols_per_page,
557
661
  figsize=(figsize_per_panel[0] * cols_per_page, figsize_per_panel[1] * nrows),
558
- dpi=dpi, squeeze=False
662
+ dpi=dpi,
663
+ squeeze=False,
559
664
  )
560
665
 
561
666
  for i, sample in enumerate(page_samples):