smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -67,4 +67,118 @@ def calculate_row_entropy(
  row_indices.extend(subset.obs_names.tolist())

  entropy_key = f"{output_key}_entropy"
- adata.obs.loc[row_indices, entropy_key] = entropy_values
+ adata.obs.loc[row_indices, entropy_key] = entropy_values
+
+ def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sorted=True):
+     """
+     Fast autocorrelation over real genomic spacing.
+     Uses a sliding window + bincount to aggregate per-lag products.
+
+     Parameters
+     ----------
+     row : 1D array (float)
+         Values per position (NaN = missing). Works for binary or real-valued data.
+     positions : 1D array (int)
+         Genomic coordinates for each column of `row`.
+     max_lag : int
+         Max genomic lag (inclusive).
+     assume_sorted : bool
+         If True, assumes `positions` are non-decreasing.
+
+     Returns
+     -------
+     autocorr : 1D array, shape (max_lag+1,)
+         Normalized autocorrelation; autocorr[0] = 1.0.
+         Lags with no valid pairs are NaN.
+     """
+     import numpy as np
+
+     # mask valid entries
+     valid = ~np.isnan(row)
+     if valid.sum() < 2:
+         return np.full(max_lag + 1, np.nan, dtype=np.float32)
+
+     x = row[valid].astype(np.float64, copy=False)
+     pos = positions[valid].astype(np.int64, copy=False)
+
+     # sort by position if needed
+     if not assume_sorted:
+         order = np.argsort(pos, kind="mergesort")
+         pos = pos[order]
+         x = x[order]
+
+     n = x.size
+     x_mean = x.mean()
+     xc = x - x_mean
+     var = np.sum(xc * xc)
+     if var == 0.0:
+         return np.full(max_lag + 1, np.nan, dtype=np.float32)
+
+     lag_sums = np.zeros(max_lag + 1, dtype=np.float64)
+     lag_counts = np.zeros(max_lag + 1, dtype=np.int64)
+
+     # sliding window upper pointer
+     j = 1
+     for i in range(n - 1):
+         # advance j to include all positions within max_lag
+         while j < n and pos[j] - pos[i] <= max_lag:
+             j += 1
+         # consider pairs (i, i+1 ... j-1)
+         if j - i > 1:
+             diffs = pos[i+1:j] - pos[i]    # lags in 1..max_lag
+             contrib = xc[i] * xc[i+1:j]    # contribution of each pair
+             # accumulate weighted sums and counts per lag
+             lag_sums += np.bincount(diffs, weights=contrib,
+                                     minlength=max_lag+1)[:max_lag+1]
+             lag_counts += np.bincount(diffs,
+                                       minlength=max_lag+1)[:max_lag+1]
+
+     autocorr = np.full(max_lag + 1, np.nan, dtype=np.float64)
+     nz = lag_counts > 0
+     autocorr[nz] = lag_sums[nz] / var
+     autocorr[0] = 1.0  # by definition
+
+     return autocorr.astype(np.float32, copy=False)
+
+ # def binary_autocorrelation_with_spacing(row, positions, max_lag=1000):
+ #     """
+ #     Compute autocorrelation within a read using real genomic spacing from `positions`.
+ #     Only valid (non-NaN) positions are considered.
+ #     Output is indexed by genomic lag (up to max_lag).
+ #     """
+ #     from collections import defaultdict
+ #     import numpy as np
+ #     # Get valid positions and values
+ #     valid_mask = ~np.isnan(row)
+ #     x = row[valid_mask]
+ #     pos = positions[valid_mask]
+ #     n = len(x)
+ #
+ #     if n < 2:
+ #         return np.full(max_lag + 1, np.nan)
+ #
+ #     x_mean = x.mean()
+ #     var = np.sum((x - x_mean)**2)
+ #     if var == 0:
+ #         return np.full(max_lag + 1, np.nan)
+ #
+ #     # Collect values by lag
+ #     lag_sums = defaultdict(float)
+ #     lag_counts = defaultdict(int)
+ #
+ #     for i in range(n):
+ #         for j in range(i + 1, n):
+ #             lag = abs(pos[j] - pos[i])
+ #             if lag > max_lag:
+ #                 continue
+ #             product = (x[i] - x_mean) * (x[j] - x_mean)
+ #             lag_sums[lag] += product
+ #             lag_counts[lag] += 1
+ #
+ #     # Normalize to get autocorrelation
+ #     autocorr = np.full(max_lag + 1, np.nan)
+ #     for lag in range(max_lag + 1):
+ #         if lag_counts[lag] > 0:
+ #             autocorr[lag] = lag_sums[lag] / var
+ #
+ #     return autocorr
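The new `binary_autocorrelation_with_spacing` (this hunk's +115/−1 counts match `smftools/tools/read_stats.py` in the file list above) replaces the commented-out O(n²) pair loop with a single sliding upper pointer plus `np.bincount`, so work scales with the number of site pairs within `max_lag` rather than all pairs. A minimal usage sketch on a toy read; the values are synthetic and the import path is an assumption inferred from the file list, not confirmed by the package:

    import numpy as np
    from smftools.tools.read_stats import binary_autocorrelation_with_spacing  # assumed path

    # toy read: 1 = modified site, 0 = unmodified, NaN = no call at that site
    row = np.array([1, 0, np.nan, 1, 1, 0, 1], dtype=float)
    # genomic coordinates of the assayed sites (non-decreasing, irregular spacing)
    positions = np.array([100, 147, 180, 260, 301, 420, 455])

    ac = binary_autocorrelation_with_spacing(row, positions, max_lag=400)
    print(ac.shape)                        # (401,): one value per genomic lag 0..400
    print(ac[0])                           # 1.0 by definition
    print(np.flatnonzero(~np.isnan(ac)))   # only lags with at least one site pair are filled

Because lags are genomic distances in bp rather than column offsets, the output is comparable across reads with different site densities.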
@@ -0,0 +1,562 @@
+ # ------------------------- Utilities -------------------------
+ import pandas as pd
+ import numpy as np
+
+ def random_fill_nans(X):
+     nan_mask = np.isnan(X)
+     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
+     return X
+
+ def binary_autocorrelation_with_spacing(
+     row,
+     positions,
+     max_lag=1000,
+     assume_sorted=True,
+     normalize: str = "sum",
+     return_counts: bool = False
+ ):
+     """
+     Fast autocorrelation over real genomic spacing.
+
+     Parameters
+     ----------
+     row : 1D array (float)
+         Values per position (NaN = missing). Works for binary or real-valued data.
+     positions : 1D array (int)
+         Genomic coordinates for each column of `row`.
+     max_lag : int
+         Max genomic lag (inclusive).
+     assume_sorted : bool
+         If True, assumes `positions` are non-decreasing.
+     normalize : {"sum", "pearson"}
+         "sum": autocorr[l] = sum_{pairs at lag l}(xc_i * xc_j) / sum(xc^2)
+             (fast; comparable across lags and molecules).
+         "pearson": autocorr[l] = (mean_{pairs at lag l}(xc_i * xc_j)) / mean(xc^2),
+             i.e., an estimate of a Pearson-like correlation at that lag.
+     return_counts : bool
+         If True, return (autocorr, lag_counts). Otherwise just autocorr.
+
+     Returns
+     -------
+     autocorr : 1D array, shape (max_lag+1,)
+         Normalized autocorrelation; autocorr[0] = 1.0.
+         Lags with no valid pairs are NaN.
+     lag_counts : 1D array, shape (max_lag+1,), optional
+         Number of pairs contributing to each lag (only when return_counts=True).
+     """
+     # mask valid entries
+     valid = ~np.isnan(row)
+     if valid.sum() < 2:
+         out = np.full(max_lag + 1, np.nan, dtype=np.float32)
+         return (out, np.zeros_like(out, dtype=int)) if return_counts else out
+
+     x = row[valid].astype(np.float64, copy=False)
+     pos = positions[valid].astype(np.int64, copy=False)
+
+     # sort by position if needed
+     if not assume_sorted:
+         order = np.argsort(pos, kind="mergesort")
+         pos = pos[order]
+         x = x[order]
+
+     n = x.size
+     x_mean = x.mean()
+     xc = x - x_mean
+     sum_xc2 = np.sum(xc * xc)
+     if sum_xc2 == 0.0:
+         out = np.full(max_lag + 1, np.nan, dtype=np.float32)
+         return (out, np.zeros_like(out, dtype=int)) if return_counts else out
+
+     lag_sums = np.zeros(max_lag + 1, dtype=np.float64)
+     lag_counts = np.zeros(max_lag + 1, dtype=np.int64)
+
+     # sliding window upper pointer
+     j = 1
+     for i in range(n - 1):
+         # ensure j starts at least at i+1 (important for correctness)
+         if j <= i:
+             j = i + 1
+         # advance j to include all positions within max_lag
+         while j < n and pos[j] - pos[i] <= max_lag:
+             j += 1
+         # consider pairs (i, i+1 ... j-1)
+         if j - i > 1:
+             diffs = pos[i+1:j] - pos[i]    # lags in 1..max_lag
+             contrib = xc[i] * xc[i+1:j]    # contribution of each pair
+             # accumulate weighted sums and counts per lag
+             # bincount returns length >= max(diffs)+1; we request minlength
+             bc_vals = np.bincount(diffs, weights=contrib, minlength=max_lag+1)[:max_lag+1]
+             bc_counts = np.bincount(diffs, minlength=max_lag+1)[:max_lag+1]
+             lag_sums += bc_vals
+             lag_counts += bc_counts
+
+     autocorr = np.full(max_lag + 1, np.nan, dtype=np.float64)
+     nz = lag_counts > 0
+
+     if normalize == "sum":
+         # matches the original behavior: sum over pairs / sum(xc^2)
+         autocorr[nz] = lag_sums[nz] / sum_xc2
+     elif normalize == "pearson":
+         # (mean of pairwise products) / mean(xc^2) -> closer to a correlation coefficient
+         mean_pair = lag_sums[nz] / lag_counts[nz]
+         mean_var = sum_xc2 / n
+         autocorr[nz] = mean_pair / mean_var
+     else:
+         raise ValueError("normalize must be 'sum' or 'pearson'")
+
+     # define lag 0 as exactly 1.0 (by definition)
+     autocorr[0] = 1.0
+
+     if return_counts:
+         return autocorr.astype(np.float32, copy=False), lag_counts
+     return autocorr.astype(np.float32, copy=False)
+
+
+ from numpy.fft import rfft, rfftfreq
+
+ # optionally use scipy for find_peaks (more robust)
+ try:
+     from scipy.signal import find_peaks
+     _have_scipy = True
+ except Exception:
+     _have_scipy = False
+
+ # ---------- helpers ----------
+ def weighted_mean_autocorr(ac_matrix, counts_matrix, min_count=20):
+     """
+     Weighted mean across molecules: sum(ac * counts) / sum(counts) per lag.
+     Mask lags with total counts < min_count (set NaN).
+     """
+     counts_total = counts_matrix.sum(axis=0)
+     # replace NaNs in ac_matrix with 0 for the weighted sum
+     filled = np.where(np.isfinite(ac_matrix), ac_matrix, 0.0)
+     s = (filled * counts_matrix).sum(axis=0)
+     with np.errstate(invalid="ignore", divide="ignore"):
+         mean_ac = np.where(counts_total > 0, s / counts_total, np.nan)
+     # mask low support
+     mean_ac[counts_total < min_count] = np.nan
+     return mean_ac, counts_total
+
+ def psd_from_autocorr(mean_ac, lags, pad_factor=4):
+     n = len(mean_ac)
+     pad_n = int(max(2**10, pad_factor * n))  # pad to a minimum length to stabilize FFT resolution
+     ac_padded = np.zeros(pad_n, dtype=np.float64)
+     ac_padded[:n] = np.where(np.isfinite(mean_ac), mean_ac, 0.0)
+     A = rfft(ac_padded)
+     power = np.abs(A) ** 2
+     df = (lags[1] - lags[0]) if len(lags) > 1 else 1.0
+     freqs = rfftfreq(pad_n, d=df)
+     return freqs, power
+
+ def find_peak_in_nrl_band(freqs, power, nrl_search_bp=(120, 260), prominence_frac=0.05):
+     fmin = 1.0 / nrl_search_bp[1]
+     fmax = 1.0 / nrl_search_bp[0]
+     band_mask = (freqs >= fmin) & (freqs <= fmax)
+     if not np.any(band_mask):
+         return None, None
+     freqs_band = freqs[band_mask]
+     power_band = power[band_mask]
+     if _have_scipy:
+         prom = max(np.max(power_band) * prominence_frac, 1e-12)
+         peaks, props = find_peaks(power_band, prominence=prom)
+         if peaks.size:
+             rel = peaks[np.argmax(power_band[peaks])]
+         else:
+             rel = int(np.argmax(power_band))
+     else:
+         rel = int(np.argmax(power_band))
+     band_indices = np.nonzero(band_mask)[0]
+     idx = band_indices[rel]
+     return freqs[idx], idx
+
+ def fwhm_freq_to_bp(freqs, power, peak_idx):
+     # find the half-power level
+     pk = power[peak_idx]
+     half = pk / 2.0
+     # move left
+     left = peak_idx
+     while left > 0 and power[left] > half:
+         left -= 1
+     # left interpolation
+     if left == peak_idx:
+         left_f = freqs[peak_idx]
+     else:
+         x0, x1 = freqs[left], freqs[left+1]
+         y0, y1 = power[left], power[left+1]
+         left_f = x0 if y1 == y0 else x0 + (half - y0)*(x1-x0)/(y1-y0)
+     # move right
+     right = peak_idx
+     while right < len(power)-1 and power[right] > half:
+         right += 1
+     if right == peak_idx:
+         right_f = freqs[peak_idx]
+     else:
+         x0, x1 = freqs[right-1], freqs[right]
+         y0, y1 = power[right-1], power[right]
+         right_f = x1 if y1 == y0 else x0 + (half - y0)*(x1-x0)/(y1-y0)
+     # convert to bp, approximating delta_NRL = |1/left_f - 1/right_f|
+     left_NRL = 1.0 / right_f if right_f > 0 else np.nan
+     right_NRL = 1.0 / left_f if left_f > 0 else np.nan
+     fwhm_bp = abs(left_NRL - right_NRL)
+     return fwhm_bp, left_f, right_f
+
+ def estimate_snr(power, peak_idx, exclude_bins=5):
+     pk = power[peak_idx]
+     mask = np.ones_like(power, dtype=bool)
+     lo = max(0, peak_idx-exclude_bins)
+     hi = min(len(power), peak_idx+exclude_bins+1)
+     mask[lo:hi] = False
+     bg = power[mask]
+     bg_med = np.median(bg) if bg.size else np.median(power)
+     return pk / (bg_med if bg_med > 0 else np.finfo(float).eps), pk, bg_med
+
+ def sample_autocorr_at_harmonics(mean_ac, lags, nrl_bp, max_harmonics=6):
+     sample_lags = []
+     heights = []
+     for m in range(1, max_harmonics+1):
+         target = m * nrl_bp
+         # stop if beyond the observed lag range
+         if target > lags[-1]:
+             break
+         idx = np.argmin(np.abs(lags - target))
+         h = mean_ac[idx]
+         if not np.isfinite(h):
+             break
+         sample_lags.append(lags[idx])
+         heights.append(h)
+     return np.array(sample_lags), np.array(heights)
+
+ def fit_exponential_envelope(sample_lags, heights, counts=None):
+     # heights ~ A * exp(-lag / xi)
+     mask = (heights > 0) & np.isfinite(heights)
+     if mask.sum() < 2:
+         return np.nan, np.nan, np.nan, np.nan
+     x = sample_lags[mask].astype(float)
+     y = np.log(heights[mask].astype(float))
+     if counts is None:
+         w = np.ones_like(y)
+     else:
+         w = np.asarray(counts[mask], dtype=float)
+         w = w / (np.max(w) if np.max(w) > 0 else 1.0)
+     # weighted linear regression y = b0 + b1 * x
+     X = np.vstack([np.ones_like(x), x]).T
+     W = np.diag(w)
+     XtWX = X.T.dot(W).dot(X)
+     XtWy = X.T.dot(W).dot(y)
+     try:
+         b = np.linalg.solve(XtWX, XtWy)
+     except np.linalg.LinAlgError:
+         return np.nan, np.nan, np.nan, np.nan
+     b0, b1 = b
+     A = np.exp(b0)
+     xi = -1.0 / b1 if b1 < 0 else np.nan
+     # R^2
+     y_pred = X.dot(b)
+     ss_res = np.sum(w * (y - y_pred)**2)
+     ss_tot = np.sum(w * (y - np.average(y, weights=w))**2)
+     r2 = 1.0 - ss_res/ss_tot if ss_tot != 0 else np.nan
+     return xi, A, b1, r2
+
+ # ---------- main analysis per site_type ----------
+ def analyze_autocorr_matrix(autocorr_matrix, counts_matrix, lags,
+                             nrl_search_bp=(120, 260), pad_factor=4,
+                             min_count=20, max_harmonics=6):
+     """
+     Return dict: nrl_bp, peak_power, fwhm_bp, snr, xi, envelope points, freqs, power, mean_ac.
+     """
+     mean_ac, counts_total = weighted_mean_autocorr(autocorr_matrix, counts_matrix, min_count=min_count)
+     freqs, power = psd_from_autocorr(mean_ac, lags, pad_factor=pad_factor)
+     f0, peak_idx = find_peak_in_nrl_band(freqs, power, nrl_search_bp=nrl_search_bp)
+     if f0 is None:
+         return {"error": "no_peak_found", "mean_ac": mean_ac, "counts": counts_total}
+     nrl_bp = 1.0 / f0
+     fwhm_bp, left_f, right_f = fwhm_freq_to_bp(freqs, power, peak_idx)
+     snr, peak_power, bg = estimate_snr(power, peak_idx)
+     sample_lags, heights = sample_autocorr_at_harmonics(mean_ac, lags, nrl_bp, max_harmonics=max_harmonics)
+     xi, A, slope, r2 = fit_exponential_envelope(sample_lags, heights) if heights.size else (np.nan,)*4
+
+     return dict(
+         nrl_bp=nrl_bp,
+         f0=f0,
+         peak_power=peak_power,
+         fwhm_bp=fwhm_bp,
+         snr=snr,
+         bg_median=bg,
+         envelope_sample_lags=sample_lags,
+         envelope_heights=heights,
+         xi=xi,
+         xi_A=A,
+         xi_slope=slope,
+         xi_r2=r2,
+         freqs=freqs,
+         power=power,
+         mean_ac=mean_ac,
+         counts=counts_total
+     )
+
+ # ---------- bootstrap wrapper ----------
+ def bootstrap_periodicity(autocorr_matrix, counts_matrix, lags, n_boot=200, **kwargs):
+     rng = np.random.default_rng()
+     metrics = []
+     n = autocorr_matrix.shape[0]
+     for _ in range(n_boot):
+         sample_idx = rng.integers(0, n, size=n)
+         res = analyze_autocorr_matrix(autocorr_matrix[sample_idx], counts_matrix[sample_idx], lags, **kwargs)
+         metrics.append(res)
+     # extract key fields robustly
+     nrls = np.array([m.get("nrl_bp", np.nan) for m in metrics])
+     xis = np.array([m.get("xi", np.nan) for m in metrics])
+     return {"nrl_boot": nrls, "xi_boot": xis, "metrics": metrics}
+
+
+ # optional parallel backend
+ try:
+     from joblib import Parallel, delayed
+     _have_joblib = True
+ except Exception:
+     _have_joblib = False
+
+ def rolling_autocorr_metrics(
+     X,
+     positions,
+     site_label: str = None,
+     window_size: int = 2000,
+     step: int = 500,
+     max_lag: int = 800,
+     min_molecules_per_window: int = 10,
+     nrl_search_bp: tuple = (120, 260),
+     pad_factor: int = 4,
+     min_count_for_mean: int = 20,
+     max_harmonics: int = 6,
+     n_jobs: int = 1,
+     verbose: bool = False,
+     return_window_results: bool = False,
+     fixed_nrl_bp: float = None,
+ ):
+     """
+     Slide a genomic window across `positions` and compute periodicity metrics per window.
+
+     Parameters
+     ----------
+     X : array-like or sparse, shape (n_molecules, n_positions)
+         Binary site matrix for a group (sample × reference × site_type).
+     positions : 1D array-like of ints
+         Genomic coordinates for the columns of X (same length as X.shape[1]).
+     site_label : str, optional
+         Label for the site type (used in returned dicts/df).
+     window_size : int
+         Window width in bp.
+     step : int
+         Slide step in bp.
+     max_lag : int
+         Max lag (bp) to compute autocorrelation out to.
+     min_molecules_per_window : int
+         Minimum molecules required to compute metrics for a window; otherwise metrics = NaN.
+     nrl_search_bp, pad_factor, min_count_for_mean, max_harmonics :
+         Forwarded to analyze_autocorr_matrix.
+     n_jobs : int
+         Number of parallel jobs (uses joblib if available).
+     verbose : bool
+         Print progress messages.
+     return_window_results : bool
+         If True, also return the per-window raw `analyze_autocorr_matrix` outputs.
+     fixed_nrl_bp : float, optional
+         If given, skip the peak search and evaluate metrics at this fixed NRL.
+
+     Returns
+     -------
+     df : pandas.DataFrame
+         One row per window with columns:
+         ['site', 'window_start', 'window_end', 'center', 'n_molecules',
+          'nrl_bp', 'snr', 'peak_power', 'fwhm_bp', 'xi', 'xi_A', 'xi_r2']
+     window_results : list of dicts, optional
+         Raw per-window analyzer outputs (same order as df rows) when return_window_results=True.
+     """
+     # normalize inputs
+     pos = np.asarray(positions, dtype=np.int64)
+     n_positions = pos.size
+     X_arr = X  # could be sparse; handled per-row below
+
+     start = int(pos.min())
+     end = int(pos.max())
+
+     # generate window starts; ensure at least one window
+     if end - start + 1 <= window_size:
+         window_starts = [start]
+     else:
+         window_starts = list(range(start, end - window_size + 1, step))
+
+     if verbose:
+         print(f"Rolling windows: {len(window_starts)} windows, window_size={window_size}, step={step}")
+
+     # helper to extract a row as a dense 1D array (supports sparse rows)
+     def _row_to_arr(row):
+         try:
+             import scipy.sparse as sp
+         except Exception:
+             sp = None
+         if sp is not None and sp.issparse(row):
+             return row.toarray().ravel()
+         return np.asarray(row).ravel()
+
+     # process one window
+     def _process_window(ws):
+         we = ws + window_size
+         mask_pos = (pos >= ws) & (pos < we)
+         if mask_pos.sum() < 2:
+             return dict(site=site_label, window_start=ws, window_end=we,
+                         center=(ws + we) / 2.0, n_molecules=0, metrics=None)
+
+         rows_ac = []
+         rows_cnt = []
+
+         # iterate molecules (rows) and compute autocorrelation on the windowed positions
+         n_mol = X_arr.shape[0]
+         for i in range(n_mol):
+             # safe row extraction for dense or sparse matrices
+             row_i = _row_to_arr(X_arr[i])
+             subrow = row_i[mask_pos]
+             # skip empty rows
+             if subrow.size == 0:
+                 continue
+             # compute autocorrelation on the windowed template; positions are pos[mask_pos]
+             try:
+                 ac, cnts = binary_autocorrelation_with_spacing(
+                     subrow, pos[mask_pos], max_lag=max_lag,
+                     assume_sorted=True, return_counts=True)
+             except Exception:
+                 # if autocorrelation fails for this row, skip it
+                 continue
+             rows_ac.append(ac)
+             rows_cnt.append(cnts)
+
+         n_used = len(rows_ac)
+         if n_used < min_molecules_per_window:
+             return dict(site=site_label, window_start=ws, window_end=we,
+                         center=(ws + we) / 2.0, n_molecules=n_used, metrics=None)
+
+         ac_mat = np.asarray(rows_ac, dtype=np.float64)
+         cnt_mat = np.asarray(rows_cnt, dtype=np.int64)
+
+         # compute the weighted mean_ac (same first step as analyze_autocorr_matrix)
+         counts_total = cnt_mat.sum(axis=0)
+         filled = np.where(np.isfinite(ac_mat), ac_mat, 0.0)
+         s = (filled * cnt_mat).sum(axis=0)
+         with np.errstate(invalid="ignore", divide="ignore"):
+             mean_ac = np.where(counts_total > 0, s / counts_total, np.nan)
+         mean_ac[counts_total < min_count_for_mean] = np.nan
+
+         # if a fixed global NRL is provided, compute metrics around that frequency
+         if fixed_nrl_bp is not None:
+             freqs, power = psd_from_autocorr(mean_ac, np.arange(mean_ac.size), pad_factor=pad_factor)
+             # locate the frequency bin nearest the target frequency
+             target_f = 1.0 / float(fixed_nrl_bp)
+             valid_mask = np.isfinite(freqs) & np.isfinite(power)
+             if not np.any(valid_mask):
+                 res = {"error": "no_power"}
+             else:
+                 # find the index closest to target_f within valid_mask
+                 idx_all = np.arange(len(freqs))
+                 valid_idx = idx_all[valid_mask]
+                 idx_closest = valid_idx[np.argmin(np.abs(freqs[valid_mask] - target_f))]
+                 peak_idx = int(idx_closest)
+                 peak_power = float(power[peak_idx])
+                 snr_val, _, bg = estimate_snr(power, peak_idx, exclude_bins=3)
+                 # sample harmonics of fixed_nrl_bp from mean_ac at integer lags 0..(mean_ac.size-1)
+                 sample_lags, heights = sample_autocorr_at_harmonics(
+                     mean_ac, np.arange(mean_ac.size), fixed_nrl_bp, max_harmonics=max_harmonics)
+                 xi, A, slope, r2 = fit_exponential_envelope(sample_lags, heights) if heights.size else (np.nan, np.nan, np.nan, np.nan)
+                 res = dict(
+                     nrl_bp=float(fixed_nrl_bp),
+                     f0=float(target_f),
+                     peak_power=peak_power,
+                     fwhm_bp=np.nan,  # not robustly defined at a fixed frequency (skip, or compute a small-band FWHM)
+                     snr=float(snr_val),
+                     bg_median=float(bg) if np.isfinite(bg) else np.nan,
+                     envelope_sample_lags=sample_lags,
+                     envelope_heights=heights,
+                     xi=xi, xi_A=A, xi_slope=slope, xi_r2=r2,
+                     freqs=freqs, power=power, mean_ac=mean_ac, counts=counts_total
+                 )
+         else:
+             # default behavior: search for the NRL peak with analyze_autocorr_matrix
+             try:
+                 res = analyze_autocorr_matrix(ac_mat, cnt_mat, np.arange(mean_ac.size),
+                                               nrl_search_bp=nrl_search_bp,
+                                               pad_factor=pad_factor,
+                                               min_count=min_count_for_mean,
+                                               max_harmonics=max_harmonics)
+             except Exception as e:
+                 res = {"error": str(e)}
+
+         return dict(site=site_label, window_start=ws, window_end=we,
+                     center=(ws + we) / 2.0, n_molecules=n_used, metrics=res)
+
+     # map over windows (in parallel if joblib is available)
+     if _have_joblib and (n_jobs is not None) and (n_jobs != 1):
+         results = Parallel(n_jobs=n_jobs)(delayed(_process_window)(ws) for ws in window_starts)
+     else:
+         results = [_process_window(ws) for ws in window_starts]
+
+     # build dataframe rows
+     rows_out = []
+     window_results = []
+     for r in results:
+         metrics = r["metrics"]
+         window_results.append(metrics)
+         if metrics is None or ("error" in metrics and metrics.get("error") == "no_peak_found"):
+             rows_out.append({
+                 "site": r["site"],
+                 "window_start": r["window_start"],
+                 "window_end": r["window_end"],
+                 "center": r["center"],
+                 "n_molecules": r["n_molecules"],
+                 "nrl_bp": np.nan,
+                 "snr": np.nan,
+                 "peak_power": np.nan,
+                 "fwhm_bp": np.nan,
+                 "xi": np.nan,
+                 "xi_A": np.nan,
+                 "xi_r2": np.nan,
+                 "analyzer_error": (metrics.get("error") if isinstance(metrics, dict) else "no_metrics"),
+             })
+         else:
+             rows_out.append({
+                 "site": r["site"],
+                 "window_start": r["window_start"],
+                 "window_end": r["window_end"],
+                 "center": r["center"],
+                 "n_molecules": r["n_molecules"],
+                 "nrl_bp": float(metrics.get("nrl_bp", np.nan)),
+                 "snr": float(metrics.get("snr", np.nan)),
+                 "peak_power": float(metrics.get("peak_power", np.nan)),
+                 "fwhm_bp": float(metrics.get("fwhm_bp", np.nan)),
+                 "xi": float(metrics.get("xi", np.nan)),
+                 "xi_A": float(metrics.get("xi_A", np.nan)),
+                 "xi_r2": float(metrics.get("xi_r2", np.nan)),
+                 "analyzer_error": None,
+             })
+
+     df = pd.DataFrame(rows_out)
+     if return_window_results:
+         return df, window_results
+     return df
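Taken together, this new module (its +562 line count matches `smftools/tools/spatial_autocorrelation.py` in the file list) forms a pipeline: per-molecule lag autocorrelations → count-weighted mean across molecules → zero-padded FFT power spectrum → NRL peak, FWHM, and SNR → exponential decay fit over the harmonics. A minimal end-to-end sketch on synthetic molecules with a planted ~190 bp period; the function names come from this hunk, but the import path, data, and expectation comment are illustrative assumptions, not from the package:

    import numpy as np
    # assumed path, inferred from the file list above
    from smftools.tools.spatial_autocorrelation import (
        binary_autocorrelation_with_spacing, analyze_autocorr_matrix, rolling_autocorr_metrics)

    rng = np.random.default_rng(0)
    n_mol, nrl = 200, 190
    # ~160 assayed sites over 4 kb, irregularly spaced
    positions = np.sort(rng.choice(4000, size=160, replace=False))

    # Bernoulli occupancy modulated by a 190 bp cosine (random phase per molecule)
    X = np.empty((n_mol, positions.size))
    for i in range(n_mol):
        phase = rng.uniform(0, nrl)
        p = 0.5 + 0.4 * np.cos(2 * np.pi * (positions + phase) / nrl)
        X[i] = (rng.random(positions.size) < p).astype(float)

    max_lag = 800
    acs, cnts = zip(*(binary_autocorrelation_with_spacing(
        X[i], positions, max_lag=max_lag, return_counts=True) for i in range(n_mol)))
    res = analyze_autocorr_matrix(np.asarray(acs, dtype=float),
                                  np.asarray(cnts, dtype=np.int64),
                                  np.arange(max_lag + 1))
    print(res["nrl_bp"], res["snr"], res["xi"])  # nrl_bp should land near 190 for this toy signal

    # positional profile of the same metrics in 2 kb windows stepped by 500 bp
    df = rolling_autocorr_metrics(X, positions, site_label="toy",
                                  window_size=2000, step=500, max_lag=max_lag)

Passing `fixed_nrl_bp` to `rolling_autocorr_metrics` pins the evaluation frequency to a global NRL estimate, which keeps per-window SNR and decay metrics comparable when individual windows are too noisy for a reliable peak search.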