smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,23 +1,61 @@
+from __future__ import annotations
+
+import os
+import warnings
+from contextlib import contextmanager
+from itertools import cycle
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
+
+import numpy as np
+import pandas as pd
+from scipy.stats import chi2_contingency
+from tqdm import tqdm
+
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
+
+# -----------------------------
+# Compute positionwise statistic (multi-method + simple site_types)
+# -----------------------------
+
+
 # ------------------------- Utilities -------------------------
-def random_fill_nans(X):
+def random_fill_nans(X: np.ndarray) -> np.ndarray:
+    """Fill NaNs with random values in-place.
+
+    Args:
+        X: Input array with NaNs.
+
+    Returns:
+        numpy.ndarray: Array with NaNs replaced by random values.
+    """
     import numpy as np
+
     nan_mask = np.isnan(X)
     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
     return X
 
-def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
-    """
-    Perform Bayesian-style methylation vs activity analysis independently within each group.
 
-    Parameters:
-        adata (AnnData): Annotated data matrix.
-        sites (list of str): List of site keys (e.g., ['GpC_site', 'CpG_site']).
-        alpha (float): FDR threshold for significance.
-        groupby (str or list of str): Column(s) in adata.obs to group by.
+def calculate_relative_risk_on_activity(
+    adata: "ad.AnnData",
+    sites: Sequence[str],
+    alpha: float = 0.05,
+    groupby: str | Sequence[str] | None = None,
+) -> dict:
+    """Perform methylation vs. activity analysis within each group.
+
+    Args:
+        adata: Annotated data matrix.
+        sites: Site keys (e.g., ``["GpC_site", "CpG_site"]``).
+        alpha: FDR threshold for significance.
+        groupby: Obs column(s) to group by.
 
     Returns:
-        results_dict (dict): Dictionary with structure:
-            results_dict[ref][group_label] = (results_df, sig_df)
+        dict: Mapping of reference -> group label -> ``(results_df, sig_df)``.
     """
     import numpy as np
     import pandas as pd
@@ -25,30 +63,44 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
     from statsmodels.stats.multitest import multipletests
 
     def compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values):
-        p_adj = multipletests(p_values, method='fdr_bh')[1] if p_values else []
+        """Build result and significant-data DataFrames for a reference.
+
+        Args:
+            ref: Reference name.
+            site_subset: AnnData subset restricted to sites.
+            positions_list: Positions tested.
+            relative_risks: Relative risk values.
+            p_values: Raw p-values.
+
+        Returns:
+            Tuple of (results_df, sig_df).
+        """
+        p_adj = multipletests(p_values, method="fdr_bh")[1] if p_values else []
 
         genomic_positions = np.array(site_subset.var_names)[positions_list]
         is_gpc_site = site_subset.var[f"{ref}_GpC_site"].values[positions_list]
         is_cpg_site = site_subset.var[f"{ref}_CpG_site"].values[positions_list]
 
-        results_df = pd.DataFrame({
-            'Feature_Index': positions_list,
-            'Genomic_Position': genomic_positions.astype(int),
-            'Relative_Risk': relative_risks,
-            'Adjusted_P_Value': p_adj,
-            'GpC_Site': is_gpc_site,
-            'CpG_Site': is_cpg_site
-        })
-
-        results_df['log2_Relative_Risk'] = np.log2(results_df['Relative_Risk'].replace(0, 1e-300))
-        results_df['-log10_Adj_P'] = -np.log10(results_df['Adjusted_P_Value'].replace(0, 1e-300))
-        sig_df = results_df[results_df['Adjusted_P_Value'] < alpha]
+        results_df = pd.DataFrame(
+            {
+                "Feature_Index": positions_list,
+                "Genomic_Position": genomic_positions.astype(int),
+                "Relative_Risk": relative_risks,
+                "Adjusted_P_Value": p_adj,
+                "GpC_Site": is_gpc_site,
+                "CpG_Site": is_cpg_site,
+            }
+        )
+
+        results_df["log2_Relative_Risk"] = np.log2(results_df["Relative_Risk"].replace(0, 1e-300))
+        results_df["-log10_Adj_P"] = -np.log10(results_df["Adjusted_P_Value"].replace(0, 1e-300))
+        sig_df = results_df[results_df["Adjusted_P_Value"] < alpha]
         return results_df, sig_df
 
     results_dict = {}
 
-    for ref in adata.obs['Reference_strand'].unique():
-        ref_subset = adata[adata.obs['Reference_strand'] == ref].copy()
+    for ref in adata.obs["Reference_strand"].unique():
+        ref_subset = adata[adata.obs["Reference_strand"] == ref].copy()
         if ref_subset.shape[0] == 0:
             continue
 
@@ -56,20 +108,22 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
         if groupby is not None:
             if isinstance(groupby, str):
                 groupby = [groupby]
+
             def format_group_label(row):
+                """Format a group label string from obs row values."""
                 return ",".join([f"{col}={row[col]}" for col in groupby])
 
-            combined_label = '__'.join(groupby)
+            combined_label = "__".join(groupby)
             ref_subset.obs[combined_label] = ref_subset.obs.apply(format_group_label, axis=1)
             groups = ref_subset.obs[combined_label].unique()
         else:
             combined_label = None
-            groups = ['all']
+            groups = ["all"]
 
         results_dict[ref] = {}
 
         for group in groups:
-            if group == 'all':
+            if group == "all":
                 group_subset = ref_subset
             else:
                 group_subset = ref_subset[ref_subset.obs[combined_label] == group]
@@ -85,7 +139,7 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
 
             # Matrix and labels
             X = random_fill_nans(site_subset.X.copy())
-            y = site_subset.obs['activity_status'].map({'Active': 1, 'Silent': 0}).values
+            y = site_subset.obs["activity_status"].map({"Active": 1, "Silent": 0}).values
             P_active = np.mean(y)
 
             # Analysis
@@ -104,7 +158,9 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
                     continue
 
                 P_active_given_methylated = (P_methylated_given_active * P_active) / P_methylated
-                P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (1 - P_methylated)
+                P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (
+                    1 - P_methylated
+                )
                 RR = P_active_given_methylated / P_active_given_unmethylated
 
                 _, p_value = fisher_exact(table)
@@ -112,49 +168,13 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
                 relative_risks.append(RR)
                 p_values.append(p_value)
 
-            results_df, sig_df = compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values)
+            results_df, sig_df = compute_risk_df(
+                ref, site_subset, positions_list, relative_risks, p_values
+            )
             results_dict[ref][group] = (results_df, sig_df)
 
     return results_dict
 
-import copy
-import warnings
-from typing import Dict, Any, List, Optional, Tuple, Union
-
-import numpy as np
-import pandas as pd
-import matplotlib.pyplot as plt
-
-# optional imports
-try:
-    from joblib import Parallel, delayed
-    JOBLIB_AVAILABLE = True
-except Exception:
-    JOBLIB_AVAILABLE = False
-
-try:
-    from scipy.stats import chi2_contingency
-    SCIPY_STATS_AVAILABLE = True
-except Exception:
-    SCIPY_STATS_AVAILABLE = False
-
-# -----------------------------
-# Compute positionwise statistic (multi-method + simple site_types)
-# -----------------------------
-import numpy as np
-import pandas as pd
-from typing import List, Optional, Sequence, Dict, Any, Tuple
-from contextlib import contextmanager
-from joblib import Parallel, delayed, cpu_count
-import joblib
-from tqdm import tqdm
-from scipy.stats import chi2_contingency
-import warnings
-import matplotlib.pyplot as plt
-from itertools import cycle
-import os
-import warnings
-
 
 # ---------------------------
 # joblib <-> tqdm integration
@@ -162,10 +182,15 @@ import warnings
 @contextmanager
 def tqdm_joblib(tqdm_object: tqdm):
     """Context manager to patch joblib to update a tqdm progress bar."""
+    joblib = require("joblib", extra="ml-base", purpose="parallel position statistics")
+
     old = joblib.parallel.BatchCompletionCallBack
 
     class TqdmBatchCompletionCallback(old):  # type: ignore
+        """Joblib callback that updates a tqdm progress bar."""
+
         def __call__(self, *args, **kwargs):
+            """Update the progress bar when a batch completes."""
             try:
                 tqdm_object.update(n=self.batch_size)
             except Exception:
@@ -183,6 +208,16 @@ def tqdm_joblib(tqdm_object: tqdm):
 # row workers (upper-triangle only)
 # ---------------------------
 def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
+    """Compute chi-squared statistics for one row of a pairwise matrix.
+
+    Args:
+        i: Row index.
+        X_bin: Binary matrix.
+        min_count_for_pairwise: Minimum count for valid comparison.
+
+    Returns:
+        Tuple of (row_index, row_values).
+    """
     n_pos = X_bin.shape[1]
     row = np.full((n_pos,), np.nan, dtype=float)
     xi = X_bin[:, i]
@@ -202,7 +237,19 @@ def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
     return (i, row)
 
 
-def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
+def _relative_risk_row_job(
+    i: int, X_bin: np.ndarray, min_count_for_pairwise: int
+) -> Tuple[int, np.ndarray]:
+    """Compute relative-risk values for one row of a pairwise matrix.
+
+    Args:
+        i: Row index.
+        X_bin: Binary matrix.
+        min_count_for_pairwise: Minimum count for valid comparison.
+
+    Returns:
+        Tuple of (row_index, row_values).
+    """
     n_pos = X_bin.shape[1]
     row = np.full((n_pos,), np.nan, dtype=float)
     xi = X_bin[:, i]
@@ -226,8 +273,9 @@ def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
             row[j] = np.nan
     return (i, row)
 
+
 def compute_positionwise_statistics(
-    adata,
+    adata: "ad.AnnData",
     layer: str,
     methods: Sequence[str] = ("pearson",),
     sample_col: str = "Barcode",
@@ -238,14 +286,24 @@
     min_count_for_pairwise: int = 10,
     max_threads: Optional[int] = None,
     reverse_indices_on_store: bool = False,
-):
+) -> None:
+    """Compute per-(sample, ref) positionwise matrices for selected methods.
+
+    Args:
+        adata: AnnData object to analyze.
+        layer: Layer name to use for statistics.
+        methods: Methods to compute (e.g., ``"pearson"``).
+        sample_col: Obs column containing sample identifiers.
+        ref_col: Obs column containing reference identifiers.
+        site_types: Optional site types to subset positions.
+        encoding: ``"signed"`` or ``"binary"`` encoding.
+        output_key: Key prefix for results stored in ``adata.uns``.
+        min_count_for_pairwise: Minimum counts for pairwise comparisons.
+        max_threads: Maximum number of threads.
+        reverse_indices_on_store: Whether to reverse indices on output storage.
     """
-    Compute per-(sample,ref) positionwise matrices for methods in `methods`.
+    joblib = require("joblib", extra="ml-base", purpose="parallel position statistics")
 
-    Results stored at:
-      adata.uns[output_key][method][ (sample, ref) ] = DataFrame
-      adata.uns[output_key + "_n"][method][ (sample, ref) ] = int(n_reads)
-    """
     if isinstance(methods, str):
         methods = [methods]
     methods = [m.lower() for m in methods]
@@ -256,7 +314,7 @@
 
     # workers
     if max_threads is None or max_threads <= 0:
-        n_jobs = max(1, cpu_count() or 1)
+        n_jobs = max(1, joblib.cpu_count() or 1)
    else:
        n_jobs = max(1, int(max_threads))
 
@@ -349,7 +407,10 @@
             Xc = X_bin - col_mean  # nan preserved
             Xc0 = np.nan_to_num(Xc, nan=0.0)
             cov = Xc0.T @ Xc0
-            denom = (np.sqrt((Xc0**2).sum(axis=0))[:, None] * np.sqrt((Xc0**2).sum(axis=0))[None, :])
+            denom = (
+                np.sqrt((Xc0**2).sum(axis=0))[:, None]
+                * np.sqrt((Xc0**2).sum(axis=0))[None, :]
+            )
             with np.errstate(divide="ignore", invalid="ignore"):
                 mat = np.where(denom != 0.0, cov / denom, np.nan)
         elif m == "binary_covariance":
@@ -366,10 +427,15 @@
             else:
                 worker = _relative_risk_row_job
             out = np.full((n_pos, n_pos), np.nan, dtype=float)
-            tasks = (delayed(worker)(i, X_bin, min_count_for_pairwise) for i in range(n_pos))
-            pbar_rows = tqdm(total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False)
+            tasks = (
+                joblib.delayed(worker)(i, X_bin, min_count_for_pairwise)
+                for i in range(n_pos)
+            )
+            pbar_rows = tqdm(
+                total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False
+            )
             with tqdm_joblib(pbar_rows):
-                results = Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
+                results = joblib.Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
             pbar_rows.close()
             for i, row in results:
                 out[int(i), :] = row
@@ -406,6 +472,7 @@
 # Plotting function
 # ---------------------------
 
+
 def plot_positionwise_matrices(
     adata,
     methods: List[str],
@@ -427,9 +494,10 @@
     """
     Plot grids of matrices for each method with pagination and rotated sample-row labels.
 
-    New args:
+    Args:
       - rows_per_page: how many sample rows per page/figure (pagination)
      - sample_label_rotation: rotation angle (deg) for the sample labels placed in the left margin.
+
     Returns:
      dict mapping method -> list of saved filenames (empty list if figures were shown).
     """
@@ -474,7 +542,12 @@
         # try stringified tuple keys (some callers store differently)
         for k in store.keys():
             try:
-                if isinstance(k, tuple) and len(k) == 2 and str(k[0]) == str(sample) and str(k[1]) == str(ref):
+                if (
+                    isinstance(k, tuple)
+                    and len(k) == 2
+                    and str(k[0]) == str(sample)
+                    and str(k[1]) == str(ref)
+                ):
                     return store[k]
                 if isinstance(k, str) and key_s == k:
                     return store[k]
@@ -486,7 +559,10 @@
         m = method.lower()
         method_store = adata.uns.get(output_key, {}).get(m, {})
         if not method_store:
-            warnings.warn(f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.", stacklevel=2)
+            warnings.warn(
+                f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.",
+                stacklevel=2,
+            )
             saved_files_by_method[method] = []
             continue
 
@@ -507,21 +583,41 @@
 
         # decide per-method defaults
         if m == "pearson":
-            vmn = -1.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
-            vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            vmn = (
+                -1.0
+                if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
+                else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            )
+            vmx = (
+                1.0
+                if (vmax is None or (isinstance(vmax, dict) and m not in vmax))
+                else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            )
             vmn = -1.0 if vmn is None else vmn
             vmx = 1.0 if vmx is None else vmx
         elif m == "binary_covariance":
-            vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
-            vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            vmn = (
+                0.0
+                if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
+                else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            )
+            vmx = (
+                1.0
+                if (vmax is None or (isinstance(vmax, dict) and m not in vmax))
+                else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            )
             vmn = 0.0 if vmn is None else vmn
             vmx = 1.0 if vmx is None else vmx
         else:
-            vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            vmn = (
+                0.0
+                if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
+                else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            )
             if (vmax is None) or (isinstance(vmax, dict) and m not in vmax):
                 vmx = float(np.nanpercentile(allvals, 99.0)) if allvals.size else 1.0
             else:
-                vmx = (vmax.get(m) if isinstance(vmax, dict) else vmax)
+                vmx = vmax.get(m) if isinstance(vmax, dict) else vmax
             vmn = 0.0 if vmn is None else vmn
             if vmx is None:
                 vmx = 1.0
@@ -536,7 +632,9 @@
         ncols = max(1, len(references))
         fig_w = ncols * figsize_per_cell[0]
         fig_h = nrows * figsize_per_cell[1]
-        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
+        fig, axes = plt.subplots(
+            nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
+        )
 
         # leave margin for rotated sample labels
         plt.subplots_adjust(left=0.12, right=0.88, top=0.95, bottom=0.05)
@@ -548,13 +646,24 @@
                 ax = axes[r_idx][c_idx]
                 df = _get_df_from_store(method_store, sample, ref)
                 if not isinstance(df, pd.DataFrame) or df.size == 0:
-                    ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
+                    ax.text(
+                        0.5,
+                        0.5,
+                        "No data",
+                        ha="center",
+                        va="center",
+                        transform=ax.transAxes,
+                        fontsize=10,
+                        color="gray",
+                    )
                     ax.set_xticks([])
                     ax.set_yticks([])
                 else:
                     mat = df.values.astype(float)
                     origin = "upper" if flip_display_axes else "lower"
-                    im = ax.imshow(mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap)
+                    im = ax.imshow(
+                        mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap
+                    )
                     any_plotted = True
                     ax.set_xticks([])
                     ax.set_yticks([])
@@ -569,9 +678,21 @@
             ax_y0, ax_y1 = ax0.get_position().y0, ax0.get_position().y1
             y_center = 0.5 * (ax_y0 + ax_y1)
             # place text at x=0.01 (just inside left margin); rotation controls orientation
-            fig.text(0.01, y_center, str(chunk[r_idx]), va="center", ha="left", rotation=sample_label_rotation, fontsize=9)
-
-        fig.suptitle(f"{method} — per-sample x per-reference matrices (page {page_idx+1}/{n_pages})", fontsize=12, y=0.99)
+            fig.text(
+                0.01,
+                y_center,
+                str(chunk[r_idx]),
+                va="center",
+                ha="left",
+                rotation=sample_label_rotation,
+                fontsize=9,
+            )
+
+        fig.suptitle(
+            f"{method} — per-sample x per-reference matrices (page {page_idx + 1}/{n_pages})",
+            fontsize=12,
+            y=0.99,
+        )
         fig.tight_layout(rect=[0.05, 0.02, 0.9, 0.96])
 
         # colorbar (shared)
@@ -587,7 +708,7 @@
 
         # save or show
         if output_dir:
-            fname = f"positionwise_{method}_page{page_idx+1}.png"
+            fname = f"positionwise_{method}_page{page_idx + 1}.png"
             outpath = os.path.join(output_dir, fname)
             plt.savefig(outpath, bbox_inches="tight")
             saved_files.append(outpath)
@@ -1,36 +1,53 @@
 # ------------------------- Utilities -------------------------
-def random_fill_nans(X):
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+if TYPE_CHECKING:
+    import anndata as ad
+    import numpy as np
+
+
+def random_fill_nans(X: "np.ndarray") -> "np.ndarray":
+    """Fill NaNs with random values in-place.
+
+    Args:
+        X: Input array with NaNs.
+
+    Returns:
+        numpy.ndarray: Array with NaNs replaced by random values.
+    """
     import numpy as np
+
     nan_mask = np.isnan(X)
     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
     return X
 
+
 def calculate_row_entropy(
-    adata,
-    layer,
-    output_key="entropy",
-    site_config=None,
-    ref_col="Reference_strand",
-    encoding="signed",
-    max_threads=None):
-    """
-    Adds an obs column to the adata that calculates entropy within each read from a given layer
-    when looking at each site type passed in the site_config list.
-
-    Parameters:
-        adata (AnnData): The annotated data matrix.
-        layer (str): Name of the layer to use for entropy calculation.
-        method (str): Unused currently. Placeholder for potential future methods.
-        output_key (str): Base name for the entropy column in adata.obs.
-        site_config (dict): {ref: [site_types]} for masking relevant sites.
-        ref_col (str): Column in adata.obs denoting reference strands.
-        encoding (str): 'signed' (1/-1/0) or 'binary' (1/0/NaN).
-        max_threads (int): Number of threads for parallel processing.
+    adata: "ad.AnnData",
+    layer: str,
+    output_key: str = "entropy",
+    site_config: dict[str, Sequence[str]] | None = None,
+    ref_col: str = "Reference_strand",
+    encoding: str = "signed",
+    max_threads: int | None = None,
+) -> None:
+    """Add per-read entropy values to ``adata.obs``.
+
+    Args:
+        adata: Annotated data matrix.
+        layer: Layer name to use for entropy calculation.
+        output_key: Base name for the entropy column in ``adata.obs``.
+        site_config: Mapping of reference to site types for masking.
+        ref_col: Obs column containing reference strands.
+        encoding: ``"signed"`` (1/-1/0) or ``"binary"`` (1/0/NaN).
+        max_threads: Number of threads for parallel processing.
     """
     import numpy as np
     import pandas as pd
-    from scipy.stats import entropy
     from joblib import Parallel, delayed
+    from scipy.stats import entropy
     from tqdm import tqdm
 
     entropy_values = []
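Each read's entropy is the Shannon entropy of its 0/1 value distribution, with NaNs dropped by `value_counts`. A single illustrative row (not package data):

import numpy as np
import pandas as pd
from scipy.stats import entropy

row = np.array([1, 1, 0, 1, np.nan, 0])
counts = pd.Series(row).value_counts(dropna=True).sort_index()  # {0.0: 2, 1.0: 3}
probs = counts / counts.sum()                                   # [0.4, 0.6]
print(entropy(probs, base=2))                                   # ~0.971 bits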
@@ -55,12 +72,14 @@ def calculate_row_entropy(
         X_bin = np.where(X == 1, 1, np.where(X == 0, 0, np.nan))
 
         def compute_entropy(row):
+            """Compute Shannon entropy for a row with NaNs ignored."""
             counts = pd.Series(row).value_counts(dropna=True).sort_index()
             probs = counts / counts.sum()
             return entropy(probs, base=2)
 
         entropies = Parallel(n_jobs=max_threads)(
-            delayed(compute_entropy)(X_bin[i, :]) for i in tqdm(range(X_bin.shape[0]), desc=f"Entropy: {ref}")
+            delayed(compute_entropy)(X_bin[i, :])
+            for i in tqdm(range(X_bin.shape[0]), desc=f"Entropy: {ref}")
         )
 
         entropy_values.extend(entropies)
@@ -69,6 +88,7 @@ def calculate_row_entropy(
     entropy_key = f"{output_key}_entropy"
     adata.obs.loc[row_indices, entropy_key] = entropy_values
 
+
 def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sorted=True):
     """
     Fast autocorrelation over real genomic spacing.
@@ -125,13 +145,13 @@ def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sorted=True):
             j += 1
         # consider pairs (i, i+1...j-1)
         if j - i > 1:
-            diffs = pos[i+1:j] - pos[i]      # 1..max_lag
-            contrib = xc[i] * xc[i+1:j]      # contributions for each pair
+            diffs = pos[i + 1 : j] - pos[i]  # 1..max_lag
+            contrib = xc[i] * xc[i + 1 : j]  # contributions for each pair
             # accumulate weighted sums and counts per lag
-            lag_sums[:max_lag+1] += np.bincount(diffs, weights=contrib,
-                                                minlength=max_lag+1)[:max_lag+1]
-            lag_counts[:max_lag+1] += np.bincount(diffs,
-                                                  minlength=max_lag+1)[:max_lag+1]
+            lag_sums[: max_lag + 1] += np.bincount(diffs, weights=contrib, minlength=max_lag + 1)[
+                : max_lag + 1
+            ]
+            lag_counts[: max_lag + 1] += np.bincount(diffs, minlength=max_lag + 1)[: max_lag + 1]
 
     autocorr = np.full(max_lag + 1, np.nan, dtype=np.float64)
     nz = lag_counts > 0
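The `np.bincount` calls accumulate centered pair products and pair counts per genomic lag in a single pass over each window, avoiding an explicit loop over lags. A usage sketch of the function as defined above (toy read; `positions` are sorted genomic coordinates):

import numpy as np

row = np.array([1, 0, 1, 1, 0, 1, 0, 0], dtype=float)
positions = np.array([10, 12, 15, 16, 20, 23, 27, 30])

acf = binary_autocorrelation_with_spacing(row, positions, max_lag=20)
# acf[d] averages the centered pair products at genomic distance d;
# lags with no observed pair remain NaN.
print(acf[:6])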
@@ -140,6 +160,7 @@ def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sorted=True):
 
     return autocorr.astype(np.float32, copy=False)
 
+
 # def binary_autocorrelation_with_spacing(row, positions, max_lag=1000):
 #     """
 #     Compute autocorrelation within a read using real genomic spacing from `positions`.