smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,48 +1,72 @@
  # ------------------------- Utilities -------------------------
- import pandas as pd
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
  import numpy as np
+ import pandas as pd
+ from numpy.fft import rfft, rfftfreq
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+ from smftools.logging_utils import get_logger
+
+ logger = get_logger(__name__)
+
+ # optional parallel backend
+ try:
+     from joblib import Parallel, delayed
+
+     _have_joblib = True
+ except Exception:
+     _have_joblib = False
+
+
+ # optionally use scipy for find_peaks (more robust)
+ try:
+     from scipy.signal import find_peaks
+
+     _have_scipy = True
+ except Exception:
+     _have_scipy = False

- def random_fill_nans(X):
+
+ def random_fill_nans(X: NDArray[np.floating]) -> NDArray[np.floating]:
+     """Fill NaNs with random values in-place.
+
+     Args:
+         X: Input array containing NaNs.
+
+     Returns:
+         numpy.ndarray: Array with NaNs replaced by random values.
+     """
      nan_mask = np.isnan(X)
      X[nan_mask] = np.random.rand(*X[nan_mask].shape)
      return X

+
  def binary_autocorrelation_with_spacing(
-     row,
-     positions,
-     max_lag=1000,
-     assume_sorted=True,
-     normalize: str = "sum",
-     return_counts: bool = False
+     row: NDArray[np.floating],
+     positions: NDArray[np.integer],
+     max_lag: int = 1000,
+     assume_sorted: bool = True,
+     normalize: str = "sum",
+     return_counts: bool = False,
  ):
-     """
-     Fast autocorrelation over real genomic spacing.
-
-     Parameters
-     ----------
-     row : 1D array (float)
-         Values per position (NaN = missing). Works for binary or real-valued.
-     positions : 1D array (int)
-         Genomic coordinates for each column of `row`.
-     max_lag : int
-         Max genomic lag (inclusive).
-     assume_sorted : bool
-         If True, assumes `positions` are strictly non-decreasing.
-     normalize : {"sum", "pearson"}
-         "sum": autocorr[l] = sum_{pairs at lag l} (xc_i * xc_j) / sum(xc^2)
-             (fast; comparable across lags and molecules).
-         "pearson": autocorr[l] = (mean_{pairs at lag l} (xc_i * xc_j)) / (mean(xc^2))
-             i.e., an estimate of Pearson-like correlation at that lag.
-     return_counts : bool
-         If True, return (autocorr, lag_counts). Otherwise just autocorr.
-
-     Returns
-     -------
-     autocorr : 1D array, shape (max_lag+1,)
-         Normalized autocorrelation; autocorr[0] = 1.0.
-         Lags with no valid pairs are NaN.
-     (optionally) lag_counts : 1D array, shape (max_lag+1,)
-         Number of pairs contributing to each lag.
+     """Compute autocorrelation over genomic spacing.
+
+     Args:
+         row: Values per position (NaN = missing).
+         positions: Genomic coordinates for each column of ``row``.
+         max_lag: Max genomic lag (inclusive).
+         assume_sorted: Whether ``positions`` are sorted.
+         normalize: ``"sum"`` or ``"pearson"`` normalization.
+         return_counts: Whether to return lag counts alongside autocorrelation.
+
+     Returns:
+         numpy.ndarray | tuple[numpy.ndarray, numpy.ndarray]: Autocorrelation values and
+             optionally counts per lag.
      """

      # mask valid entries
@@ -82,12 +106,12 @@ def binary_autocorrelation_with_spacing(
              j += 1
          # consider pairs (i, i+1...j-1)
          if j - i > 1:
-             diffs = pos[i+1:j] - pos[i] # 1..max_lag
-             contrib = xc[i] * xc[i+1:j] # contributions for each pair
+             diffs = pos[i + 1 : j] - pos[i]  # 1..max_lag
+             contrib = xc[i] * xc[i + 1 : j]  # contributions for each pair
              # accumulate weighted sums and counts per lag
              # bincount returns length >= max(diffs)+1; we request minlength
-             bc_vals = np.bincount(diffs, weights=contrib, minlength=max_lag+1)[:max_lag+1]
-             bc_counts = np.bincount(diffs, minlength=max_lag+1)[:max_lag+1]
+             bc_vals = np.bincount(diffs, weights=contrib, minlength=max_lag + 1)[: max_lag + 1]
+             bc_counts = np.bincount(diffs, minlength=max_lag + 1)[: max_lag + 1]
              lag_sums += bc_vals
              lag_counts += bc_counts

@@ -113,20 +137,17 @@
      return autocorr.astype(np.float32, copy=False)


- from numpy.fft import rfft, rfftfreq
-
- # optionally use scipy for find_peaks (more robust)
- try:
-     from scipy.signal import find_peaks
-     _have_scipy = True
- except Exception:
-     _have_scipy = False
-
  # ---------- helpers ----------
  def weighted_mean_autocorr(ac_matrix, counts_matrix, min_count=20):
-     """
-     Weighted mean across molecules: sum(ac * counts) / sum(counts) per lag.
-     Mask lags with total counts < min_count (set NaN).
+     """Compute weighted mean autocorrelation per lag.
+
+     Args:
+         ac_matrix: Autocorrelation matrix per molecule.
+         counts_matrix: Pair counts per lag.
+         min_count: Minimum total count required to keep a lag.
+
+     Returns:
+         tuple[numpy.ndarray, numpy.ndarray]: Mean autocorrelation and total counts.
      """
      counts_total = counts_matrix.sum(axis=0)
      # replace NaNs in ac_matrix with 0 for weighted sum
@@ -138,7 +159,22 @@ def weighted_mean_autocorr(ac_matrix, counts_matrix, min_count=20):
      mean_ac[counts_total < min_count] = np.nan
      return mean_ac, counts_total

- def psd_from_autocorr(mean_ac, lags, pad_factor=4):
+
+ def psd_from_autocorr(
+     mean_ac: NDArray[np.floating],
+     lags: NDArray[np.floating],
+     pad_factor: int = 4,
+ ) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
+     """Compute a power spectral density from autocorrelation.
+
+     Args:
+         mean_ac: Mean autocorrelation values.
+         lags: Lag values in base pairs.
+         pad_factor: Padding factor for FFT resolution.
+
+     Returns:
+         tuple[numpy.ndarray, numpy.ndarray]: Frequencies and power values.
+     """
      n = len(mean_ac)
      pad_n = int(max(2**10, pad_factor * n)) # pad to at least some min to stabilize FFT res
      ac_padded = np.zeros(pad_n, dtype=np.float64)
@@ -149,7 +185,24 @@ def psd_from_autocorr(mean_ac, lags, pad_factor=4):
      freqs = rfftfreq(pad_n, d=df)
      return freqs, power

- def find_peak_in_nrl_band(freqs, power, nrl_search_bp=(120,260), prominence_frac=0.05):
+
+ def find_peak_in_nrl_band(
+     freqs: NDArray[np.floating],
+     power: NDArray[np.floating],
+     nrl_search_bp: tuple[int, int] = (120, 260),
+     prominence_frac: float = 0.05,
+ ) -> tuple[float | None, int | None]:
+     """Find the peak frequency in the nucleosome repeat length band.
+
+     Args:
+         freqs: Frequency bins.
+         power: Power values.
+         nrl_search_bp: Search band in base pairs.
+         prominence_frac: Fraction of peak power for prominence.
+
+     Returns:
+         tuple[float | None, int | None]: Peak frequency and index, or ``(None, None)``.
+     """
      fmin = 1.0 / nrl_search_bp[1]
      fmax = 1.0 / nrl_search_bp[0]
      band_mask = (freqs >= fmin) & (freqs <= fmax)
@@ -170,7 +223,22 @@ def find_peak_in_nrl_band(freqs, power, nrl_search_bp=(120,260), prominence_frac
      idx = band_indices[rel]
      return freqs[idx], idx

- def fwhm_freq_to_bp(freqs, power, peak_idx):
+
+ def fwhm_freq_to_bp(
+     freqs: NDArray[np.floating],
+     power: NDArray[np.floating],
+     peak_idx: int,
+ ) -> tuple[float, float, float]:
+     """Estimate FWHM in base pairs for a spectral peak.
+
+     Args:
+         freqs: Frequency bins.
+         power: Power values.
+         peak_idx: Index of the peak.
+
+     Returns:
+         tuple[float, float, float]: FWHM in bp and left/right frequencies.
+     """
      # find half power
      pk = power[peak_idx]
      half = pk / 2.0
@@ -182,39 +250,71 @@ def fwhm_freq_to_bp(freqs, power, peak_idx):
      if left == peak_idx:
          left_f = freqs[peak_idx]
      else:
-         x0, x1 = freqs[left], freqs[left+1]
-         y0, y1 = power[left], power[left+1]
-         left_f = x0 if y1 == y0 else x0 + (half - y0)*(x1-x0)/(y1-y0)
+         x0, x1 = freqs[left], freqs[left + 1]
+         y0, y1 = power[left], power[left + 1]
+         left_f = x0 if y1 == y0 else x0 + (half - y0) * (x1 - x0) / (y1 - y0)
      # move right
      right = peak_idx
-     while right < len(power)-1 and power[right] > half:
+     while right < len(power) - 1 and power[right] > half:
          right += 1
      if right == peak_idx:
          right_f = freqs[peak_idx]
      else:
-         x0, x1 = freqs[right-1], freqs[right]
-         y0, y1 = power[right-1], power[right]
-         right_f = x1 if y1 == y0 else x0 + (half - y0)*(x1-x0)/(y1-y0)
+         x0, x1 = freqs[right - 1], freqs[right]
+         y0, y1 = power[right - 1], power[right]
+         right_f = x1 if y1 == y0 else x0 + (half - y0) * (x1 - x0) / (y1 - y0)
      # convert to bp approximating delta_NRL = |1/left_f - 1/right_f|
      left_NRL = 1.0 / right_f if right_f > 0 else np.nan
      right_NRL = 1.0 / left_f if left_f > 0 else np.nan
      fwhm_bp = abs(left_NRL - right_NRL)
      return fwhm_bp, left_f, right_f

- def estimate_snr(power, peak_idx, exclude_bins=5):
+
+ def estimate_snr(
+     power: NDArray[np.floating],
+     peak_idx: int,
+     exclude_bins: int = 5,
+ ) -> tuple[float, float, float]:
+     """Estimate signal-to-noise ratio around a spectral peak.
+
+     Args:
+         power: Power values.
+         peak_idx: Index of the peak.
+         exclude_bins: Bins to exclude around the peak when estimating background.
+
+     Returns:
+         tuple[float, float, float]: SNR, peak power, and background median.
+     """
      pk = power[peak_idx]
      mask = np.ones_like(power, dtype=bool)
-     lo = max(0, peak_idx-exclude_bins)
-     hi = min(len(power), peak_idx+exclude_bins+1)
+     lo = max(0, peak_idx - exclude_bins)
+     hi = min(len(power), peak_idx + exclude_bins + 1)
      mask[lo:hi] = False
      bg = power[mask]
      bg_med = np.median(bg) if bg.size else np.median(power)
      return pk / (bg_med if bg_med > 0 else np.finfo(float).eps), pk, bg_med

- def sample_autocorr_at_harmonics(mean_ac, lags, nrl_bp, max_harmonics=6):
+
+ def sample_autocorr_at_harmonics(
+     mean_ac: NDArray[np.floating],
+     lags: NDArray[np.floating],
+     nrl_bp: float,
+     max_harmonics: int = 6,
+ ) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
+     """Sample autocorrelation heights at NRL harmonics.
+
+     Args:
+         mean_ac: Mean autocorrelation values.
+         lags: Lag values in base pairs.
+         nrl_bp: NRL in base pairs.
+         max_harmonics: Maximum harmonics to sample.
+
+     Returns:
+         tuple[numpy.ndarray, numpy.ndarray]: Sampled lags and heights.
+     """
      sample_lags = []
      heights = []
-     for m in range(1, max_harmonics+1):
+     for m in range(1, max_harmonics + 1):
          target = m * nrl_bp
          # stop if beyond observed lag range
          if target > lags[-1]:
@@ -227,7 +327,22 @@ def sample_autocorr_at_harmonics(mean_ac, lags, nrl_bp, max_harmonics=6):
          heights.append(h)
      return np.array(sample_lags), np.array(heights)

- def fit_exponential_envelope(sample_lags, heights, counts=None):
+
+ def fit_exponential_envelope(
+     sample_lags: NDArray[np.floating],
+     heights: NDArray[np.floating],
+     counts: NDArray[np.floating] | None = None,
+ ) -> tuple[float, float, float, float]:
+     """Fit an exponential envelope to sampled autocorrelation peaks.
+
+     Args:
+         sample_lags: Sampled lag values.
+         heights: Sampled autocorrelation heights.
+         counts: Optional weights per sample.
+
+     Returns:
+         tuple[float, float, float, float]: ``(xi, A, slope, r2)``.
+     """
      # heights ~ A * exp(-lag / xi)
      mask = (heights > 0) & np.isfinite(heights)
      if mask.sum() < 2:
@@ -238,7 +353,7 @@ def fit_exponential_envelope(sample_lags, heights, counts=None):
          w = np.ones_like(y)
      else:
          w = np.asarray(counts[mask], dtype=float)
-         w = w / (np.max(w) if np.max(w)>0 else 1.0)
+         w = w / (np.max(w) if np.max(w) > 0 else 1.0)
      # weighted linear regression y = b0 + b1 * x
      X = np.vstack([np.ones_like(x), x]).T
      W = np.diag(w)
@@ -253,122 +368,147 @@
      xi = -1.0 / b1 if b1 < 0 else np.nan
      # R^2
      y_pred = X.dot(b)
-     ss_res = np.sum(w * (y - y_pred)**2)
-     ss_tot = np.sum(w * (y - np.average(y, weights=w))**2)
-     r2 = 1.0 - ss_res/ss_tot if ss_tot != 0 else np.nan
+     ss_res = np.sum(w * (y - y_pred) ** 2)
+     ss_tot = np.sum(w * (y - np.average(y, weights=w)) ** 2)
+     r2 = 1.0 - ss_res / ss_tot if ss_tot != 0 else np.nan
      return xi, A, b1, r2

+
  # ---------- main analysis per site_type ----------
- def analyze_autocorr_matrix(autocorr_matrix, counts_matrix, lags,
-                             nrl_search_bp=(120,260), pad_factor=4,
-                             min_count=20, max_harmonics=6):
-     """
-     Return dict: nrl_bp, peak_power, fwhm_bp, snr, xi, envelope points, freqs, power, mean_ac
+ def analyze_autocorr_matrix(
+     autocorr_matrix: NDArray[np.floating],
+     counts_matrix: NDArray[np.integer],
+     lags: NDArray[np.floating],
+     nrl_search_bp: tuple[int, int] = (120, 260),
+     pad_factor: int = 4,
+     min_count: int = 20,
+     max_harmonics: int = 6,
+ ):
+     """Analyze autocorrelation matrix and extract periodicity metrics.
+
+     Args:
+         autocorr_matrix: Autocorrelation values per molecule.
+         counts_matrix: Pair counts per lag.
+         lags: Lag values in base pairs.
+         nrl_search_bp: NRL search band in base pairs.
+         pad_factor: Padding factor for FFT.
+         min_count: Minimum total count to retain a lag.
+         max_harmonics: Maximum harmonics to sample.
+
+     Returns:
+         dict: Metrics including NRL, SNR, and PSD summaries.
      """
-     mean_ac, counts_total = weighted_mean_autocorr(autocorr_matrix, counts_matrix, min_count=min_count)
+     mean_ac, counts_total = weighted_mean_autocorr(
+         autocorr_matrix, counts_matrix, min_count=min_count
+     )
      freqs, power = psd_from_autocorr(mean_ac, lags, pad_factor=pad_factor)
      f0, peak_idx = find_peak_in_nrl_band(freqs, power, nrl_search_bp=nrl_search_bp)
      if f0 is None:
-         return {"error":"no_peak_found", "mean_ac":mean_ac, "counts":counts_total}
+         return {"error": "no_peak_found", "mean_ac": mean_ac, "counts": counts_total}
      nrl_bp = 1.0 / f0
      fwhm_bp, left_f, right_f = fwhm_freq_to_bp(freqs, power, peak_idx)
      snr, peak_power, bg = estimate_snr(power, peak_idx)
-     sample_lags, heights = sample_autocorr_at_harmonics(mean_ac, lags, nrl_bp, max_harmonics=max_harmonics)
-     xi, A, slope, r2 = fit_exponential_envelope(sample_lags, heights) if heights.size else (np.nan,)*4
+     sample_lags, heights = sample_autocorr_at_harmonics(
+         mean_ac, lags, nrl_bp, max_harmonics=max_harmonics
+     )
+     xi, A, slope, r2 = (
+         fit_exponential_envelope(sample_lags, heights) if heights.size else (np.nan,) * 4
+     )

      return dict(
-         nrl_bp = nrl_bp,
-         f0 = f0,
-         peak_power = peak_power,
-         fwhm_bp = fwhm_bp,
-         snr = snr,
-         bg_median = bg,
-         envelope_sample_lags = sample_lags,
-         envelope_heights = heights,
-         xi = xi,
-         xi_A = A,
-         xi_slope = slope,
-         xi_r2 = r2,
-         freqs = freqs,
-         power = power,
-         mean_ac = mean_ac,
-         counts = counts_total
+         nrl_bp=nrl_bp,
+         f0=f0,
+         peak_power=peak_power,
+         fwhm_bp=fwhm_bp,
+         snr=snr,
+         bg_median=bg,
+         envelope_sample_lags=sample_lags,
+         envelope_heights=heights,
+         xi=xi,
+         xi_A=A,
+         xi_slope=slope,
+         xi_r2=r2,
+         freqs=freqs,
+         power=power,
+         mean_ac=mean_ac,
+         counts=counts_total,
      )

+
  # ---------- bootstrap wrapper ----------
- def bootstrap_periodicity(autocorr_matrix, counts_matrix, lags, n_boot=200, **kwargs):
+ def bootstrap_periodicity(
+     autocorr_matrix: NDArray[np.floating],
+     counts_matrix: NDArray[np.integer],
+     lags: NDArray[np.floating],
+     n_boot: int = 200,
+     **kwargs,
+ ) -> dict:
+     """Bootstrap periodicity metrics from autocorrelation matrices.
+
+     Args:
+         autocorr_matrix: Autocorrelation matrix per molecule.
+         counts_matrix: Pair counts per lag.
+         lags: Lag values in base pairs.
+         n_boot: Number of bootstrap samples.
+         **kwargs: Additional arguments for ``analyze_autocorr_matrix``.
+
+     Returns:
+         dict: Bootstrapped metric arrays and per-iteration metrics.
+     """
      rng = np.random.default_rng()
      metrics = []
      n = autocorr_matrix.shape[0]
      for _ in range(n_boot):
          sample_idx = rng.integers(0, n, size=n)
-         res = analyze_autocorr_matrix(autocorr_matrix[sample_idx], counts_matrix[sample_idx], lags, **kwargs)
+         res = analyze_autocorr_matrix(
+             autocorr_matrix[sample_idx], counts_matrix[sample_idx], lags, **kwargs
+         )
          metrics.append(res)
      # extract key fields robustly
      nrls = np.array([m.get("nrl_bp", np.nan) for m in metrics])
-     xis = np.array([m.get("xi", np.nan) for m in metrics])
-     return {"nrl_boot":nrls, "xi_boot":xis, "metrics":metrics}
-
+     xis = np.array([m.get("xi", np.nan) for m in metrics])
+     return {"nrl_boot": nrls, "xi_boot": xis, "metrics": metrics}

- # optional parallel backend
- try:
-     from joblib import Parallel, delayed
-     _have_joblib = True
- except Exception:
-     _have_joblib = False

  def rolling_autocorr_metrics(
-     X,
-     positions,
-     site_label: str = None,
+     X: NDArray[np.floating],
+     positions: NDArray[np.integer],
+     site_label: str | None = None,
      window_size: int = 2000,
      step: int = 500,
      max_lag: int = 800,
      min_molecules_per_window: int = 10,
-     nrl_search_bp: tuple = (120, 260),
+     nrl_search_bp: tuple[int, int] = (120, 260),
      pad_factor: int = 4,
      min_count_for_mean: int = 20,
      max_harmonics: int = 6,
      n_jobs: int = 1,
      verbose: bool = False,
      return_window_results: bool = False,
-     fixed_nrl_bp: float = None,
-
+     fixed_nrl_bp: float | None = None,
  ):
-     """
-     Slide a genomic window across `positions` and compute periodicity metrics per window.
-
-     Parameters
-     ----------
-     X : array-like or sparse, shape (n_molecules, n_positions)
-         Binary site matrix for a group (sample × reference × site_type).
-     positions : 1D array-like of ints
-         Genomic coordinates for columns of X (same length as X.shape[1]).
-     site_label : optional str
-         Label for the site type (used in returned dicts/df).
-     window_size : int
-         Window width in bp.
-     step : int
-         Slide step in bp.
-     max_lag : int
-         Max lag (bp) to compute autocorr out to.
-     min_molecules_per_window : int
-         Minimum molecules required to compute metrics for a window; otherwise metrics = NaN.
-     nrl_search_bp, pad_factor, min_count_for_mean, max_harmonics : forwarded to analyze_autocorr_matrix
-     n_jobs : int
-         Number of parallel jobs (uses joblib if available).
-     verbose : bool
-         Print progress messages.
-     return_window_results : bool
-         If True, return also the per-window raw `analyze_autocorr_matrix` outputs.
-
-     Returns
-     -------
-     df : pandas.DataFrame
-         One row per window with columns:
-         ['site', 'window_start', 'window_end', 'center', 'n_molecules',
-          'nrl_bp', 'snr', 'peak_power', 'fwhm_bp', 'xi', 'xi_A', 'xi_r2']
-     (optionally) window_results : list of dicts (same order as df rows) when return_window_results=True
+     """Slide a genomic window across positions and compute periodicity metrics.
+
+     Args:
+         X: Binary site matrix for a group (sample × reference × site_type).
+         positions: Genomic coordinates for columns of ``X``.
+         site_label: Label for the site type.
+         window_size: Window width in bp.
+         step: Slide step in bp.
+         max_lag: Max lag (bp) to compute autocorr out to.
+         min_molecules_per_window: Minimum molecules required per window.
+         nrl_search_bp: NRL search band in base pairs.
+         pad_factor: Padding factor for FFT.
+         min_count_for_mean: Minimum count for mean autocorrelation.
+         max_harmonics: Maximum harmonics to sample.
+         n_jobs: Number of parallel jobs (joblib if available).
+         verbose: Whether to log progress.
+         return_window_results: Whether to return per-window analyzer outputs.
+         fixed_nrl_bp: If provided, use a fixed NRL in bp for analysis.
+
+     Returns:
+         pandas.DataFrame | tuple[pandas.DataFrame, list[dict]]: Window-level metrics,
+             with optional raw analyzer outputs.
      """

      # normalize inputs
@@ -386,10 +526,16 @@ def rolling_autocorr_metrics(
      window_starts = list(range(start, end - window_size + 1, step))

      if verbose:
-         print(f"Rolling windows: {len(window_starts)} windows, window_size={window_size}, step={step}")
+         logger.info(
+             "Rolling windows: %s windows, window_size=%s, step=%s",
+             len(window_starts),
+             window_size,
+             step,
+         )

      # helper to extract row to dense 1D np array (supports sparse rows)
      def _row_to_arr(row):
+         """Convert a matrix row (dense or sparse) to a 1D NumPy array."""
          # handle scipy sparse row
          try:
              import scipy.sparse as sp
@@ -402,6 +548,7 @@ def rolling_autocorr_metrics(

      # function to process one window
      def _process_window(ws):
+         """Compute rolling-window autocorrelation metrics for a window start."""
          we = ws + window_size
          mask_pos = (pos >= ws) & (pos < we)
          if mask_pos.sum() < 2:
@@ -428,7 +575,9 @@ def rolling_autocorr_metrics(
                  continue
              # compute autocorr on the windowed template; positions are pos[mask_pos]
              try:
-                 ac, cnts = binary_autocorrelation_with_spacing(subrow, pos[mask_pos], max_lag=max_lag, assume_sorted=True, return_counts=True)
+                 ac, cnts = binary_autocorrelation_with_spacing(
+                     subrow, pos[mask_pos], max_lag=max_lag, assume_sorted=True, return_counts=True
+                 )
              except Exception:
                  # if autocorr fails for this row, skip it
                  continue
@@ -460,7 +609,9 @@

          # If a fixed global NRL is provided, compute metrics around that frequency
          if fixed_nrl_bp is not None:
-             freqs, power = psd_from_autocorr(mean_ac, np.arange(mean_ac.size), pad_factor=pad_factor)
+             freqs, power = psd_from_autocorr(
+                 mean_ac, np.arange(mean_ac.size), pad_factor=pad_factor
+             )
              # locate nearest freq bin to target_freq
              target_f = 1.0 / float(fixed_nrl_bp)
              # mask valid freqs
@@ -477,28 +628,44 @@
              snr_val, _, bg = estimate_snr(power, peak_idx, exclude_bins=3)
              # sample harmonics from mean_ac at integer-lag positions using fixed_nrl_bp
              # note: lags array is integer 0..(mean_ac.size-1)
-             sample_lags, heights = sample_autocorr_at_harmonics(mean_ac, np.arange(mean_ac.size), fixed_nrl_bp, max_harmonics=max_harmonics)
-             xi, A, slope, r2 = fit_exponential_envelope(sample_lags, heights) if heights.size else (np.nan, np.nan, np.nan, np.nan)
+             sample_lags, heights = sample_autocorr_at_harmonics(
+                 mean_ac, np.arange(mean_ac.size), fixed_nrl_bp, max_harmonics=max_harmonics
+             )
+             xi, A, slope, r2 = (
+                 fit_exponential_envelope(sample_lags, heights)
+                 if heights.size
+                 else (np.nan, np.nan, np.nan, np.nan)
+             )
              res = dict(
                  nrl_bp=float(fixed_nrl_bp),
                  f0=float(target_f),
                  peak_power=peak_power,
-                 fwhm_bp=np.nan, # not robustly defined when using fixed freq (skip or compute small-band FWHM)
+                 fwhm_bp=np.nan,  # not robustly defined when using fixed freq (skip or compute small-band FWHM)
                  snr=float(snr_val),
                  bg_median=float(bg) if np.isfinite(bg) else np.nan,
                  envelope_sample_lags=sample_lags,
                  envelope_heights=heights,
-                 xi=xi, xi_A=A, xi_slope=slope, xi_r2=r2,
-                 freqs=freqs, power=power, mean_ac=mean_ac, counts=counts_total
+                 xi=xi,
+                 xi_A=A,
+                 xi_slope=slope,
+                 xi_r2=r2,
+                 freqs=freqs,
+                 power=power,
+                 mean_ac=mean_ac,
+                 counts=counts_total,
              )
          else:
              # existing behavior: call analyzer_fn
              try:
-                 res = analyze_autocorr_matrix(ac_mat, cnt_mat, np.arange(mean_ac.size),
-                                               nrl_search_bp=nrl_search_bp,
-                                               pad_factor=pad_factor,
-                                               min_count=min_count_for_mean,
-                                               max_harmonics=max_harmonics)
+                 res = analyze_autocorr_matrix(
+                     ac_mat,
+                     cnt_mat,
+                     np.arange(mean_ac.size),
+                     nrl_search_bp=nrl_search_bp,
+                     pad_factor=pad_factor,
+                     min_count=min_count_for_mean,
+                     max_harmonics=max_harmonics,
+                 )
              except Exception as e:
                  res = {"error": str(e)}

@@ -524,39 +691,45 @@ def rolling_autocorr_metrics(
          metrics = r["metrics"]
          window_results.append(metrics)
          if metrics is None or ("error" in metrics and metrics.get("error") == "no_peak_found"):
-             rows_out.append({
-                 "site": r["site"],
-                 "window_start": r["window_start"],
-                 "window_end": r["window_end"],
-                 "center": r["center"],
-                 "n_molecules": r["n_molecules"],
-                 "nrl_bp": np.nan,
-                 "snr": np.nan,
-                 "peak_power": np.nan,
-                 "fwhm_bp": np.nan,
-                 "xi": np.nan,
-                 "xi_A": np.nan,
-                 "xi_r2": np.nan,
-                 "analyzer_error": (metrics.get("error") if isinstance(metrics, dict) else "no_metrics"),
-             })
+             rows_out.append(
+                 {
+                     "site": r["site"],
+                     "window_start": r["window_start"],
+                     "window_end": r["window_end"],
+                     "center": r["center"],
+                     "n_molecules": r["n_molecules"],
+                     "nrl_bp": np.nan,
+                     "snr": np.nan,
+                     "peak_power": np.nan,
+                     "fwhm_bp": np.nan,
+                     "xi": np.nan,
+                     "xi_A": np.nan,
+                     "xi_r2": np.nan,
+                     "analyzer_error": (
+                         metrics.get("error") if isinstance(metrics, dict) else "no_metrics"
+                     ),
+                 }
+             )
          else:
-             rows_out.append({
-                 "site": r["site"],
-                 "window_start": r["window_start"],
-                 "window_end": r["window_end"],
-                 "center": r["center"],
-                 "n_molecules": r["n_molecules"],
-                 "nrl_bp": float(metrics.get("nrl_bp", np.nan)),
-                 "snr": float(metrics.get("snr", np.nan)),
-                 "peak_power": float(metrics.get("peak_power", np.nan)),
-                 "fwhm_bp": float(metrics.get("fwhm_bp", np.nan)),
-                 "xi": float(metrics.get("xi", np.nan)),
-                 "xi_A": float(metrics.get("xi_A", np.nan)),
-                 "xi_r2": float(metrics.get("xi_r2", np.nan)),
-                 "analyzer_error": None,
-             })
+             rows_out.append(
+                 {
+                     "site": r["site"],
+                     "window_start": r["window_start"],
+                     "window_end": r["window_end"],
+                     "center": r["center"],
+                     "n_molecules": r["n_molecules"],
+                     "nrl_bp": float(metrics.get("nrl_bp", np.nan)),
+                     "snr": float(metrics.get("snr", np.nan)),
+                     "peak_power": float(metrics.get("peak_power", np.nan)),
+                     "fwhm_bp": float(metrics.get("fwhm_bp", np.nan)),
+                     "xi": float(metrics.get("xi", np.nan)),
+                     "xi_A": float(metrics.get("xi_A", np.nan)),
+                     "xi_r2": float(metrics.get("xi_r2", np.nan)),
+                     "analyzer_error": None,
+                 }
+             )

      df = pd.DataFrame(rows_out)
      if return_window_results:
          return df, window_results
-     return df
+     return df
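
Usage sketch (not part of the diff): the helpers above form a small pipeline — per-molecule autocorrelation over real genomic spacing, a count-weighted mean across molecules, an FFT-based power spectrum, and peak/envelope metrics. The snippet below wires them together on synthetic data, using only the signatures and return keys visible in this diff. It is a minimal sketch: the import path smftools.tools.spatial_autocorrelation is a guess based on the "Files changed" list, and all data here is synthetic.

import numpy as np

# Hypothetical import path, inferred from the file list above.
from smftools.tools.spatial_autocorrelation import (
    analyze_autocorr_matrix,
    binary_autocorrelation_with_spacing,
)

rng = np.random.default_rng(0)
n_molecules, n_sites, max_lag = 50, 400, 800
# Sorted genomic coordinates for the site columns.
positions = np.sort(rng.choice(5000, size=n_sites, replace=False))

# Synthetic binary signal with ~190 bp periodicity, plus missing calls (NaN).
template = (np.cos(2 * np.pi * positions / 190.0) > 0.5).astype(float)
X = np.tile(template, (n_molecules, 1))
X[rng.random(X.shape) < 0.25] = np.nan

# Per-molecule autocorrelation and pair counts over genomic lags 0..max_lag.
ac_rows, cnt_rows = [], []
for row in X:
    ac, cnts = binary_autocorrelation_with_spacing(
        row.copy(), positions, max_lag=max_lag, assume_sorted=True, return_counts=True
    )
    ac_rows.append(ac)
    cnt_rows.append(cnts)

# Pool molecules and extract periodicity metrics (NRL, SNR, decay length xi).
res = analyze_autocorr_matrix(
    np.vstack(ac_rows), np.vstack(cnt_rows), np.arange(max_lag + 1)
)
if "error" not in res:
    print(f"NRL ~ {res['nrl_bp']:.1f} bp, SNR = {res['snr']:.2f}, xi = {res['xi']}")

Per the signatures in the diff, bootstrap_periodicity wraps the same analyzer to resample molecules for confidence intervals, and rolling_autocorr_metrics applies it per sliding genomic window to produce a per-window metrics DataFrame.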