smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
smftools/tools/position_stats.py
@@ -117,123 +117,485 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
 
     return results_dict
 
-def compute_positionwise_statistic(
+import copy
+import warnings
+from typing import Dict, Any, List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+# optional imports
+try:
+    from joblib import Parallel, delayed
+    JOBLIB_AVAILABLE = True
+except Exception:
+    JOBLIB_AVAILABLE = False
+
+try:
+    from scipy.stats import chi2_contingency
+    SCIPY_STATS_AVAILABLE = True
+except Exception:
+    SCIPY_STATS_AVAILABLE = False
+
+# -----------------------------
+# Compute positionwise statistic (multi-method + simple site_types)
+# -----------------------------
+import numpy as np
+import pandas as pd
+from typing import List, Optional, Sequence, Dict, Any, Tuple
+from contextlib import contextmanager
+from joblib import Parallel, delayed, cpu_count
+import joblib
+from tqdm import tqdm
+from scipy.stats import chi2_contingency
+import warnings
+import matplotlib.pyplot as plt
+from itertools import cycle
+import os
+import warnings
+
+
+# ---------------------------
+# joblib <-> tqdm integration
+# ---------------------------
+@contextmanager
+def tqdm_joblib(tqdm_object: tqdm):
+    """Context manager to patch joblib to update a tqdm progress bar."""
+    old = joblib.parallel.BatchCompletionCallBack
+
+    class TqdmBatchCompletionCallback(old):  # type: ignore
+        def __call__(self, *args, **kwargs):
+            try:
+                tqdm_object.update(n=self.batch_size)
+            except Exception:
+                tqdm_object.update(1)
+            return super().__call__(*args, **kwargs)
+
+    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
+    try:
+        yield tqdm_object
+    finally:
+        joblib.parallel.BatchCompletionCallBack = old
+
+
+# ---------------------------
+# row workers (upper-triangle only)
+# ---------------------------
+def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
+    n_pos = X_bin.shape[1]
+    row = np.full((n_pos,), np.nan, dtype=float)
+    xi = X_bin[:, i]
+    for j in range(i, n_pos):
+        xj = X_bin[:, j]
+        mask = (~np.isnan(xi)) & (~np.isnan(xj))
+        if int(mask.sum()) < int(min_count_for_pairwise):
+            continue
+        try:
+            table = pd.crosstab(xi[mask], xj[mask])
+            if table.shape != (2, 2):
+                continue
+            chi2, _, _, _ = chi2_contingency(table, correction=False)
+            row[j] = float(chi2)
+        except Exception:
+            row[j] = np.nan
+    return (i, row)
+
+
+def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
+    n_pos = X_bin.shape[1]
+    row = np.full((n_pos,), np.nan, dtype=float)
+    xi = X_bin[:, i]
+    for j in range(i, n_pos):
+        xj = X_bin[:, j]
+        mask = (~np.isnan(xi)) & (~np.isnan(xj))
+        if int(mask.sum()) < int(min_count_for_pairwise):
+            continue
+        a = np.sum((xi[mask] == 1) & (xj[mask] == 1))
+        b = np.sum((xi[mask] == 1) & (xj[mask] == 0))
+        c = np.sum((xi[mask] == 0) & (xj[mask] == 1))
+        d = np.sum((xi[mask] == 0) & (xj[mask] == 0))
+        try:
+            if (a + b) > 0 and (c + d) > 0 and (c > 0):
+                p1 = a / float(a + b)
+                p2 = c / float(c + d)
+                row[j] = float(p1 / p2) if p2 > 0 else np.nan
+            else:
+                row[j] = np.nan
+        except Exception:
+            row[j] = np.nan
+    return (i, row)
+
+def compute_positionwise_statistics(
     adata,
-    layer,
-    method="pearson",
-    groupby=["Reference_strand"],
-    output_key="positionwise_result",
-    site_config=None,
-    encoding="signed",
-    max_threads=None
+    layer: str,
+    methods: Sequence[str] = ("pearson",),
+    sample_col: str = "Barcode",
+    ref_col: str = "Reference_strand",
+    site_types: Optional[Sequence[str]] = None,
+    encoding: str = "signed",
+    output_key: str = "positionwise_result",
+    min_count_for_pairwise: int = 10,
+    max_threads: Optional[int] = None,
+    reverse_indices_on_store: bool = False,
 ):
     """
-    Computes a position-by-position matrix (correlation, RR, or Chi-squared) from an adata layer.
+    Compute per-(sample, ref) positionwise matrices for methods in `methods`.
 
-    Parameters:
-        adata (AnnData): Annotated data matrix.
-        layer (str): Name of the adata layer to use.
-        method (str): 'pearson', 'binary_covariance', 'relative_risk', or 'chi_squared'.
-        groupby (str or list): Column(s) in adata.obs to group by.
-        output_key (str): Key in adata.uns to store results.
-        site_config (dict): Optional {ref: [site_types]} to restrict sites per reference.
-        encoding (str): 'signed' (1/-1/0) or 'binary' (1/0/NaN).
-        max_threads (int): Number of parallel threads to use (joblib).
+    Results stored at:
+        adata.uns[output_key][method][(sample, ref)] = DataFrame
+        adata.uns[output_key + "_n"][method][(sample, ref)] = int(n_reads)
     """
-    import numpy as np
-    import pandas as pd
-    from scipy.stats import chi2_contingency
-    from joblib import Parallel, delayed
-    from tqdm import tqdm
+    if isinstance(methods, str):
+        methods = [methods]
+    methods = [m.lower() for m in methods]
+
+    # prepare containers
+    adata.uns[output_key] = {m: {} for m in methods}
+    adata.uns[output_key + "_n"] = {m: {} for m in methods}
+
+    # workers
+    if max_threads is None or max_threads <= 0:
+        n_jobs = max(1, cpu_count() or 1)
+    else:
+        n_jobs = max(1, int(max_threads))
+
+    # samples / refs
+    sseries = adata.obs[sample_col]
+    if not pd.api.types.is_categorical_dtype(sseries):
+        sseries = sseries.astype("category")
+    samples = list(sseries.cat.categories)
+
+    rseries = adata.obs[ref_col]
+    if not pd.api.types.is_categorical_dtype(rseries):
+        rseries = rseries.astype("category")
+    references = list(rseries.cat.categories)
+
+    total_tasks = len(samples) * len(references)
+    pbar_outer = tqdm(total=total_tasks, desc="positionwise (sample x ref)", unit="cell")
+
+    for sample in samples:
+        for ref in references:
+            label = (sample, ref)
+            try:
+                mask = (adata.obs[sample_col] == sample) & (adata.obs[ref_col] == ref)
+                subset = adata[mask]
+                n_reads = subset.shape[0]
+
+                # nothing to do -> store empty placeholders
+                if n_reads == 0:
+                    for m in methods:
+                        adata.uns[output_key][m][label] = pd.DataFrame()
+                        adata.uns[output_key + "_n"][m][label] = 0
+                    pbar_outer.update(1)
+                    continue
 
-    if isinstance(groupby, str):
-        groupby = [groupby]
+                # select var columns based on site_types and reference
+                if site_types:
+                    col_mask = np.zeros(subset.shape[1], dtype=bool)
+                    for st in site_types:
+                        colname = f"{ref}_{st}"
+                        if colname in subset.var.columns:
+                            col_mask |= np.asarray(subset.var[colname].values, dtype=bool)
+                        else:
+                            # if mask not present, warn once (but keep searching)
+                            # user may pass generic site types
+                            pass
+                    if not col_mask.any():
+                        selected_var_idx = np.arange(subset.shape[1])
+                    else:
+                        selected_var_idx = np.nonzero(col_mask)[0]
+                else:
+                    selected_var_idx = np.arange(subset.shape[1])
+
+                if selected_var_idx.size == 0:
+                    for m in methods:
+                        adata.uns[output_key][m][label] = pd.DataFrame()
+                        adata.uns[output_key + "_n"][m][label] = int(n_reads)
+                    pbar_outer.update(1)
+                    continue
 
-    adata.uns[output_key] = {}
-    adata.uns[output_key + "_n"] = {}
+                # extract matrix
+                if (layer in subset.layers) and (subset.layers[layer] is not None):
+                    X = subset.layers[layer]
+                else:
+                    X = subset.X
+                X = np.asarray(X, dtype=float)
+                X = X[:, selected_var_idx]  # (n_reads, n_pos)
+
+                # binary encoding
+                if encoding == "signed":
+                    X_bin = np.where(X == 1, 1.0, np.where(X == -1, 0.0, np.nan))
+                else:
+                    X_bin = np.where(X == 1, 1.0, np.where(X == 0, 0.0, np.nan))
+
+                n_pos = X_bin.shape[1]
+                if n_pos == 0:
+                    for m in methods:
+                        adata.uns[output_key][m][label] = pd.DataFrame()
+                        adata.uns[output_key + "_n"][m][label] = int(n_reads)
+                    pbar_outer.update(1)
+                    continue
 
-    label_col = "__".join(groupby)
-    adata.obs[label_col] = adata.obs[groupby].astype(str).agg("_".join, axis=1)
+                var_names = list(subset.var_names[selected_var_idx])
+
+                # compute per-method
+                for method in methods:
+                    m = method.lower()
+                    if m == "pearson":
+                        # pairwise Pearson with column demean (nan-aware approximation)
+                        with np.errstate(invalid="ignore"):
+                            col_mean = np.nanmean(X_bin, axis=0)
+                        Xc = X_bin - col_mean  # nan preserved
+                        Xc0 = np.nan_to_num(Xc, nan=0.0)
+                        cov = Xc0.T @ Xc0
+                        denom = (np.sqrt((Xc0**2).sum(axis=0))[:, None] * np.sqrt((Xc0**2).sum(axis=0))[None, :])
+                        with np.errstate(divide="ignore", invalid="ignore"):
+                            mat = np.where(denom != 0.0, cov / denom, np.nan)
+                    elif m == "binary_covariance":
+                        binary = (X_bin == 1).astype(float)
+                        valid = (~np.isnan(X_bin)).astype(float)
+                        with np.errstate(divide="ignore", invalid="ignore"):
+                            numerator = binary.T @ binary
+                            denominator = valid.T @ valid
+                            mat = np.true_divide(numerator, denominator)
+                        mat[~np.isfinite(mat)] = 0.0
+                    elif m in ("chi_squared", "relative_risk"):
+                        if m == "chi_squared":
+                            worker = _chi2_row_job
+                        else:
+                            worker = _relative_risk_row_job
+                        out = np.full((n_pos, n_pos), np.nan, dtype=float)
+                        tasks = (delayed(worker)(i, X_bin, min_count_for_pairwise) for i in range(n_pos))
+                        pbar_rows = tqdm(total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False)
+                        with tqdm_joblib(pbar_rows):
+                            results = Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
+                        pbar_rows.close()
+                        for i, row in results:
+                            out[int(i), :] = row
+                        iu = np.triu_indices(n_pos, k=1)
+                        out[iu[1], iu[0]] = out[iu]
+                        mat = out
+                    else:
+                        raise ValueError(f"Unsupported method: {method}")
+
+                    # optionally reverse order at store-time
+                    if reverse_indices_on_store:
+                        mat_store = np.flip(np.flip(mat, axis=0), axis=1)
+                        idx_names = var_names[::-1]
+                    else:
+                        mat_store = mat
+                        idx_names = var_names
+
+                    # make dataframe with labels
+                    df = pd.DataFrame(mat_store, index=idx_names, columns=idx_names)
+
+                    adata.uns[output_key][m][label] = df
+                    adata.uns[output_key + "_n"][m][label] = int(n_reads)
+
+            except Exception as exc:
+                warnings.warn(f"Failed computing positionwise for {sample}__{ref}: {exc}")
+            finally:
+                pbar_outer.update(1)
+
+    pbar_outer.close()
+    return None
+
+
+# ---------------------------
+# Plotting function
+# ---------------------------
+
+def plot_positionwise_matrices(
+    adata,
+    methods: List[str],
+    cmaps: Optional[List[str]] = None,
+    sample_col: str = "Barcode",
+    ref_col: str = "Reference_strand",
+    output_dir: Optional[str] = None,
+    vmin: Optional[Dict[str, float]] = None,
+    vmax: Optional[Dict[str, float]] = None,
+    figsize_per_cell: Tuple[float, float] = (3.5, 3.5),
+    dpi: int = 160,
+    cbar_shrink: float = 0.9,
+    output_key: str = "positionwise_result",
+    show_colorbar: bool = True,
+    flip_display_axes: bool = False,
+    rows_per_page: int = 6,
+    sample_label_rotation: float = 90.0,
+):
+    """
+    Plot grids of matrices for each method with pagination and rotated sample-row labels.
 
-    for group in adata.obs[label_col].unique():
-        subset = adata[adata.obs[label_col] == group].copy()
-        if subset.shape[0] == 0:
+    New args:
+      - rows_per_page: how many sample rows per page/figure (pagination)
+      - sample_label_rotation: rotation angle (deg) for the sample labels placed in the left margin.
+    Returns:
+        dict mapping method -> list of saved filenames (empty list if figures were shown).
+    """
+    if isinstance(methods, str):
+        methods = [methods]
+    if cmaps is None:
+        cmaps = ["viridis"] * len(methods)
+    cmap_cycle = cycle(cmaps)
+
+    # canonicalize sample/ref order
+    sseries = adata.obs[sample_col]
+    if not pd.api.types.is_categorical_dtype(sseries):
+        sseries = sseries.astype("category")
+    samples = list(sseries.cat.categories)
+
+    rseries = adata.obs[ref_col]
+    if not pd.api.types.is_categorical_dtype(rseries):
+        rseries = rseries.astype("category")
+    references = list(rseries.cat.categories)
+
+    # ensure directories
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    saved_files_by_method = {}
+
+    def _get_df_from_store(store, sample, ref):
+        """
+        Try multiple key formats: (sample, ref) tuple, 'sample__ref' string,
+        or str(sample) + '__' + str(ref). Return None if not found.
+        """
+        if store is None:
+            return None
+        # try tuple key
+        key_t = (sample, ref)
+        if key_t in store:
+            return store[key_t]
+        # try string key
+        key_s = f"{sample}__{ref}"
+        if key_s in store:
+            return store[key_s]
+        # try stringified tuple keys (some callers store differently)
+        for k in store.keys():
+            try:
+                if isinstance(k, tuple) and len(k) == 2 and str(k[0]) == str(sample) and str(k[1]) == str(ref):
+                    return store[k]
+                if isinstance(k, str) and key_s == k:
+                    return store[k]
+            except Exception:
+                continue
+        return None
+
+    for method, cmap in zip(methods, cmap_cycle):
+        m = method.lower()
+        method_store = adata.uns.get(output_key, {}).get(m, {})
+        if not method_store:
+            warnings.warn(f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.", stacklevel=2)
+            saved_files_by_method[method] = []
             continue
 
-        ref = subset.obs["Reference_strand"].unique()[0] if "Reference_strand" in groupby else None
-
-        if site_config and ref in site_config:
-            site_mask = np.zeros(subset.shape[1], dtype=bool)
-            for site in site_config[ref]:
-                site_mask |= subset.var[f"{ref}_{site}"]
-            subset = subset[:, site_mask].copy()
-
-        X = subset.layers[layer].copy()
-
-        if encoding == "signed":
-            X_bin = np.where(X == 1, 1, np.where(X == -1, 0, np.nan))
+        # gather numeric values to pick sensible vmin/vmax when not provided
+        vals = []
+        for s in samples:
+            for r in references:
+                df = _get_df_from_store(method_store, s, r)
+                if isinstance(df, pd.DataFrame) and df.size > 0:
+                    a = df.values
+                    a = a[np.isfinite(a)]
+                    if a.size:
+                        vals.append(a)
+        if vals:
+            allvals = np.concatenate(vals)
         else:
-            X_bin = np.where(X == 1, 1, np.where(X == 0, 0, np.nan))
-
-        n_pos = subset.shape[1]
-        mat = np.zeros((n_pos, n_pos))
-
-        if method == "pearson":
-            with np.errstate(invalid='ignore'):
-                mat = np.corrcoef(np.nan_to_num(X_bin).T)
-
-        elif method == "binary_covariance":
-            binary = (X_bin == 1).astype(float)
-            valid = (X_bin == 1) | (X_bin == 0)  # Only consider true binary (ignore NaN)
-            valid = valid.astype(float)
-
-            numerator = np.dot(binary.T, binary)
-            denominator = np.dot(valid.T, valid)
-
-            with np.errstate(divide='ignore', invalid='ignore'):
-                mat = np.true_divide(numerator, denominator)
-                mat[~np.isfinite(mat)] = 0
-
-        elif method in ["relative_risk", "chi_squared"]:
-            def compute_row(i):
-                row = np.zeros(n_pos)
-                xi = X_bin[:, i]
-                for j in range(n_pos):
-                    xj = X_bin[:, j]
-                    mask = ~np.isnan(xi) & ~np.isnan(xj)
-                    if np.sum(mask) < 10:
-                        row[j] = np.nan
-                        continue
-                    if method == "relative_risk":
-                        a = np.sum((xi[mask] == 1) & (xj[mask] == 1))
-                        b = np.sum((xi[mask] == 1) & (xj[mask] == 0))
-                        c = np.sum((xi[mask] == 0) & (xj[mask] == 1))
-                        d = np.sum((xi[mask] == 0) & (xj[mask] == 0))
-                        if (a + b) > 0 and (c + d) > 0 and c > 0:
-                            p1 = a / (a + b)
-                            p2 = c / (c + d)
-                            row[j] = p1 / p2 if p2 > 0 else np.nan
-                        else:
-                            row[j] = np.nan
-                    elif method == "chi_squared":
-                        table = pd.crosstab(xi[mask], xj[mask])
-                        if table.shape != (2, 2):
-                            row[j] = np.nan
-                        else:
-                            chi2, _, _, _ = chi2_contingency(table, correction=False)
-                            row[j] = chi2
-                return row
-
-            mat = np.array(
-                Parallel(n_jobs=max_threads)(
-                    delayed(compute_row)(i) for i in tqdm(range(n_pos), desc=f"{method}: {group}")
-                )
-            )
-
+            allvals = np.array([])
+
+        # decide per-method defaults
+        if m == "pearson":
+            vmn = -1.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            vmn = -1.0 if vmn is None else vmn
+            vmx = 1.0 if vmx is None else vmx
+        elif m == "binary_covariance":
+            vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            vmn = 0.0 if vmn is None else vmn
+            vmx = 1.0 if vmx is None else vmx
         else:
-            raise ValueError(f"Unsupported method: {method}")
+            vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            if (vmax is None) or (isinstance(vmax, dict) and m not in vmax):
+                vmx = float(np.nanpercentile(allvals, 99.0)) if allvals.size else 1.0
+            else:
+                vmx = (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            vmn = 0.0 if vmn is None else vmn
+            if vmx is None:
+                vmx = 1.0
+
+        # prepare pagination over sample rows
+        saved_files = []
+        n_pages = max(1, int(np.ceil(len(samples) / float(max(1, rows_per_page)))))
+        for page_idx in range(n_pages):
+            start = page_idx * rows_per_page
+            chunk = samples[start : start + rows_per_page]
+            nrows = len(chunk)
+            ncols = max(1, len(references))
+            fig_w = ncols * figsize_per_cell[0]
+            fig_h = nrows * figsize_per_cell[1]
+            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
+
+            # leave margin for rotated sample labels
+            plt.subplots_adjust(left=0.12, right=0.88, top=0.95, bottom=0.05)
+
+            any_plotted = False
+            im = None
+            for r_idx, sample in enumerate(chunk):
+                for c_idx, ref in enumerate(references):
+                    ax = axes[r_idx][c_idx]
+                    df = _get_df_from_store(method_store, sample, ref)
+                    if not isinstance(df, pd.DataFrame) or df.size == 0:
+                        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
+                        ax.set_xticks([])
+                        ax.set_yticks([])
+                    else:
+                        mat = df.values.astype(float)
+                        origin = "upper" if flip_display_axes else "lower"
+                        im = ax.imshow(mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap)
+                        any_plotted = True
+                        ax.set_xticks([])
+                        ax.set_yticks([])
+
+                    # top title is reference (only for top-row)
+                    if r_idx == 0:
+                        ax.set_title(str(ref), fontsize=9)
+
+                # draw rotated sample label into left margin centered on the row
+                # compute vertical center of this row's axis in figure coords
+                ax0 = axes[r_idx][0]
+                ax_y0, ax_y1 = ax0.get_position().y0, ax0.get_position().y1
+                y_center = 0.5 * (ax_y0 + ax_y1)
+                # place text at x=0.01 (just inside left margin); rotation controls orientation
+                fig.text(0.01, y_center, str(chunk[r_idx]), va="center", ha="left", rotation=sample_label_rotation, fontsize=9)
+
+            fig.suptitle(f"{method} — per-sample x per-reference matrices (page {page_idx+1}/{n_pages})", fontsize=12, y=0.99)
+            fig.tight_layout(rect=[0.05, 0.02, 0.9, 0.96])
+
+            # colorbar (shared)
+            if any_plotted and show_colorbar and (im is not None):
+                try:
+                    cbar_ax = fig.add_axes([0.9, 0.15, 0.02, 0.7])
+                    fig.colorbar(im, cax=cbar_ax, shrink=cbar_shrink)
+                except Exception:
+                    try:
+                        fig.colorbar(im, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
+                    except Exception:
+                        pass
+
+            # save or show
+            if output_dir:
+                fname = f"positionwise_{method}_page{page_idx+1}.png"
+                outpath = os.path.join(output_dir, fname)
+                plt.savefig(outpath, bbox_inches="tight")
+                saved_files.append(outpath)
+                plt.close(fig)
+            else:
+                plt.show()
+                saved_files.append("")  # placeholder to indicate a figure was shown
+
+        saved_files_by_method[method] = saved_files
 
-        var_names = subset.var_names.astype(int)
-        mat_df = pd.DataFrame(mat, index=var_names, columns=var_names)
-        adata.uns[output_key][group] = mat_df
-        adata.uns[output_key + "_n"][group] = subset.shape[0]
+    return saved_files_by_method
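
For orientation, a minimal usage sketch of the two functions this hunk introduces. This example is not part of the diff: the h5ad path, layer name, and the barcode/reference labels are hypothetical placeholders, and the functions are imported directly from the module where the diff defines them.

# Hypothetical usage of the new positionwise API in
# smftools/tools/position_stats.py (names marked below are placeholders).
from smftools.tools.position_stats import (
    compute_positionwise_statistics,
    plot_positionwise_matrices,
)
import anndata as ad

adata = ad.read_h5ad("experiment.h5ad")  # placeholder input file

# Build one matrix per (Barcode, Reference_strand) pair for each method;
# per the docstring, results land in
# adata.uns["positionwise_result"][method][(sample, ref)].
compute_positionwise_statistics(
    adata,
    layer="binarized_methylation",        # placeholder layer name
    methods=("pearson", "relative_risk"),
    encoding="signed",                    # 1/-1/0 calls, per the docstring
    min_count_for_pairwise=10,
    max_threads=4,
)

# Paginated heatmap grids (rows_per_page sample rows per figure),
# saved as PNGs under output_dir.
saved = plot_positionwise_matrices(
    adata,
    methods=["pearson", "relative_risk"],
    cmaps=["coolwarm", "viridis"],
    output_dir="positionwise_plots",
    rows_per_page=4,
)

Note that compute_positionwise_statistics returns None and writes everything into adata.uns, so the plotting step only needs the AnnData object plus the same output_key.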