smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. smftools/__init__.py +34 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/cli.py +184 -0
  5. smftools/config/__init__.py +1 -0
  6. smftools/config/conversion.yaml +33 -0
  7. smftools/config/deaminase.yaml +56 -0
  8. smftools/config/default.yaml +253 -0
  9. smftools/config/direct.yaml +17 -0
  10. smftools/config/experiment_config.py +1191 -0
  11. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  12. smftools/datasets/F1_sample_sheet.csv +5 -0
  13. smftools/datasets/__init__.py +9 -0
  14. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  15. smftools/datasets/datasets.py +28 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/hmm/apply_hmm_batched.py +242 -0
  19. smftools/hmm/calculate_distances.py +18 -0
  20. smftools/hmm/call_hmm_peaks.py +106 -0
  21. smftools/hmm/display_hmm.py +18 -0
  22. smftools/hmm/hmm_readwrite.py +16 -0
  23. smftools/hmm/nucleosome_hmm_refinement.py +104 -0
  24. smftools/hmm/train_hmm.py +78 -0
  25. smftools/informatics/__init__.py +14 -0
  26. smftools/informatics/archived/bam_conversion.py +59 -0
  27. smftools/informatics/archived/bam_direct.py +63 -0
  28. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  29. smftools/informatics/archived/conversion_smf.py +132 -0
  30. smftools/informatics/archived/deaminase_smf.py +132 -0
  31. smftools/informatics/archived/direct_smf.py +137 -0
  32. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  33. smftools/informatics/basecall_pod5s.py +80 -0
  34. smftools/informatics/fast5_to_pod5.py +24 -0
  35. smftools/informatics/helpers/__init__.py +73 -0
  36. smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
  37. smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
  38. smftools/informatics/helpers/archived/informatics.py +260 -0
  39. smftools/informatics/helpers/archived/load_adata.py +516 -0
  40. smftools/informatics/helpers/bam_qc.py +66 -0
  41. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  42. smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
  43. smftools/informatics/helpers/canoncall.py +34 -0
  44. smftools/informatics/helpers/complement_base_list.py +21 -0
  45. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
  46. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  47. smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
  48. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  49. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  50. smftools/informatics/helpers/discover_input_files.py +100 -0
  51. smftools/informatics/helpers/extract_base_identities.py +70 -0
  52. smftools/informatics/helpers/extract_mods.py +83 -0
  53. smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
  54. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  55. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  56. smftools/informatics/helpers/find_conversion_sites.py +51 -0
  57. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  58. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  59. smftools/informatics/helpers/get_native_references.py +28 -0
  60. smftools/informatics/helpers/index_fasta.py +12 -0
  61. smftools/informatics/helpers/make_dirs.py +21 -0
  62. smftools/informatics/helpers/make_modbed.py +27 -0
  63. smftools/informatics/helpers/modQC.py +27 -0
  64. smftools/informatics/helpers/modcall.py +36 -0
  65. smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
  66. smftools/informatics/helpers/ohe_batching.py +76 -0
  67. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  68. smftools/informatics/helpers/one_hot_decode.py +27 -0
  69. smftools/informatics/helpers/one_hot_encode.py +57 -0
  70. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  71. smftools/informatics/helpers/run_multiqc.py +28 -0
  72. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  73. smftools/informatics/helpers/split_and_index_BAM.py +32 -0
  74. smftools/informatics/readwrite.py +106 -0
  75. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  76. smftools/informatics/subsample_pod5.py +104 -0
  77. smftools/load_adata.py +1346 -0
  78. smftools/machine_learning/__init__.py +12 -0
  79. smftools/machine_learning/data/__init__.py +2 -0
  80. smftools/machine_learning/data/anndata_data_module.py +234 -0
  81. smftools/machine_learning/data/preprocessing.py +6 -0
  82. smftools/machine_learning/evaluation/__init__.py +2 -0
  83. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  84. smftools/machine_learning/evaluation/evaluators.py +223 -0
  85. smftools/machine_learning/inference/__init__.py +3 -0
  86. smftools/machine_learning/inference/inference_utils.py +27 -0
  87. smftools/machine_learning/inference/lightning_inference.py +68 -0
  88. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  89. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  90. smftools/machine_learning/models/__init__.py +9 -0
  91. smftools/machine_learning/models/base.py +295 -0
  92. smftools/machine_learning/models/cnn.py +138 -0
  93. smftools/machine_learning/models/lightning_base.py +345 -0
  94. smftools/machine_learning/models/mlp.py +26 -0
  95. smftools/machine_learning/models/positional.py +18 -0
  96. smftools/machine_learning/models/rnn.py +17 -0
  97. smftools/machine_learning/models/sklearn_models.py +273 -0
  98. smftools/machine_learning/models/transformer.py +303 -0
  99. smftools/machine_learning/models/wrappers.py +20 -0
  100. smftools/machine_learning/training/__init__.py +2 -0
  101. smftools/machine_learning/training/train_lightning_model.py +135 -0
  102. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  103. smftools/machine_learning/utils/__init__.py +2 -0
  104. smftools/machine_learning/utils/device.py +10 -0
  105. smftools/machine_learning/utils/grl.py +14 -0
  106. smftools/plotting/__init__.py +18 -0
  107. smftools/plotting/autocorrelation_plotting.py +611 -0
  108. smftools/plotting/classifiers.py +355 -0
  109. smftools/plotting/general_plotting.py +682 -0
  110. smftools/plotting/hmm_plotting.py +260 -0
  111. smftools/plotting/position_stats.py +462 -0
  112. smftools/plotting/qc_plotting.py +270 -0
  113. smftools/preprocessing/__init__.py +38 -0
  114. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  115. smftools/preprocessing/append_base_context.py +122 -0
  116. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  117. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  118. smftools/preprocessing/archives/preprocessing.py +614 -0
  119. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  120. smftools/preprocessing/binarize_on_Youden.py +45 -0
  121. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  122. smftools/preprocessing/calculate_complexity.py +72 -0
  123. smftools/preprocessing/calculate_complexity_II.py +248 -0
  124. smftools/preprocessing/calculate_consensus.py +47 -0
  125. smftools/preprocessing/calculate_coverage.py +51 -0
  126. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  127. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  128. smftools/preprocessing/calculate_position_Youden.py +115 -0
  129. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  130. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  131. smftools/preprocessing/clean_NaN.py +62 -0
  132. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  133. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  134. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  135. smftools/preprocessing/flag_duplicate_reads.py +1351 -0
  136. smftools/preprocessing/invert_adata.py +37 -0
  137. smftools/preprocessing/load_sample_sheet.py +53 -0
  138. smftools/preprocessing/make_dirs.py +21 -0
  139. smftools/preprocessing/min_non_diagonal.py +25 -0
  140. smftools/preprocessing/recipes.py +127 -0
  141. smftools/preprocessing/subsample_adata.py +58 -0
  142. smftools/readwrite.py +1004 -0
  143. smftools/tools/__init__.py +20 -0
  144. smftools/tools/archived/apply_hmm.py +202 -0
  145. smftools/tools/archived/classifiers.py +787 -0
  146. smftools/tools/archived/classify_methylated_features.py +66 -0
  147. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  148. smftools/tools/archived/subset_adata_v1.py +32 -0
  149. smftools/tools/archived/subset_adata_v2.py +46 -0
  150. smftools/tools/calculate_umap.py +62 -0
  151. smftools/tools/cluster_adata_on_methylation.py +105 -0
  152. smftools/tools/general_tools.py +69 -0
  153. smftools/tools/position_stats.py +601 -0
  154. smftools/tools/read_stats.py +184 -0
  155. smftools/tools/spatial_autocorrelation.py +562 -0
  156. smftools/tools/subset_adata.py +28 -0
  157. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
  158. smftools-0.2.1.dist-info/RECORD +161 -0
  159. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  160. smftools-0.1.6.dist-info/RECORD +0 -4
  161. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  162. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,115 @@
+ ## calculate_position_Youden
+
+ ## Calculating and applying position-level thresholds for methylation calls to binarize the SMF data
+ def calculate_position_Youden(adata, positive_control_sample='positive', negative_control_sample='negative', J_threshold=0.5, obs_column='Reference', infer_on_percentile=False, inference_variable='', save=False, output_directory=''):
+     """
+     Adds new variable metadata to each position indicating whether the position provides reliable SMF methylation calls. Also outputs plots of the positional ROC curves.
+
+     Parameters:
+         adata (AnnData): An AnnData object.
+         positive_control_sample (str): Sample name corresponding to the plus-MTase control sample.
+         negative_control_sample (str): Sample name corresponding to the minus-MTase control sample.
+         J_threshold (float): The minimum Youden's J statistic a position must reach to pass QC for methylation calls.
+         obs_column (str): The category to iterate over.
+         infer_on_percentile (bool | int): If False, use the defined positive and negative control samples. If an int (0 < int < 100) is passed, use the top and bottom int percentile of methylated reads based on the metric in the inference_variable column.
+         inference_variable (str): If infer_on_percentile is an integer, the AnnData observation column name to use as the metric.
+         save (bool): Whether to save the ROC plots.
+         output_directory (str): Path to the output directory for the ROC curves.
+
+     Returns:
+         None
+     """
+     import numpy as np
+     import pandas as pd
+     import anndata as ad
+     import matplotlib.pyplot as plt
+     from sklearn.metrics import roc_curve, roc_auc_score
+
+     control_samples = [positive_control_sample, negative_control_sample]
+     categories = adata.obs[obs_column].cat.categories
+     # Iterate over each category in the specified obs_column
+     for cat in categories:
+         print(f"Calculating position Youden statistics for {cat}")
+         # Subset to keep only reads associated with the category
+         cat_subset = adata[adata.obs[obs_column] == cat]
+         # Iterate over positive and negative control samples
+         for control in control_samples:
+             # Initialize a dictionary for the given control sample. It is keyed by position and points to a tuple of (position, methylation probability array, fraction coverage).
+             adata.uns[f'{cat}_position_methylation_dict_{control}'] = {}
+             if infer_on_percentile:
+                 sorted_column = cat_subset.obs[inference_variable].sort_values(ascending=False)
+                 if control == positive_control_sample:
+                     threshold = np.percentile(sorted_column, 100 - infer_on_percentile)
+                     control_subset = cat_subset[cat_subset.obs[inference_variable] >= threshold, :]
+                 else:
+                     threshold = np.percentile(sorted_column, infer_on_percentile)
+                     control_subset = cat_subset[cat_subset.obs[inference_variable] <= threshold, :]
+             else:
+                 # Get the current control subset for the given category
+                 filtered_obs = cat_subset.obs[cat_subset.obs['Sample_names'].str.contains(control, na=False, regex=True)]
+                 control_subset = cat_subset[filtered_obs.index]
+             # Iterate through every position in the control subset
+             for position in range(control_subset.shape[1]):
+                 # Get the coordinate name associated with that position
+                 coordinate = control_subset.var_names[position]
+                 # Get the array of methylation probabilities for each read in the subset at that position
+                 position_data = control_subset.X[:, position]
+                 # Mask of everywhere that is not a NaN value
+                 nan_mask = ~np.isnan(position_data)
+                 # Keep only the methylation data that has real values
+                 position_data = position_data[nan_mask]
+                 # Get the position data coverage
+                 position_coverage = len(position_data)
+                 # Get fraction coverage
+                 fraction_coverage = position_coverage / control_subset.shape[0]
+                 # Save the position and the position methylation data for the control subset
+                 adata.uns[f'{cat}_position_methylation_dict_{control}'][f'{position}'] = (position, position_data, fraction_coverage)
+
+     for cat in categories:
+         fig, ax = plt.subplots(figsize=(6, 4))
+         plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
+         plt.xlabel('False Positive Rate')
+         plt.ylabel('True Positive Rate')
+         ax.spines['right'].set_visible(False)
+         ax.spines['top'].set_visible(False)
+         n_passed_positions = 0
+         n_total_positions = 0
+         # Initialize a list that will hold the positional thresholds for the category
+         probability_thresholding_list = [(np.nan, np.nan)] * adata.shape[1]
+         for key in adata.uns[f'{cat}_position_methylation_dict_{positive_control_sample}'].keys():
+             position = int(adata.uns[f'{cat}_position_methylation_dict_{positive_control_sample}'][key][0])
+             positive_position_array = adata.uns[f'{cat}_position_methylation_dict_{positive_control_sample}'][key][1]
+             fraction_coverage = adata.uns[f'{cat}_position_methylation_dict_{positive_control_sample}'][key][2]
+             if fraction_coverage > 0.2:
+                 try:
+                     negative_position_array = adata.uns[f'{cat}_position_methylation_dict_{negative_control_sample}'][key][1]
+                     # Combine the negative and positive control data
+                     data = np.concatenate([negative_position_array, positive_position_array])
+                     labels = np.array([0] * len(negative_position_array) + [1] * len(positive_position_array))
+                     # Calculate the ROC curve
+                     fpr, tpr, thresholds = roc_curve(labels, data)
+                     # Calculate Youden's J statistic
+                     J = tpr - fpr
+                     optimal_idx = np.argmax(J)
+                     optimal_threshold = thresholds[optimal_idx]
+                     max_J = np.max(J)
+                     data_tuple = (optimal_threshold, max_J)
+                     probability_thresholding_list[position] = data_tuple
+                     n_total_positions += 1
+                     if max_J > J_threshold:
+                         n_passed_positions += 1
+                     plt.plot(fpr, tpr, label='ROC curve')
+                 except Exception:
+                     probability_thresholding_list[position] = (0.8, np.nan)
+         title = f'ROC Curve for {n_passed_positions} positions with J-stat greater than {J_threshold}\n out of {n_total_positions} total positions on {cat}'
+         plt.title(title)
+         # Drop the newline from the title so the save path is a valid file name
+         save_name = output_directory + '/' + title.replace('\n', '')
+         if save:
+             plt.savefig(save_name)
+             plt.close()
+         else:
+             plt.show()
+
+         adata.var[f'{cat}_position_methylation_thresholding_Youden_stats'] = probability_thresholding_list
+         J_max_list = [probability_thresholding_list[i][1] for i in range(adata.shape[1])]
+         adata.var[f'{cat}_position_passed_QC'] = [i > J_threshold for i in J_max_list]
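For orientation, a minimal usage sketch follows. The input path and sample labels are hypothetical, and it assumes `calculate_position_Youden` is re-exported from `smftools.preprocessing` (consistent with the package layout above) and that `adata.obs['Sample_names']` contains entries matching the control labels:

```python
# Hypothetical usage sketch; file name and sample labels are illustrative.
import anndata as ad
from smftools.preprocessing import calculate_position_Youden  # assumed export

adata = ad.read_h5ad("smf_experiment.h5ad")  # AnnData with methylation probabilities in .X

# 'Reference' must be a categorical obs column (the function uses .cat.categories).
# Reads whose 'Sample_names' match 'positive'/'negative' serve as the
# +MTase / -MTase controls for the per-position ROC analysis.
calculate_position_Youden(
    adata,
    positive_control_sample="positive",
    negative_control_sample="negative",
    J_threshold=0.5,
    obs_column="Reference",
    save=True,
    output_directory="qc_plots",
)

# Per-reference QC flags are appended to adata.var,
# e.g. adata.var['<reference>_position_passed_QC'].
```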
@@ -0,0 +1,79 @@
+ ## calculate_read_length_stats
+
+ # Read length QC
+ def calculate_read_length_stats(adata, reference_column='', sample_names_col=''):
+     """
+     Append the first and last valid positions of each read, then determine and append the read length from them.
+
+     Parameters:
+         adata (AnnData): An AnnData object.
+         reference_column (str): Name of the reference column to use.
+         sample_names_col (str): Name of the sample name column to use.
+
+     Returns:
+         upper_bound (int): Last valid position in the dataset.
+         lower_bound (int): First valid position in the dataset.
+     """
+     import numpy as np
+     import anndata as ad
+     import pandas as pd
+
+     print('Calculating read length statistics')
+
+     references = set(adata.obs[reference_column])
+     sample_names = set(adata.obs[sample_names_col])
+
+     ## Add basic observation-level (read-level) metadata to the object: the first and last valid positions of each read, and the read length derived from them. Also save two variables holding the first and last valid positions in the entire dataset.
+     read_first_valid_position = np.array([int(adata.var_names[i]) for i in np.argmax(~np.isnan(adata.X), axis=1)])
+     read_last_valid_position = np.array([int(adata.var_names[i]) for i in (adata.X.shape[1] - 1 - np.argmax(~np.isnan(adata.X[:, ::-1]), axis=1))])
+     read_length = read_last_valid_position - read_first_valid_position + np.ones(len(read_first_valid_position))
+
+     adata.obs['first_valid_position'] = pd.Series(read_first_valid_position, index=adata.obs.index, dtype=int)
+     adata.obs['last_valid_position'] = pd.Series(read_last_valid_position, index=adata.obs.index, dtype=int)
+     adata.obs['read_length'] = pd.Series(read_length, index=adata.obs.index, dtype=int)
+
+     # Define variables to hold the first and last valid position in the dataset
+     upper_bound = int(np.nanmax(adata.obs['last_valid_position']))
+     lower_bound = int(np.nanmin(adata.obs['first_valid_position']))
+
+     return upper_bound, lower_bound
+
+     # # Add an unstructured element to the anndata object which points to a dictionary of read lengths keyed by reference and sample name. Points to a tuple containing (mean, median, stdev) of the read lengths of the sample for the given reference strand
+     # ## Plot histogram of read length data and save the median and stdev of the read lengths for each sample.
+     # adata.uns['read_length_dict'] = {}
+
+     # for reference in references:
+     #     temp_reference_adata = adata[adata.obs[reference_column] == reference].copy()
+     #     split_reference = reference.split('_')[0][1:]
+     #     for sample in sample_names:
+     #         temp_sample_adata = temp_reference_adata[temp_reference_adata.obs[sample_names_col] == sample].copy()
+     #         temp_data = temp_sample_adata.obs['read_length']
+     #         max_length = np.max(temp_data)
+     #         mean = np.mean(temp_data)
+     #         median = np.median(temp_data)
+     #         stdev = np.std(temp_data)
+     #         adata.uns['read_length_dict'][f'{reference}_{sample}'] = [mean, median, stdev]
+     #         if not np.isnan(max_length):
+     #             n_bins = int(max_length // 100)
+     #         else:
+     #             n_bins = 1
+     #         if show_read_length_histogram or save_read_length_histogram:
+     #             plt.figure(figsize=(10, 6))
+     #             plt.text(median + 0.5, max(plt.hist(temp_data, bins=n_bins)[0]) / 2, f'Median: {median:.2f}', color='red')
+     #             plt.hist(temp_data, bins=n_bins, alpha=0.7, color='blue', edgecolor='black')
+     #             plt.xlabel('Read Length')
+     #             plt.ylabel('Count')
+     #             title = f'Read length distribution of {temp_sample_adata.shape[0]} total reads from {sample} sample on {split_reference} allele'
+     #             plt.title(title)
+     #             # Add a vertical line at the median and annotate it
+     #             plt.axvline(median, color='red', linestyle='dashed', linewidth=1)
+     #             plt.xlim(lower_bound - 100, upper_bound + 100)
+     #             if save_read_length_histogram:
+     #                 save_name = output_directory + f'/{readwrite.date_string()} {title}'
+     #                 plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
+     #                 plt.close()
+     #             else:
+     #                 plt.show()
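A small self-contained sketch of the boundary logic (toy data; assumes the function is importable from `smftools.preprocessing`). NaNs mark positions a read does not cover, so the first and last non-NaN columns define the read span:

```python
import numpy as np
import anndata as ad
from smftools.preprocessing import calculate_read_length_stats  # assumed export

# Toy AnnData: 2 reads x 5 positions; NaN marks positions a read does not cover.
X = np.array([
    [np.nan, 0.1, 0.9, 0.2, np.nan],        # spans coordinates 101..103
    [0.8,    0.7, np.nan, np.nan, np.nan],  # spans coordinates 100..101
])
toy = ad.AnnData(X=X)
toy.var_names = ["100", "101", "102", "103", "104"]  # numeric genomic coordinates
toy.obs["Reference"] = ["chr1"] * 2
toy.obs["Sample"] = ["s1", "s2"]

upper, lower = calculate_read_length_stats(toy, reference_column="Reference", sample_names_col="Sample")
print(toy.obs["read_length"].tolist())  # [3, 2]
print(lower, upper)                     # 100 103
```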
@@ -0,0 +1,101 @@
+ def calculate_read_modification_stats(adata,
+                                       reference_column,
+                                       sample_names_col,
+                                       mod_target_bases,
+                                       uns_flag="read_modification_stats_calculated",
+                                       bypass=False,
+                                       force_redo=False
+                                       ):
+     """
+     Adds methylation/deamination statistics for each read.
+     Reports the read-level GpC and CpG methylation ratios relative to other_C methylation (a background false-positive metric for cytosine-MTase SMF).
+
+     Parameters:
+         adata (AnnData): An AnnData object.
+         reference_column (str): Name of the reference column to use.
+         sample_names_col (str): Name of the sample name column to use.
+         mod_target_bases (list): Base contexts targeted for modification (e.g., 'GpC', 'CpG', 'C', 'A').
+         uns_flag (str): Key in adata.uns used to mark that this step has been performed.
+         bypass (bool): If True, skip this step entirely.
+         force_redo (bool): If True, rerun even if the uns_flag is already set.
+
+     Returns:
+         None
+     """
+     import numpy as np
+     import anndata as ad
+     import pandas as pd
+
+     # Only run if not already performed
+     already = bool(adata.uns.get(uns_flag, False))
+     if (already and not force_redo) or bypass:
+         # QC already performed; nothing to do
+         return
+
+     print('Calculating read level modification statistics')
+
+     references = set(adata.obs[reference_column])
+     sample_names = set(adata.obs[sample_names_col])
+     site_types = []
+
+     if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
+         site_types += ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C_site', 'any_C_site']
+
+     if 'A' in mod_target_bases:
+         site_types += ['A_site']
+
+     for site_type in site_types:
+         adata.obs[f'Modified_{site_type}_count'] = pd.Series(0, index=adata.obs_names, dtype=int)
+         adata.obs[f'Total_{site_type}_in_read'] = pd.Series(0, index=adata.obs_names, dtype=int)
+         adata.obs[f'Fraction_{site_type}_modified'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
+         # Use float dtype: an integer Series cannot hold the NaN placeholder
+         adata.obs[f'Total_{site_type}_in_reference'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
+         adata.obs[f'Valid_{site_type}_in_read_vs_reference'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
+
+     for ref in references:
+         ref_subset = adata[adata.obs[reference_column] == ref]
+         for site_type in site_types:
+             print(f'Iterating over {ref}_{site_type}')
+             observation_matrix = ref_subset.obsm[f'{ref}_{site_type}']
+             total_positions_in_read = np.sum(~np.isnan(observation_matrix), axis=1)
+             total_positions_in_reference = observation_matrix.shape[1]
+             fraction_valid_positions_in_read_vs_ref = total_positions_in_read / total_positions_in_reference
+             number_mods_in_read = np.nansum(observation_matrix, axis=1)
+
+             # Guard against division by zero for reads with no valid positions
+             fraction_modified = np.divide(
+                 number_mods_in_read,
+                 total_positions_in_read,
+                 out=np.full_like(number_mods_in_read, np.nan, dtype=float),
+                 where=total_positions_in_read != 0
+             )
+
+             temp_obs_data = pd.DataFrame({f'Total_{site_type}_in_read': total_positions_in_read,
+                                           f'Modified_{site_type}_count': number_mods_in_read,
+                                           f'Fraction_{site_type}_modified': fraction_modified,
+                                           f'Total_{site_type}_in_reference': total_positions_in_reference,
+                                           f'Valid_{site_type}_in_read_vs_reference': fraction_valid_positions_in_read_vs_ref},
+                                          index=ref_subset.obs.index)
+
+             adata.obs.update(temp_obs_data)
+
+     if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
+         with np.errstate(divide='ignore', invalid='ignore'):
+             gpc_to_c_ratio = np.divide(
+                 adata.obs['Fraction_GpC_site_modified'],
+                 adata.obs['Fraction_other_C_site_modified'],
+                 out=np.full_like(adata.obs['Fraction_GpC_site_modified'], np.nan, dtype=float),
+                 where=adata.obs['Fraction_other_C_site_modified'] != 0
+             )
+
+             cpg_to_c_ratio = np.divide(
+                 adata.obs['Fraction_CpG_site_modified'],
+                 adata.obs['Fraction_other_C_site_modified'],
+                 out=np.full_like(adata.obs['Fraction_CpG_site_modified'], np.nan, dtype=float),
+                 where=adata.obs['Fraction_other_C_site_modified'] != 0
+             )
+
+         adata.obs['GpC_to_other_C_mod_ratio'] = gpc_to_c_ratio
+         adata.obs['CpG_to_other_C_mod_ratio'] = cpg_to_c_ratio
+
+     # mark as done
+     adata.uns[uns_flag] = True
+
+     return
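A hedged call sketch. It assumes upstream preprocessing (e.g., `append_binary_layer_by_base_context` from the file list above) has already populated `adata.obsm` with per-reference binary site matrices named like `f"{ref}_GpC_site"`; without those keys the obsm lookup will raise:

```python
# Hypothetical call; obsm keys such as f"{ref}_GpC_site" must already exist.
calculate_read_modification_stats(
    adata,
    reference_column="Reference",
    sample_names_col="Sample_names",
    mod_target_bases=["GpC", "CpG"],  # cytosine contexts -> all five C site types are scored
)

# Appended read-level obs columns include, per site type:
#   Total_<site>_in_read, Modified_<site>_count, Fraction_<site>_modified,
# plus the background ratios GpC_to_other_C_mod_ratio and CpG_to_other_C_mod_ratio.
```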
@@ -0,0 +1,62 @@
+ def clean_NaN(adata,
+               layer=None,
+               uns_flag='clean_NaN_performed',
+               bypass=False,
+               force_redo=True
+               ):
+     """
+     Append layers to adata that contain NaN-cleaning strategies.
+
+     Parameters:
+         adata (AnnData): An AnnData object.
+         layer (str, optional): Name of the layer to fill NaN values in. If None, uses adata.X.
+         uns_flag (str): Key in adata.uns used to mark that this step has been performed.
+         bypass (bool): If True, skip this step entirely.
+         force_redo (bool): If True, rerun even if the uns_flag is already set.
+
+     Modifies:
+         - Adds new layers to `adata.layers` with different NaN-filling strategies.
+     """
+     import numpy as np
+     import pandas as pd
+     import anndata as ad
+     from ..readwrite import adata_to_df
+
+     # Only run if not already performed
+     already = bool(adata.uns.get(uns_flag, False))
+     if (already and not force_redo) or bypass:
+         # Cleaning already performed; nothing to do
+         return
+
+     # Ensure the specified layer exists
+     if layer and layer not in adata.layers:
+         raise ValueError(f"Layer '{layer}' not found in adata.layers.")
+
+     # Convert to DataFrame
+     df = adata_to_df(adata, layer=layer)
+
+     # Fill NaN with the closest SMF value (forward fill, then backward fill)
+     print('Making layer: fill_nans_closest')
+     adata.layers['fill_nans_closest'] = df.ffill(axis=1).bfill(axis=1).values
+
+     # Replace 0 with -1, then NaN with 0 (net: NaN -> 0, original 0 -> -1)
+     print('Making layer: nan0_0minus1')
+     df_nan0_0minus1 = df.replace(0, -1).fillna(0)
+     adata.layers['nan0_0minus1'] = df_nan0_0minus1.values
+
+     # Replace 1 with 2, then NaN with 1 (net: NaN -> 1, original 1 -> 2)
+     print('Making layer: nan1_12')
+     df_nan1_12 = df.replace(1, 2).fillna(1)
+     adata.layers['nan1_12'] = df_nan1_12.values
+
+     # Replace NaN with -1
+     print('Making layer: nan_minus_1')
+     df_nan_minus_1 = df.fillna(-1)
+     adata.layers['nan_minus_1'] = df_nan_minus_1.values
+
+     # Replace NaN with 0.5
+     print('Making layer: nan_half')
+     df_nan_half = df.fillna(0.5)
+     adata.layers['nan_half'] = df_nan_half.values
+
+     # mark as done
+     adata.uns[uns_flag] = True
+
+     return None
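To make the five strategies concrete, a toy example (one read with a missing call in the middle; assumes `clean_NaN` is importable from `smftools.preprocessing` and that `adata_to_df` returns the matrix as a DataFrame):

```python
import numpy as np
import anndata as ad
from smftools.preprocessing import clean_NaN  # assumed export

toy = ad.AnnData(X=np.array([[1.0, np.nan, 0.0]]))
clean_NaN(toy)

toy.layers["fill_nans_closest"]  # [[1. , 1. , 0. ]] (forward fill reaches the gap first)
toy.layers["nan0_0minus1"]       # [[1. , 0. , -1.]] (NaN -> 0, original 0 -> -1)
toy.layers["nan1_12"]            # [[2. , 1. , 0. ]] (NaN -> 1, original 1 -> 2)
toy.layers["nan_minus_1"]        # [[1. , -1., 0. ]]
toy.layers["nan_half"]           # [[1. , 0.5, 0. ]]
```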
@@ -0,0 +1,31 @@
+ ## filter_adata_by_nan_proportion
+
+ def filter_adata_by_nan_proportion(adata, threshold, axis='obs'):
+     """
+     Filters an AnnData object on a NaN-proportion threshold along a given matrix axis.
+
+     Parameters:
+         adata (AnnData): An AnnData object.
+         threshold (float): The maximum allowed proportion of NaN values along the given axis.
+         axis (str): Whether to filter the adata based on 'obs' or 'var' NaN content.
+
+     Returns:
+         filtered_adata (AnnData): The filtered AnnData object.
+     """
+     import numpy as np
+     import anndata as ad
+
+     if axis == 'obs':
+         # Calculate the proportion of NaN values in each read
+         nan_proportion = np.isnan(adata.X).mean(axis=1)
+         # Keep reads at or below the allowed NaN proportion
+         filtered_indices = np.where(nan_proportion <= threshold)[0]
+         filtered_adata = adata[filtered_indices, :].copy()
+     elif axis == 'var':
+         # Calculate the proportion of NaN values at each position
+         nan_proportion = np.isnan(adata.X).mean(axis=0)
+         # Keep positions at or below the allowed NaN proportion
+         filtered_indices = np.where(nan_proportion <= threshold)[0]
+         filtered_adata = adata[:, filtered_indices].copy()
+     else:
+         raise ValueError("Axis must be either 'obs' or 'var'")
+     return filtered_adata
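Typical usage is to filter along both axes in turn; a brief sketch (threshold values are illustrative):

```python
# Keep reads with at most 20% NaN positions...
filtered = filter_adata_by_nan_proportion(adata, threshold=0.2, axis="obs")
# ...then keep positions with at most 50% NaN across the remaining reads.
filtered = filter_adata_by_nan_proportion(filtered, threshold=0.5, axis="var")
print(f"{filtered.n_obs} reads x {filtered.n_vars} positions retained")
```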
@@ -0,0 +1,158 @@
+ from typing import Optional, Union, Sequence
+ import numpy as np
+ import pandas as pd
+ import anndata as ad
+
+ def filter_reads_on_length_quality_mapping(
+     adata: ad.AnnData,
+     filter_on_coordinates: Union[bool, Sequence] = False,
+     # Single-range params: each is a [min, max] sequence
+     read_length: Optional[Sequence[float]] = None,      # e.g. [min, max]
+     length_ratio: Optional[Sequence[float]] = None,     # e.g. [min, max]
+     read_quality: Optional[Sequence[float]] = None,     # e.g. [min, max] (commonly min only)
+     mapping_quality: Optional[Sequence[float]] = None,  # e.g. [min, max] (commonly min only)
+     uns_flag: str = "reads_removed_failing_length_quality_mapping_qc",
+     bypass: bool = False,
+     force_redo: bool = True
+ ) -> ad.AnnData:
+     """
+     Filter AnnData by coordinate window, read length, length ratio, read quality, and mapping quality.
+
+     Each range parameter accepts a two-element sequence `[min, max]` (list or tuple); either bound may be
+     None to leave that side unbounded. This applies to `read_length`, `length_ratio`, `read_quality`,
+     and `mapping_quality`.
+
+     Returns a filtered copy of the input AnnData and marks adata.uns[uns_flag] = True.
+     """
+     # early exit
+     already = bool(adata.uns.get(uns_flag, False))
+     if bypass or (already and not force_redo):
+         return adata
+
+     adata_work = adata
+     start_n = adata_work.n_obs
+
+     # --- coordinate filtering ---
+     if filter_on_coordinates:
+         try:
+             low, high = tuple(filter_on_coordinates)
+         except Exception:
+             raise ValueError("filter_on_coordinates must be False or an iterable of two numbers (low, high).")
+         try:
+             var_coords = np.array([float(v) for v in adata_work.var_names])
+             if low > high:
+                 low, high = high, low
+             col_mask_bool = (var_coords >= float(low)) & (var_coords <= float(high))
+             if not col_mask_bool.any():
+                 # No coordinates fall inside the window; fall back to the positions nearest each bound
+                 start_idx = int(np.argmin(np.abs(var_coords - float(low))))
+                 end_idx = int(np.argmin(np.abs(var_coords - float(high))))
+                 lo_idx, hi_idx = min(start_idx, end_idx), max(start_idx, end_idx)
+                 selected_cols = list(adata_work.var_names[lo_idx : hi_idx + 1])
+             else:
+                 selected_cols = list(adata_work.var_names[col_mask_bool])
+             print(f"Subsetting adata to coordinates between {low} and {high}: keeping {len(selected_cols)} variables.")
+             adata_work = adata_work[:, selected_cols].copy()
+         except Exception:
+             print("Warning: could not interpret adata.var_names as numeric coordinates — skipping coordinate filtering.")
+
+     # --- helper to coerce range inputs ---
+     def _coerce_range(range_arg):
+         """
+         Given range_arg, which may be None or a 2-sequence [min, max], return (min_or_None, max_or_None).
+         If both bounds are present and min > max, they are swapped.
+         """
+         if range_arg is None:
+             return None, None
+         if not isinstance(range_arg, (list, tuple, np.ndarray)) or len(range_arg) != 2:
+             # not a 2-element range -> treat as no restriction
+             return None, None
+         lo_raw, hi_raw = range_arg[0], range_arg[1]
+         lo = None if lo_raw is None else float(lo_raw)
+         hi = None if hi_raw is None else float(hi_raw)
+         if (lo is not None) and (hi is not None) and lo > hi:
+             lo, hi = hi, lo
+         return lo, hi
+
+     # Resolve ranges using only the provided range arguments
+     rl_min, rl_max = _coerce_range(read_length)
+     lr_min, lr_max = _coerce_range(length_ratio)
+     rq_min, rq_max = _coerce_range(read_quality)
+     mq_min, mq_max = _coerce_range(mapping_quality)
+
+     # --- build combined mask ---
+     combined_mask = pd.Series(True, index=adata_work.obs.index)
+
+     # read length filter
+     if (rl_min is not None) or (rl_max is not None):
+         if "mapped_length" not in adata_work.obs.columns:
+             print("Warning: 'mapped_length' not found in adata.obs — skipping read_length filter.")
+         else:
+             vals = pd.to_numeric(adata_work.obs["mapped_length"], errors="coerce")
+             mask = pd.Series(True, index=adata_work.obs.index)
+             if rl_min is not None:
+                 mask &= (vals >= rl_min)
+             if rl_max is not None:
+                 mask &= (vals <= rl_max)
+             mask &= vals.notna()
+             combined_mask &= mask
+             print(f"Planned read_length filter: min={rl_min}, max={rl_max}")
+
+     # length ratio filter
+     if (lr_min is not None) or (lr_max is not None):
+         if "mapped_length_to_reference_length_ratio" not in adata_work.obs.columns:
+             print("Warning: 'mapped_length_to_reference_length_ratio' not found in adata.obs — skipping length_ratio filter.")
+         else:
+             vals = pd.to_numeric(adata_work.obs["mapped_length_to_reference_length_ratio"], errors="coerce")
+             mask = pd.Series(True, index=adata_work.obs.index)
+             if lr_min is not None:
+                 mask &= (vals >= lr_min)
+             if lr_max is not None:
+                 mask &= (vals <= lr_max)
+             mask &= vals.notna()
+             combined_mask &= mask
+             print(f"Planned length_ratio filter: min={lr_min}, max={lr_max}")
+
+     # read quality filter (supports an optional range, but typically min only)
+     if (rq_min is not None) or (rq_max is not None):
+         if "read_quality" not in adata_work.obs.columns:
+             print("Warning: 'read_quality' not found in adata.obs — skipping read_quality filter.")
+         else:
+             vals = pd.to_numeric(adata_work.obs["read_quality"], errors="coerce")
+             mask = pd.Series(True, index=adata_work.obs.index)
+             if rq_min is not None:
+                 mask &= (vals >= rq_min)
+             if rq_max is not None:
+                 mask &= (vals <= rq_max)
+             mask &= vals.notna()
+             combined_mask &= mask
+             print(f"Planned read_quality filter: min={rq_min}, max={rq_max}")
+
+     # mapping quality filter (supports an optional range, but typically min only)
+     if (mq_min is not None) or (mq_max is not None):
+         if "mapping_quality" not in adata_work.obs.columns:
+             print("Warning: 'mapping_quality' not found in adata.obs — skipping mapping_quality filter.")
+         else:
+             vals = pd.to_numeric(adata_work.obs["mapping_quality"], errors="coerce")
+             mask = pd.Series(True, index=adata_work.obs.index)
+             if mq_min is not None:
+                 mask &= (vals >= mq_min)
+             if mq_max is not None:
+                 mask &= (vals <= mq_max)
+             mask &= vals.notna()
+             combined_mask &= mask
+             print(f"Planned mapping_quality filter: min={mq_min}, max={mq_max}")
+
+     # Apply combined mask and report
+     s0 = adata_work.n_obs
+     combined_mask_bool = combined_mask.astype(bool).values
+     adata_work = adata_work[combined_mask_bool].copy()
+     s1 = adata_work.n_obs
+     print(f"Combined filters applied: kept {s1} / {s0} reads (removed {s0 - s1})")
+
+     final_n = adata_work.n_obs
+     print(f"Filtering complete: start={start_n}, final={final_n}, removed={start_n - final_n}")
+
+     # mark as done
+     adata_work.uns[uns_flag] = True
+
+     return adata_work
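A hedged call sketch. The obs columns it checks ('mapped_length', 'mapped_length_to_reference_length_ratio', 'read_quality', 'mapping_quality') are expected from upstream QC (e.g., `add_read_length_and_mapping_qc` in the file list above); a missing column is skipped with a warning rather than raising:

```python
# Hypothetical call; bounds are illustrative. Use None to leave a side unbounded.
filtered = filter_reads_on_length_quality_mapping(
    adata,
    filter_on_coordinates=[1000, 2000],  # keep variables with coordinates in 1000..2000
    read_length=[500, None],             # minimum mapped length only
    length_ratio=[0.8, 1.2],             # mapped/reference length ratio window
    read_quality=[10, None],             # minimum read quality
    mapping_quality=[20, None],          # minimum mapping quality
)
```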