smftools 0.1.7__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. smftools/__init__.py +9 -4
  2. smftools/_version.py +1 -1
  3. smftools/cli.py +184 -0
  4. smftools/config/__init__.py +1 -0
  5. smftools/config/conversion.yaml +33 -0
  6. smftools/config/deaminase.yaml +56 -0
  7. smftools/config/default.yaml +253 -0
  8. smftools/config/direct.yaml +17 -0
  9. smftools/config/experiment_config.py +1191 -0
  10. smftools/hmm/HMM.py +1576 -0
  11. smftools/hmm/__init__.py +20 -0
  12. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  13. smftools/hmm/call_hmm_peaks.py +106 -0
  14. smftools/{tools → hmm}/display_hmm.py +3 -3
  15. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  16. smftools/{tools → hmm}/train_hmm.py +1 -1
  17. smftools/informatics/__init__.py +0 -2
  18. smftools/informatics/archived/deaminase_smf.py +132 -0
  19. smftools/informatics/fast5_to_pod5.py +4 -1
  20. smftools/informatics/helpers/__init__.py +3 -4
  21. smftools/informatics/helpers/align_and_sort_BAM.py +34 -7
  22. smftools/informatics/helpers/aligned_BAM_to_bed.py +35 -24
  23. smftools/informatics/helpers/binarize_converted_base_identities.py +116 -23
  24. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +365 -42
  25. smftools/informatics/helpers/converted_BAM_to_adata_II.py +165 -29
  26. smftools/informatics/helpers/discover_input_files.py +100 -0
  27. smftools/informatics/helpers/extract_base_identities.py +29 -3
  28. smftools/informatics/helpers/extract_read_features_from_bam.py +4 -2
  29. smftools/informatics/helpers/find_conversion_sites.py +5 -4
  30. smftools/informatics/helpers/modkit_extract_to_adata.py +6 -3
  31. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  32. smftools/informatics/helpers/separate_bam_by_bc.py +2 -2
  33. smftools/informatics/helpers/split_and_index_BAM.py +1 -5
  34. smftools/load_adata.py +1346 -0
  35. smftools/machine_learning/__init__.py +12 -0
  36. smftools/machine_learning/data/__init__.py +2 -0
  37. smftools/machine_learning/data/anndata_data_module.py +234 -0
  38. smftools/machine_learning/evaluation/__init__.py +2 -0
  39. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  40. smftools/machine_learning/evaluation/evaluators.py +223 -0
  41. smftools/machine_learning/inference/__init__.py +3 -0
  42. smftools/machine_learning/inference/inference_utils.py +27 -0
  43. smftools/machine_learning/inference/lightning_inference.py +68 -0
  44. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  45. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  46. smftools/machine_learning/models/base.py +295 -0
  47. smftools/machine_learning/models/cnn.py +138 -0
  48. smftools/machine_learning/models/lightning_base.py +345 -0
  49. smftools/machine_learning/models/mlp.py +26 -0
  50. smftools/{tools → machine_learning}/models/positional.py +3 -2
  51. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  52. smftools/machine_learning/models/sklearn_models.py +273 -0
  53. smftools/machine_learning/models/transformer.py +303 -0
  54. smftools/machine_learning/training/__init__.py +2 -0
  55. smftools/machine_learning/training/train_lightning_model.py +135 -0
  56. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  57. smftools/plotting/__init__.py +4 -1
  58. smftools/plotting/autocorrelation_plotting.py +611 -0
  59. smftools/plotting/general_plotting.py +566 -89
  60. smftools/plotting/hmm_plotting.py +260 -0
  61. smftools/plotting/qc_plotting.py +270 -0
  62. smftools/preprocessing/__init__.py +13 -8
  63. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  64. smftools/preprocessing/append_base_context.py +122 -0
  65. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  66. smftools/preprocessing/calculate_complexity_II.py +248 -0
  67. smftools/preprocessing/calculate_coverage.py +10 -1
  68. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  69. smftools/preprocessing/clean_NaN.py +17 -1
  70. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  71. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  72. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  73. smftools/preprocessing/invert_adata.py +12 -5
  74. smftools/preprocessing/load_sample_sheet.py +19 -4
  75. smftools/readwrite.py +849 -43
  76. smftools/tools/__init__.py +3 -32
  77. smftools/tools/calculate_umap.py +5 -5
  78. smftools/tools/general_tools.py +3 -3
  79. smftools/tools/position_stats.py +468 -106
  80. smftools/tools/read_stats.py +115 -1
  81. smftools/tools/spatial_autocorrelation.py +562 -0
  82. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/METADATA +5 -1
  83. smftools-0.2.1.dist-info/RECORD +161 -0
  84. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  85. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  86. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  87. smftools/informatics/load_adata.py +0 -182
  88. smftools/preprocessing/append_C_context.py +0 -82
  89. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  90. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  91. smftools/preprocessing/filter_reads_on_length.py +0 -51
  92. smftools/tools/call_hmm_peaks.py +0 -105
  93. smftools/tools/data/__init__.py +0 -2
  94. smftools/tools/data/anndata_data_module.py +0 -90
  95. smftools/tools/evaluation/__init__.py +0 -0
  96. smftools/tools/inference/__init__.py +0 -1
  97. smftools/tools/inference/lightning_inference.py +0 -41
  98. smftools/tools/models/base.py +0 -14
  99. smftools/tools/models/cnn.py +0 -34
  100. smftools/tools/models/lightning_base.py +0 -41
  101. smftools/tools/models/mlp.py +0 -17
  102. smftools/tools/models/sklearn_models.py +0 -40
  103. smftools/tools/models/transformer.py +0 -133
  104. smftools/tools/training/__init__.py +0 -1
  105. smftools/tools/training/train_lightning_model.py +0 -47
  106. smftools-0.1.7.dist-info/RECORD +0 -136
  107. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  108. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  109. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  110. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  111. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  112. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  113. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  114. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  115. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  116. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  117. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  118. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  119. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  120. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/calculate_read_modification_stats.py
@@ -0,0 +1,101 @@
+ def calculate_read_modification_stats(adata,
+                                       reference_column,
+                                       sample_names_col,
+                                       mod_target_bases,
+                                       uns_flag="read_modification_stats_calculated",
+                                       bypass=False,
+                                       force_redo=False
+                                       ):
+     """
+     Adds methylation/deamination statistics for each read.
+     Reports each read's GpC and CpG methylation ratio relative to other_C methylation (a background false-positive metric for cytosine MTase SMF).
+ 
+     Parameters:
+         adata (AnnData): An AnnData object.
+         reference_column (str): Name of the reference column in adata.obs to use.
+         sample_names_col (str): Name of the sample name column in adata.obs to use.
+         mod_target_bases (list): Modified base targets to summarize (e.g. ['GpC', 'CpG', 'C', 'A']).
+ 
+     Returns:
+         None
+     """
+     import numpy as np
+     import anndata as ad
+     import pandas as pd
+ 
+     # Only run if not already performed
+     already = bool(adata.uns.get(uns_flag, False))
+     if (already and not force_redo) or bypass:
+         # QC already performed; nothing to do
+         return
+ 
+     print('Calculating read-level modification statistics')
+ 
+     references = set(adata.obs[reference_column])
+     sample_names = set(adata.obs[sample_names_col])
+     site_types = []
+ 
+     if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
+         site_types += ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C_site', 'any_C_site']
+ 
+     if 'A' in mod_target_bases:
+         site_types += ['A_site']
+ 
+     for site_type in site_types:
+         adata.obs[f'Modified_{site_type}_count'] = pd.Series(0, index=adata.obs_names, dtype=int)
+         adata.obs[f'Total_{site_type}_in_read'] = pd.Series(0, index=adata.obs_names, dtype=int)
+         adata.obs[f'Fraction_{site_type}_modified'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
+         # NaN cannot be stored in an integer Series, so these columns are float
+         adata.obs[f'Total_{site_type}_in_reference'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
+         adata.obs[f'Valid_{site_type}_in_read_vs_reference'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
+ 
+     for ref in references:
+         ref_subset = adata[adata.obs[reference_column] == ref]
+         for site_type in site_types:
+             print(f'Iterating over {ref}_{site_type}')
+             observation_matrix = ref_subset.obsm[f'{ref}_{site_type}']
+             total_positions_in_read = np.sum(~np.isnan(observation_matrix), axis=1)
+             total_positions_in_reference = observation_matrix.shape[1]
+             fraction_valid_positions_in_read_vs_ref = total_positions_in_read / total_positions_in_reference
+             number_mods_in_read = np.nansum(observation_matrix, axis=1)
+ 
+             # Guard against division by zero for reads with no valid positions
+             fraction_modified = np.divide(
+                 number_mods_in_read,
+                 total_positions_in_read,
+                 out=np.full_like(number_mods_in_read, np.nan, dtype=float),
+                 where=total_positions_in_read != 0
+             )
+ 
+             temp_obs_data = pd.DataFrame({f'Total_{site_type}_in_read': total_positions_in_read,
+                                           f'Modified_{site_type}_count': number_mods_in_read,
+                                           f'Fraction_{site_type}_modified': fraction_modified,
+                                           f'Total_{site_type}_in_reference': total_positions_in_reference,
+                                           f'Valid_{site_type}_in_read_vs_reference': fraction_valid_positions_in_read_vs_ref},
+                                          index=ref_subset.obs.index)
+ 
+             adata.obs.update(temp_obs_data)
+ 
+     if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
+         with np.errstate(divide='ignore', invalid='ignore'):
+             gpc_to_c_ratio = np.divide(
+                 adata.obs['Fraction_GpC_site_modified'],
+                 adata.obs['Fraction_other_C_site_modified'],
+                 out=np.full_like(adata.obs['Fraction_GpC_site_modified'], np.nan, dtype=float),
+                 where=adata.obs['Fraction_other_C_site_modified'] != 0
+             )
+ 
+             cpg_to_c_ratio = np.divide(
+                 adata.obs['Fraction_CpG_site_modified'],
+                 adata.obs['Fraction_other_C_site_modified'],
+                 out=np.full_like(adata.obs['Fraction_CpG_site_modified'], np.nan, dtype=float),
+                 where=adata.obs['Fraction_other_C_site_modified'] != 0
+             )
+ 
+         adata.obs['GpC_to_other_C_mod_ratio'] = gpc_to_c_ratio
+         adata.obs['CpG_to_other_C_mod_ratio'] = cpg_to_c_ratio
+ 
+     # mark as done
+     adata.uns[uns_flag] = True
+ 
+     return
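For orientation, a minimal usage sketch (not part of the diff itself): it assumes an AnnData whose obsm carries per-reference site matrices named f'{ref}_{site_type}', as the function expects, and obs columns named 'Reference_strand' and 'Sample_names' (the column names and the input file are illustrative).

from smftools.preprocessing.calculate_read_modification_stats import calculate_read_modification_stats
import anndata as ad

adata = ad.read_h5ad("smf_experiment.h5ad")  # hypothetical input file

# Writes per-read columns such as Fraction_GpC_site_modified and
# Valid_GpC_site_in_read_vs_reference into adata.obs, then sets the uns flag.
calculate_read_modification_stats(
    adata,
    reference_column="Reference_strand",  # illustrative obs column name
    sample_names_col="Sample_names",      # illustrative obs column name
    mod_target_bases=["GpC", "CpG"],
)

# A second call is a no-op unless force_redo=True (or bypass short-circuits it).
calculate_read_modification_stats(
    adata, "Reference_strand", "Sample_names", ["GpC", "CpG"], force_redo=True
)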
smftools/preprocessing/clean_NaN.py
@@ -1,4 +1,9 @@
- def clean_NaN(adata, layer=None):
+ def clean_NaN(adata,
+               layer=None,
+               uns_flag='clean_NaN_performed',
+               bypass=False,
+               force_redo=True
+               ):
      """
      Append layers to adata that contain NaN cleaning strategies.
 
@@ -14,6 +19,12 @@ def clean_NaN(adata, layer=None):
      import anndata as ad
      from ..readwrite import adata_to_df
 
+     # Only run if not already performed
+     already = bool(adata.uns.get(uns_flag, False))
+     if (already and not force_redo) or bypass:
+         # NaN cleaning already performed; nothing to do
+         return
+ 
      # Ensure the specified layer exists
      if layer and layer not in adata.layers:
          raise ValueError(f"Layer '{layer}' not found in adata.layers.")
@@ -44,3 +55,8 @@ def clean_NaN(adata, layer=None):
      print('Making layer: nan_half')
      df_nan_half = df.fillna(0.5)
      adata.layers['nan_half'] = df_nan_half.values
+ 
+     # mark as done
+     adata.uns[uns_flag] = True
+ 
+     return None
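A quick sketch of the new guarded behavior, assuming an AnnData whose X contains NaNs at unobserved positions (input file name is illustrative):

from smftools.preprocessing.clean_NaN import clean_NaN
import anndata as ad

adata = ad.read_h5ad("smf_experiment.h5ad")  # hypothetical input file

# Appends NaN-handling layers (e.g. 'nan_half', which fills NaNs with 0.5)
# and sets adata.uns['clean_NaN_performed'] = True.
clean_NaN(adata)

# force_redo defaults to True, so a repeat call recomputes the layers;
# pass force_redo=False to honor the uns flag and skip instead.
clean_NaN(adata, force_redo=False)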
smftools/preprocessing/filter_reads_on_length_quality_mapping.py
@@ -0,0 +1,158 @@
+ from typing import Optional, Union, Sequence
+ import numpy as np
+ import pandas as pd
+ import anndata as ad
+ 
+ def filter_reads_on_length_quality_mapping(
+     adata: ad.AnnData,
+     filter_on_coordinates: Union[bool, Sequence] = False,
+     # Single-range params: each is [min, max]
+     read_length: Optional[Sequence[float]] = None,
+     length_ratio: Optional[Sequence[float]] = None,
+     read_quality: Optional[Sequence[float]] = None,    # commonly min only
+     mapping_quality: Optional[Sequence[float]] = None, # commonly min only
+     uns_flag: str = "reads_removed_failing_length_quality_mapping_qc",
+     bypass: bool = False,
+     force_redo: bool = True
+ ) -> ad.AnnData:
+     """
+     Filter an AnnData by coordinate window, read length, length ratio, read quality, and mapping quality.
+ 
+     Each of `read_length`, `length_ratio`, `read_quality`, and `mapping_quality`
+     takes a [min, max] pair (list or tuple) that sets both bounds in one argument;
+     either bound may be None to leave that side open.
+ 
+     Returns a filtered copy of the input AnnData and marks adata.uns[uns_flag] = True.
+     """
+     # early exit
+     already = bool(adata.uns.get(uns_flag, False))
+     if bypass or (already and not force_redo):
+         return adata
+ 
+     adata_work = adata
+     start_n = adata_work.n_obs
+ 
+     # --- coordinate filtering ---
+     if filter_on_coordinates:
+         try:
+             low, high = tuple(filter_on_coordinates)
+         except Exception:
+             raise ValueError("filter_on_coordinates must be False or an iterable of two numbers (low, high).")
+         try:
+             var_coords = np.array([float(v) for v in adata_work.var_names])
+             if low > high:
+                 low, high = high, low
+             col_mask_bool = (var_coords >= float(low)) & (var_coords <= float(high))
+             if not col_mask_bool.any():
+                 # no exact hits: fall back to the variables nearest each bound
+                 start_idx = int(np.argmin(np.abs(var_coords - float(low))))
+                 end_idx = int(np.argmin(np.abs(var_coords - float(high))))
+                 lo_idx, hi_idx = min(start_idx, end_idx), max(start_idx, end_idx)
+                 selected_cols = list(adata_work.var_names[lo_idx : hi_idx + 1])
+             else:
+                 selected_cols = list(adata_work.var_names[col_mask_bool])
+             print(f"Subsetting adata to coordinates between {low} and {high}: keeping {len(selected_cols)} variables.")
+             adata_work = adata_work[:, selected_cols].copy()
+         except Exception:
+             print("Warning: could not interpret adata.var_names as numeric coordinates; skipping coordinate filtering.")
+ 
+     # --- helper to coerce range inputs ---
+     def _coerce_range(range_arg):
+         """
+         Given None or a 2-sequence [min, max], return (min_or_None, max_or_None).
+         If both bounds are present and min > max, they are swapped.
+         """
+         if range_arg is None:
+             return None, None
+         if not isinstance(range_arg, (list, tuple, np.ndarray)) or len(range_arg) != 2:
+             # not a 2-element range -> treat as no restriction
+             return None, None
+         lo_raw, hi_raw = range_arg[0], range_arg[1]
+         lo = None if lo_raw is None else float(lo_raw)
+         hi = None if hi_raw is None else float(hi_raw)
+         if (lo is not None) and (hi is not None) and lo > hi:
+             lo, hi = hi, lo
+         return lo, hi
+ 
+     # Resolve ranges from the provided range arguments
+     rl_min, rl_max = _coerce_range(read_length)
+     lr_min, lr_max = _coerce_range(length_ratio)
+     rq_min, rq_max = _coerce_range(read_quality)
+     mq_min, mq_max = _coerce_range(mapping_quality)
+ 
+     # --- build combined mask ---
+     combined_mask = pd.Series(True, index=adata_work.obs.index)
+ 
+     # read length filter
+     if (rl_min is not None) or (rl_max is not None):
+         if "mapped_length" not in adata_work.obs.columns:
+             print("Warning: 'mapped_length' not found in adata.obs; skipping read_length filter.")
+         else:
+             vals = pd.to_numeric(adata_work.obs["mapped_length"], errors="coerce")
+             mask = pd.Series(True, index=adata_work.obs.index)
+             if rl_min is not None:
+                 mask &= (vals >= rl_min)
+             if rl_max is not None:
+                 mask &= (vals <= rl_max)
+             mask &= vals.notna()
+             combined_mask &= mask
+             print(f"Planned read_length filter: min={rl_min}, max={rl_max}")
+ 
+     # length ratio filter
+     if (lr_min is not None) or (lr_max is not None):
+         if "mapped_length_to_reference_length_ratio" not in adata_work.obs.columns:
+             print("Warning: 'mapped_length_to_reference_length_ratio' not found in adata.obs; skipping length_ratio filter.")
+         else:
+             vals = pd.to_numeric(adata_work.obs["mapped_length_to_reference_length_ratio"], errors="coerce")
+             mask = pd.Series(True, index=adata_work.obs.index)
+             if lr_min is not None:
+                 mask &= (vals >= lr_min)
+             if lr_max is not None:
+                 mask &= (vals <= lr_max)
+             mask &= vals.notna()
+             combined_mask &= mask
+             print(f"Planned length_ratio filter: min={lr_min}, max={lr_max}")
+ 
+     # read quality filter (supports an optional range but typically min only)
+     if (rq_min is not None) or (rq_max is not None):
+         if "read_quality" not in adata_work.obs.columns:
+             print("Warning: 'read_quality' not found in adata.obs; skipping read_quality filter.")
+         else:
+             vals = pd.to_numeric(adata_work.obs["read_quality"], errors="coerce")
+             mask = pd.Series(True, index=adata_work.obs.index)
+             if rq_min is not None:
+                 mask &= (vals >= rq_min)
+             if rq_max is not None:
+                 mask &= (vals <= rq_max)
+             mask &= vals.notna()
+             combined_mask &= mask
+             print(f"Planned read_quality filter: min={rq_min}, max={rq_max}")
+ 
+     # mapping quality filter (supports an optional range but typically min only)
+     if (mq_min is not None) or (mq_max is not None):
+         if "mapping_quality" not in adata_work.obs.columns:
+             print("Warning: 'mapping_quality' not found in adata.obs; skipping mapping_quality filter.")
+         else:
+             vals = pd.to_numeric(adata_work.obs["mapping_quality"], errors="coerce")
+             mask = pd.Series(True, index=adata_work.obs.index)
+             if mq_min is not None:
+                 mask &= (vals >= mq_min)
+             if mq_max is not None:
+                 mask &= (vals <= mq_max)
+             mask &= vals.notna()
+             combined_mask &= mask
+             print(f"Planned mapping_quality filter: min={mq_min}, max={mq_max}")
+ 
+     # Apply combined mask and report
+     s0 = adata_work.n_obs
+     combined_mask_bool = combined_mask.astype(bool).values
+     adata_work = adata_work[combined_mask_bool].copy()
+     s1 = adata_work.n_obs
+     print(f"Combined filters applied: kept {s1} / {s0} reads (removed {s0 - s1})")
+ 
+     final_n = adata_work.n_obs
+     print(f"Filtering complete: start={start_n}, final={final_n}, removed={start_n - final_n}")
+ 
+     # mark as done
+     adata_work.uns[uns_flag] = True
+ 
+     return adata_work
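A usage sketch of the range-style arguments, assuming adata.obs already carries the 'mapped_length', 'read_quality', and 'mapping_quality' columns the function looks for (the numeric cutoffs and input file are illustrative):

from smftools.preprocessing.filter_reads_on_length_quality_mapping import (
    filter_reads_on_length_quality_mapping,
)
import anndata as ad

adata = ad.read_h5ad("smf_experiment.h5ad")  # hypothetical input file

# Keep reads 500-5000 bp long with read quality >= 10 and mapping quality >= 20.
# Each argument is a [min, max] pair; None leaves that bound open.
adata_filt = filter_reads_on_length_quality_mapping(
    adata,
    read_length=[500, 5000],
    read_quality=[10, None],
    mapping_quality=[20, None],
)

# Optionally restrict to a coordinate window first; adata.var_names must be
# interpretable as numeric coordinates for this to take effect.
adata_windowed = filter_reads_on_length_quality_mapping(
    adata, filter_on_coordinates=(1000, 2000)
)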
smftools/preprocessing/filter_reads_on_modification_thresholds.py
@@ -0,0 +1,352 @@
+ import math
+ import gc
+ import numpy as np
+ import pandas as pd
+ import anndata as ad
+ from typing import Optional, Sequence, List
+ 
+ def filter_reads_on_modification_thresholds(
+     adata: ad.AnnData,
+     smf_modality: str,
+     mod_target_bases: List[str] = [],
+     gpc_thresholds: Optional[Sequence[float]] = None,
+     cpg_thresholds: Optional[Sequence[float]] = None,
+     any_c_thresholds: Optional[Sequence[float]] = None,
+     a_thresholds: Optional[Sequence[float]] = None,
+     use_other_c_as_background: bool = False,
+     min_valid_fraction_positions_in_read_vs_ref: Optional[float] = None,
+     uns_flag: str = 'reads_filtered_on_modification_thresholds',
+     bypass: bool = False,
+     force_redo: bool = False,
+     reference_column: str = 'Reference_strand',
+     # memory-control options:
+     batch_size: int = 200,
+     compute_obs_if_missing: bool = True,
+     treat_zero_as_invalid: bool = False
+ ) -> ad.AnnData:
+     """
+     Memory-efficient filtering by per-read modification thresholds.
+ 
+     - If the required obs columns exist, they are used directly (fast path).
+     - Otherwise, the relevant per-read metrics are computed per reference in
+       batches and written into adata.obs before filtering.
+ 
+     Parameters of interest:
+     - gpc_thresholds, cpg_thresholds, any_c_thresholds, a_thresholds:
+       each should be [min, max] (floats 0..1) or None.
+     - use_other_c_as_background: require GpC/CpG modification above the other_C
+       background (if the ratio columns are present).
+     - min_valid_fraction_positions_in_read_vs_ref: minimum fraction of valid sites
+       in the read vs reference (0..1). If None, this check is skipped.
+     - compute_obs_if_missing: if True, compute Fraction_* and Valid_* obs columns
+       when they are not already present, using a low-memory per-reference strategy.
+     - treat_zero_as_invalid: if True, a zero in X counts as an invalid (non-site)
+       position; if False, zeros are considered valid positions (adjust to your data semantics).
+     """
+ 
+     # quick exit flags:
+     already = bool(adata.uns.get(uns_flag, False))
+     if (already and not force_redo) or bypass:
+         return adata
+ 
+     # helper: check whether obs columns exist for a particular mod type
+     def obs_has_columns_for(mod_type):
+         col_pref = {
+             "GpC": ("Fraction_GpC_site_modified", "Valid_GpC_site_in_read_vs_reference"),
+             "CpG": ("Fraction_CpG_site_modified", "Valid_CpG_site_in_read_vs_reference"),
+             "C": ("Fraction_any_C_site_modified", "Valid_any_C_site_in_read_vs_reference"),
+             "A": ("Fraction_A_site_modified", "Valid_A_site_in_read_vs_reference"),
+         }.get(mod_type, (None, None))
+         return (col_pref[0] in adata.obs.columns) and (col_pref[1] in adata.obs.columns)
+ 
+     # if all required obs columns are present, use them directly (fast path)
+     required_present = True
+     for mt, thr in (("GpC", gpc_thresholds), ("CpG", cpg_thresholds), ("C", any_c_thresholds), ("A", a_thresholds)):
+         if thr is not None and mt in mod_target_bases:
+             if not obs_has_columns_for(mt):
+                 required_present = False
+                 break
+ 
+     # If required obs columns are not present and compute_obs_if_missing is False => error
+     if not required_present and not compute_obs_if_missing:
+         raise RuntimeError(
+             "Required per-read summary columns not found in adata.obs and compute_obs_if_missing is False."
+         )
+ 
+     # Expected var column naming pattern per reference, e.g. "{ref}_GpC_site",
+     # "{ref}_CpG_site", "{ref}_any_C_site", "{ref}_other_C_site", "{ref}_A_site".
+     # If your var column naming differs, adjust these suffixes.
+     refs = list(adata.obs[reference_column].astype('category').cat.categories)
+ 
+     def _find_var_col_for(ref, suffix):
+         name = f"{ref}_{suffix}"
+         if name in adata.var.columns:
+             return name
+         return None
+ 
+     # If we need to compute obs summaries: do so per-reference in batches
+     if not required_present and compute_obs_if_missing:
+         n_obs = adata.n_obs
+         # prepare empty columns relevant to mod_target_bases; filled below
+         create_cols = {}
+         if "GpC" in mod_target_bases:
+             create_cols["Fraction_GpC_site_modified"] = np.full((n_obs,), np.nan)
+             create_cols["Valid_GpC_site_in_read_vs_reference"] = np.full((n_obs,), np.nan)
+             # optional background ratio if other_C exists
+             create_cols["GpC_to_other_C_mod_ratio"] = np.full((n_obs,), np.nan)
+         if "CpG" in mod_target_bases:
+             create_cols["Fraction_CpG_site_modified"] = np.full((n_obs,), np.nan)
+             create_cols["Valid_CpG_site_in_read_vs_reference"] = np.full((n_obs,), np.nan)
+             create_cols["CpG_to_other_C_mod_ratio"] = np.full((n_obs,), np.nan)
+         if "C" in mod_target_bases:
+             create_cols["Fraction_any_C_site_modified"] = np.full((n_obs,), np.nan)
+             create_cols["Valid_any_C_site_in_read_vs_reference"] = np.full((n_obs,), np.nan)
+         if "A" in mod_target_bases:
+             create_cols["Fraction_A_site_modified"] = np.full((n_obs,), np.nan)
+             create_cols["Valid_A_site_in_read_vs_reference"] = np.full((n_obs,), np.nan)
+ 
+         # helper to compute for one reference and one suffix
+         def _compute_for_ref_and_suffix(ref, suffix, out_frac_arr, out_valid_arr):
+             """
+             Compute fraction modified and valid fraction for reads mapping to 'ref',
+             using the var column named f"{ref}_{suffix}" to select var columns.
+             """
+             var_colname = _find_var_col_for(ref, suffix)
+             if var_colname is None:
+                 # nothing to compute
+                 return
+ 
+             # var boolean mask (which var columns belong to this suffix for the ref)
+             try:
+                 var_mask_bool = np.asarray(adata.var[var_colname].values).astype(bool)
+             except Exception:
+                 # if var holds non-boolean values, attempt coercion
+                 var_mask_bool = np.asarray(pd.to_numeric(adata.var[var_colname], errors='coerce').fillna(0).astype(bool))
+ 
+             if not var_mask_bool.any():
+                 return
+             col_indices = np.where(var_mask_bool)[0]
+             n_cols_for_ref = len(col_indices)
+             if n_cols_for_ref == 0:
+                 return
+ 
+             # rows that belong to this reference
+             row_indices_all = np.where(adata.obs[reference_column].values == ref)[0]
+             if len(row_indices_all) == 0:
+                 return
+ 
+             # process rows for this reference in batches to avoid allocating huge slices
+             for start in range(0, len(row_indices_all), batch_size):
+                 block_rows_idx = row_indices_all[start : start + batch_size]
+                 # slice rows x selected columns
+                 X_block = adata.X[block_rows_idx, :][:, col_indices]
+ 
+                 # Count modified entries (values > 0 indicate modification) and valid
+                 # positions per row; both paths cap memory at batch_size * n_cols_for_ref.
+                 try:
+                     if hasattr(X_block, "toarray") and not isinstance(X_block, np.ndarray):
+                         # sparse/matrix-like: (X_block > 0) only touches stored nonzeros
+                         modified_count = np.asarray((X_block > 0).sum(axis=1)).ravel()
+                         has_stored_nan = hasattr(X_block, 'data') and np.isnan(X_block.data).any()
+                         if has_stored_nan:
+                             # sparse with stored NaNs (unlikely): count non-NaN positions per row
+                             valid_count = (~np.isnan(X_block.toarray())).sum(axis=1).astype(float)
+                         else:
+                             if treat_zero_as_invalid:
+                                 # valid = number of non-zero entries
+                                 valid_count = np.asarray((X_block != 0).sum(axis=1)).ravel().astype(float)
+                             else:
+                                 # all reference positions count as valid -> denominator = n_cols_for_ref
+                                 valid_count = np.full_like(modified_count, n_cols_for_ref, dtype=float)
+                     else:
+                         # dense numpy
+                         Xb = np.asarray(X_block)
+                         if np.isnan(Xb).any():
+                             valid_count = np.sum(~np.isnan(Xb), axis=1).astype(float)
+                         else:
+                             if treat_zero_as_invalid:
+                                 valid_count = np.sum(Xb != 0, axis=1).astype(float)
+                             else:
+                                 valid_count = np.full((Xb.shape[0],), float(n_cols_for_ref))
+                         modified_count = np.sum(Xb > 0, axis=1).astype(float)
+                 except Exception:
+                     # fallback to safe dense conversion (shouldn't normally be needed)
+                     Xb = np.asarray(X_block.toarray() if hasattr(X_block, "toarray") else X_block)
+                     if Xb.size == 0:
+                         modified_count = np.zeros(len(block_rows_idx), dtype=float)
+                         valid_count = np.zeros(len(block_rows_idx), dtype=float)
+                     else:
+                         if np.isnan(Xb).any():
+                             valid_count = np.sum(~np.isnan(Xb), axis=1).astype(float)
+                         else:
+                             if treat_zero_as_invalid:
+                                 valid_count = np.sum(Xb != 0, axis=1).astype(float)
+                             else:
+                                 valid_count = np.full((Xb.shape[0],), float(n_cols_for_ref))
+                         modified_count = np.sum(Xb > 0, axis=1).astype(float)
+ 
+                 # fraction modified = modified_count / valid_count (guard divide-by-zero)
+                 frac = np.zeros_like(modified_count, dtype=float)
+                 mask_valid_nonzero = (valid_count > 0)
+                 frac[mask_valid_nonzero] = modified_count[mask_valid_nonzero] / valid_count[mask_valid_nonzero]
+ 
+                 # write results back for this batch of rows; build the valid-fraction
+                 # array first, since chained fancy indexing would write into a copy
+                 out_frac_arr[block_rows_idx] = frac
+                 valid_frac = np.zeros_like(valid_count, dtype=float)
+                 valid_frac[mask_valid_nonzero] = valid_count[mask_valid_nonzero] / float(n_cols_for_ref)
+                 out_valid_arr[block_rows_idx] = valid_frac
+ 
+                 # free block memory ASAP
+                 del X_block, modified_count, valid_count, frac, valid_frac
+                 gc.collect()
+ 
+         # compute for each reference and required suffixes
+         # GpC
+         if "GpC" in mod_target_bases:
+             for ref in refs:
+                 _compute_for_ref_and_suffix(ref, "GpC_site", create_cols["Fraction_GpC_site_modified"], create_cols["Valid_GpC_site_in_read_vs_reference"])
+             # also collect per-reference other_C column indices (for background), where present
+             other_c_per_ref = {}
+             for ref in refs:
+                 other_col = _find_var_col_for(ref, "other_C_site")
+                 if other_col:
+                     other_c_per_ref[ref] = np.where(np.asarray(adata.var[other_col].values).astype(bool))[0]
+ 
+         # CpG
+         if "CpG" in mod_target_bases:
+             for ref in refs:
+                 _compute_for_ref_and_suffix(ref, "CpG_site", create_cols["Fraction_CpG_site_modified"], create_cols["Valid_CpG_site_in_read_vs_reference"])
+ 
+         # any C
+         if "C" in mod_target_bases:
+             for ref in refs:
+                 _compute_for_ref_and_suffix(ref, "any_C_site", create_cols["Fraction_any_C_site_modified"], create_cols["Valid_any_C_site_in_read_vs_reference"])
+ 
+         # A
+         if "A" in mod_target_bases:
+             for ref in refs:
+                 _compute_for_ref_and_suffix(ref, "A_site", create_cols["Fraction_A_site_modified"], create_cols["Valid_A_site_in_read_vs_reference"])
+ 
+         # write created arrays into adata.obs
+         for cname, arr in create_cols.items():
+             adata.obs[cname] = arr
+ 
+         # optionally compute GpC/CpG-to-other_C background ratios, if the
+         # 'Fraction_other_C_site_modified' column exists
+         if "GpC" in mod_target_bases and use_other_c_as_background:
+             if "Fraction_other_C_site_modified" in adata.obs.columns:
+                 with np.errstate(divide='ignore', invalid='ignore'):
+                     ratio = adata.obs["Fraction_GpC_site_modified"].astype(float) / adata.obs["Fraction_other_C_site_modified"].astype(float)
+                 adata.obs["GpC_to_other_C_mod_ratio"] = ratio.fillna(0.0)
+             else:
+                 adata.obs["GpC_to_other_C_mod_ratio"] = np.nan
+ 
+         if "CpG" in mod_target_bases and use_other_c_as_background:
+             if "Fraction_other_C_site_modified" in adata.obs.columns:
+                 with np.errstate(divide='ignore', invalid='ignore'):
+                     ratio = adata.obs["Fraction_CpG_site_modified"].astype(float) / adata.obs["Fraction_other_C_site_modified"].astype(float)
+                 adata.obs["CpG_to_other_C_mod_ratio"] = ratio.fillna(0.0)
+             else:
+                 adata.obs["CpG_to_other_C_mod_ratio"] = np.nan
+ 
+         # free memory
+         del create_cols
+         gc.collect()
+ 
+     # --- apply the filters using the adata.obs columns ---
+     filtered = adata  # we'll chain subset operations
+ 
+     # helper to unpack [min, max]; either bound may be None
+     def _unpack_minmax(thr):
+         if thr is None:
+             return None, None
+         try:
+             lo = float(thr[0]) if thr[0] is not None else None
+             hi = float(thr[1]) if thr[1] is not None else None
+             if lo is not None and hi is not None and lo > hi:
+                 lo, hi = hi, lo
+             return lo, hi
+         except Exception:
+             return None, None
+ 
+     # GpC thresholds
+     if gpc_thresholds and 'GpC' in mod_target_bases:
+         lo, hi = _unpack_minmax(gpc_thresholds)
+         if use_other_c_as_background and smf_modality != 'deaminase' and "GpC_to_other_C_mod_ratio" in filtered.obs.columns:
+             filtered = filtered[filtered.obs["GpC_to_other_C_mod_ratio"].astype(float) > 1]
+         if lo is not None:
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Fraction_GpC_site_modified"].astype(float) > lo]
+             print(f"Removed {s0 - filtered.n_obs} reads below min GpC fraction {lo}")
+         if hi is not None:
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Fraction_GpC_site_modified"].astype(float) < hi]
+             print(f"Removed {s0 - filtered.n_obs} reads above max GpC fraction {hi}")
+         if (min_valid_fraction_positions_in_read_vs_ref is not None) and ("Valid_GpC_site_in_read_vs_reference" in filtered.obs.columns):
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Valid_GpC_site_in_read_vs_reference"].astype(float) > float(min_valid_fraction_positions_in_read_vs_ref)]
+             print(f"Removed {s0 - filtered.n_obs} reads with insufficient valid GpC site fraction vs ref")
+ 
+     # CpG thresholds
+     if cpg_thresholds and 'CpG' in mod_target_bases:
+         lo, hi = _unpack_minmax(cpg_thresholds)
+         if use_other_c_as_background and smf_modality != 'deaminase' and "CpG_to_other_C_mod_ratio" in filtered.obs.columns:
+             filtered = filtered[filtered.obs["CpG_to_other_C_mod_ratio"].astype(float) > 1]
+         if lo is not None:
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Fraction_CpG_site_modified"].astype(float) > lo]
+             print(f"Removed {s0 - filtered.n_obs} reads below min CpG fraction {lo}")
+         if hi is not None:
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Fraction_CpG_site_modified"].astype(float) < hi]
+             print(f"Removed {s0 - filtered.n_obs} reads above max CpG fraction {hi}")
+         if (min_valid_fraction_positions_in_read_vs_ref is not None) and ("Valid_CpG_site_in_read_vs_reference" in filtered.obs.columns):
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Valid_CpG_site_in_read_vs_reference"].astype(float) > float(min_valid_fraction_positions_in_read_vs_ref)]
+             print(f"Removed {s0 - filtered.n_obs} reads with insufficient valid CpG site fraction vs ref")
+ 
+     # any C thresholds
+     if any_c_thresholds and 'C' in mod_target_bases:
+         lo, hi = _unpack_minmax(any_c_thresholds)
+         if lo is not None:
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Fraction_any_C_site_modified"].astype(float) > lo]
+             print(f"Removed {s0 - filtered.n_obs} reads below min any-C fraction {lo}")
+         if hi is not None:
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Fraction_any_C_site_modified"].astype(float) < hi]
+             print(f"Removed {s0 - filtered.n_obs} reads above max any-C fraction {hi}")
+         if (min_valid_fraction_positions_in_read_vs_ref is not None) and ("Valid_any_C_site_in_read_vs_reference" in filtered.obs.columns):
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Valid_any_C_site_in_read_vs_reference"].astype(float) > float(min_valid_fraction_positions_in_read_vs_ref)]
+             print(f"Removed {s0 - filtered.n_obs} reads with insufficient valid any-C site fraction vs ref")
+ 
+     # A thresholds
+     if a_thresholds and 'A' in mod_target_bases:
+         lo, hi = _unpack_minmax(a_thresholds)
+         if lo is not None:
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Fraction_A_site_modified"].astype(float) > lo]
+             print(f"Removed {s0 - filtered.n_obs} reads below min A fraction {lo}")
+         if hi is not None:
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Fraction_A_site_modified"].astype(float) < hi]
+             print(f"Removed {s0 - filtered.n_obs} reads above max A fraction {hi}")
+         if (min_valid_fraction_positions_in_read_vs_ref is not None) and ("Valid_A_site_in_read_vs_reference" in filtered.obs.columns):
+             s0 = filtered.n_obs
+             filtered = filtered[filtered.obs["Valid_A_site_in_read_vs_reference"].astype(float) > float(min_valid_fraction_positions_in_read_vs_ref)]
+             print(f"Removed {s0 - filtered.n_obs} reads with insufficient valid A site fraction vs ref")
+ 
+     filtered = filtered.copy()
+ 
+     # mark as done
+     filtered.uns[uns_flag] = True
+ 
+     return filtered
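Finally, a sketch of how the threshold filter might be chained after the stats step, assuming a conversion-modality experiment targeting GpC/CpG; the thresholds and input file are illustrative, not recommended defaults:

from smftools.preprocessing.filter_reads_on_modification_thresholds import (
    filter_reads_on_modification_thresholds,
)
import anndata as ad

adata = ad.read_h5ad("smf_experiment.h5ad")  # hypothetical input file

# Keep reads whose GpC modification fraction lies in (0.05, 0.95), whose GpC
# signal exceeds the other-C background, and which cover > 80% of the
# reference GpC sites. Missing Fraction_*/Valid_* obs columns are computed
# per reference in batches before filtering.
adata_filt = filter_reads_on_modification_thresholds(
    adata,
    smf_modality="conversion",        # matches the package's conversion workflow
    mod_target_bases=["GpC", "CpG"],
    gpc_thresholds=[0.05, 0.95],
    use_other_c_as_background=True,
    min_valid_fraction_positions_in_read_vs_ref=0.8,
    batch_size=500,                   # rows per block when computing summaries
)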