smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,62 +1,92 @@
1
- def calculate_read_modification_stats(adata,
2
- reference_column,
3
- sample_names_col,
4
- mod_target_bases,
5
- uns_flag="calculate_read_modification_stats_performed",
6
- bypass=False,
7
- force_redo=False
8
- ):
9
- """
10
- Adds methylation/deamination statistics for each read.
11
- Indicates the read GpC and CpG methylation ratio to other_C methylation (background false positive metric for Cytosine MTase SMF).
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from smftools.logging_utils import get_logger
6
+
7
+ if TYPE_CHECKING:
8
+ import anndata as ad
12
9
 
13
- Parameters:
14
- adata (AnnData): An adata object
15
- reference_column (str): String representing the name of the Reference column to use
16
- sample_names_col (str): String representing the name of the sample name column to use
17
- mod_target_bases:
10
+ logger = get_logger(__name__)
18
11
 
19
- Returns:
20
- None
12
+
13
+ def calculate_read_modification_stats(
14
+ adata: "ad.AnnData",
15
+ reference_column: str,
16
+ sample_names_col: str,
17
+ mod_target_bases: list[str],
18
+ uns_flag: str = "calculate_read_modification_stats_performed",
19
+ bypass: bool = False,
20
+ force_redo: bool = False,
21
+ valid_sites_only: bool = False,
22
+ valid_site_suffix: str = "_valid_coverage",
23
+ ) -> None:
24
+ """Add methylation/deamination statistics for each read.
25
+
26
+ Args:
27
+ adata: AnnData object.
28
+ reference_column: Obs column containing reference identifiers.
29
+ sample_names_col: Obs column containing sample identifiers.
30
+ mod_target_bases: List of target base contexts (e.g., ``["GpC", "CpG"]``).
31
+ uns_flag: Flag in ``adata.uns`` indicating prior completion.
32
+ bypass: Whether to skip processing.
33
+ force_redo: Whether to rerun even if ``uns_flag`` is set.
34
+ valid_sites_only: Whether to restrict to valid coverage sites.
35
+ valid_site_suffix: Suffix used for valid-site matrices.
21
36
  """
22
37
  import numpy as np
23
- import anndata as ad
24
38
  import pandas as pd
25
39
 
40
+ if valid_sites_only:
41
+ if adata.uns.get("calculate_coverage_performed", False):
42
+ pass
43
+ else:
44
+ valid_sites_only = False
45
+
46
+ if not valid_sites_only:
47
+ valid_site_suffix = ""
48
+
26
49
  # Only run if not already performed
27
50
  already = bool(adata.uns.get(uns_flag, False))
28
51
  if (already and not force_redo) or bypass:
29
52
  # QC already performed; nothing to do
30
53
  return
31
54
 
32
- print('Calculating read level Modification statistics')
55
+ logger.info("Calculating read level Modification statistics")
33
56
 
34
57
  references = set(adata.obs[reference_column])
35
58
  sample_names = set(adata.obs[sample_names_col])
36
59
  site_types = []
37
60
 
38
- if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
39
- site_types += ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C_site', 'C_site']
40
-
41
- if 'A' in mod_target_bases:
42
- site_types += ['A_site']
61
+ if any(base in mod_target_bases for base in ["GpC", "CpG", "C"]):
62
+ site_types += ["GpC_site", "CpG_site", "ambiguous_GpC_CpG_site", "other_C_site", "C_site"]
43
63
 
44
- for site_type in site_types:
45
- adata.obs[f'Modified_{site_type}_count'] = pd.Series(0, index=adata.obs_names, dtype=int)
46
- adata.obs[f'Total_{site_type}_in_read'] = pd.Series(0, index=adata.obs_names, dtype=int)
47
- adata.obs[f'Fraction_{site_type}_modified'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
48
- adata.obs[f'Total_{site_type}_in_reference'] = pd.Series(np.nan, index=adata.obs_names, dtype=int)
49
- adata.obs[f'Valid_{site_type}_in_read_vs_reference'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
64
+ if "A" in mod_target_bases:
65
+ site_types += ["A_site"]
50
66
 
67
+ for site_type in site_types:
68
+ adata.obs[f"Modified_{site_type}_count"] = pd.Series(0, index=adata.obs_names, dtype=int)
69
+ adata.obs[f"Total_{site_type}_in_read"] = pd.Series(0, index=adata.obs_names, dtype=int)
70
+ adata.obs[f"Fraction_{site_type}_modified"] = pd.Series(
71
+ np.nan, index=adata.obs_names, dtype=float
72
+ )
73
+ adata.obs[f"Total_{site_type}_in_reference"] = pd.Series(
74
+ np.nan, index=adata.obs_names, dtype=int
75
+ )
76
+ adata.obs[f"Valid_{site_type}_in_read_vs_reference"] = pd.Series(
77
+ np.nan, index=adata.obs_names, dtype=float
78
+ )
51
79
 
52
80
  for ref in references:
53
81
  ref_subset = adata[adata.obs[reference_column] == ref]
54
82
  for site_type in site_types:
55
- print(f'Iterating over {ref}_{site_type}')
56
- observation_matrix = ref_subset.obsm[f'{ref}_{site_type}']
83
+ logger.info("Iterating over %s_%s", ref, site_type)
84
+ observation_matrix = ref_subset.obsm[f"{ref}_{site_type}{valid_site_suffix}"]
57
85
  total_positions_in_read = np.nansum(~np.isnan(observation_matrix), axis=1)
58
86
  total_positions_in_reference = observation_matrix.shape[1]
59
- fraction_valid_positions_in_read_vs_ref = total_positions_in_read / total_positions_in_reference
87
+ fraction_valid_positions_in_read_vs_ref = (
88
+ total_positions_in_read / total_positions_in_reference
89
+ )
60
90
  number_mods_in_read = np.nansum(observation_matrix, axis=1)
61
91
  fraction_modified = number_mods_in_read / total_positions_in_read
62
92
 
@@ -64,38 +94,42 @@ def calculate_read_modification_stats(adata,
64
94
  number_mods_in_read,
65
95
  total_positions_in_read,
66
96
  out=np.full_like(number_mods_in_read, np.nan, dtype=float),
67
- where=total_positions_in_read != 0
97
+ where=total_positions_in_read != 0,
98
+ )
99
+
100
+ temp_obs_data = pd.DataFrame(
101
+ {
102
+ f"Total_{site_type}_in_read": total_positions_in_read,
103
+ f"Modified_{site_type}_count": number_mods_in_read,
104
+ f"Fraction_{site_type}_modified": fraction_modified,
105
+ f"Total_{site_type}_in_reference": total_positions_in_reference,
106
+ f"Valid_{site_type}_in_read_vs_reference": fraction_valid_positions_in_read_vs_ref,
107
+ },
108
+ index=ref_subset.obs.index,
68
109
  )
69
110
 
70
- temp_obs_data = pd.DataFrame({f'Total_{site_type}_in_read': total_positions_in_read,
71
- f'Modified_{site_type}_count': number_mods_in_read,
72
- f'Fraction_{site_type}_modified': fraction_modified,
73
- f'Total_{site_type}_in_reference': total_positions_in_reference,
74
- f'Valid_{site_type}_in_read_vs_reference': fraction_valid_positions_in_read_vs_ref},
75
- index=ref_subset.obs.index)
76
-
77
111
  adata.obs.update(temp_obs_data)
78
112
 
79
- if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
80
- with np.errstate(divide='ignore', invalid='ignore'):
113
+ if any(base in mod_target_bases for base in ["GpC", "CpG", "C"]):
114
+ with np.errstate(divide="ignore", invalid="ignore"):
81
115
  gpc_to_c_ratio = np.divide(
82
- adata.obs[f'Fraction_GpC_site_modified'],
83
- adata.obs[f'Fraction_other_C_site_modified'],
84
- out=np.full_like(adata.obs[f'Fraction_GpC_site_modified'], np.nan, dtype=float),
85
- where=adata.obs[f'Fraction_other_C_site_modified'] != 0
116
+ adata.obs["Fraction_GpC_site_modified"],
117
+ adata.obs["Fraction_other_C_site_modified"],
118
+ out=np.full_like(adata.obs["Fraction_GpC_site_modified"], np.nan, dtype=float),
119
+ where=adata.obs["Fraction_other_C_site_modified"] != 0,
86
120
  )
87
121
 
88
122
  cpg_to_c_ratio = np.divide(
89
- adata.obs[f'Fraction_CpG_site_modified'],
90
- adata.obs[f'Fraction_other_C_site_modified'],
91
- out=np.full_like(adata.obs[f'Fraction_CpG_site_modified'], np.nan, dtype=float),
92
- where=adata.obs[f'Fraction_other_C_site_modified'] != 0
93
- )
94
-
95
- adata.obs['GpC_to_other_C_mod_ratio'] = gpc_to_c_ratio
96
- adata.obs['CpG_to_other_C_mod_ratio'] = cpg_to_c_ratio
123
+ adata.obs["Fraction_CpG_site_modified"],
124
+ adata.obs["Fraction_other_C_site_modified"],
125
+ out=np.full_like(adata.obs["Fraction_CpG_site_modified"], np.nan, dtype=float),
126
+ where=adata.obs["Fraction_other_C_site_modified"] != 0,
127
+ )
128
+
129
+ adata.obs["GpC_to_other_C_mod_ratio"] = gpc_to_c_ratio
130
+ adata.obs["CpG_to_other_C_mod_ratio"] = cpg_to_c_ratio
97
131
 
98
132
  # mark as done
99
133
  adata.uns[uns_flag] = True
100
134
 
101
- return
135
+ return
@@ -1,23 +1,33 @@
1
- def clean_NaN(adata,
2
- layer=None,
3
- uns_flag='clean_NaN_performed',
4
- bypass=False,
5
- force_redo=True
6
- ):
7
- """
8
- Append layers to adata that contain NaN cleaning strategies.
1
+ from __future__ import annotations
9
2
 
10
- Parameters:
11
- adata (AnnData): an anndata object
12
- layer (str, optional): Name of the layer to fill NaN values in. If None, uses adata.X.
3
+ from typing import TYPE_CHECKING
13
4
 
14
- Modifies:
15
- - Adds new layers to `adata.layers` with different NaN-filling strategies.
16
- """
17
- import numpy as np
18
- import pandas as pd
5
+ from smftools.logging_utils import get_logger
6
+
7
+ if TYPE_CHECKING:
19
8
  import anndata as ad
20
- from ..readwrite import adata_to_df
9
+
10
+ logger = get_logger(__name__)
11
+
12
+
13
+ def clean_NaN(
14
+ adata: "ad.AnnData",
15
+ layer: str | None = None,
16
+ uns_flag: str = "clean_NaN_performed",
17
+ bypass: bool = False,
18
+ force_redo: bool = True,
19
+ ) -> None:
20
+ """Append layers to ``adata`` that contain NaN-cleaning strategies.
21
+
22
+ Args:
23
+ adata: AnnData object.
24
+ layer: Layer to fill NaN values in. If ``None``, uses ``adata.X``.
25
+ uns_flag: Flag in ``adata.uns`` indicating prior completion.
26
+ bypass: Whether to skip processing.
27
+ force_redo: Whether to rerun even if ``uns_flag`` is set.
28
+ """
29
+
30
+ from ..readwrite import adata_to_df
21
31
 
22
32
  # Only run if not already performed
23
33
  already = bool(adata.uns.get(uns_flag, False))
@@ -33,30 +43,30 @@ def clean_NaN(adata,
33
43
  df = adata_to_df(adata, layer=layer)
34
44
 
35
45
  # Fill NaN with closest SMF value (forward then backward fill)
36
- print('Making layer: fill_nans_closest')
37
- adata.layers['fill_nans_closest'] = df.ffill(axis=1).bfill(axis=1).values
46
+ logger.info("Making layer: fill_nans_closest")
47
+ adata.layers["fill_nans_closest"] = df.ffill(axis=1).bfill(axis=1).values
38
48
 
39
49
  # Replace NaN with 0, and 0 with -1
40
- print('Making layer: nan0_0minus1')
50
+ logger.info("Making layer: nan0_0minus1")
41
51
  df_nan0_0minus1 = df.replace(0, -1).fillna(0)
42
- adata.layers['nan0_0minus1'] = df_nan0_0minus1.values
52
+ adata.layers["nan0_0minus1"] = df_nan0_0minus1.values
43
53
 
44
54
  # Replace NaN with 1, and 1 with 2
45
- print('Making layer: nan1_12')
55
+ logger.info("Making layer: nan1_12")
46
56
  df_nan1_12 = df.replace(1, 2).fillna(1)
47
- adata.layers['nan1_12'] = df_nan1_12.values
57
+ adata.layers["nan1_12"] = df_nan1_12.values
48
58
 
49
59
  # Replace NaN with -1
50
- print('Making layer: nan_minus_1')
60
+ logger.info("Making layer: nan_minus_1")
51
61
  df_nan_minus_1 = df.fillna(-1)
52
- adata.layers['nan_minus_1'] = df_nan_minus_1.values
62
+ adata.layers["nan_minus_1"] = df_nan_minus_1.values
53
63
 
54
64
  # Replace NaN with 0.5
55
- print('Making layer: nan_half')
65
+ logger.info("Making layer: nan_half")
56
66
  df_nan_half = df.fillna(0.5)
57
- adata.layers['nan_half'] = df_nan_half.values
67
+ adata.layers["nan_half"] = df_nan_half.values
58
68
 
59
69
  # mark as done
60
70
  adata.uns[uns_flag] = True
61
71
 
62
- return None
72
+ return None
@@ -1,26 +1,38 @@
1
1
  ## filter_adata_by_nan_proportion
2
2
 
3
- def filter_adata_by_nan_proportion(adata, threshold, axis='obs'):
4
- """
5
- Filters an anndata object on a nan proportion threshold in a given matrix axis.
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ import anndata as ad
9
+
10
+
11
+ def filter_adata_by_nan_proportion(
12
+ adata: "ad.AnnData", threshold: float, axis: str = "obs"
13
+ ) -> "ad.AnnData":
14
+ """Filter an AnnData object on NaN proportion in a matrix axis.
15
+
16
+ Args:
17
+ adata: AnnData object to filter.
18
+ threshold: Maximum allowed NaN proportion.
19
+ axis: Whether to filter based on ``"obs"`` or ``"var"`` NaN content.
6
20
 
7
- Parameters:
8
- adata (AnnData):
9
- threshold (float): The max np.nan content to allow in the given axis.
10
- axis (str): Whether to filter the adata based on obs or var np.nan content
11
21
  Returns:
12
- filtered_adata
22
+ anndata.AnnData: Filtered AnnData object.
23
+
24
+ Raises:
25
+ ValueError: If ``axis`` is not ``"obs"`` or ``"var"``.
13
26
  """
14
27
  import numpy as np
15
- import anndata as ad
16
28
 
17
- if axis == 'obs':
29
+ if axis == "obs":
18
30
  # Calculate the proportion of NaN values in each read
19
31
  nan_proportion = np.isnan(adata.X).mean(axis=1)
20
32
  # Filter reads to keep reads with less than a certain NaN proportion
21
33
  filtered_indices = np.where(nan_proportion <= threshold)[0]
22
34
  filtered_adata = adata[filtered_indices, :].copy()
23
- elif axis == 'var':
35
+ elif axis == "var":
24
36
  # Calculate the proportion of NaN values at a given position
25
37
  nan_proportion = np.isnan(adata.X).mean(axis=0)
26
38
  # Filter positions to keep positions with less than a certain NaN proportion
@@ -28,4 +40,4 @@ def filter_adata_by_nan_proportion(adata, threshold, axis='obs'):
28
40
  filtered_adata = adata[:, filtered_indices].copy()
29
41
  else:
30
42
  raise ValueError("Axis must be either 'obs' or 'var'")
31
- return filtered_adata
43
+ return filtered_adata
@@ -1,28 +1,43 @@
1
- from typing import Optional, Union, Sequence
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Sequence, Union
4
+
5
+ import anndata as ad
2
6
  import numpy as np
3
7
  import pandas as pd
4
- import anndata as ad
8
+
9
+ from smftools.logging_utils import get_logger
10
+
11
+ logger = get_logger(__name__)
12
+
5
13
 
6
14
  def filter_reads_on_length_quality_mapping(
7
15
  adata: ad.AnnData,
8
16
  filter_on_coordinates: Union[bool, Sequence] = False,
9
17
  # New single-range params (preferred):
10
- read_length: Optional[Sequence[float]] = None, # e.g. [min, max]
11
- length_ratio: Optional[Sequence[float]] = None, # e.g. [min, max]
12
- read_quality: Optional[Sequence[float]] = None, # e.g. [min, max] (commonly min only)
13
- mapping_quality: Optional[Sequence[float]] = None, # e.g. [min, max] (commonly min only)
18
+ read_length: Optional[Sequence[float]] = None, # e.g. [min, max]
19
+ length_ratio: Optional[Sequence[float]] = None, # e.g. [min, max]
20
+ read_quality: Optional[Sequence[float]] = None, # e.g. [min, max] (commonly min only)
21
+ mapping_quality: Optional[Sequence[float]] = None, # e.g. [min, max] (commonly min only)
14
22
  uns_flag: str = "filter_reads_on_length_quality_mapping_performed",
15
23
  bypass: bool = False,
16
- force_redo: bool = True
24
+ force_redo: bool = True,
17
25
  ) -> ad.AnnData:
18
- """
19
- Filter AnnData by coordinate window, read length, length ratios, read quality and mapping quality.
20
-
21
- New: you may pass `read_length=[min, max]` (or tuple) to set both min/max in one argument.
22
- If `read_length` is given it overrides scalar min/max variants (which are not present in this signature).
23
- Same behavior supported for `length_ratio`, `read_quality`, `mapping_quality`.
24
-
25
- Returns a filtered copy of the input AnnData and marks adata.uns[uns_flag] = True.
26
+ """Filter AnnData by coordinates, read length, quality, and mapping metrics.
27
+
28
+ Args:
29
+ adata: AnnData object to filter.
30
+ filter_on_coordinates: Optional coordinate window as a two-value sequence.
31
+ read_length: Read length range as ``[min, max]``.
32
+ length_ratio: Length ratio range as ``[min, max]``.
33
+ read_quality: Read quality range as ``[min, max]``.
34
+ mapping_quality: Mapping quality range as ``[min, max]``.
35
+ uns_flag: Flag in ``adata.uns`` indicating prior completion.
36
+ bypass: Whether to skip processing.
37
+ force_redo: Whether to rerun even if ``uns_flag`` is set.
38
+
39
+ Returns:
40
+ anndata.AnnData: Filtered copy of the input AnnData.
26
41
  """
27
42
  # early exit
28
43
  already = bool(adata.uns.get(uns_flag, False))
@@ -37,7 +52,9 @@ def filter_reads_on_length_quality_mapping(
37
52
  try:
38
53
  low, high = tuple(filter_on_coordinates)
39
54
  except Exception:
40
- raise ValueError("filter_on_coordinates must be False or an iterable of two numbers (low, high).")
55
+ raise ValueError(
56
+ "filter_on_coordinates must be False or an iterable of two numbers (low, high)."
57
+ )
41
58
  try:
42
59
  var_coords = np.array([float(v) for v in adata_work.var_names])
43
60
  if low > high:
@@ -50,10 +67,17 @@ def filter_reads_on_length_quality_mapping(
50
67
  selected_cols = list(adata_work.var_names[lo_idx : hi_idx + 1])
51
68
  else:
52
69
  selected_cols = list(adata_work.var_names[col_mask_bool])
53
- print(f"Subsetting adata to coordinates between {low} and {high}: keeping {len(selected_cols)} variables.")
70
+ logger.info(
71
+ "Subsetting adata to coordinates between %s and %s: keeping %s variables.",
72
+ low,
73
+ high,
74
+ len(selected_cols),
75
+ )
54
76
  adata_work = adata_work[:, selected_cols].copy()
55
77
  except Exception:
56
- print("Warning: could not interpret adata.var_names as numeric coordinates — skipping coordinate filtering.")
78
+ logger.warning(
79
+ "Could not interpret adata.var_names as numeric coordinates — skipping coordinate filtering."
80
+ )
57
81
 
58
82
  # --- helper to coerce range inputs ---
59
83
  def _coerce_range(range_arg):
@@ -85,72 +109,83 @@ def filter_reads_on_length_quality_mapping(
85
109
  # read length filter
86
110
  if (rl_min is not None) or (rl_max is not None):
87
111
  if "mapped_length" not in adata_work.obs.columns:
88
- print("Warning: 'mapped_length' not found in adata.obs — skipping read_length filter.")
112
+ logger.warning("'mapped_length' not found in adata.obs — skipping read_length filter.")
89
113
  else:
90
114
  vals = pd.to_numeric(adata_work.obs["mapped_length"], errors="coerce")
91
115
  mask = pd.Series(True, index=adata_work.obs.index)
92
116
  if rl_min is not None:
93
- mask &= (vals >= rl_min)
117
+ mask &= vals >= rl_min
94
118
  if rl_max is not None:
95
- mask &= (vals <= rl_max)
119
+ mask &= vals <= rl_max
96
120
  mask &= vals.notna()
97
121
  combined_mask &= mask
98
- print(f"Planned read_length filter: min={rl_min}, max={rl_max}")
122
+ logger.info("Planned read_length filter: min=%s, max=%s", rl_min, rl_max)
99
123
 
100
124
  # length ratio filter
101
125
  if (lr_min is not None) or (lr_max is not None):
102
126
  if "mapped_length_to_reference_length_ratio" not in adata_work.obs.columns:
103
- print("Warning: 'mapped_length_to_reference_length_ratio' not found in adata.obs — skipping length_ratio filter.")
127
+ logger.warning(
128
+ "'mapped_length_to_reference_length_ratio' not found in adata.obs — skipping length_ratio filter."
129
+ )
104
130
  else:
105
- vals = pd.to_numeric(adata_work.obs["mapped_length_to_reference_length_ratio"], errors="coerce")
131
+ vals = pd.to_numeric(
132
+ adata_work.obs["mapped_length_to_reference_length_ratio"], errors="coerce"
133
+ )
106
134
  mask = pd.Series(True, index=adata_work.obs.index)
107
135
  if lr_min is not None:
108
- mask &= (vals >= lr_min)
136
+ mask &= vals >= lr_min
109
137
  if lr_max is not None:
110
- mask &= (vals <= lr_max)
138
+ mask &= vals <= lr_max
111
139
  mask &= vals.notna()
112
140
  combined_mask &= mask
113
- print(f"Planned length_ratio filter: min={lr_min}, max={lr_max}")
141
+ logger.info("Planned length_ratio filter: min=%s, max=%s", lr_min, lr_max)
114
142
 
115
143
  # read quality filter (supporting optional range but typically min only)
116
144
  if (rq_min is not None) or (rq_max is not None):
117
145
  if "read_quality" not in adata_work.obs.columns:
118
- print("Warning: 'read_quality' not found in adata.obs — skipping read_quality filter.")
146
+ logger.warning("'read_quality' not found in adata.obs — skipping read_quality filter.")
119
147
  else:
120
148
  vals = pd.to_numeric(adata_work.obs["read_quality"], errors="coerce")
121
149
  mask = pd.Series(True, index=adata_work.obs.index)
122
150
  if rq_min is not None:
123
- mask &= (vals >= rq_min)
151
+ mask &= vals >= rq_min
124
152
  if rq_max is not None:
125
- mask &= (vals <= rq_max)
153
+ mask &= vals <= rq_max
126
154
  mask &= vals.notna()
127
155
  combined_mask &= mask
128
- print(f"Planned read_quality filter: min={rq_min}, max={rq_max}")
156
+ logger.info("Planned read_quality filter: min=%s, max=%s", rq_min, rq_max)
129
157
 
130
158
  # mapping quality filter (supporting optional range but typically min only)
131
159
  if (mq_min is not None) or (mq_max is not None):
132
160
  if "mapping_quality" not in adata_work.obs.columns:
133
- print("Warning: 'mapping_quality' not found in adata.obs — skipping mapping_quality filter.")
161
+ logger.warning(
162
+ "'mapping_quality' not found in adata.obs — skipping mapping_quality filter."
163
+ )
134
164
  else:
135
165
  vals = pd.to_numeric(adata_work.obs["mapping_quality"], errors="coerce")
136
166
  mask = pd.Series(True, index=adata_work.obs.index)
137
167
  if mq_min is not None:
138
- mask &= (vals >= mq_min)
168
+ mask &= vals >= mq_min
139
169
  if mq_max is not None:
140
- mask &= (vals <= mq_max)
170
+ mask &= vals <= mq_max
141
171
  mask &= vals.notna()
142
172
  combined_mask &= mask
143
- print(f"Planned mapping_quality filter: min={mq_min}, max={mq_max}")
173
+ logger.info("Planned mapping_quality filter: min=%s, max=%s", mq_min, mq_max)
144
174
 
145
175
  # Apply combined mask and report
146
176
  s0 = adata_work.n_obs
147
177
  combined_mask_bool = combined_mask.astype(bool).values
148
178
  adata_work = adata_work[combined_mask_bool].copy()
149
179
  s1 = adata_work.n_obs
150
- print(f"Combined filters applied: kept {s1} / {s0} reads (removed {s0 - s1})")
180
+ logger.info("Combined filters applied: kept %s / %s reads (removed %s)", s1, s0, s0 - s1)
151
181
 
152
182
  final_n = adata_work.n_obs
153
- print(f"Filtering complete: start={start_n}, final={final_n}, removed={start_n - final_n}")
183
+ logger.info(
184
+ "Filtering complete: start=%s, final=%s, removed=%s",
185
+ start_n,
186
+ final_n,
187
+ start_n - final_n,
188
+ )
154
189
 
155
190
  # mark as done
156
191
  adata_work.uns[uns_flag] = True