smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. smftools/__init__.py +34 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/cli.py +184 -0
  5. smftools/config/__init__.py +1 -0
  6. smftools/config/conversion.yaml +33 -0
  7. smftools/config/deaminase.yaml +56 -0
  8. smftools/config/default.yaml +253 -0
  9. smftools/config/direct.yaml +17 -0
  10. smftools/config/experiment_config.py +1191 -0
  11. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  12. smftools/datasets/F1_sample_sheet.csv +5 -0
  13. smftools/datasets/__init__.py +9 -0
  14. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  15. smftools/datasets/datasets.py +28 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/hmm/apply_hmm_batched.py +242 -0
  19. smftools/hmm/calculate_distances.py +18 -0
  20. smftools/hmm/call_hmm_peaks.py +106 -0
  21. smftools/hmm/display_hmm.py +18 -0
  22. smftools/hmm/hmm_readwrite.py +16 -0
  23. smftools/hmm/nucleosome_hmm_refinement.py +104 -0
  24. smftools/hmm/train_hmm.py +78 -0
  25. smftools/informatics/__init__.py +14 -0
  26. smftools/informatics/archived/bam_conversion.py +59 -0
  27. smftools/informatics/archived/bam_direct.py +63 -0
  28. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  29. smftools/informatics/archived/conversion_smf.py +132 -0
  30. smftools/informatics/archived/deaminase_smf.py +132 -0
  31. smftools/informatics/archived/direct_smf.py +137 -0
  32. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  33. smftools/informatics/basecall_pod5s.py +80 -0
  34. smftools/informatics/fast5_to_pod5.py +24 -0
  35. smftools/informatics/helpers/__init__.py +73 -0
  36. smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
  37. smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
  38. smftools/informatics/helpers/archived/informatics.py +260 -0
  39. smftools/informatics/helpers/archived/load_adata.py +516 -0
  40. smftools/informatics/helpers/bam_qc.py +66 -0
  41. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  42. smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
  43. smftools/informatics/helpers/canoncall.py +34 -0
  44. smftools/informatics/helpers/complement_base_list.py +21 -0
  45. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
  46. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  47. smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
  48. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  49. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  50. smftools/informatics/helpers/discover_input_files.py +100 -0
  51. smftools/informatics/helpers/extract_base_identities.py +70 -0
  52. smftools/informatics/helpers/extract_mods.py +83 -0
  53. smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
  54. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  55. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  56. smftools/informatics/helpers/find_conversion_sites.py +51 -0
  57. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  58. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  59. smftools/informatics/helpers/get_native_references.py +28 -0
  60. smftools/informatics/helpers/index_fasta.py +12 -0
  61. smftools/informatics/helpers/make_dirs.py +21 -0
  62. smftools/informatics/helpers/make_modbed.py +27 -0
  63. smftools/informatics/helpers/modQC.py +27 -0
  64. smftools/informatics/helpers/modcall.py +36 -0
  65. smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
  66. smftools/informatics/helpers/ohe_batching.py +76 -0
  67. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  68. smftools/informatics/helpers/one_hot_decode.py +27 -0
  69. smftools/informatics/helpers/one_hot_encode.py +57 -0
  70. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  71. smftools/informatics/helpers/run_multiqc.py +28 -0
  72. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  73. smftools/informatics/helpers/split_and_index_BAM.py +32 -0
  74. smftools/informatics/readwrite.py +106 -0
  75. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  76. smftools/informatics/subsample_pod5.py +104 -0
  77. smftools/load_adata.py +1346 -0
  78. smftools/machine_learning/__init__.py +12 -0
  79. smftools/machine_learning/data/__init__.py +2 -0
  80. smftools/machine_learning/data/anndata_data_module.py +234 -0
  81. smftools/machine_learning/data/preprocessing.py +6 -0
  82. smftools/machine_learning/evaluation/__init__.py +2 -0
  83. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  84. smftools/machine_learning/evaluation/evaluators.py +223 -0
  85. smftools/machine_learning/inference/__init__.py +3 -0
  86. smftools/machine_learning/inference/inference_utils.py +27 -0
  87. smftools/machine_learning/inference/lightning_inference.py +68 -0
  88. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  89. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  90. smftools/machine_learning/models/__init__.py +9 -0
  91. smftools/machine_learning/models/base.py +295 -0
  92. smftools/machine_learning/models/cnn.py +138 -0
  93. smftools/machine_learning/models/lightning_base.py +345 -0
  94. smftools/machine_learning/models/mlp.py +26 -0
  95. smftools/machine_learning/models/positional.py +18 -0
  96. smftools/machine_learning/models/rnn.py +17 -0
  97. smftools/machine_learning/models/sklearn_models.py +273 -0
  98. smftools/machine_learning/models/transformer.py +303 -0
  99. smftools/machine_learning/models/wrappers.py +20 -0
  100. smftools/machine_learning/training/__init__.py +2 -0
  101. smftools/machine_learning/training/train_lightning_model.py +135 -0
  102. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  103. smftools/machine_learning/utils/__init__.py +2 -0
  104. smftools/machine_learning/utils/device.py +10 -0
  105. smftools/machine_learning/utils/grl.py +14 -0
  106. smftools/plotting/__init__.py +18 -0
  107. smftools/plotting/autocorrelation_plotting.py +611 -0
  108. smftools/plotting/classifiers.py +355 -0
  109. smftools/plotting/general_plotting.py +682 -0
  110. smftools/plotting/hmm_plotting.py +260 -0
  111. smftools/plotting/position_stats.py +462 -0
  112. smftools/plotting/qc_plotting.py +270 -0
  113. smftools/preprocessing/__init__.py +38 -0
  114. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  115. smftools/preprocessing/append_base_context.py +122 -0
  116. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  117. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  118. smftools/preprocessing/archives/preprocessing.py +614 -0
  119. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  120. smftools/preprocessing/binarize_on_Youden.py +45 -0
  121. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  122. smftools/preprocessing/calculate_complexity.py +72 -0
  123. smftools/preprocessing/calculate_complexity_II.py +248 -0
  124. smftools/preprocessing/calculate_consensus.py +47 -0
  125. smftools/preprocessing/calculate_coverage.py +51 -0
  126. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  127. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  128. smftools/preprocessing/calculate_position_Youden.py +115 -0
  129. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  130. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  131. smftools/preprocessing/clean_NaN.py +62 -0
  132. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  133. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  134. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  135. smftools/preprocessing/flag_duplicate_reads.py +1351 -0
  136. smftools/preprocessing/invert_adata.py +37 -0
  137. smftools/preprocessing/load_sample_sheet.py +53 -0
  138. smftools/preprocessing/make_dirs.py +21 -0
  139. smftools/preprocessing/min_non_diagonal.py +25 -0
  140. smftools/preprocessing/recipes.py +127 -0
  141. smftools/preprocessing/subsample_adata.py +58 -0
  142. smftools/readwrite.py +1004 -0
  143. smftools/tools/__init__.py +20 -0
  144. smftools/tools/archived/apply_hmm.py +202 -0
  145. smftools/tools/archived/classifiers.py +787 -0
  146. smftools/tools/archived/classify_methylated_features.py +66 -0
  147. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  148. smftools/tools/archived/subset_adata_v1.py +32 -0
  149. smftools/tools/archived/subset_adata_v2.py +46 -0
  150. smftools/tools/calculate_umap.py +62 -0
  151. smftools/tools/cluster_adata_on_methylation.py +105 -0
  152. smftools/tools/general_tools.py +69 -0
  153. smftools/tools/position_stats.py +601 -0
  154. smftools/tools/read_stats.py +184 -0
  155. smftools/tools/spatial_autocorrelation.py +562 -0
  156. smftools/tools/subset_adata.py +28 -0
  157. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
  158. smftools-0.2.1.dist-info/RECORD +161 -0
  159. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  160. smftools-0.1.6.dist-info/RECORD +0 -4
  161. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  162. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
smftools/tools/position_stats.py (new file, entry 153 above)
@@ -0,0 +1,601 @@
# ------------------------- Utilities -------------------------
def random_fill_nans(X):
    """Replace NaNs in X with uniform random draws from [0, 1), in place."""
    import numpy as np
    nan_mask = np.isnan(X)
    X[nan_mask] = np.random.rand(*X[nan_mask].shape)
    return X

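Worth noting for callers: random_fill_nans fills in place, so pass a copy when the original matrix must survive (the call site further down does exactly that). A minimal sketch with a toy array, not package code:

import numpy as np

X = np.array([[1.0, np.nan], [np.nan, 0.0]])
filled = random_fill_nans(X.copy())  # the copy keeps X intact
assert not np.isnan(filled).any()
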
def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
    """
    Perform Bayesian-style methylation vs activity analysis independently within each group.

    Parameters:
        adata (AnnData): Annotated data matrix.
        sites (list of str): List of site keys (e.g., ['GpC_site', 'CpG_site']).
        alpha (float): FDR threshold for significance.
        groupby (str or list of str): Column(s) in adata.obs to group by.

    Returns:
        results_dict (dict): Dictionary with structure:
            results_dict[ref][group_label] = (results_df, sig_df)
    """
    import numpy as np
    import pandas as pd
    from scipy.stats import fisher_exact
    from statsmodels.stats.multitest import multipletests

    def compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values):
        # Benjamini-Hochberg FDR correction across all tested positions
        p_adj = multipletests(p_values, method='fdr_bh')[1] if p_values else []

        genomic_positions = np.array(site_subset.var_names)[positions_list]
        is_gpc_site = site_subset.var[f"{ref}_GpC_site"].values[positions_list]
        is_cpg_site = site_subset.var[f"{ref}_CpG_site"].values[positions_list]

        results_df = pd.DataFrame({
            'Feature_Index': positions_list,
            'Genomic_Position': genomic_positions.astype(int),
            'Relative_Risk': relative_risks,
            'Adjusted_P_Value': p_adj,
            'GpC_Site': is_gpc_site,
            'CpG_Site': is_cpg_site
        })

        # tiny floor guards against log(0)
        results_df['log2_Relative_Risk'] = np.log2(results_df['Relative_Risk'].replace(0, 1e-300))
        results_df['-log10_Adj_P'] = -np.log10(results_df['Adjusted_P_Value'].replace(0, 1e-300))
        sig_df = results_df[results_df['Adjusted_P_Value'] < alpha]
        return results_df, sig_df

    results_dict = {}

    for ref in adata.obs['Reference_strand'].unique():
        ref_subset = adata[adata.obs['Reference_strand'] == ref].copy()
        if ref_subset.shape[0] == 0:
            continue

        # Normalize groupby to a list of obs columns
        if groupby is not None:
            if isinstance(groupby, str):
                groupby = [groupby]

            def format_group_label(row):
                return ",".join([f"{col}={row[col]}" for col in groupby])

            combined_label = '__'.join(groupby)
            ref_subset.obs[combined_label] = ref_subset.obs.apply(format_group_label, axis=1)
            groups = ref_subset.obs[combined_label].unique()
        else:
            combined_label = None
            groups = ['all']

        results_dict[ref] = {}

        for group in groups:
            if group == 'all':
                group_subset = ref_subset
            else:
                group_subset = ref_subset[ref_subset.obs[combined_label] == group]

            if group_subset.shape[0] == 0:
                continue

            # Build site mask (cast to bool so |= also works for 0/1-encoded columns)
            site_mask = np.zeros(group_subset.shape[1], dtype=bool)
            for site in sites:
                site_mask |= group_subset.var[f"{ref}_{site}"].values.astype(bool)
            site_subset = group_subset[:, site_mask].copy()

            # Matrix and labels
            X = random_fill_nans(site_subset.X.copy())
            y = site_subset.obs['activity_status'].map({'Active': 1, 'Silent': 0}).values
            P_active = np.mean(y)

            # Per-position analysis: 2x2 table of methylation state vs activity
            positions_list, relative_risks, p_values = [], [], []
            for pos in range(X.shape[1]):
                methylation_state = (X[:, pos] > 0).astype(int)
                table = pd.crosstab(methylation_state, y)
                if table.shape != (2, 2):
                    continue

                P_methylated = np.mean(methylation_state)
                P_methylated_given_active = np.mean(methylation_state[y == 1])
                P_methylated_given_inactive = np.mean(methylation_state[y == 0])

                if P_methylated_given_inactive == 0 or P_methylated in [0, 1]:
                    continue

                # Bayes' rule: P(active | methylated) = P(methylated | active) * P(active) / P(methylated)
                P_active_given_methylated = (P_methylated_given_active * P_active) / P_methylated
                P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (1 - P_methylated)
                RR = P_active_given_methylated / P_active_given_unmethylated

                _, p_value = fisher_exact(table)
                positions_list.append(pos)
                relative_risks.append(RR)
                p_values.append(p_value)

            results_df, sig_df = compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values)
            results_dict[ref][group] = (results_df, sig_df)

    return results_dict

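A hedged usage sketch for the function above. It assumes an AnnData whose obs carries 'Reference_strand' and 'activity_status' ('Active'/'Silent') and whose var carries per-reference boolean site columns such as '<ref>_GpC_site'; the 'Barcode' groupby column is illustrative:

results = calculate_relative_risk_on_activity(
    adata, sites=['GpC_site', 'CpG_site'], alpha=0.05, groupby='Barcode'
)
for ref, groups in results.items():
    for group, (results_df, sig_df) in groups.items():
        print(ref, group, len(sig_df), 'significant positions')
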
import copy
import os
import warnings
from contextlib import contextmanager
from itertools import cycle
from typing import Dict, Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# optional imports (availability flags kept for callers that probe them)
try:
    from joblib import Parallel, delayed
    JOBLIB_AVAILABLE = True
except Exception:
    JOBLIB_AVAILABLE = False

try:
    from scipy.stats import chi2_contingency
    SCIPY_STATS_AVAILABLE = True
except Exception:
    SCIPY_STATS_AVAILABLE = False

# -----------------------------
# Compute positionwise statistic (multi-method + simple site_types)
# -----------------------------
# hard requirements for the positionwise machinery below; note these bypass
# the availability flags above, so joblib/scipy/tqdm are effectively mandatory
import joblib
from joblib import cpu_count
from tqdm import tqdm
from scipy.stats import chi2_contingency


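Because the hard imports above bypass the JOBLIB_AVAILABLE / SCIPY_STATS_AVAILABLE flags, the positionwise functions currently require joblib, scipy, and tqdm outright; the flags remain useful only to callers that want to fail early with a clearer message. Illustrative only, not package API:

if not (JOBLIB_AVAILABLE and SCIPY_STATS_AVAILABLE):
    raise ImportError("positionwise statistics require joblib and scipy to be installed")
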
# ---------------------------
# joblib <-> tqdm integration
# ---------------------------
@contextmanager
def tqdm_joblib(tqdm_object: tqdm):
    """Context manager to patch joblib to update a tqdm progress bar."""
    old = joblib.parallel.BatchCompletionCallBack

    class TqdmBatchCompletionCallback(old):  # type: ignore
        def __call__(self, *args, **kwargs):
            try:
                tqdm_object.update(n=self.batch_size)
            except Exception:
                tqdm_object.update(1)
            return super().__call__(*args, **kwargs)

    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        joblib.parallel.BatchCompletionCallBack = old


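The context manager above temporarily swaps joblib's batch-completion callback so that finished batches tick a progress bar, then restores the original callback on exit. A minimal sketch with a toy workload (not package code):

def _square(k):
    return k * k

with tqdm_joblib(tqdm(total=100, desc="squares")):
    squares = Parallel(n_jobs=2)(delayed(_square)(k) for k in range(100))
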
# ---------------------------
# row workers (upper-triangle only)
# ---------------------------
def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
    n_pos = X_bin.shape[1]
    row = np.full((n_pos,), np.nan, dtype=float)
    xi = X_bin[:, i]
    for j in range(i, n_pos):
        xj = X_bin[:, j]
        # only reads with both positions observed enter the contingency table
        mask = (~np.isnan(xi)) & (~np.isnan(xj))
        if int(mask.sum()) < int(min_count_for_pairwise):
            continue
        try:
            table = pd.crosstab(xi[mask], xj[mask])
            if table.shape != (2, 2):
                continue
            chi2, _, _, _ = chi2_contingency(table, correction=False)
            row[j] = float(chi2)
        except Exception:
            row[j] = np.nan
    return (i, row)


def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
    n_pos = X_bin.shape[1]
    row = np.full((n_pos,), np.nan, dtype=float)
    xi = X_bin[:, i]
    for j in range(i, n_pos):
        xj = X_bin[:, j]
        mask = (~np.isnan(xi)) & (~np.isnan(xj))
        if int(mask.sum()) < int(min_count_for_pairwise):
            continue
        # 2x2 counts: a = both methylated, b = i only, c = j only, d = neither
        a = np.sum((xi[mask] == 1) & (xj[mask] == 1))
        b = np.sum((xi[mask] == 1) & (xj[mask] == 0))
        c = np.sum((xi[mask] == 0) & (xj[mask] == 1))
        d = np.sum((xi[mask] == 0) & (xj[mask] == 0))
        try:
            if (a + b) > 0 and (c + d) > 0 and (c > 0):
                p1 = a / float(a + b)
                p2 = c / float(c + d)
                row[j] = float(p1 / p2) if p2 > 0 else np.nan
            else:
                row[j] = np.nan
        except Exception:
            row[j] = np.nan
    return (i, row)

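For one pair of positions the worker computes RR = (a / (a + b)) / (c / (c + d)) over reads where both positions are observed. A toy check, with values chosen by hand rather than taken from package data:

import numpy as np

X_toy = np.array([
    [1.0, 1.0],
    [1.0, 1.0],
    [1.0, 0.0],
    [0.0, 1.0],
    [0.0, 0.0],
    [np.nan, 1.0],   # excluded: position 0 unobserved on this read
])
_, row = _relative_risk_row_job(0, X_toy, min_count_for_pairwise=2)
# a=2, b=1, c=1, d=1  ->  RR = (2/3) / (1/2) = 1.33...
print(round(row[1], 2))
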
def compute_positionwise_statistics(
    adata,
    layer: str,
    methods: Sequence[str] = ("pearson",),
    sample_col: str = "Barcode",
    ref_col: str = "Reference_strand",
    site_types: Optional[Sequence[str]] = None,
    encoding: str = "signed",
    output_key: str = "positionwise_result",
    min_count_for_pairwise: int = 10,
    max_threads: Optional[int] = None,
    reverse_indices_on_store: bool = False,
):
    """
    Compute per-(sample, ref) positionwise matrices for methods in `methods`.

    Results stored at:
        adata.uns[output_key][method][(sample, ref)] = DataFrame
        adata.uns[output_key + "_n"][method][(sample, ref)] = int(n_reads)
    """
    if isinstance(methods, str):
        methods = [methods]
    methods = [m.lower() for m in methods]

    # prepare containers
    adata.uns[output_key] = {m: {} for m in methods}
    adata.uns[output_key + "_n"] = {m: {} for m in methods}

    # workers
    if max_threads is None or max_threads <= 0:
        n_jobs = max(1, cpu_count() or 1)
    else:
        n_jobs = max(1, int(max_threads))

    # samples / refs
    sseries = adata.obs[sample_col]
    if not pd.api.types.is_categorical_dtype(sseries):
        sseries = sseries.astype("category")
    samples = list(sseries.cat.categories)

    rseries = adata.obs[ref_col]
    if not pd.api.types.is_categorical_dtype(rseries):
        rseries = rseries.astype("category")
    references = list(rseries.cat.categories)

    total_tasks = len(samples) * len(references)
    pbar_outer = tqdm(total=total_tasks, desc="positionwise (sample x ref)", unit="cell")

276
+
277
+ for sample in samples:
278
+ for ref in references:
279
+ label = (sample, ref)
280
+ try:
281
+ mask = (adata.obs[sample_col] == sample) & (adata.obs[ref_col] == ref)
282
+ subset = adata[mask]
283
+ n_reads = subset.shape[0]
284
+
285
+ # nothing to do -> store empty placeholders
286
+ if n_reads == 0:
287
+ for m in methods:
288
+ adata.uns[output_key][m][label] = pd.DataFrame()
289
+ adata.uns[output_key + "_n"][m][label] = 0
290
+ pbar_outer.update(1)
291
+ continue
292
+
293
+ # select var columns based on site_types and reference
294
+ if site_types:
295
+ col_mask = np.zeros(subset.shape[1], dtype=bool)
296
+ for st in site_types:
297
+ colname = f"{ref}_{st}"
298
+ if colname in subset.var.columns:
299
+ col_mask |= np.asarray(subset.var[colname].values, dtype=bool)
300
+ else:
301
+ # if mask not present, warn once (but keep searching)
302
+ # user may pass generic site types
303
+ pass
304
+ if not col_mask.any():
305
+ selected_var_idx = np.arange(subset.shape[1])
306
+ else:
307
+ selected_var_idx = np.nonzero(col_mask)[0]
308
+ else:
309
+ selected_var_idx = np.arange(subset.shape[1])
310
+
311
+ if selected_var_idx.size == 0:
312
+ for m in methods:
313
+ adata.uns[output_key][m][label] = pd.DataFrame()
314
+ adata.uns[output_key + "_n"][m][label] = int(n_reads)
315
+ pbar_outer.update(1)
316
+ continue
317
+
318
+ # extract matrix
319
+ if (layer in subset.layers) and (subset.layers[layer] is not None):
320
+ X = subset.layers[layer]
321
+ else:
322
+ X = subset.X
323
+ X = np.asarray(X, dtype=float)
324
+ X = X[:, selected_var_idx] # (n_reads, n_pos)
325
+
326
+ # binary encoding
327
+ if encoding == "signed":
328
+ X_bin = np.where(X == 1, 1.0, np.where(X == -1, 0.0, np.nan))
329
+ else:
330
+ X_bin = np.where(X == 1, 1.0, np.where(X == 0, 0.0, np.nan))
331
+
332
+ n_pos = X_bin.shape[1]
333
+ if n_pos == 0:
334
+ for m in methods:
335
+ adata.uns[output_key][m][label] = pd.DataFrame()
336
+ adata.uns[output_key + "_n"][m][label] = int(n_reads)
337
+ pbar_outer.update(1)
338
+ continue
339
+
340
+ var_names = list(subset.var_names[selected_var_idx])
341
+
                # compute per-method
                for method in methods:
                    m = method.lower()
                    if m == "pearson":
                        # pairwise Pearson with column demean (nan-aware approximation:
                        # NaNs are zero-filled after demeaning)
                        with np.errstate(invalid="ignore"):
                            col_mean = np.nanmean(X_bin, axis=0)
                            Xc = X_bin - col_mean  # nan preserved
                            Xc0 = np.nan_to_num(Xc, nan=0.0)
                            cov = Xc0.T @ Xc0
                            denom = (np.sqrt((Xc0**2).sum(axis=0))[:, None] * np.sqrt((Xc0**2).sum(axis=0))[None, :])
                            with np.errstate(divide="ignore", invalid="ignore"):
                                mat = np.where(denom != 0.0, cov / denom, np.nan)
                    elif m == "binary_covariance":
                        # fraction of jointly observed reads methylated at both positions
                        binary = (X_bin == 1).astype(float)
                        valid = (~np.isnan(X_bin)).astype(float)
                        with np.errstate(divide="ignore", invalid="ignore"):
                            numerator = binary.T @ binary
                            denominator = valid.T @ valid
                            mat = np.true_divide(numerator, denominator)
                        mat[~np.isfinite(mat)] = 0.0
                    elif m in ("chi_squared", "relative_risk"):
                        worker = _chi2_row_job if m == "chi_squared" else _relative_risk_row_job
                        out = np.full((n_pos, n_pos), np.nan, dtype=float)
                        tasks = (delayed(worker)(i, X_bin, min_count_for_pairwise) for i in range(n_pos))
                        pbar_rows = tqdm(total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False)
                        with tqdm_joblib(pbar_rows):
                            results = Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
                        pbar_rows.close()
                        for i, row in results:
                            out[int(i), :] = row
                        # mirror the computed upper triangle onto the lower triangle
                        iu = np.triu_indices(n_pos, k=1)
                        out[iu[1], iu[0]] = out[iu]
                        mat = out
                    else:
                        raise ValueError(f"Unsupported method: {method}")

                    # optionally reverse order at store-time
                    if reverse_indices_on_store:
                        mat_store = np.flip(np.flip(mat, axis=0), axis=1)
                        idx_names = var_names[::-1]
                    else:
                        mat_store = mat
                        idx_names = var_names

                    # make dataframe with labels
                    df = pd.DataFrame(mat_store, index=idx_names, columns=idx_names)

                    adata.uns[output_key][m][label] = df
                    adata.uns[output_key + "_n"][m][label] = int(n_reads)

            except Exception as exc:
                warnings.warn(f"Failed computing positionwise for {sample}__{ref}: {exc}")
            finally:
                pbar_outer.update(1)

    pbar_outer.close()
    return None

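A hedged call sketch for compute_positionwise_statistics. The layer name 'binarized' is an assumption; under the default "signed" encoding it should hold 1 (methylated), -1 (unmethylated), and NaN (unobserved):

compute_positionwise_statistics(
    adata,
    layer='binarized',                      # assumed layer name
    methods=('pearson', 'relative_risk'),
    site_types=['GpC_site'],
    min_count_for_pairwise=10,
)
for (sample, ref), df in adata.uns['positionwise_result']['pearson'].items():
    print(sample, ref, df.shape)
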
# ---------------------------
# Plotting function
# ---------------------------

def plot_positionwise_matrices(
    adata,
    methods: List[str],
    cmaps: Optional[List[str]] = None,
    sample_col: str = "Barcode",
    ref_col: str = "Reference_strand",
    output_dir: Optional[str] = None,
    vmin: Optional[Dict[str, float]] = None,
    vmax: Optional[Dict[str, float]] = None,
    figsize_per_cell: Tuple[float, float] = (3.5, 3.5),
    dpi: int = 160,
    cbar_shrink: float = 0.9,
    output_key: str = "positionwise_result",
    show_colorbar: bool = True,
    flip_display_axes: bool = False,
    rows_per_page: int = 6,
    sample_label_rotation: float = 90.0,
):
    """
    Plot grids of matrices for each method, with pagination and rotated sample-row labels.

    New args:
        rows_per_page: how many sample rows per page/figure (pagination).
        sample_label_rotation: rotation angle (deg) for the sample labels placed in the left margin.

    Returns:
        dict mapping method -> list of saved filenames (empty list if figures were shown).
    """
    if isinstance(methods, str):
        methods = [methods]
    if cmaps is None:
        cmaps = ["viridis"] * len(methods)
    cmap_cycle = cycle(cmaps)

    # canonicalize sample/ref order
    sseries = adata.obs[sample_col]
    if not pd.api.types.is_categorical_dtype(sseries):
        sseries = sseries.astype("category")
    samples = list(sseries.cat.categories)

    rseries = adata.obs[ref_col]
    if not pd.api.types.is_categorical_dtype(rseries):
        rseries = rseries.astype("category")
    references = list(rseries.cat.categories)

    # ensure directories
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    saved_files_by_method = {}

    def _get_df_from_store(store, sample, ref):
        """
        Try multiple key formats: a (sample, ref) tuple, a 'sample__ref' string,
        or a stringified tuple. Return None if not found.
        """
        if store is None:
            return None
        # try tuple key
        key_t = (sample, ref)
        if key_t in store:
            return store[key_t]
        # try string key
        key_s = f"{sample}__{ref}"
        if key_s in store:
            return store[key_s]
        # try stringified tuple keys (some callers store differently)
        for k in store.keys():
            try:
                if isinstance(k, tuple) and len(k) == 2 and str(k[0]) == str(sample) and str(k[1]) == str(ref):
                    return store[k]
                if isinstance(k, str) and key_s == k:
                    return store[k]
            except Exception:
                continue
        return None

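The lookup helper tolerates both tuple and 'sample__ref' string keys, which is what lets the plotting code read stores written by different callers. A toy illustration with hypothetical store contents (shown at module level here; in the source the helper is nested inside plot_positionwise_matrices):

store = {('bc01', 'refA'): pd.DataFrame([[1.0]]), 'bc02__refA': pd.DataFrame([[2.0]])}
assert _get_df_from_store(store, 'bc01', 'refA') is not None   # tuple key
assert _get_df_from_store(store, 'bc02', 'refA') is not None   # string key
assert _get_df_from_store(store, 'bc03', 'refA') is None       # missing
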
    for method, cmap in zip(methods, cmap_cycle):
        m = method.lower()
        method_store = adata.uns.get(output_key, {}).get(m, {})
        if not method_store:
            warnings.warn(f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.", stacklevel=2)
            saved_files_by_method[method] = []
            continue

        # gather numeric values to pick sensible vmin/vmax when not provided
        vals = []
        for s in samples:
            for r in references:
                df = _get_df_from_store(method_store, s, r)
                if isinstance(df, pd.DataFrame) and df.size > 0:
                    a = df.values
                    a = a[np.isfinite(a)]
                    if a.size:
                        vals.append(a)
        allvals = np.concatenate(vals) if vals else np.array([])

        # resolve color limits: explicit scalar > per-method dict entry > default
        def _resolve(bound, default):
            if bound is None:
                return default
            if isinstance(bound, dict):
                val = bound.get(m)
                return default if val is None else val
            return bound

        if m == "pearson":
            vmn, vmx = _resolve(vmin, -1.0), _resolve(vmax, 1.0)
        elif m == "binary_covariance":
            vmn, vmx = _resolve(vmin, 0.0), _resolve(vmax, 1.0)
        else:
            default_vmx = float(np.nanpercentile(allvals, 99.0)) if allvals.size else 1.0
            vmn, vmx = _resolve(vmin, 0.0), _resolve(vmax, default_vmx)

        # prepare pagination over sample rows
        saved_files = []
        n_pages = max(1, int(np.ceil(len(samples) / float(max(1, rows_per_page)))))
        for page_idx in range(n_pages):
            start = page_idx * rows_per_page
            chunk = samples[start : start + rows_per_page]
            nrows = len(chunk)
            ncols = max(1, len(references))
            fig_w = ncols * figsize_per_cell[0]
            fig_h = nrows * figsize_per_cell[1]
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)

            # leave margin for rotated sample labels
            plt.subplots_adjust(left=0.12, right=0.88, top=0.95, bottom=0.05)

            any_plotted = False
            im = None
            for r_idx, sample in enumerate(chunk):
                for c_idx, ref in enumerate(references):
                    ax = axes[r_idx][c_idx]
                    df = _get_df_from_store(method_store, sample, ref)
                    if not isinstance(df, pd.DataFrame) or df.size == 0:
                        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
                        ax.set_xticks([])
                        ax.set_yticks([])
                    else:
                        mat = df.values.astype(float)
                        origin = "upper" if flip_display_axes else "lower"
                        im = ax.imshow(mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap)
                        any_plotted = True
                        ax.set_xticks([])
                        ax.set_yticks([])

                    # top title is reference (only for top-row)
                    if r_idx == 0:
                        ax.set_title(str(ref), fontsize=9)

            fig.suptitle(f"{method} — per-sample x per-reference matrices (page {page_idx+1}/{n_pages})", fontsize=12, y=0.99)
            fig.tight_layout(rect=[0.05, 0.02, 0.9, 0.96])

            # draw rotated sample labels into the left margin, centered on each row;
            # placed after tight_layout so the axis positions are final
            for r_idx in range(nrows):
                ax0 = axes[r_idx][0]
                ax_y0, ax_y1 = ax0.get_position().y0, ax0.get_position().y1
                y_center = 0.5 * (ax_y0 + ax_y1)
                fig.text(0.01, y_center, str(chunk[r_idx]), va="center", ha="left", rotation=sample_label_rotation, fontsize=9)

            # colorbar (shared)
            if any_plotted and show_colorbar and (im is not None):
                try:
                    cbar_ax = fig.add_axes([0.9, 0.15, 0.02, 0.7])
                    fig.colorbar(im, cax=cbar_ax, shrink=cbar_shrink)
                except Exception:
                    try:
                        fig.colorbar(im, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
                    except Exception:
                        pass

            # save or show
            if output_dir:
                fname = f"positionwise_{method}_page{page_idx+1}.png"
                outpath = os.path.join(output_dir, fname)
                plt.savefig(outpath, bbox_inches="tight")
                saved_files.append(outpath)
                plt.close(fig)
            else:
                plt.show()
                saved_files.append("")  # placeholder to indicate a figure was shown

        saved_files_by_method[method] = saved_files

    return saved_files_by_method
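
A final usage sketch, assuming compute_positionwise_statistics has already populated adata.uns['positionwise_result']; the output directory name is illustrative:

saved = plot_positionwise_matrices(
    adata,
    methods=['pearson', 'relative_risk'],
    cmaps=['coolwarm', 'viridis'],
    output_dir='figures',        # hypothetical path; omit to show figures instead
    rows_per_page=4,
)
print(saved['pearson'])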