smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/calculate_coverage.py
@@ -1,54 +1,76 @@
-def calculate_coverage(adata,
-                       ref_column='Reference_strand',
-                       position_nan_threshold=0.01,
-                       uns_flag='calculate_coverage_performed'):
-    """
-    Append position-level metadata regarding whether the position is informative within the given observation category.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
 
-    Parameters:
-        adata (AnnData): An AnnData object
-        obs_column (str): Observation column value to subset on prior to calculating position statistics for that category.
-        position_nan_threshold (float): A minimal fractional threshold of coverage within the obs_column category to call the position as valid.
+logger = get_logger(__name__)
 
-    Modifies:
-        - Adds new columns to `adata.var` containing coverage statistics.
+
+def calculate_coverage(
+    adata: "ad.AnnData",
+    ref_column: str = "Reference_strand",
+    position_nan_threshold: float = 0.01,
+    smf_modality: str = "deaminase",
+    target_layer: str = "binarized_methylation",
+    uns_flag: str = "calculate_coverage_performed",
+    force_redo: bool = False,
+) -> None:
+    """Append position-level coverage metadata per reference category.
+
+    Args:
+        adata: AnnData object.
+        ref_column: Obs column used to define reference/strand categories.
+        position_nan_threshold: Minimum fraction of coverage to mark a position as valid.
+        smf_modality: SMF modality. Use ``adata.X`` for conversion/deaminase or ``target_layer`` for direct.
+        target_layer: Layer used for direct SMF coverage calculations.
+        uns_flag: Flag in ``adata.uns`` indicating prior completion.
+        force_redo: Whether to rerun even if ``uns_flag`` is set.
     """
     import numpy as np
     import pandas as pd
-    import anndata as ad
 
     # Only run if not already performed
     already = bool(adata.uns.get(uns_flag, False))
-    if already:
+    if already and not force_redo:
         # QC already performed; nothing to do
         return
-
+
     references = adata.obs[ref_column].cat.categories
     n_categories_with_position = np.zeros(adata.shape[1])
 
     # Loop over references
     for ref in references:
-        print(f'Assessing positional coverage across samples for {ref} reference')
+        logger.info("Assessing positional coverage across samples for %s reference", ref)
 
         # Subset to current category
         ref_mask = adata.obs[ref_column] == ref
         temp_ref_adata = adata[ref_mask]
 
+        if smf_modality == "direct":
+            matrix = temp_ref_adata.layers[target_layer]
+        else:
+            matrix = temp_ref_adata.X
+
         # Compute fraction of valid coverage
-        ref_valid_coverage = np.sum(~np.isnan(temp_ref_adata.X), axis=0)
+        ref_valid_coverage = np.sum(~np.isnan(matrix), axis=0)
         ref_valid_fraction = ref_valid_coverage / temp_ref_adata.shape[0] # Avoid extra computation
 
         # Store coverage stats
-        adata.var[f'{ref}_valid_fraction'] = pd.Series(ref_valid_fraction, index=adata.var.index)
+        adata.var[f"{ref}_valid_count"] = pd.Series(ref_valid_coverage, index=adata.var.index)
+        adata.var[f"{ref}_valid_fraction"] = pd.Series(ref_valid_fraction, index=adata.var.index)
 
         # Assign whether the position is covered based on threshold
-        adata.var[f'position_in_{ref}'] = ref_valid_fraction >= position_nan_threshold
+        adata.var[f"position_in_{ref}"] = ref_valid_fraction >= position_nan_threshold
 
         # Sum the number of categories covering each position
-        n_categories_with_position += adata.var[f'position_in_{ref}'].values
+        n_categories_with_position += adata.var[f"position_in_{ref}"].values
 
     # Store final category count
-    adata.var[f'N_{ref_column}_with_position'] = n_categories_with_position.astype(int)
+    adata.var[f"N_{ref_column}_with_position"] = n_categories_with_position.astype(int)
 
     # mark as done
-    adata.uns[uns_flag] = True
+    adata.uns[uns_flag] = True
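Note on the new calculate_coverage API: 0.3.0 adds smf_modality, target_layer, and force_redo, and writes a new {ref}_valid_count column alongside the existing {ref}_valid_fraction. A minimal sketch of the new calling convention on toy data (the import path is taken from the file list above; the toy values and reference names are illustrative only):

import anndata as ad
import numpy as np
import pandas as pd

from smftools.preprocessing.calculate_coverage import calculate_coverage

# Toy AnnData: 4 reads x 3 positions; NaN marks positions a read does not cover.
X = np.array(
    [[1.0, np.nan, 0.0],
     [0.0, 0.0, np.nan],
     [np.nan, 1.0, 1.0],
     [1.0, 0.0, 0.0]]
)
obs = pd.DataFrame(
    {"Reference_strand": pd.Categorical(["refA_top", "refA_top", "refB_top", "refB_top"])}
)
adata = ad.AnnData(X=X, obs=obs)

# Conversion/deaminase modality reads coverage from adata.X (the default);
# smf_modality="direct" would read from adata.layers[target_layer] instead.
calculate_coverage(adata, position_nan_threshold=0.5)
print(adata.var[["refA_top_valid_count", "refA_top_valid_fraction", "position_in_refA_top"]])

# force_redo=True reruns even though adata.uns["calculate_coverage_performed"] is now set.
calculate_coverage(adata, force_redo=True)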
smftools/preprocessing/calculate_pairwise_differences.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 # calculate_pairwise_differences
 
+
 def calculate_pairwise_differences(arrays):
     """
     Calculate the pairwise differences for a list of h-stacked ndarrays. Ignore N-positions
@@ -41,7 +44,7 @@ def calculate_pairwise_differences(arrays):
             # Calculate the hamming distance directly with boolean operations
             differences = (array_i != array_j) & ~combined_mask
             distance = np.sum(differences) / np.sum(~combined_mask)
-
+
             # Store the symmetric distances
             distance_matrix[i, j] = distance
             distance_matrix[j, i] = distance
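The hunk above shows the core masked comparison: positions masked in either read are dropped from both the mismatch count and the denominator. A self-contained sketch of that arithmetic, assuming (as a stand-in for the file's mask construction, which is outside this hunk) that NaN marks the ignored N-positions:

import numpy as np

# Two toy binarized reads; NaN is assumed here to mark N-positions.
array_i = np.array([1.0, 0.0, np.nan, 1.0])
array_j = np.array([1.0, 1.0, 0.0, np.nan])

combined_mask = np.isnan(array_i) | np.isnan(array_j)
differences = (array_i != array_j) & ~combined_mask
distance = np.sum(differences) / np.sum(~combined_mask)
print(distance)  # 1 mismatch over 2 comparable positions -> 0.5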
smftools/preprocessing/calculate_pairwise_hamming_distances.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
 ## calculate_pairwise_hamming_distances
 
-## Conversion SMF Specific
+
+## Conversion SMF Specific
 def calculate_pairwise_hamming_distances(arrays):
     """
     Calculate the pairwise Hamming distances for a list of h-stacked ndarrays.
@@ -13,8 +16,9 @@ def calculate_pairwise_hamming_distances(arrays):
 
     """
     import numpy as np
-    from tqdm import tqdm
     from scipy.spatial.distance import hamming
+    from tqdm import tqdm
+
     num_arrays = len(arrays)
     # Initialize an empty distance matrix
     distance_matrix = np.zeros((num_arrays, num_arrays))
@@ -24,4 +28,4 @@ def calculate_pairwise_hamming_distances(arrays):
             distance = hamming(arrays[i], arrays[j])
             distance_matrix[i, j] = distance
             distance_matrix[j, i] = distance
-    return distance_matrix
+    return distance_matrix
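For reference, the symmetric distance-matrix fill shown in the final hunk reduces to the following self-contained sketch (the file's tqdm progress wrapper is omitted here); scipy's hamming returns the fraction of positions at which two equal-length arrays disagree:

import numpy as np
from scipy.spatial.distance import hamming

# Toy binarized reads of equal length.
arrays = [np.array([1, 0, 1, 1]), np.array([1, 1, 1, 0]), np.array([0, 0, 1, 1])]

n = len(arrays)
distance_matrix = np.zeros((n, n))
for i in range(n):
    for j in range(i + 1, n):
        d = hamming(arrays[i], arrays[j])  # fraction of disagreeing positions
        distance_matrix[i, j] = d
        distance_matrix[j, i] = d  # store the symmetric entry
print(distance_matrix)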
smftools/preprocessing/calculate_position_Youden.py
@@ -1,69 +1,102 @@
 ## calculate_position_Youden
 ## Calculating and applying position level thresholds for methylation calls to binarize the SMF data
-def calculate_position_Youden(adata,
-                              positive_control_sample=None,
-                              negative_control_sample=None,
-                              J_threshold=0.5,
-                              ref_column='Reference_strand',
-                              sample_column='Sample_names',
-                              infer_on_percentile=True,
-                              inference_variable='Raw_modification_signal',
-                              save=False,
-                              output_directory=''):
-    """
-    Adds new variable metadata to each position indicating whether the position provides reliable SMF methylation calls. Also outputs plots of the positional ROC curves.
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def calculate_position_Youden(
+    adata: "ad.AnnData",
+    positive_control_sample: str | None = None,
+    negative_control_sample: str | None = None,
+    J_threshold: float = 0.5,
+    ref_column: str = "Reference_strand",
+    sample_column: str = "Sample_names",
+    infer_on_percentile: bool | int = True,
+    inference_variable: str = "Raw_modification_signal",
+    save: bool = False,
+    output_directory: str | Path = "",
+) -> None:
+    """Add position-level Youden thresholds and optional ROC plots.
 
-    Parameters:
-        adata (AnnData): An AnnData object.
-        positive_control_sample (str): string representing the sample name corresponding to the Plus MTase control sample.
-        negative_control_sample (str): string representing the sample name corresponding to the Minus MTase control sample.
-        J_threshold (float): A float indicating the J-statistic used to indicate whether a position passes QC for methylation calls.
-        obs_column (str): The category to iterate over.
-        infer_on_perdentile (bool | int): If False, use defined postive and negative control samples. If an int (0 < int < 100) is passed, this uses the top and bottom int percentile of methylated reads based on metric in inference_variable column.
-        inference_variable (str): If infer_on_percentile has an integer value passed, use the AnnData observation column name passed by this string as the metric.
-        save (bool): Whether to save the ROC plots.
-        output_directory (str): String representing the path to the output directory to output the ROC curves.
-
-    Returns:
-        None
+    Args:
+        adata: AnnData object.
+        positive_control_sample: Sample name for the plus MTase control.
+        negative_control_sample: Sample name for the minus MTase control.
+        J_threshold: J-statistic threshold for QC.
+        ref_column: Obs column for reference/strand categories.
+        sample_column: Obs column for sample identifiers.
+        infer_on_percentile: If ``False``, use provided controls. If an int in ``(0, 100)``,
+            use percentile-based inference from ``inference_variable``.
+        inference_variable: Obs column used for percentile inference.
+        save: Whether to save ROC plots to disk.
+        output_directory: Output directory for ROC plots.
     """
     import numpy as np
-    import pandas as pd
-    import anndata as ad
-    import matplotlib.pyplot as plt
-    from sklearn.metrics import roc_curve, roc_auc_score
+
+    plt = require("matplotlib.pyplot", extra="plotting", purpose="Youden ROC plots")
+    sklearn_metrics = require(
+        "sklearn.metrics",
+        extra="ml-base",
+        purpose="Youden ROC curve calculation",
+    )
+    roc_curve = sklearn_metrics.roc_curve
 
     control_samples = [positive_control_sample, negative_control_sample]
-    references = adata.obs[ref_column].cat.categories
+    references = adata.obs[ref_column].cat.categories
     # Iterate over each category in the specified obs_column
     for ref in references:
-        print(f"Calculating position Youden statistics for {ref}")
+        logger.info("Calculating position Youden statistics for %s", ref)
         # Subset to keep only reads associated with the category
         ref_subset = adata[adata.obs[ref_column] == ref]
         # Iterate over positive and negative control samples
         for i, control in enumerate(control_samples):
-            # Initialize a dictionary for the given control sample. This will be keyed by dataset and position to point to a tuple of coordinate position and an array of methylation probabilities
-            adata.uns[f'{ref}_position_methylation_dict_{control}'] = {}
             # If controls are not passed and infer on percentile is True, infer thresholds based on top and bottom percentile windows for a given obs column metric.
             if infer_on_percentile and not control:
+                logger.info(
+                    "Inferring methylation control reads for %s based on %s percentiles of %s",
+                    ref,
+                    infer_on_percentile,
+                    inference_variable,
+                )
                 sorted_column = ref_subset.obs[inference_variable].sort_values(ascending=False)
                 if i == 0:
-                    control == 'positive'
+                    logger.info("Using top %s percentile for positive control", infer_on_percentile)
+                    control = "positive"
                     positive_control_sample = control
                     threshold = np.percentile(sorted_column, 100 - infer_on_percentile)
                     control_subset = ref_subset[ref_subset.obs[inference_variable] >= threshold, :]
                 else:
-                    control == 'negative'
+                    logger.info(
+                        "Using bottom %s percentile for negative control", infer_on_percentile
+                    )
+                    control = "negative"
                     negative_control_sample = control
                     threshold = np.percentile(sorted_column, infer_on_percentile)
-                    control_subset = ref_subset[ref_subset.obs[inference_variable] <= threshold, :]
+                    control_subset = ref_subset[ref_subset.obs[inference_variable] <= threshold, :]
             elif not infer_on_percentile and not control:
-                print("Can not threshold Anndata on Youden threshold. Need to either provide control samples or set infer_on_percentile to True")
+                logger.error(
+                    "Can not threshold Anndata on Youden threshold. Need to either provide control samples or set infer_on_percentile to True"
+                )
                 return
             else:
+                logger.info("Using provided control sample: %s", control)
                 # get the current control subset on the given category
                 filtered_obs = ref_subset.obs[ref_subset.obs[sample_column] == control]
                 control_subset = ref_subset[filtered_obs.index]
+
+            # Initialize a dictionary for the given control sample. This will be keyed by dataset and position to point to a tuple of coordinate position and an array of methylation probabilities
+            adata.uns[f"{ref}_position_methylation_dict_{control}"] = {}
+
             # Iterate through every position in the control subset
             for position in range(control_subset.shape[1]):
                 # Get the coordinate name associated with that position
@@ -79,29 +112,45 @@ def calculate_position_Youden(adata,
                 # Get fraction coverage
                 fraction_coverage = position_coverage / control_subset.shape[0]
                 # Save the position and the position methylation data for the control subset
-                adata.uns[f'{ref}_position_methylation_dict_{control}'][f'{position}'] = (position, position_data, fraction_coverage)
+                adata.uns[f"{ref}_position_methylation_dict_{control}"][f"{position}"] = (
+                    position,
+                    position_data,
+                    fraction_coverage,
+                )
 
     for ref in references:
         fig, ax = plt.subplots(figsize=(6, 4))
-        plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
-        plt.xlabel('False Positive Rate')
-        plt.ylabel('True Positive Rate')
-        ax.spines['right'].set_visible(False)
-        ax.spines['top'].set_visible(False)
+        plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
+        plt.xlabel("False Positive Rate")
+        plt.ylabel("True Positive Rate")
+        ax.spines["right"].set_visible(False)
+        ax.spines["top"].set_visible(False)
         n_passed_positions = 0
         n_total_positions = 0
         # Initialize a list that will hold the positional thresholds for the category
         probability_thresholding_list = [(np.nan, np.nan)] * adata.shape[1]
-        for i, key in enumerate(adata.uns[f'{ref}_position_methylation_dict_{positive_control_sample}'].keys()):
-            position = int(adata.uns[f'{ref}_position_methylation_dict_{positive_control_sample}'][key][0])
-            positive_position_array = adata.uns[f'{ref}_position_methylation_dict_{positive_control_sample}'][key][1]
-            fraction_coverage = adata.uns[f'{ref}_position_methylation_dict_{positive_control_sample}'][key][2]
+        for i, key in enumerate(
+            adata.uns[f"{ref}_position_methylation_dict_{positive_control_sample}"].keys()
+        ):
+            position = int(
+                adata.uns[f"{ref}_position_methylation_dict_{positive_control_sample}"][key][0]
+            )
+            positive_position_array = adata.uns[
+                f"{ref}_position_methylation_dict_{positive_control_sample}"
+            ][key][1]
+            fraction_coverage = adata.uns[
+                f"{ref}_position_methylation_dict_{positive_control_sample}"
+            ][key][2]
            if fraction_coverage > 0.2:
                 try:
-                    negative_position_array = adata.uns[f'{ref}_position_methylation_dict_{negative_control_sample}'][key][1]
+                    negative_position_array = adata.uns[
+                        f"{ref}_position_methylation_dict_{negative_control_sample}"
+                    ][key][1]
                     # Combine the negative and positive control data
                     data = np.concatenate([negative_position_array, positive_position_array])
-                    labels = np.array([0] * len(negative_position_array) + [1] * len(positive_position_array))
+                    labels = np.array(
+                        [0] * len(negative_position_array) + [1] * len(positive_position_array)
+                    )
                     # Calculate the ROC curve
                     fpr, tpr, thresholds = roc_curve(labels, data)
                     # Calculate Youden's J statistic
@@ -114,18 +163,24 @@ def calculate_position_Youden(adata,
                     n_total_positions += 1
                     if max_J > J_threshold:
                         n_passed_positions += 1
-                        plt.plot(fpr, tpr, label='ROC curve')
-                except:
+                        plt.plot(fpr, tpr, label="ROC curve")
+                except Exception:
                     probability_thresholding_list[position] = (0.8, np.nan)
-        title = f'ROC Curve for {n_passed_positions} positions with J-stat greater than {J_threshold}\n out of {n_total_positions} total positions on {ref}'
+        title = f"ROC Curve for {n_passed_positions} positions with J-stat greater than {J_threshold}\n out of {n_total_positions} total positions on {ref}"
         plt.title(title)
         save_name = output_directory / f"{title}.png"
         if save:
             plt.savefig(save_name)
             plt.close()
         else:
-            plt.show()
+            plt.show()
 
-        adata.var[f'{ref}_position_methylation_thresholding_Youden_stats'] = probability_thresholding_list
+        adata.var[f"{ref}_position_methylation_thresholding_Youden_stats"] = (
+            probability_thresholding_list
+        )
         J_max_list = [probability_thresholding_list[i][1] for i in range(adata.shape[1])]
-        adata.var[f'{ref}_position_passed_QC'] = [True if i > J_threshold else False for i in J_max_list]
+        adata.var[f"{ref}_position_passed_Youden_thresholding_QC"] = [
+            True if i > J_threshold else False for i in J_max_list
+        ]
+
+    logger.info("Finished calculating position Youden statistics")
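The statistic behind this function: for each position, reads from the minus-MTase control are labeled 0 and reads from the plus-MTase control are labeled 1, a ROC curve is computed over methylation-probability thresholds, and Youden's J = TPR - FPR is maximized; the position passes QC when max J exceeds J_threshold. A self-contained sketch of that per-position computation, with made-up probabilities:

import numpy as np
from sklearn.metrics import roc_curve

# Toy per-position methylation probabilities for the two control populations.
negative_position_array = np.array([0.05, 0.10, 0.20, 0.30])  # minus MTase
positive_position_array = np.array([0.60, 0.70, 0.80, 0.95])  # plus MTase

data = np.concatenate([negative_position_array, positive_position_array])
labels = np.array([0] * len(negative_position_array) + [1] * len(positive_position_array))

fpr, tpr, thresholds = roc_curve(labels, data)
J = tpr - fpr                     # Youden's J at each candidate threshold
best = np.argmax(J)
max_J, best_threshold = J[best], thresholds[best]
print(max_J, best_threshold)      # the position "passes QC" when max_J > J_threshold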
smftools/preprocessing/calculate_read_length_stats.py
@@ -1,45 +1,74 @@
 ## calculate_read_length_stats
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
 # Read length QC
-def calculate_read_length_stats(adata, reference_column='', sample_names_col=''):
-    """
-    Append first valid position in a read and last valid position in the read. From this determine and append the read length.
+def calculate_read_length_stats(
+    adata: "ad.AnnData",
+    reference_column: str = "",
+    sample_names_col: str = "",
+) -> tuple[int, int]:
+    """Calculate per-read length statistics and store them in ``adata.obs``.
+
+    Args:
+        adata: AnnData object.
+        reference_column: Obs column containing reference identifiers.
+        sample_names_col: Obs column containing sample identifiers.
 
-    Parameters:
-        adata (AnnData): An adata object
-        reference_column (str): String representing the name of the Reference column to use
-        sample_names_col (str): String representing the name of the sample name column to use
-
     Returns:
-        upper_bound (int): last valid position in the dataset
-        lower_bound (int): first valid position in the dataset
+        tuple[int, int]: ``(upper_bound, lower_bound)`` for valid positions in the dataset.
     """
     import numpy as np
-    import anndata as ad
    import pandas as pd
 
-    print('Calculating read length statistics')
+    logger.info("Calculating read length statistics")
 
     references = set(adata.obs[reference_column])
     sample_names = set(adata.obs[sample_names_col])
 
     ## Add basic observation-level (read-level) metadata to the object: first valid position in a read and last valid position in the read. From this determine the read length. Save two new variable which hold the first and last valid positions in the entire dataset
-    print('calculating read length stats')
+    logger.info("Calculating read length stats")
     # Add some basic observation-level (read-level) metadata to the anndata object
-    read_first_valid_position = np.array([int(adata.var_names[i]) for i in np.argmax(~np.isnan(adata.X), axis=1)])
-    read_last_valid_position = np.array([int(adata.var_names[i]) for i in (adata.X.shape[1] - 1 - np.argmax(~np.isnan(adata.X[:, ::-1]), axis=1))])
-    read_length = read_last_valid_position - read_first_valid_position + np.ones(len(read_first_valid_position))
+    read_first_valid_position = np.array(
+        [int(adata.var_names[i]) for i in np.argmax(~np.isnan(adata.X), axis=1)]
+    )
+    read_last_valid_position = np.array(
+        [
+            int(adata.var_names[i])
+            for i in (adata.X.shape[1] - 1 - np.argmax(~np.isnan(adata.X[:, ::-1]), axis=1))
+        ]
+    )
+    read_length = (
+        read_last_valid_position
+        - read_first_valid_position
+        + np.ones(len(read_first_valid_position))
+    )
 
-    adata.obs['first_valid_position'] = pd.Series(read_first_valid_position, index=adata.obs.index, dtype=int)
-    adata.obs['last_valid_position'] = pd.Series(read_last_valid_position, index=adata.obs.index, dtype=int)
-    adata.obs['read_length'] = pd.Series(read_length, index=adata.obs.index, dtype=int)
+    adata.obs["first_valid_position"] = pd.Series(
+        read_first_valid_position, index=adata.obs.index, dtype=int
+    )
+    adata.obs["last_valid_position"] = pd.Series(
+        read_last_valid_position, index=adata.obs.index, dtype=int
+    )
+    adata.obs["read_length"] = pd.Series(read_length, index=adata.obs.index, dtype=int)
 
     # Define variables to hold the first and last valid position in the dataset
-    upper_bound = int(np.nanmax(adata.obs['last_valid_position']))
-    lower_bound = int(np.nanmin(adata.obs['first_valid_position']))
+    upper_bound = int(np.nanmax(adata.obs["last_valid_position"]))
+    lower_bound = int(np.nanmin(adata.obs["first_valid_position"]))
 
     return upper_bound, lower_bound
 
+
 # # Add an unstructured element to the anndata object which points to a dictionary of read lengths keyed by reference and sample name. Points to a tuple containing (mean, median, stdev) of the read lengths of the sample for the given reference strand
 # ## Plot histogram of read length data and save the median and stdev of the read lengths for each sample.
 # adata.uns['read_length_dict'] = {}
@@ -70,10 +99,10 @@ def calculate_read_length_stats(adata, reference_column='', sample_names_col='')
 #         # Add a vertical line at the median
 #         plt.axvline(median, color='red', linestyle='dashed', linewidth=1)
 #         # Annotate the median
-#         plt.xlim(lower_bound - 100, upper_bound + 100)
+#         plt.xlim(lower_bound - 100, upper_bound + 100)
 #         if save_read_length_histogram:
 #             save_name = output_directory + f'/{readwrite.date_string()} {title}'
 #             plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
 #             plt.close()
 #         else:
-#             plt.show()
+#             plt.show()
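The first/last-valid-position computation above relies on np.argmax over a boolean mask returning the index of the first True; scanning the column-reversed matrix therefore finds the last valid column. A self-contained sketch of that trick on toy data:

import numpy as np

# Toy matrix: 2 reads x 4 positions; NaN marks positions a read does not cover.
X = np.array(
    [[np.nan, 1.0, 0.0, np.nan],
     [0.0, np.nan, 1.0, 1.0]]
)

valid = ~np.isnan(X)
first_idx = np.argmax(valid, axis=1)                           # [1, 0]
last_idx = X.shape[1] - 1 - np.argmax(valid[:, ::-1], axis=1)  # [2, 3]
read_length = last_idx - first_idx + 1                         # [2, 4]
print(first_idx, last_idx, read_length)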