smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,23 +1,76 @@
+ from __future__ import annotations
+
+ import warnings
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
+
+ if TYPE_CHECKING:
+     import anndata as ad
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+
+ # optional imports
+ try:
+     from joblib import Parallel, delayed
+
+     JOBLIB_AVAILABLE = True
+ except Exception:
+     JOBLIB_AVAILABLE = False
+
+ try:
+     from scipy.stats import chi2_contingency
+
+     SCIPY_STATS_AVAILABLE = True
+ except Exception:
+     SCIPY_STATS_AVAILABLE = False
+
+ # -----------------------------
+ # Compute positionwise statistic (multi-method + simple site_types)
+ # -----------------------------
+ import os
+ from contextlib import contextmanager
+ from itertools import cycle
+
+ import joblib
+ from joblib import Parallel, cpu_count, delayed
+ from scipy.stats import chi2_contingency
+ from tqdm import tqdm
+
+
  # ------------------------- Utilities -------------------------
- def random_fill_nans(X):
+ def random_fill_nans(X: np.ndarray) -> np.ndarray:
+     """Fill NaNs with random values in-place.
+
+     Args:
+         X: Input array with NaNs.
+
+     Returns:
+         numpy.ndarray: Array with NaNs replaced by random values.
+     """
      import numpy as np
+
      nan_mask = np.isnan(X)
      X[nan_mask] = np.random.rand(*X[nan_mask].shape)
      return X

- def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
-     """
-     Perform Bayesian-style methylation vs activity analysis independently within each group.

-     Parameters:
-         adata (AnnData): Annotated data matrix.
-         sites (list of str): List of site keys (e.g., ['GpC_site', 'CpG_site']).
-         alpha (float): FDR threshold for significance.
-         groupby (str or list of str): Column(s) in adata.obs to group by.
+ def calculate_relative_risk_on_activity(
+     adata: "ad.AnnData",
+     sites: Sequence[str],
+     alpha: float = 0.05,
+     groupby: str | Sequence[str] | None = None,
+ ) -> dict:
+     """Perform methylation vs. activity analysis within each group.
+
+     Args:
+         adata: Annotated data matrix.
+         sites: Site keys (e.g., ``["GpC_site", "CpG_site"]``).
+         alpha: FDR threshold for significance.
+         groupby: Obs column(s) to group by.

      Returns:
-         results_dict (dict): Dictionary with structure:
-             results_dict[ref][group_label] = (results_df, sig_df)
+         dict: Mapping of reference -> group label -> ``(results_df, sig_df)``.
      """
      import numpy as np
      import pandas as pd
@@ -25,30 +78,44 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
      from statsmodels.stats.multitest import multipletests

      def compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values):
-         p_adj = multipletests(p_values, method='fdr_bh')[1] if p_values else []
+         """Build result and significant-data DataFrames for a reference.
+
+         Args:
+             ref: Reference name.
+             site_subset: AnnData subset restricted to sites.
+             positions_list: Positions tested.
+             relative_risks: Relative risk values.
+             p_values: Raw p-values.
+
+         Returns:
+             Tuple of (results_df, sig_df).
+         """
+         p_adj = multipletests(p_values, method="fdr_bh")[1] if p_values else []

          genomic_positions = np.array(site_subset.var_names)[positions_list]
          is_gpc_site = site_subset.var[f"{ref}_GpC_site"].values[positions_list]
          is_cpg_site = site_subset.var[f"{ref}_CpG_site"].values[positions_list]

-         results_df = pd.DataFrame({
-             'Feature_Index': positions_list,
-             'Genomic_Position': genomic_positions.astype(int),
-             'Relative_Risk': relative_risks,
-             'Adjusted_P_Value': p_adj,
-             'GpC_Site': is_gpc_site,
-             'CpG_Site': is_cpg_site
-         })
-
-         results_df['log2_Relative_Risk'] = np.log2(results_df['Relative_Risk'].replace(0, 1e-300))
-         results_df['-log10_Adj_P'] = -np.log10(results_df['Adjusted_P_Value'].replace(0, 1e-300))
-         sig_df = results_df[results_df['Adjusted_P_Value'] < alpha]
+         results_df = pd.DataFrame(
+             {
+                 "Feature_Index": positions_list,
+                 "Genomic_Position": genomic_positions.astype(int),
+                 "Relative_Risk": relative_risks,
+                 "Adjusted_P_Value": p_adj,
+                 "GpC_Site": is_gpc_site,
+                 "CpG_Site": is_cpg_site,
+             }
+         )
+
+         results_df["log2_Relative_Risk"] = np.log2(results_df["Relative_Risk"].replace(0, 1e-300))
+         results_df["-log10_Adj_P"] = -np.log10(results_df["Adjusted_P_Value"].replace(0, 1e-300))
+         sig_df = results_df[results_df["Adjusted_P_Value"] < alpha]
          return results_df, sig_df

      results_dict = {}

-     for ref in adata.obs['Reference_strand'].unique():
-         ref_subset = adata[adata.obs['Reference_strand'] == ref].copy()
+     for ref in adata.obs["Reference_strand"].unique():
+         ref_subset = adata[adata.obs["Reference_strand"] == ref].copy()
          if ref_subset.shape[0] == 0:
              continue

@@ -56,20 +123,22 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
          if groupby is not None:
              if isinstance(groupby, str):
                  groupby = [groupby]
+
              def format_group_label(row):
+                 """Format a group label string from obs row values."""
                  return ",".join([f"{col}={row[col]}" for col in groupby])

-             combined_label = '__'.join(groupby)
+             combined_label = "__".join(groupby)
              ref_subset.obs[combined_label] = ref_subset.obs.apply(format_group_label, axis=1)
              groups = ref_subset.obs[combined_label].unique()
          else:
              combined_label = None
-             groups = ['all']
+             groups = ["all"]

          results_dict[ref] = {}

          for group in groups:
-             if group == 'all':
+             if group == "all":
                  group_subset = ref_subset
              else:
                  group_subset = ref_subset[ref_subset.obs[combined_label] == group]
@@ -85,7 +154,7 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):

              # Matrix and labels
              X = random_fill_nans(site_subset.X.copy())
-             y = site_subset.obs['activity_status'].map({'Active': 1, 'Silent': 0}).values
+             y = site_subset.obs["activity_status"].map({"Active": 1, "Silent": 0}).values
              P_active = np.mean(y)

              # Analysis
@@ -104,7 +173,9 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
                  continue

                 P_active_given_methylated = (P_methylated_given_active * P_active) / P_methylated
-                 P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (1 - P_methylated)
+                 P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (
+                     1 - P_methylated
+                 )
                  RR = P_active_given_methylated / P_active_given_unmethylated

                  _, p_value = fisher_exact(table)
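The hunk above applies Bayes' rule to turn per-position methylation/activity counts into a relative risk before testing the 2x2 table with fisher_exact. A minimal sketch of the same arithmetic on a hypothetical table of read counts (the numbers here are illustrative, not from the package):

import numpy as np
from scipy.stats import fisher_exact

# Hypothetical 2x2 table; rows: methylated / unmethylated, columns: active / silent.
table = np.array([[30, 10], [20, 40]])

n_reads = table.sum()
P_active = table[:, 0].sum() / n_reads                 # 0.50
P_methylated = table[0, :].sum() / n_reads             # 0.40
P_methylated_given_active = table[0, 0] / table[:, 0].sum()  # 0.60

# Bayes' rule, mirroring the block above:
P_active_given_methylated = (P_methylated_given_active * P_active) / P_methylated
P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (1 - P_methylated)
RR = P_active_given_methylated / P_active_given_unmethylated  # 2.25 for this table

_, p_value = fisher_exact(table)  # same test the function applies per position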
@@ -112,49 +183,13 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
                  relative_risks.append(RR)
                  p_values.append(p_value)

-             results_df, sig_df = compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values)
+             results_df, sig_df = compute_risk_df(
+                 ref, site_subset, positions_list, relative_risks, p_values
+             )
              results_dict[ref][group] = (results_df, sig_df)

      return results_dict

- import copy
- import warnings
- from typing import Dict, Any, List, Optional, Tuple, Union
-
- import numpy as np
- import pandas as pd
- import matplotlib.pyplot as plt
-
- # optional imports
- try:
-     from joblib import Parallel, delayed
-     JOBLIB_AVAILABLE = True
- except Exception:
-     JOBLIB_AVAILABLE = False
-
- try:
-     from scipy.stats import chi2_contingency
-     SCIPY_STATS_AVAILABLE = True
- except Exception:
-     SCIPY_STATS_AVAILABLE = False
-
- # -----------------------------
- # Compute positionwise statistic (multi-method + simple site_types)
- # -----------------------------
- import numpy as np
- import pandas as pd
- from typing import List, Optional, Sequence, Dict, Any, Tuple
- from contextlib import contextmanager
- from joblib import Parallel, delayed, cpu_count
- import joblib
- from tqdm import tqdm
- from scipy.stats import chi2_contingency
- import warnings
- import matplotlib.pyplot as plt
- from itertools import cycle
- import os
- import warnings
-

  # ---------------------------
  # joblib <-> tqdm integration
@@ -165,7 +200,10 @@ def tqdm_joblib(tqdm_object: tqdm):
      old = joblib.parallel.BatchCompletionCallBack

      class TqdmBatchCompletionCallback(old):  # type: ignore
+         """Joblib callback that updates a tqdm progress bar."""
+
          def __call__(self, *args, **kwargs):
+             """Update the progress bar when a batch completes."""
              try:
                  tqdm_object.update(n=self.batch_size)
              except Exception:
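tqdm_joblib above uses the common joblib/tqdm integration trick: temporarily subclassing and swapping joblib.parallel.BatchCompletionCallBack so the bar advances as batches finish. A minimal usage sketch, mirroring the with tqdm_joblib(...) call that appears later in this diff (the worker function is hypothetical):

from joblib import Parallel, delayed
from tqdm import tqdm

def square(x):  # placeholder worker
    return x * x

# The bar updates as joblib completes batches; the old callback is restored on exit.
with tqdm_joblib(tqdm(total=100, desc="rows")):
    results = Parallel(n_jobs=2)(delayed(square)(i) for i in range(100))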
@@ -183,6 +221,16 @@
  # row workers (upper-triangle only)
  # ---------------------------
  def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
+     """Compute chi-squared statistics for one row of a pairwise matrix.
+
+     Args:
+         i: Row index.
+         X_bin: Binary matrix.
+         min_count_for_pairwise: Minimum count for valid comparison.
+
+     Returns:
+         Tuple of (row_index, row_values).
+     """
      n_pos = X_bin.shape[1]
      row = np.full((n_pos,), np.nan, dtype=float)
      xi = X_bin[:, i]
@@ -202,7 +250,19 @@ def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
      return (i, row)


- def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
+ def _relative_risk_row_job(
+     i: int, X_bin: np.ndarray, min_count_for_pairwise: int
+ ) -> Tuple[int, np.ndarray]:
+     """Compute relative-risk values for one row of a pairwise matrix.
+
+     Args:
+         i: Row index.
+         X_bin: Binary matrix.
+         min_count_for_pairwise: Minimum count for valid comparison.
+
+     Returns:
+         Tuple of (row_index, row_values).
+     """
      n_pos = X_bin.shape[1]
      row = np.full((n_pos,), np.nan, dtype=float)
      xi = X_bin[:, i]
@@ -226,8 +286,9 @@ def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
              row[j] = np.nan
      return (i, row)

+
  def compute_positionwise_statistics(
-     adata,
+     adata: "ad.AnnData",
      layer: str,
      methods: Sequence[str] = ("pearson",),
      sample_col: str = "Barcode",
@@ -238,13 +299,21 @@
      min_count_for_pairwise: int = 10,
      max_threads: Optional[int] = None,
      reverse_indices_on_store: bool = False,
- ):
-     """
-     Compute per-(sample,ref) positionwise matrices for methods in `methods`.
-
-     Results stored at:
-         adata.uns[output_key][method][ (sample, ref) ] = DataFrame
-         adata.uns[output_key + "_n"][method][ (sample, ref) ] = int(n_reads)
+ ) -> None:
+     """Compute per-(sample, ref) positionwise matrices for selected methods.
+
+     Args:
+         adata: AnnData object to analyze.
+         layer: Layer name to use for statistics.
+         methods: Methods to compute (e.g., ``"pearson"``).
+         sample_col: Obs column containing sample identifiers.
+         ref_col: Obs column containing reference identifiers.
+         site_types: Optional site types to subset positions.
+         encoding: ``"signed"`` or ``"binary"`` encoding.
+         output_key: Key prefix for results stored in ``adata.uns``.
+         min_count_for_pairwise: Minimum counts for pairwise comparisons.
+         max_threads: Maximum number of threads.
+         reverse_indices_on_store: Whether to reverse indices on output storage.
      """
      if isinstance(methods, str):
          methods = [methods]
@@ -349,7 +418,10 @@
              Xc = X_bin - col_mean  # nan preserved
              Xc0 = np.nan_to_num(Xc, nan=0.0)
              cov = Xc0.T @ Xc0
-             denom = (np.sqrt((Xc0**2).sum(axis=0))[:, None] * np.sqrt((Xc0**2).sum(axis=0))[None, :])
+             denom = (
+                 np.sqrt((Xc0**2).sum(axis=0))[:, None]
+                 * np.sqrt((Xc0**2).sum(axis=0))[None, :]
+             )
              with np.errstate(divide="ignore", invalid="ignore"):
                  mat = np.where(denom != 0.0, cov / denom, np.nan)
          elif m == "binary_covariance":
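The pearson branch above computes a NaN-aware correlation: columns are centered on their NaN-ignoring means, NaNs are zero-filled so missing calls contribute nothing to the Gram product or the norms, and the division is guarded against zero denominators. A standalone sketch of the same computation on a toy matrix:

import numpy as np

X_bin = np.array([[1.0, 0.0, np.nan],
                  [1.0, 1.0, 0.0],
                  [0.0, 1.0, 1.0]])

col_mean = np.nanmean(X_bin, axis=0)
Xc = X_bin - col_mean             # NaNs preserved
Xc0 = np.nan_to_num(Xc, nan=0.0)  # zero-filled copy
cov = Xc0.T @ Xc0
norms = np.sqrt((Xc0**2).sum(axis=0))
denom = norms[:, None] * norms[None, :]
with np.errstate(divide="ignore", invalid="ignore"):
    mat = np.where(denom != 0.0, cov / denom, np.nan)  # positionwise correlation matrix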
@@ -366,8 +438,12 @@
          else:
              worker = _relative_risk_row_job
          out = np.full((n_pos, n_pos), np.nan, dtype=float)
-         tasks = (delayed(worker)(i, X_bin, min_count_for_pairwise) for i in range(n_pos))
-         pbar_rows = tqdm(total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False)
+         tasks = (
+             delayed(worker)(i, X_bin, min_count_for_pairwise) for i in range(n_pos)
+         )
+         pbar_rows = tqdm(
+             total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False
+         )
          with tqdm_joblib(pbar_rows):
              results = Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
          pbar_rows.close()
@@ -406,6 +482,7 @@
  # Plotting function
  # ---------------------------

+
  def plot_positionwise_matrices(
      adata,
      methods: List[str],
@@ -427,9 +504,10 @@
      """
      Plot grids of matrices for each method with pagination and rotated sample-row labels.

-     New args:
+     Args:
      - rows_per_page: how many sample rows per page/figure (pagination)
      - sample_label_rotation: rotation angle (deg) for the sample labels placed in the left margin.
+
      Returns:
          dict mapping method -> list of saved filenames (empty list if figures were shown).
      """
@@ -474,7 +552,12 @@
          # try stringified tuple keys (some callers store differently)
          for k in store.keys():
              try:
-                 if isinstance(k, tuple) and len(k) == 2 and str(k[0]) == str(sample) and str(k[1]) == str(ref):
+                 if (
+                     isinstance(k, tuple)
+                     and len(k) == 2
+                     and str(k[0]) == str(sample)
+                     and str(k[1]) == str(ref)
+                 ):
                      return store[k]
                  if isinstance(k, str) and key_s == k:
                      return store[k]
@@ -486,7 +569,10 @@
      m = method.lower()
      method_store = adata.uns.get(output_key, {}).get(m, {})
      if not method_store:
-         warnings.warn(f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.", stacklevel=2)
+         warnings.warn(
+             f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.",
+             stacklevel=2,
+         )
          saved_files_by_method[method] = []
          continue

@@ -507,21 +593,41 @@

      # decide per-method defaults
      if m == "pearson":
-         vmn = -1.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
-         vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+         vmn = (
+             -1.0
+             if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
+             else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+         )
+         vmx = (
+             1.0
+             if (vmax is None or (isinstance(vmax, dict) and m not in vmax))
+             else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+         )
          vmn = -1.0 if vmn is None else vmn
          vmx = 1.0 if vmx is None else vmx
      elif m == "binary_covariance":
-         vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
-         vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+         vmn = (
+             0.0
+             if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
+             else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+         )
+         vmx = (
+             1.0
+             if (vmax is None or (isinstance(vmax, dict) and m not in vmax))
+             else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+         )
          vmn = 0.0 if vmn is None else vmn
          vmx = 1.0 if vmx is None else vmx
      else:
-         vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+         vmn = (
+             0.0
+             if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
+             else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+         )
          if (vmax is None) or (isinstance(vmax, dict) and m not in vmax):
              vmx = float(np.nanpercentile(allvals, 99.0)) if allvals.size else 1.0
          else:
-             vmx = (vmax.get(m) if isinstance(vmax, dict) else vmax)
+             vmx = vmax.get(m) if isinstance(vmax, dict) else vmax
          vmn = 0.0 if vmn is None else vmn
          if vmx is None:
              vmx = 1.0
@@ -536,7 +642,9 @@
          ncols = max(1, len(references))
          fig_w = ncols * figsize_per_cell[0]
          fig_h = nrows * figsize_per_cell[1]
-         fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
+         fig, axes = plt.subplots(
+             nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
+         )

          # leave margin for rotated sample labels
          plt.subplots_adjust(left=0.12, right=0.88, top=0.95, bottom=0.05)
@@ -548,13 +656,24 @@
                  ax = axes[r_idx][c_idx]
                  df = _get_df_from_store(method_store, sample, ref)
                  if not isinstance(df, pd.DataFrame) or df.size == 0:
-                     ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
+                     ax.text(
+                         0.5,
+                         0.5,
+                         "No data",
+                         ha="center",
+                         va="center",
+                         transform=ax.transAxes,
+                         fontsize=10,
+                         color="gray",
+                     )
                      ax.set_xticks([])
                      ax.set_yticks([])
                  else:
                      mat = df.values.astype(float)
                      origin = "upper" if flip_display_axes else "lower"
-                     im = ax.imshow(mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap)
+                     im = ax.imshow(
+                         mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap
+                     )
                      any_plotted = True
                      ax.set_xticks([])
                      ax.set_yticks([])
@@ -569,9 +688,21 @@
              ax_y0, ax_y1 = ax0.get_position().y0, ax0.get_position().y1
              y_center = 0.5 * (ax_y0 + ax_y1)
              # place text at x=0.01 (just inside left margin); rotation controls orientation
-             fig.text(0.01, y_center, str(chunk[r_idx]), va="center", ha="left", rotation=sample_label_rotation, fontsize=9)
-
-         fig.suptitle(f"{method} — per-sample x per-reference matrices (page {page_idx+1}/{n_pages})", fontsize=12, y=0.99)
+             fig.text(
+                 0.01,
+                 y_center,
+                 str(chunk[r_idx]),
+                 va="center",
+                 ha="left",
+                 rotation=sample_label_rotation,
+                 fontsize=9,
+             )
+
+         fig.suptitle(
+             f"{method} — per-sample x per-reference matrices (page {page_idx + 1}/{n_pages})",
+             fontsize=12,
+             y=0.99,
+         )
          fig.tight_layout(rect=[0.05, 0.02, 0.9, 0.96])

          # colorbar (shared)
@@ -587,7 +718,7 @@

          # save or show
          if output_dir:
-             fname = f"positionwise_{method}_page{page_idx+1}.png"
+             fname = f"positionwise_{method}_page{page_idx + 1}.png"
              outpath = os.path.join(output_dir, fname)
              plt.savefig(outpath, bbox_inches="tight")
              saved_files.append(outpath)
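Taken together, the hunks above type the compute side and paginate the plotting side: results land in adata.uns[output_key] keyed by (sample, ref), and each method's grid is saved one page per figure. A hedged end-to-end sketch against the signatures visible in this diff; the input file, layer name, and any parameter not shown above are assumptions:

import anndata as ad

adata = ad.read_h5ad("experiment.h5ad")  # placeholder input

compute_positionwise_statistics(
    adata,
    layer="binarized",          # assumed layer name
    methods=("pearson",),
    sample_col="Barcode",
    ref_col="Reference_strand",
    output_key="positionwise",  # results stored in adata.uns["positionwise"]
)

plot_positionwise_matrices(
    adata,
    methods=["pearson"],
    output_key="positionwise",
    output_dir="figures",       # pages saved as positionwise_pearson_page<N>.png
)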
@@ -1,36 +1,53 @@
  # ------------------------- Utilities -------------------------
- def random_fill_nans(X):
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Sequence
+
+ if TYPE_CHECKING:
+     import anndata as ad
+     import numpy as np
+
+
+ def random_fill_nans(X: "np.ndarray") -> "np.ndarray":
+     """Fill NaNs with random values in-place.
+
+     Args:
+         X: Input array with NaNs.
+
+     Returns:
+         numpy.ndarray: Array with NaNs replaced by random values.
+     """
      import numpy as np
+
      nan_mask = np.isnan(X)
      X[nan_mask] = np.random.rand(*X[nan_mask].shape)
      return X

+
  def calculate_row_entropy(
-     adata,
-     layer,
-     output_key="entropy",
-     site_config=None,
-     ref_col="Reference_strand",
-     encoding="signed",
-     max_threads=None):
-     """
-     Adds an obs column to the adata that calculates entropy within each read from a given layer
-     when looking at each site type passed in the site_config list.
-
-     Parameters:
-         adata (AnnData): The annotated data matrix.
-         layer (str): Name of the layer to use for entropy calculation.
-         method (str): Unused currently. Placeholder for potential future methods.
-         output_key (str): Base name for the entropy column in adata.obs.
-         site_config (dict): {ref: [site_types]} for masking relevant sites.
-         ref_col (str): Column in adata.obs denoting reference strands.
-         encoding (str): 'signed' (1/-1/0) or 'binary' (1/0/NaN).
-         max_threads (int): Number of threads for parallel processing.
+     adata: "ad.AnnData",
+     layer: str,
+     output_key: str = "entropy",
+     site_config: dict[str, Sequence[str]] | None = None,
+     ref_col: str = "Reference_strand",
+     encoding: str = "signed",
+     max_threads: int | None = None,
+ ) -> None:
+     """Add per-read entropy values to ``adata.obs``.
+
+     Args:
+         adata: Annotated data matrix.
+         layer: Layer name to use for entropy calculation.
+         output_key: Base name for the entropy column in ``adata.obs``.
+         site_config: Mapping of reference to site types for masking.
+         ref_col: Obs column containing reference strands.
+         encoding: ``"signed"`` (1/-1/0) or ``"binary"`` (1/0/NaN).
+         max_threads: Number of threads for parallel processing.
      """
      import numpy as np
      import pandas as pd
-     from scipy.stats import entropy
      from joblib import Parallel, delayed
+     from scipy.stats import entropy
      from tqdm import tqdm

      entropy_values = []
@@ -55,12 +72,14 @@
          X_bin = np.where(X == 1, 1, np.where(X == 0, 0, np.nan))

          def compute_entropy(row):
+             """Compute Shannon entropy for a row with NaNs ignored."""
              counts = pd.Series(row).value_counts(dropna=True).sort_index()
              probs = counts / counts.sum()
              return entropy(probs, base=2)

          entropies = Parallel(n_jobs=max_threads)(
-             delayed(compute_entropy)(X_bin[i, :]) for i in tqdm(range(X_bin.shape[0]), desc=f"Entropy: {ref}")
+             delayed(compute_entropy)(X_bin[i, :])
+             for i in tqdm(range(X_bin.shape[0]), desc=f"Entropy: {ref}")
          )

          entropy_values.extend(entropies)
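compute_entropy above treats each read as a categorical distribution over its observed (non-NaN) values. A small worked sketch of the same calculation on a toy row:

import numpy as np
import pandas as pd
from scipy.stats import entropy

row = np.array([1, 0, 1, 1, np.nan])
counts = pd.Series(row).value_counts(dropna=True).sort_index()  # {0.0: 1, 1.0: 3}
probs = counts / counts.sum()                                   # [0.25, 0.75]
print(entropy(probs, base=2))                                   # ~0.811 bits per position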
@@ -69,6 +88,7 @@
      entropy_key = f"{output_key}_entropy"
      adata.obs.loc[row_indices, entropy_key] = entropy_values

+
  def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sorted=True):
      """
      Fast autocorrelation over real genomic spacing.
@@ -125,13 +145,13 @@
              j += 1
          # consider pairs (i, i+1...j-1)
          if j - i > 1:
-             diffs = pos[i+1:j] - pos[i]  # 1..max_lag
-             contrib = xc[i] * xc[i+1:j]  # contributions for each pair
+             diffs = pos[i + 1 : j] - pos[i]  # 1..max_lag
+             contrib = xc[i] * xc[i + 1 : j]  # contributions for each pair
              # accumulate weighted sums and counts per lag
-             lag_sums[:max_lag+1] += np.bincount(diffs, weights=contrib,
-                                                 minlength=max_lag+1)[:max_lag+1]
-             lag_counts[:max_lag+1] += np.bincount(diffs,
-                                                   minlength=max_lag+1)[:max_lag+1]
+             lag_sums[: max_lag + 1] += np.bincount(diffs, weights=contrib, minlength=max_lag + 1)[
+                 : max_lag + 1
+             ]
+             lag_counts[: max_lag + 1] += np.bincount(diffs, minlength=max_lag + 1)[: max_lag + 1]

      autocorr = np.full(max_lag + 1, np.nan, dtype=np.float64)
      nz = lag_counts > 0
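The loop body above accumulates autocorrelation numerators and pair counts per genomic lag with np.bincount instead of looping over lags. A self-contained sketch of one accumulation step using the same variable names (toy positions and values):

import numpy as np

pos = np.array([0, 2, 3, 7])          # genomic coordinates of observed sites
xc = np.array([0.5, -0.5, 0.5, 0.5])  # mean-centered binary calls
max_lag = 10
lag_sums = np.zeros(max_lag + 1)
lag_counts = np.zeros(max_lag + 1)

i, j = 0, 4  # pairs (i, i+1 .. j-1) all lie within max_lag of pos[i]
diffs = pos[i + 1 : j] - pos[i]  # lags: [2, 3, 7]
contrib = xc[i] * xc[i + 1 : j]  # products: [-0.25, 0.25, 0.25]
lag_sums[: max_lag + 1] += np.bincount(diffs, weights=contrib, minlength=max_lag + 1)[: max_lag + 1]
lag_counts[: max_lag + 1] += np.bincount(diffs, minlength=max_lag + 1)[: max_lag + 1]
# per-lag autocorrelation is then lag_sums / lag_counts wherever lag_counts > 0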
@@ -140,6 +160,7 @@

      return autocorr.astype(np.float32, copy=False)

+
  # def binary_autocorrelation_with_spacing(row, positions, max_lag=1000):
  #     """
  #     Compute autocorrelation within a read using real genomic spacing from `positions`.