smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174) hide show
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,609 @@
1
+ from typing import Optional
2
+
3
def plot_spatial_autocorr_grid(
    adata,
    out_dir: str,
    site_types=("GpC", "CpG", "any_C"),
    sample_col: str = "Sample",
    reference_col: str = "Reference_strand",
    window: int = 25,
    rows_per_fig: int = 6,
    dpi: int = 160,
    filename_prefix: str = "autocorr_grid",
    include_combined_column: bool = True,
    references: Optional[list] = None,
    annotate_periodicity: bool = True,
    counts_key_suffix: str = "_counts",
    # plotting thresholds
    plot_min_count: int = 10,
):
    """
    Plot a grid of mean spatial autocorrelations per sample × (site_type × reference).

    One column per (site_type, reference) combination (plus an optional combined
    "all refs" column per site type), one row per sample; pages of at most
    `rows_per_fig` rows are written as PNGs into `out_dir`. Per-group periodicity
    metrics are also flattened into per-group CSVs and one combined CSV.

    Expects preprocessing to have created:
      - adata.obsm[f"{site}_spatial_autocorr"]        -> (n_molecules, n_lags) float32
      - adata.obsm[f"{site}_spatial_autocorr_counts"] -> (n_molecules, n_lags) int32 (optional)
      - adata.uns[f"{site}_spatial_autocorr_lags"]    -> 1D lags array
      - adata.uns[f"{site}_spatial_periodicity_metrics_by_group"] -> dict keyed by (sample, ref)
    If per-group metrics are missing and `analyze_autocorr_matrix` is importable, the function will
    fall back to running the analyzer for that group (slow) and cache the result into adata.uns.

    Parameters
    ----------
    adata : AnnData-like object with .obs/.obsm/.uns (duck-typed; only those attributes are used).
    out_dir : directory for output PNG pages and CSVs (created if missing).
    site_types : site-type prefixes used to build the obsm/uns key names above.
    sample_col, reference_col : obs columns defining the grid rows / columns.
    window : rolling-mean window (in lag bins) used to smooth curves for display.
    rows_per_fig : samples per output page.
    dpi : figure resolution.
    filename_prefix : prefix for the per-page PNG filenames.
    include_combined_column : if True, add an "all references" column per site type.
    references : explicit reference ordering; defaults to the categories of `reference_col`.
    annotate_periodicity : overlay NRL lines, envelope fit and PSD inset when metrics exist.
    counts_key_suffix : suffix of the per-lag support-count obsm key.
    plot_min_count : lags with summed support below this are greyed out in the main trace.

    Returns
    -------
    list[str]
        Paths of the saved PNG pages.

    Raises
    ------
    KeyError
        If `sample_col` or `reference_col` is missing from adata.obs.
    """
    import os
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import warnings

    # Try importing analyzer (used only as fallback when precomputed metrics are absent)
    try:
        from ..tools.spatial_autocorrelation import analyze_autocorr_matrix  # prefer packaged analyzer
    except Exception:
        # fall back to a globally-defined analyzer (e.g. notebook usage); None disables the fallback
        analyze_autocorr_matrix = globals().get("analyze_autocorr_matrix", None)

    os.makedirs(out_dir, exist_ok=True)
    site_types = list(site_types)

    # NaN-aware rolling average helper for smoother visualization; NaNs neither
    # contribute to the numerator nor the effective window length.
    def _rolling_1d(arr: np.ndarray, win: int) -> np.ndarray:
        if win <= 1:
            return arr
        valid = np.isfinite(arr).astype(float)
        arr_z = np.nan_to_num(arr, nan=0.0)
        k = np.ones(win, dtype=float)
        num = np.convolve(arr_z, k, mode="same")
        den = np.convolve(valid, k, mode="same")
        with np.errstate(invalid="ignore", divide="ignore"):
            out = num / den
        out[den == 0] = np.nan  # windows with no finite samples stay NaN
        return out

    # group summary extractor: returns (lags, mean_curve_smoothed, std_curve_smoothed, counts_block_or_None);
    # all four are None when the required keys/rows are absent.
    def _compute_group_summary_for_mask(site: str, mask: np.ndarray):
        obsm_key = f"{site}_spatial_autocorr"
        lags_key = f"{site}_spatial_autocorr_lags"
        counts_key = f"{site}_spatial_autocorr{counts_key_suffix}"
        if obsm_key not in adata.obsm or lags_key not in adata.uns:
            return None, None, None, None
        mat = np.asarray(adata.obsm[obsm_key])
        if mat.size == 0:
            return None, None, None, None
        sel = mat[mask, :]
        if sel.size == 0:
            return None, None, None, None
        mean_per_lag = np.nanmean(sel, axis=0)
        # ddof=1: sample std; single-molecule groups yield NaN (nanstd warns) — acceptable for display
        std_per_lag = np.nanstd(sel, axis=0, ddof=1)
        counts = None
        if counts_key in adata.obsm:
            counts_mat = np.asarray(adata.obsm[counts_key])
            counts = counts_mat[mask, :].astype(int)
        return np.asarray(adata.uns[lags_key]), _rolling_1d(mean_per_lag, window), _rolling_1d(std_per_lag, window), counts

    # samples meta
    if sample_col not in adata.obs:
        raise KeyError(f"sample_col '{sample_col}' not present in adata.obs")
    samples = adata.obs[sample_col]
    if not pd.api.types.is_categorical_dtype(samples):
        samples = samples.astype("category")
    sample_levels = list(samples.cat.categories)

    # references meta
    if reference_col not in adata.obs:
        raise KeyError(f"reference_col '{reference_col}' not present in adata.obs")
    if references is None:
        refs_series = adata.obs[reference_col]
        if not pd.api.types.is_categorical_dtype(refs_series):
            refs_series = refs_series.astype("category")
        references = list(refs_series.cat.categories)
    references = list(references)

    # build column metadata: per site type, an optional combined column then one per reference
    group_column_meta = []
    for site in site_types:
        cols = []
        if include_combined_column:
            cols.append(("all", None))
        for r in references:
            cols.append(("ref", r))
        group_column_meta.append((site, cols))

    ncols = sum(len(cols) for _, cols in group_column_meta)
    saved_pages = []
    # metrics_cache for fallback-computed entries (persisted into adata.uns at end)
    metrics_cache = {site: {} for site in site_types}

    # Iterate pages (rows_per_fig samples per page)
    for start_idx in range(0, len(sample_levels), rows_per_fig):
        chunk = sample_levels[start_idx : start_idx + rows_per_fig]
        nrows = len(chunk)

        fig, axes = plt.subplots(
            nrows=nrows, ncols=ncols,
            figsize=(4.2 * ncols, 2.4 * nrows),
            dpi=dpi,
            squeeze=False,
        )

        col_idx = 0
        # per-site prefetching (avoid repeated conversion inside the row loop)
        for site, cols in group_column_meta:
            obsm_key = f"{site}_spatial_autocorr"
            counts_key = f"{site}_spatial_autocorr{counts_key_suffix}"
            lags_key = f"{site}_spatial_autocorr_lags"
            ac_full = np.asarray(adata.obsm[obsm_key]) if obsm_key in adata.obsm else None
            counts_full = np.asarray(adata.obsm[counts_key]) if counts_key in adata.obsm else None
            lags = np.asarray(adata.uns[lags_key]) if lags_key in adata.uns else None

            # metrics_by_group may already exist (precomputed upstream)
            metrics_by_group_key = f"{site}_spatial_periodicity_metrics_by_group"
            metrics_by_group_precomp = adata.uns.get(metrics_by_group_key, None)

            for col_kind, col_val in cols:
                for r, sample_name in enumerate(chunk):
                    ax = axes[r, col_idx]

                    # compute molecule mask for this (sample, reference) cell
                    sample_mask = (adata.obs[sample_col].values == sample_name)
                    if col_kind == "ref":
                        ref_mask = (adata.obs[reference_col].values == col_val)
                        mask = sample_mask & ref_mask
                    else:
                        mask = sample_mask

                    # count molecules
                    n_reads_grp = int(mask.sum())

                    # group summary (smoothed mean/std and raw counts_block)
                    lags_local, mean_curve, std_curve, counts_block = _compute_group_summary_for_mask(site, mask)

                    # plot title for top row only
                    if r == 0:
                        title = f"{site} (all refs)" if col_kind == "all" else f"{site} [{col_val}]"
                        ax.set_title(title, fontsize=9)

                    # handle no-data: annotate the cell and move to the next row
                    if lags_local is None:
                        ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=8)
                        ax.set_xlim(0, 1)
                        ax.set_xlabel("Lag (bp)", fontsize=7)
                        ax.tick_params(axis='both', which='major', labelsize=6)
                        ax.grid(True, alpha=0.22)
                        #col_idx += 1
                        continue

                    # mask low-support lags if counts available
                    mean_plot = mean_curve.copy()
                    if counts_block is not None:
                        # counts_block shape: (n_molecules_in_group, n_lags)
                        support = counts_block.sum(axis=0)
                        low_support = support < plot_min_count
                        high_support = ~low_support

                        # smooth the original mean once for plotting context
                        # NOTE(review): mean_curve was already smoothed by
                        # _compute_group_summary_for_mask, so this applies the
                        # rolling window a second time — confirm intended.
                        mean_curve_smooth = _rolling_1d(mean_curve, window)

                        # mask the smoothed mean to only show high-support points as the main trace
                        mean_plot = mean_curve_smooth.copy()
                        mean_plot[low_support] = np.nan

                        # plot a faint grey line for the low-support regions (context only)
                        if low_support.any():
                            ax.plot(lags_local[low_support], mean_curve_smooth[low_support], color="0.85", lw=0.6, label="_nolegend_")

                    # plot mean (high-support only) and +/- std (std is computed from all molecules)
                    ax.plot(lags_local, mean_plot, lw=1.1)
                    upper = mean_curve + std_curve
                    lower = mean_curve - std_curve
                    ax.fill_between(lags_local, lower, upper, alpha=0.18)

                    # ---------- use precomputed metrics if present, otherwise fallback ----------
                    # group keys are (sample, reference) tuples; the combined column uses None as ref
                    # NOTE(review): tuple keys in adata.uns may not survive h5ad
                    # serialization — verify against the writer used downstream.
                    group_key = (sample_name, None if col_kind == "all" else col_val)
                    res = None
                    if metrics_by_group_precomp is not None:
                        # metrics_by_group_precomp can be dict-like
                        res = metrics_by_group_precomp.get(group_key, None)

                    if res is None and annotate_periodicity and (analyze_autocorr_matrix is not None) and (ac_full is not None):
                        # fallback: run analyzer on the subset (warn + cache)
                        ac_sel = ac_full[mask, :]
                        cnt_sel = counts_full[mask, :] if counts_full is not None else None
                        if ac_sel.size:
                            warnings.warn(f"Precomputed periodicity metrics for {site} {group_key} not found — running analyzer as fallback (slow).")
                            try:
                                res = analyze_autocorr_matrix(
                                    ac_sel,
                                    cnt_sel if cnt_sel is not None else np.zeros_like(ac_sel, dtype=int),
                                    lags_local,
                                    nrl_search_bp=(120, 260),
                                    pad_factor=4,
                                    min_count=plot_min_count,
                                    max_harmonics=6,
                                )
                            except Exception as e:
                                # keep the error in the result dict so CSV export can report it
                                res = {"error": str(e)}
                            # cache into adata.uns for future plotting runs
                            if metrics_by_group_precomp is None:
                                adata.uns[metrics_by_group_key] = {}
                                metrics_by_group_precomp = adata.uns[metrics_by_group_key]
                            metrics_by_group_precomp[group_key] = res
                            # also record in local metrics_cache for persistence at the end
                            metrics_cache[site][group_key] = res

                    # overlay periodicity annotations if available and valid
                    if annotate_periodicity and (res is not None) and ("error" not in res):
                        # safe array conversion (analyzer values may be lists or arrays)
                        sample_lags = np.asarray(res.get("envelope_sample_lags", np.array([])))
                        envelope_heights = np.asarray(res.get("envelope_heights", np.array([])))
                        nrl = res.get("nrl_bp", None)
                        xi_val = res.get("xi", None)
                        snr = res.get("snr", None)
                        fwhm_bp = res.get("fwhm_bp", None)

                        # vertical NRL line & harmonics (safe check)
                        if (nrl is not None) and np.isfinite(nrl):
                            ax.axvline(float(nrl), color="C3", linestyle="--", linewidth=1.0, alpha=0.9)
                            for m in range(2, 5):
                                ax.axvline(float(nrl) * m, color="C3", linestyle=":", linewidth=0.7, alpha=0.6)

                        # envelope points + fitted exponential decay A*exp(-x/xi)
                        if sample_lags.size:
                            ax.scatter(sample_lags, envelope_heights, s=18, color="C2")
                            if (xi_val is not None) and np.isfinite(xi_val) and np.isfinite(res.get("xi_A", np.nan)):
                                A = float(res.get("xi_A", np.nan))
                                xi_val = float(xi_val)
                                env_x = np.linspace(np.min(sample_lags), np.max(sample_lags), 200)
                                env_y = A * np.exp(-env_x / xi_val)
                                ax.plot(env_x, env_y, linestyle="--", color="C2", linewidth=1.0, alpha=0.9)

                        # inset PSD plotted vs NRL (linear x-axis)
                        freqs = res.get("freqs", None)
                        power = res.get("power", None)
                        peak_f = res.get("f0", None)
                        if freqs is not None and power is not None:
                            freqs = np.asarray(freqs)
                            power = np.asarray(power)
                            valid = (freqs > 0) & np.isfinite(freqs) & np.isfinite(power)
                            if valid.any():
                                inset = ax.inset_axes([0.62, 0.58, 0.36, 0.37])
                                nrl_vals = 1.0 / freqs[valid]  # convert freq -> NRL (bp)
                                inset.plot(nrl_vals, power[valid], lw=0.7)
                                if peak_f is not None and peak_f > 0:
                                    inset.axvline(1.0 / float(peak_f), color="C3", linestyle="--", linewidth=0.8)
                                # choose a reasonable linear x-limits (prefer typical NRL range but fallback to data)
                                default_xlim = (60, 400)
                                data_xlim = (float(np.nanmin(nrl_vals)), 600)
                                # pick intersection/covering range
                                left = min(default_xlim[0], data_xlim[0])
                                right = max(default_xlim[1], data_xlim[1])
                                inset.set_xlim(left, right)
                                inset.set_xlabel("NRL (bp)", fontsize=6)
                                inset.set_ylabel("power", fontsize=6)
                                inset.tick_params(labelsize=6)
                                if (snr is not None) and np.isfinite(snr):
                                    inset.text(0.95, 0.95, f"SNR={float(snr):.1f}", transform=inset.transAxes,
                                               ha="right", va="top", fontsize=6, bbox=dict(facecolor="white", alpha=0.6, edgecolor="none"))

                    # set x-limits based on finite lags
                    finite_mask = np.isfinite(lags_local)
                    if finite_mask.any():
                        ax.set_xlim(float(np.nanmin(lags_local[finite_mask])), float(np.nanmax(lags_local[finite_mask])))

                    # small cosmetics
                    ax.set_xlabel("Lag (bp)", fontsize=7)
                    ax.tick_params(axis='both', which='major', labelsize=6)
                    ax.grid(True, alpha=0.22)

                col_idx += 1

        # layout and left-hand sample labels (drawn in figure coordinates)
        fig.tight_layout(rect=[0.06, 0, 1, 0.97])
        for r, sample_name in enumerate(chunk):
            first_ax = axes[r, 0]
            pos = first_ax.get_position()
            ycenter = pos.y0 + pos.height / 2.0
            n_reads_grp = int((adata.obs[sample_col].values == sample_name).sum())
            label = f"{sample_name}\n(n={n_reads_grp})"
            fig.text(0.02, ycenter, label, va='center', ha='left', rotation='vertical', fontsize=9)

        fig.suptitle("Spatial autocorrelation by sample × (site_type × reference)", y=0.995, fontsize=11)

        page_idx = start_idx // rows_per_fig + 1
        out_png = os.path.join(out_dir, f"{filename_prefix}_page{page_idx}.png")
        plt.savefig(out_png, bbox_inches="tight")
        plt.close(fig)
        saved_pages.append(out_png)

    # persist any metrics we computed via fallback into adata.uns
    for site, d in metrics_cache.items():
        if d:
            adata.uns[f"{site}_spatial_periodicity_metrics_by_group"] = d

    # ---------------------------
    # Write combined CSV + per-sample/ref CSVs
    # ---------------------------
    csv_dir = os.path.join(out_dir, "periodicity_csvs")
    os.makedirs(csv_dir, exist_ok=True)

    # include combined ('all') as a reference group for convenience
    ref_values = list(references) + ["all"]

    combined_rows = []

    for sample_name in sample_levels:
        for ref in ref_values:
            rows = []
            for site in site_types:
                key = (sample_name, None) if ref == "all" else (sample_name, ref)
                metrics_by_group_key = f"{site}_spatial_periodicity_metrics_by_group"
                group_dict = adata.uns.get(metrics_by_group_key, None)
                entry = None
                if group_dict is not None:
                    entry = group_dict.get(key, None)

                def to_list(x):
                    """
                    Normalize x to a Python list:
                    - None -> []
                    - list/tuple -> list(x)
                    - numpy array -> arr.tolist()
                    - scalar -> [scalar]
                    - string -> [string] (preserve)
                    """
                    if x is None:
                        return []
                    if isinstance(x, (list, tuple)):
                        return list(x)
                    # treat strings separately to avoid splitting into characters
                    if isinstance(x, str):
                        return [x]
                    try:
                        arr = np.asarray(x)
                    except Exception:
                        return [x]
                    # numpy scalars -> 0-dim arrays
                    if arr.ndim == 0:
                        return [arr.item()]
                    # convert to python list
                    return arr.tolist()

                def _safe_float(x):
                    # best-effort float conversion; NaN for anything unconvertible
                    try:
                        return float(x)
                    except Exception:
                        return float("nan")

                # --- inside your combined CSV loop, replace the envelope handling with this ---
                env_lags_raw = entry.get("envelope_sample_lags", []) if entry is not None else []
                env_heights_raw = entry.get("envelope_heights", []) if entry is not None else []

                env_lags_list = to_list(env_lags_raw)
                env_heights_list = to_list(env_heights_raw)

                # one CSV row per (site, sample, reference); arrays are ';'-joined strings
                row = {
                    "site": site,
                    "sample": sample_name,
                    "reference": ref,
                    "nrl_bp": _safe_float(entry.get("nrl_bp", float("nan"))) if entry is not None else float("nan"),
                    "snr": _safe_float(entry.get("snr", float("nan"))) if entry is not None else float("nan"),
                    "fwhm_bp": _safe_float(entry.get("fwhm_bp", float("nan"))) if entry is not None else float("nan"),
                    "xi": _safe_float(entry.get("xi", float("nan"))) if entry is not None else float("nan"),
                    "xi_A": _safe_float(entry.get("xi_A", float("nan"))) if entry is not None else float("nan"),
                    "xi_r2": _safe_float(entry.get("xi_r2", float("nan"))) if entry is not None else float("nan"),
                    "envelope_sample_lags": ";".join(map(str, env_lags_list)) if len(env_lags_list) else "",
                    "envelope_heights": ";".join(map(str, env_heights_list)) if len(env_heights_list) else "",
                    "analyzer_error": entry.get("error", entry.get("analyzer_error", None)) if entry is not None else "no_metrics",
                }
                rows.append(row)
                combined_rows.append(row)

            # write per-(sample,ref) CSV
            df_group = pd.DataFrame(rows)
            safe_sample = str(sample_name).replace(os.sep, "_")
            safe_ref = str(ref).replace(os.sep, "_")
            out_csv = os.path.join(csv_dir, f"{safe_sample}__{safe_ref}__periodicity_metrics.csv")
            try:
                df_group.to_csv(out_csv, index=False)
            except Exception as e:
                # don't fail the whole pipeline for a single write error; log and continue
                import warnings
                warnings.warn(f"Failed to write {out_csv}: {e}")

    # write the single combined CSV (one row per sample x ref x site)
    combined_df = pd.DataFrame(combined_rows)
    combined_out = os.path.join(out_dir, "periodicity_metrics_combined.csv")
    try:
        combined_df.to_csv(combined_out, index=False)
    except Exception as e:
        import warnings
        warnings.warn(f"Failed to write combined CSV {combined_out}: {e}")

    return saved_pages
419
+
420
def plot_rolling_metrics(df, out_png=None, title=None, figsize=(10, 3.5), dpi=160, show=False):
    """
    Plot NRL and SNR vs window center from the dataframe returned by rolling_autocorr_metrics.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'center', 'nrl_bp' and 'snr' columns.
    out_png : str or None
        If None, the matplotlib Figure object is returned; otherwise the figure
        is saved to this path and the path is returned.
    title : str or None
        Optional figure suptitle.
    figsize : tuple
        Figure size in inches.
    dpi : int
        Figure resolution.
    show : bool
        When False (default) the figure is closed before returning, so no
        window/backend resources are left open.

    Returns
    -------
    str or matplotlib.figure.Figure
        `out_png` when a path was given, otherwise the Figure.
    """
    import matplotlib.pyplot as plt

    # sort by center so the line plots run left-to-right
    df2 = df.sort_values("center")
    x = df2["center"].values
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=figsize, dpi=dpi, sharex=True)

    axes[0].plot(x, df2["nrl_bp"].values, marker="o", lw=1)
    axes[0].set_xlabel("Window center (bp)")
    axes[0].set_ylabel("NRL (bp)")
    axes[0].grid(True, alpha=0.2)

    axes[1].plot(x, df2["snr"].values, marker="o", lw=1, color="C3")
    axes[1].set_xlabel("Window center (bp)")
    axes[1].set_ylabel("SNR")
    axes[1].grid(True, alpha=0.2)

    if title:
        fig.suptitle(title, y=1.02)

    fig.tight_layout()

    if out_png:
        fig.savefig(out_png, bbox_inches="tight")
    # Single close path (the original duplicated this block in both branches and
    # re-imported matplotlib just to call pyplot.close; plt is already in scope).
    if not show:
        plt.close(fig)
    return out_png if out_png else fig
456
+
457
+ import numpy as np
458
+ import pandas as pd
459
+
460
def plot_rolling_grid(
    rolling_dict,
    out_dir,
    site,
    metrics=("nrl_bp", "snr", "xi"),
    sample_order=None,
    reference_order=None,
    rows_per_page: int = 6,
    cols_per_page: int = None,
    dpi: int = 160,
    figsize_per_panel=(3.5, 2.2),
    per_metric_ylim: dict = None,
    filename_prefix: str = "rolling_grid",
    metric_display_names: dict = None,
):
    """
    Plot rolling metrics in a grid, creating a separate paginated page-set for each metric.

    Parameters
    ----------
    rolling_dict : dict
        mapping (sample, ref) -> DataFrame (must contain 'center' and metric columns).
        Keys may use `None` for combined/"all" reference.
    out_dir : str
        Output directory (created if missing).
    site : str
        Site-type label used in titles and filenames.
    metrics : sequence[str]
        list of metric column names to plot. One page-set per metric will be written.
    sample_order, reference_order : optional lists for ordering (values as in keys)
    rows_per_page : int
        number of sample rows per page.
    cols_per_page : int | None
        number of columns per page (defaults to number of unique refs).
    figsize_per_panel : (w,h) for each subplot panel.
    per_metric_ylim : dict or None
        optional mapping metric -> (ymin,ymax) to force consistent y-limits for that metric.
        If absent, y-limits are autoscaled per page.
    filename_prefix : str
    metric_display_names : dict or None
        optional mapping metric -> friendly label for y-axis/title.

    Returns
    -------
    pages_by_metric : dict mapping metric -> [out_png_paths]

    Raises
    ------
    ValueError
        If `rolling_dict` is empty.
    """
    # NOTE: the original body also imported math/numpy/pandas here, none of
    # which were used; the unused imports have been removed.
    import os
    import matplotlib.pyplot as plt

    if per_metric_ylim is None:
        per_metric_ylim = {}
    if metric_display_names is None:
        metric_display_names = {}

    os.makedirs(out_dir, exist_ok=True)

    keys = list(rolling_dict.keys())
    if not keys:
        raise ValueError("rolling_dict is empty")

    # normalize reference labels ("all" for None) and keep mapping back to the original key value
    label_to_orig = {}
    for (_sample, ref) in keys:
        label = "all" if (ref is None) else str(ref)
        if label not in label_to_orig:
            label_to_orig[label] = ref

    # sample ordering: caller-specified order wins, restricted to samples actually present
    all_samples = sorted({k[0] for k in keys}, key=lambda x: str(x))
    sample_list = [s for s in (sample_order or all_samples) if s in all_samples]

    # reference labels ordering
    default_ref_labels = sorted(label_to_orig.keys(), key=lambda s: s)
    if reference_order is not None:
        ref_labels = [("all" if r is None else str(r)) for r in reference_order if (("all" if r is None else str(r)) in label_to_orig)]
    else:
        ref_labels = default_ref_labels

    ncols_total = len(ref_labels)
    if cols_per_page is None:
        cols_per_page = ncols_total

    pages_by_metric = {}

    # for each metric produce a paginated page-set
    for metric in metrics:
        saved_pages = []
        display_name = metric_display_names.get(metric, metric)

        # paginate samples
        for start in range(0, len(sample_list), rows_per_page):
            page_samples = sample_list[start : start + rows_per_page]
            nrows = len(page_samples)

            fig, axes = plt.subplots(
                nrows=nrows, ncols=cols_per_page,
                figsize=(figsize_per_panel[0] * cols_per_page, figsize_per_panel[1] * nrows),
                dpi=dpi, squeeze=False
            )

            for i, sample in enumerate(page_samples):
                for j in range(cols_per_page):
                    ax = axes[i, j]
                    # columns beyond the available references are blanked out
                    if j >= len(ref_labels):
                        ax.axis("off")
                        continue

                    label = ref_labels[j]
                    orig_ref = label_to_orig.get(label, None)
                    key = (sample, orig_ref)
                    df = rolling_dict.get(key, None)

                    ax.set_title(f"{sample} | {label}", fontsize=8)

                    # missing group or missing metric column -> annotated empty panel
                    if df is None or df.empty or (metric not in df.columns):
                        ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=8)
                        ax.set_xticks([])
                        ax.set_yticks([])
                        continue

                    df2 = df.sort_values("center")
                    x = df2["center"].values
                    y = df2[metric].values

                    ax.plot(x, y, lw=1, marker="o")
                    ax.set_xlabel("center (bp)", fontsize=7)
                    ax.set_ylabel(display_name, fontsize=7)
                    ax.grid(True, alpha=0.18)

                    # apply per-metric y-lim if provided (ignore malformed entries)
                    if metric in per_metric_ylim:
                        yl = per_metric_ylim[metric]
                        try:
                            ax.set_ylim(float(yl[0]), float(yl[1]))
                        except Exception:
                            pass

            fig.suptitle(f"{site} — {display_name}", fontsize=10)
            fig.tight_layout(rect=[0.03, 0.03, 1, 0.96])

            page_idx = start // rows_per_page + 1
            out_png = os.path.join(out_dir, f"{filename_prefix}_{site}_{metric}_page{page_idx}.png")
            fig.savefig(out_png, bbox_inches="tight")
            plt.close(fig)
            saved_pages.append(out_png)

        pages_by_metric[metric] = saved_pages

    return pages_by_metric