smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/calculate_position_Youden.py

@@ -1,53 +1,95 @@
  ## calculate_position_Youden
-
  ## Calculating and applying position level thresholds for methylation calls to binarize the SMF data
- def calculate_position_Youden(adata, positive_control_sample='positive', negative_control_sample='negative', J_threshold=0.5, obs_column='Reference', infer_on_percentile=False, inference_variable='', save=False, output_directory=''):
- """
- Adds new variable metadata to each position indicating whether the position provides reliable SMF methylation calls. Also outputs plots of the positional ROC curves.
+ from __future__ import annotations

- Parameters:
- adata (AnnData): An AnnData object.
- positive_control_sample (str): string representing the sample name corresponding to the Plus MTase control sample.
- negative_control_sample (str): string representing the sample name corresponding to the Minus MTase control sample.
- J_threshold (float): A float indicating the J-statistic used to indicate whether a position passes QC for methylation calls.
- obs_column (str): The category to iterate over.
- infer_on_perdentile (bool | int): If False, use defined postive and negative control samples. If an int (0 < int < 100) is passed, this uses the top and bottom int percentile of methylated reads based on metric in inference_variable column.
- inference_variable (str): If infer_on_percentile has an integer value passed, use the AnnData observation column name passed by this string as the metric.
- save (bool): Whether to save the ROC plots.
- output_directory (str): String representing the path to the output directory to output the ROC curves.
-
- Returns:
- None
- """
- import numpy as np
- import pandas as pd
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ from smftools.logging_utils import get_logger
+
+ if TYPE_CHECKING:
  import anndata as ad
+
+ logger = get_logger(__name__)
+
+
+ def calculate_position_Youden(
+ adata: "ad.AnnData",
+ positive_control_sample: str | None = None,
+ negative_control_sample: str | None = None,
+ J_threshold: float = 0.5,
+ ref_column: str = "Reference_strand",
+ sample_column: str = "Sample_names",
+ infer_on_percentile: bool | int = True,
+ inference_variable: str = "Raw_modification_signal",
+ save: bool = False,
+ output_directory: str | Path = "",
+ ) -> None:
+ """Add position-level Youden thresholds and optional ROC plots.
+
+ Args:
+ adata: AnnData object.
+ positive_control_sample: Sample name for the plus MTase control.
+ negative_control_sample: Sample name for the minus MTase control.
+ J_threshold: J-statistic threshold for QC.
+ ref_column: Obs column for reference/strand categories.
+ sample_column: Obs column for sample identifiers.
+ infer_on_percentile: If ``False``, use provided controls. If an int in ``(0, 100)``,
+ use percentile-based inference from ``inference_variable``.
+ inference_variable: Obs column used for percentile inference.
+ save: Whether to save ROC plots to disk.
+ output_directory: Output directory for ROC plots.
+ """
  import matplotlib.pyplot as plt
- from sklearn.metrics import roc_curve, roc_auc_score
+ import numpy as np
+ from sklearn.metrics import roc_curve

  control_samples = [positive_control_sample, negative_control_sample]
- categories = adata.obs[obs_column].cat.categories
+ references = adata.obs[ref_column].cat.categories
  # Iterate over each category in the specified obs_column
- for cat in categories:
- print(f"Calculating position Youden statistics for {cat}")
+ for ref in references:
+ logger.info("Calculating position Youden statistics for %s", ref)
  # Subset to keep only reads associated with the category
- cat_subset = adata[adata.obs[obs_column] == cat]
+ ref_subset = adata[adata.obs[ref_column] == ref]
  # Iterate over positive and negative control samples
- for control in control_samples:
- # Initialize a dictionary for the given control sample. This will be keyed by dataset and position to point to a tuple of coordinate position and an array of methylation probabilities
- adata.uns[f'{cat}_position_methylation_dict_{control}'] = {}
- if infer_on_percentile:
- sorted_column = cat_subset.obs[inference_variable].sort_values(ascending=False)
- if control == "positive":
+ for i, control in enumerate(control_samples):
+ # If controls are not passed and infer on percentile is True, infer thresholds based on top and bottom percentile windows for a given obs column metric.
+ if infer_on_percentile and not control:
+ logger.info(
+ "Inferring methylation control reads for %s based on %s percentiles of %s",
+ ref,
+ infer_on_percentile,
+ inference_variable,
+ )
+ sorted_column = ref_subset.obs[inference_variable].sort_values(ascending=False)
+ if i == 0:
+ logger.info("Using top %s percentile for positive control", infer_on_percentile)
+ control = "positive"
+ positive_control_sample = control
  threshold = np.percentile(sorted_column, 100 - infer_on_percentile)
- control_subset = cat_subset[cat_subset.obs[inference_variable] >= threshold, :]
+ control_subset = ref_subset[ref_subset.obs[inference_variable] >= threshold, :]
  else:
+ logger.info(
+ "Using bottom %s percentile for negative control", infer_on_percentile
+ )
+ control = "negative"
+ negative_control_sample = control
  threshold = np.percentile(sorted_column, infer_on_percentile)
- control_subset = cat_subset[cat_subset.obs[inference_variable] <= threshold, :]
+ control_subset = ref_subset[ref_subset.obs[inference_variable] <= threshold, :]
+ elif not infer_on_percentile and not control:
+ logger.error(
+ "Can not threshold Anndata on Youden threshold. Need to either provide control samples or set infer_on_percentile to True"
+ )
+ return
  else:
+ logger.info("Using provided control sample: %s", control)
  # get the current control subset on the given category
- filtered_obs = cat_subset.obs[cat_subset.obs['Sample_names'].str.contains(control, na=False, regex=True)]
- control_subset = cat_subset[filtered_obs.index]
+ filtered_obs = ref_subset.obs[ref_subset.obs[sample_column] == control]
+ control_subset = ref_subset[filtered_obs.index]
+
+ # Initialize a dictionary for the given control sample. This will be keyed by dataset and position to point to a tuple of coordinate position and an array of methylation probabilities
+ adata.uns[f"{ref}_position_methylation_dict_{control}"] = {}
+
  # Iterate through every position in the control subset
  for position in range(control_subset.shape[1]):
  # Get the coordinate name associated with that position
@@ -63,29 +105,45 @@ def calculate_position_Youden(adata, positive_control_sample='positive', negativ
  # Get fraction coverage
  fraction_coverage = position_coverage / control_subset.shape[0]
  # Save the position and the position methylation data for the control subset
- adata.uns[f'{cat}_position_methylation_dict_{control}'][f'{position}'] = (position, position_data, fraction_coverage)
+ adata.uns[f"{ref}_position_methylation_dict_{control}"][f"{position}"] = (
+ position,
+ position_data,
+ fraction_coverage,
+ )

- for cat in categories:
+ for ref in references:
  fig, ax = plt.subplots(figsize=(6, 4))
- plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
- plt.xlabel('False Positive Rate')
- plt.ylabel('True Positive Rate')
- ax.spines['right'].set_visible(False)
- ax.spines['top'].set_visible(False)
+ plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
+ plt.xlabel("False Positive Rate")
+ plt.ylabel("True Positive Rate")
+ ax.spines["right"].set_visible(False)
+ ax.spines["top"].set_visible(False)
  n_passed_positions = 0
  n_total_positions = 0
  # Initialize a list that will hold the positional thresholds for the category
  probability_thresholding_list = [(np.nan, np.nan)] * adata.shape[1]
- for i, key in enumerate(adata.uns[f'{cat}_position_methylation_dict_{positive_control_sample}'].keys()):
- position = int(adata.uns[f'{cat}_position_methylation_dict_{positive_control_sample}'][key][0])
- positive_position_array = adata.uns[f'{cat}_position_methylation_dict_{positive_control_sample}'][key][1]
- fraction_coverage = adata.uns[f'{cat}_position_methylation_dict_{positive_control_sample}'][key][2]
+ for i, key in enumerate(
+ adata.uns[f"{ref}_position_methylation_dict_{positive_control_sample}"].keys()
+ ):
+ position = int(
+ adata.uns[f"{ref}_position_methylation_dict_{positive_control_sample}"][key][0]
+ )
+ positive_position_array = adata.uns[
+ f"{ref}_position_methylation_dict_{positive_control_sample}"
+ ][key][1]
+ fraction_coverage = adata.uns[
+ f"{ref}_position_methylation_dict_{positive_control_sample}"
+ ][key][2]
  if fraction_coverage > 0.2:
  try:
- negative_position_array = adata.uns[f'{cat}_position_methylation_dict_{negative_control_sample}'][key][1]
+ negative_position_array = adata.uns[
+ f"{ref}_position_methylation_dict_{negative_control_sample}"
+ ][key][1]
  # Combine the negative and positive control data
  data = np.concatenate([negative_position_array, positive_position_array])
- labels = np.array([0] * len(negative_position_array) + [1] * len(positive_position_array))
+ labels = np.array(
+ [0] * len(negative_position_array) + [1] * len(positive_position_array)
+ )
  # Calculate the ROC curve
  fpr, tpr, thresholds = roc_curve(labels, data)
  # Calculate Youden's J statistic
@@ -98,18 +156,24 @@ def calculate_position_Youden(adata, positive_control_sample='positive', negativ
  n_total_positions += 1
  if max_J > J_threshold:
  n_passed_positions += 1
- plt.plot(fpr, tpr, label='ROC curve')
- except:
+ plt.plot(fpr, tpr, label="ROC curve")
+ except Exception:
  probability_thresholding_list[position] = (0.8, np.nan)
- title = f'ROC Curve for {n_passed_positions} positions with J-stat greater than {J_threshold}\n out of {n_total_positions} total positions on {cat}'
+ title = f"ROC Curve for {n_passed_positions} positions with J-stat greater than {J_threshold}\n out of {n_total_positions} total positions on {ref}"
  plt.title(title)
  save_name = output_directory / f"{title}.png"
  if save:
  plt.savefig(save_name)
  plt.close()
  else:
- plt.show()
+ plt.show()

- adata.var[f'{cat}_position_methylation_thresholding_Youden_stats'] = probability_thresholding_list
+ adata.var[f"{ref}_position_methylation_thresholding_Youden_stats"] = (
+ probability_thresholding_list
+ )
  J_max_list = [probability_thresholding_list[i][1] for i in range(adata.shape[1])]
- adata.var[f'{cat}_position_passed_QC'] = [True if i > J_threshold else False for i in J_max_list]
+ adata.var[f"{ref}_position_passed_Youden_thresholding_QC"] = [
+ True if i > J_threshold else False for i in J_max_list
+ ]
+
+ logger.info("Finished calculating position Youden statistics")
smftools/preprocessing/calculate_read_length_stats.py

@@ -1,45 +1,74 @@
  ## calculate_read_length_stats

+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ from smftools.logging_utils import get_logger
+
+ if TYPE_CHECKING:
+ import anndata as ad
+
+ logger = get_logger(__name__)
+
+
  # Read length QC
- def calculate_read_length_stats(adata, reference_column='', sample_names_col=''):
- """
- Append first valid position in a read and last valid position in the read. From this determine and append the read length.
+ def calculate_read_length_stats(
+ adata: "ad.AnnData",
+ reference_column: str = "",
+ sample_names_col: str = "",
+ ) -> tuple[int, int]:
+ """Calculate per-read length statistics and store them in ``adata.obs``.
+
+ Args:
+ adata: AnnData object.
+ reference_column: Obs column containing reference identifiers.
+ sample_names_col: Obs column containing sample identifiers.

- Parameters:
- adata (AnnData): An adata object
- reference_column (str): String representing the name of the Reference column to use
- sample_names_col (str): String representing the name of the sample name column to use
-
  Returns:
- upper_bound (int): last valid position in the dataset
- lower_bound (int): first valid position in the dataset
+ tuple[int, int]: ``(upper_bound, lower_bound)`` for valid positions in the dataset.
  """
  import numpy as np
- import anndata as ad
  import pandas as pd

- print('Calculating read length statistics')
+ logger.info("Calculating read length statistics")

  references = set(adata.obs[reference_column])
  sample_names = set(adata.obs[sample_names_col])

  ## Add basic observation-level (read-level) metadata to the object: first valid position in a read and last valid position in the read. From this determine the read length. Save two new variable which hold the first and last valid positions in the entire dataset
- print('calculating read length stats')
+ logger.info("Calculating read length stats")
  # Add some basic observation-level (read-level) metadata to the anndata object
- read_first_valid_position = np.array([int(adata.var_names[i]) for i in np.argmax(~np.isnan(adata.X), axis=1)])
- read_last_valid_position = np.array([int(adata.var_names[i]) for i in (adata.X.shape[1] - 1 - np.argmax(~np.isnan(adata.X[:, ::-1]), axis=1))])
- read_length = read_last_valid_position - read_first_valid_position + np.ones(len(read_first_valid_position))
+ read_first_valid_position = np.array(
+ [int(adata.var_names[i]) for i in np.argmax(~np.isnan(adata.X), axis=1)]
+ )
+ read_last_valid_position = np.array(
+ [
+ int(adata.var_names[i])
+ for i in (adata.X.shape[1] - 1 - np.argmax(~np.isnan(adata.X[:, ::-1]), axis=1))
+ ]
+ )
+ read_length = (
+ read_last_valid_position
+ - read_first_valid_position
+ + np.ones(len(read_first_valid_position))
+ )

- adata.obs['first_valid_position'] = pd.Series(read_first_valid_position, index=adata.obs.index, dtype=int)
- adata.obs['last_valid_position'] = pd.Series(read_last_valid_position, index=adata.obs.index, dtype=int)
- adata.obs['read_length'] = pd.Series(read_length, index=adata.obs.index, dtype=int)
+ adata.obs["first_valid_position"] = pd.Series(
+ read_first_valid_position, index=adata.obs.index, dtype=int
+ )
+ adata.obs["last_valid_position"] = pd.Series(
+ read_last_valid_position, index=adata.obs.index, dtype=int
+ )
+ adata.obs["read_length"] = pd.Series(read_length, index=adata.obs.index, dtype=int)

  # Define variables to hold the first and last valid position in the dataset
- upper_bound = int(np.nanmax(adata.obs['last_valid_position']))
- lower_bound = int(np.nanmin(adata.obs['first_valid_position']))
+ upper_bound = int(np.nanmax(adata.obs["last_valid_position"]))
+ lower_bound = int(np.nanmin(adata.obs["first_valid_position"]))

  return upper_bound, lower_bound

+
  # # Add an unstructured element to the anndata object which points to a dictionary of read lengths keyed by reference and sample name. Points to a tuple containing (mean, median, stdev) of the read lengths of the sample for the given reference strand
  # ## Plot histogram of read length data and save the median and stdev of the read lengths for each sample.
  # adata.uns['read_length_dict'] = {}
@@ -70,10 +99,10 @@ def calculate_read_length_stats(adata, reference_column='', sample_names_col='')
  # # Add a vertical line at the median
  # plt.axvline(median, color='red', linestyle='dashed', linewidth=1)
  # # Annotate the median
- # plt.xlim(lower_bound - 100, upper_bound + 100)
+ # plt.xlim(lower_bound - 100, upper_bound + 100)
  # if save_read_length_histogram:
  # save_name = output_directory + f'/{readwrite.date_string()} {title}'
  # plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
  # plt.close()
  # else:
- # plt.show()
+ # plt.show()
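
A short usage sketch of the updated calculate_read_length_stats (not part of the diff; import path assumed from the package layout, column names illustrative):

    from smftools.preprocessing.calculate_read_length_stats import calculate_read_length_stats

    # Returns the dataset-wide bounds and appends per-read columns to adata.obs.
    upper_bound, lower_bound = calculate_read_length_stats(
        adata,
        reference_column="Reference_strand",
        sample_names_col="Sample_names",
    )
    # adata.obs now contains "first_valid_position", "last_valid_position", and "read_length".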
smftools/preprocessing/calculate_read_modification_stats.py

@@ -1,62 +1,92 @@
- def calculate_read_modification_stats(adata,
- reference_column,
- sample_names_col,
- mod_target_bases,
- uns_flag="read_modification_stats_calculated",
- bypass=False,
- force_redo=False
- ):
- """
- Adds methylation/deamination statistics for each read.
- Indicates the read GpC and CpG methylation ratio to other_C methylation (background false positive metric for Cytosine MTase SMF).
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ from smftools.logging_utils import get_logger
+
+ if TYPE_CHECKING:
+ import anndata as ad

- Parameters:
- adata (AnnData): An adata object
- reference_column (str): String representing the name of the Reference column to use
- sample_names_col (str): String representing the name of the sample name column to use
- mod_target_bases:
+ logger = get_logger(__name__)

- Returns:
- None
+
+ def calculate_read_modification_stats(
+ adata: "ad.AnnData",
+ reference_column: str,
+ sample_names_col: str,
+ mod_target_bases: list[str],
+ uns_flag: str = "calculate_read_modification_stats_performed",
+ bypass: bool = False,
+ force_redo: bool = False,
+ valid_sites_only: bool = False,
+ valid_site_suffix: str = "_valid_coverage",
+ ) -> None:
+ """Add methylation/deamination statistics for each read.
+
+ Args:
+ adata: AnnData object.
+ reference_column: Obs column containing reference identifiers.
+ sample_names_col: Obs column containing sample identifiers.
+ mod_target_bases: List of target base contexts (e.g., ``["GpC", "CpG"]``).
+ uns_flag: Flag in ``adata.uns`` indicating prior completion.
+ bypass: Whether to skip processing.
+ force_redo: Whether to rerun even if ``uns_flag`` is set.
+ valid_sites_only: Whether to restrict to valid coverage sites.
+ valid_site_suffix: Suffix used for valid-site matrices.
  """
  import numpy as np
- import anndata as ad
  import pandas as pd

+ if valid_sites_only:
+ if adata.uns.get("calculate_coverage_performed", False):
+ pass
+ else:
+ valid_sites_only = False
+
+ if not valid_sites_only:
+ valid_site_suffix = ""
+
  # Only run if not already performed
  already = bool(adata.uns.get(uns_flag, False))
  if (already and not force_redo) or bypass:
  # QC already performed; nothing to do
  return

- print('Calculating read level Modification statistics')
+ logger.info("Calculating read level Modification statistics")

  references = set(adata.obs[reference_column])
  sample_names = set(adata.obs[sample_names_col])
  site_types = []

- if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
- site_types += ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C_site', 'C_site']
-
- if 'A' in mod_target_bases:
- site_types += ['A_site']
+ if any(base in mod_target_bases for base in ["GpC", "CpG", "C"]):
+ site_types += ["GpC_site", "CpG_site", "ambiguous_GpC_CpG_site", "other_C_site", "C_site"]

- for site_type in site_types:
- adata.obs[f'Modified_{site_type}_count'] = pd.Series(0, index=adata.obs_names, dtype=int)
- adata.obs[f'Total_{site_type}_in_read'] = pd.Series(0, index=adata.obs_names, dtype=int)
- adata.obs[f'Fraction_{site_type}_modified'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
- adata.obs[f'Total_{site_type}_in_reference'] = pd.Series(np.nan, index=adata.obs_names, dtype=int)
- adata.obs[f'Valid_{site_type}_in_read_vs_reference'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
+ if "A" in mod_target_bases:
+ site_types += ["A_site"]

+ for site_type in site_types:
+ adata.obs[f"Modified_{site_type}_count"] = pd.Series(0, index=adata.obs_names, dtype=int)
+ adata.obs[f"Total_{site_type}_in_read"] = pd.Series(0, index=adata.obs_names, dtype=int)
+ adata.obs[f"Fraction_{site_type}_modified"] = pd.Series(
+ np.nan, index=adata.obs_names, dtype=float
+ )
+ adata.obs[f"Total_{site_type}_in_reference"] = pd.Series(
+ np.nan, index=adata.obs_names, dtype=int
+ )
+ adata.obs[f"Valid_{site_type}_in_read_vs_reference"] = pd.Series(
+ np.nan, index=adata.obs_names, dtype=float
+ )

  for ref in references:
  ref_subset = adata[adata.obs[reference_column] == ref]
  for site_type in site_types:
- print(f'Iterating over {ref}_{site_type}')
- observation_matrix = ref_subset.obsm[f'{ref}_{site_type}']
+ logger.info("Iterating over %s_%s", ref, site_type)
+ observation_matrix = ref_subset.obsm[f"{ref}_{site_type}{valid_site_suffix}"]
  total_positions_in_read = np.nansum(~np.isnan(observation_matrix), axis=1)
  total_positions_in_reference = observation_matrix.shape[1]
- fraction_valid_positions_in_read_vs_ref = total_positions_in_read / total_positions_in_reference
+ fraction_valid_positions_in_read_vs_ref = (
+ total_positions_in_read / total_positions_in_reference
+ )
  number_mods_in_read = np.nansum(observation_matrix, axis=1)
  fraction_modified = number_mods_in_read / total_positions_in_read

@@ -64,38 +94,42 @@ def calculate_read_modification_stats(adata,
  number_mods_in_read,
  total_positions_in_read,
  out=np.full_like(number_mods_in_read, np.nan, dtype=float),
- where=total_positions_in_read != 0
+ where=total_positions_in_read != 0,
+ )
+
+ temp_obs_data = pd.DataFrame(
+ {
+ f"Total_{site_type}_in_read": total_positions_in_read,
+ f"Modified_{site_type}_count": number_mods_in_read,
+ f"Fraction_{site_type}_modified": fraction_modified,
+ f"Total_{site_type}_in_reference": total_positions_in_reference,
+ f"Valid_{site_type}_in_read_vs_reference": fraction_valid_positions_in_read_vs_ref,
+ },
+ index=ref_subset.obs.index,
  )

- temp_obs_data = pd.DataFrame({f'Total_{site_type}_in_read': total_positions_in_read,
- f'Modified_{site_type}_count': number_mods_in_read,
- f'Fraction_{site_type}_modified': fraction_modified,
- f'Total_{site_type}_in_reference': total_positions_in_reference,
- f'Valid_{site_type}_in_read_vs_reference': fraction_valid_positions_in_read_vs_ref},
- index=ref_subset.obs.index)
-
  adata.obs.update(temp_obs_data)

- if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
- with np.errstate(divide='ignore', invalid='ignore'):
+ if any(base in mod_target_bases for base in ["GpC", "CpG", "C"]):
+ with np.errstate(divide="ignore", invalid="ignore"):
  gpc_to_c_ratio = np.divide(
- adata.obs[f'Fraction_GpC_site_modified'],
- adata.obs[f'Fraction_other_C_site_modified'],
- out=np.full_like(adata.obs[f'Fraction_GpC_site_modified'], np.nan, dtype=float),
- where=adata.obs[f'Fraction_other_C_site_modified'] != 0
+ adata.obs["Fraction_GpC_site_modified"],
+ adata.obs["Fraction_other_C_site_modified"],
+ out=np.full_like(adata.obs["Fraction_GpC_site_modified"], np.nan, dtype=float),
+ where=adata.obs["Fraction_other_C_site_modified"] != 0,
  )

  cpg_to_c_ratio = np.divide(
- adata.obs[f'Fraction_CpG_site_modified'],
- adata.obs[f'Fraction_other_C_site_modified'],
- out=np.full_like(adata.obs[f'Fraction_CpG_site_modified'], np.nan, dtype=float),
- where=adata.obs[f'Fraction_other_C_site_modified'] != 0
- )
-
- adata.obs['GpC_to_other_C_mod_ratio'] = gpc_to_c_ratio
- adata.obs['CpG_to_other_C_mod_ratio'] = cpg_to_c_ratio
+ adata.obs["Fraction_CpG_site_modified"],
+ adata.obs["Fraction_other_C_site_modified"],
+ out=np.full_like(adata.obs["Fraction_CpG_site_modified"], np.nan, dtype=float),
+ where=adata.obs["Fraction_other_C_site_modified"] != 0,
+ )
+
+ adata.obs["GpC_to_other_C_mod_ratio"] = gpc_to_c_ratio
+ adata.obs["CpG_to_other_C_mod_ratio"] = cpg_to_c_ratio

  # mark as done
  adata.uns[uns_flag] = True

- return
+ return
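
A usage sketch of the new calculate_read_modification_stats signature (not part of the diff; import path assumed from the package layout). Note that valid_sites_only silently falls back to False unless adata.uns["calculate_coverage_performed"] has been set by an earlier coverage step:

    from smftools.preprocessing.calculate_read_modification_stats import (
        calculate_read_modification_stats,
    )

    calculate_read_modification_stats(
        adata,
        reference_column="Reference_strand",
        sample_names_col="Sample_names",
        mod_target_bases=["GpC", "CpG"],
        valid_sites_only=True,   # uses obsm matrices carrying the "_valid_coverage" suffix
    )
    # Appends per-read columns such as "Fraction_GpC_site_modified" and
    # "GpC_to_other_C_mod_ratio", then sets the uns completion flag.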
smftools/preprocessing/clean_NaN.py

@@ -1,23 +1,33 @@
- def clean_NaN(adata,
- layer=None,
- uns_flag='clean_NaN_performed',
- bypass=False,
- force_redo=True
- ):
- """
- Append layers to adata that contain NaN cleaning strategies.
+ from __future__ import annotations

- Parameters:
- adata (AnnData): an anndata object
- layer (str, optional): Name of the layer to fill NaN values in. If None, uses adata.X.
+ from typing import TYPE_CHECKING

- Modifies:
- - Adds new layers to `adata.layers` with different NaN-filling strategies.
- """
- import numpy as np
- import pandas as pd
+ from smftools.logging_utils import get_logger
+
+ if TYPE_CHECKING:
  import anndata as ad
- from ..readwrite import adata_to_df
+
+ logger = get_logger(__name__)
+
+
+ def clean_NaN(
+ adata: "ad.AnnData",
+ layer: str | None = None,
+ uns_flag: str = "clean_NaN_performed",
+ bypass: bool = False,
+ force_redo: bool = True,
+ ) -> None:
+ """Append layers to ``adata`` that contain NaN-cleaning strategies.
+
+ Args:
+ adata: AnnData object.
+ layer: Layer to fill NaN values in. If ``None``, uses ``adata.X``.
+ uns_flag: Flag in ``adata.uns`` indicating prior completion.
+ bypass: Whether to skip processing.
+ force_redo: Whether to rerun even if ``uns_flag`` is set.
+ """
+
+ from ..readwrite import adata_to_df

  # Only run if not already performed
  already = bool(adata.uns.get(uns_flag, False))
@@ -33,30 +43,30 @@ def clean_NaN(adata,
  df = adata_to_df(adata, layer=layer)

  # Fill NaN with closest SMF value (forward then backward fill)
- print('Making layer: fill_nans_closest')
- adata.layers['fill_nans_closest'] = df.ffill(axis=1).bfill(axis=1).values
+ logger.info("Making layer: fill_nans_closest")
+ adata.layers["fill_nans_closest"] = df.ffill(axis=1).bfill(axis=1).values

  # Replace NaN with 0, and 0 with -1
- print('Making layer: nan0_0minus1')
+ logger.info("Making layer: nan0_0minus1")
  df_nan0_0minus1 = df.replace(0, -1).fillna(0)
- adata.layers['nan0_0minus1'] = df_nan0_0minus1.values
+ adata.layers["nan0_0minus1"] = df_nan0_0minus1.values

  # Replace NaN with 1, and 1 with 2
- print('Making layer: nan1_12')
+ logger.info("Making layer: nan1_12")
  df_nan1_12 = df.replace(1, 2).fillna(1)
- adata.layers['nan1_12'] = df_nan1_12.values
+ adata.layers["nan1_12"] = df_nan1_12.values

  # Replace NaN with -1
- print('Making layer: nan_minus_1')
+ logger.info("Making layer: nan_minus_1")
  df_nan_minus_1 = df.fillna(-1)
- adata.layers['nan_minus_1'] = df_nan_minus_1.values
+ adata.layers["nan_minus_1"] = df_nan_minus_1.values

  # Replace NaN with -1
- print('Making layer: nan_half')
+ logger.info("Making layer: nan_half")
  df_nan_half = df.fillna(0.5)
- adata.layers['nan_half'] = df_nan_half.values
+ adata.layers["nan_half"] = df_nan_half.values

  # mark as done
  adata.uns[uns_flag] = True

- return None
+ return None
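
A brief usage sketch of clean_NaN as refactored here (not part of the diff; import path assumed from the package layout):

    from smftools.preprocessing.clean_NaN import clean_NaN

    clean_NaN(adata)   # layer=None operates on adata.X
    # Adds layers "fill_nans_closest", "nan0_0minus1", "nan1_12", "nan_minus_1",
    # and "nan_half", and sets adata.uns["clean_NaN_performed"] = True.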