smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,82 +0,0 @@
- ## append_C_context
-
- ## Conversion SMF Specific
- # Read methylation QC
- def append_C_context(adata, obs_column='Reference', use_consensus=False, native=False):
-     """
-     Adds Cytosine context to the position within the given category. When use_consensus is True, it uses the consensus sequence, otherwise it defaults to the FASTA sequence.
-
-     Parameters:
-         adata (AnnData): The input adata object.
-         obs_column (str): The observation column to stratify on. Default is 'Reference', which should not be changed for most purposes.
-         use_consensus (bool): Whether to use the consensus sequence from the reads mapped to the reference. If False, the reference FASTA is used instead.
-         native (bool): If False, perform conversion SMF assumptions. If True, perform native SMF assumptions.
-
-     Returns:
-         None
-     """
-     import numpy as np
-     import anndata as ad
-
-     print('Adding Cytosine context based on reference FASTA sequence for sample')
-
-     site_types = ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C', 'any_C_site']
-     categories = adata.obs[obs_column].cat.categories
-     for cat in categories:
-         # Assess if the strand is the top or bottom strand converted
-         if 'top' in cat:
-             strand = 'top'
-         elif 'bottom' in cat:
-             strand = 'bottom'
-
-         if native:
-             basename = cat.split(f"_{strand}")[0]
-             if use_consensus:
-                 sequence = adata.uns[f'{basename}_consensus_sequence']
-             else:
-                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
-                 sequence = adata.uns[f'{basename}_FASTA_sequence']
-         else:
-             basename = cat.split(f"_{strand}")[0]
-             if use_consensus:
-                 sequence = adata.uns[f'{basename}_consensus_sequence']
-             else:
-                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
-                 sequence = adata.uns[f'{basename}_FASTA_sequence']
-         # Init a dict keyed by reference site type that points to a bool of whether the position is that site type.
-         boolean_dict = {}
-         for site_type in site_types:
-             boolean_dict[f'{cat}_{site_type}'] = np.full(len(sequence), False, dtype=bool)
-
-         if strand == 'top':
-             # Iterate through the sequence and apply the criteria
-             for i in range(1, len(sequence) - 1):
-                 if sequence[i] == 'C':
-                     boolean_dict[f'{cat}_any_C_site'][i] = True
-                     if sequence[i - 1] == 'G' and sequence[i + 1] != 'G':
-                         boolean_dict[f'{cat}_GpC_site'][i] = True
-                     elif sequence[i - 1] == 'G' and sequence[i + 1] == 'G':
-                         boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
-                     elif sequence[i - 1] != 'G' and sequence[i + 1] == 'G':
-                         boolean_dict[f'{cat}_CpG_site'][i] = True
-                     elif sequence[i - 1] != 'G' and sequence[i + 1] != 'G':
-                         boolean_dict[f'{cat}_other_C'][i] = True
-         elif strand == 'bottom':
-             # Iterate through the sequence and apply the criteria
-             for i in range(1, len(sequence) - 1):
-                 if sequence[i] == 'G':
-                     boolean_dict[f'{cat}_any_C_site'][i] = True
-                     if sequence[i + 1] == 'C' and sequence[i - 1] != 'C':
-                         boolean_dict[f'{cat}_GpC_site'][i] = True
-                     elif sequence[i - 1] == 'C' and sequence[i + 1] == 'C':
-                         boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
-                     elif sequence[i - 1] == 'C' and sequence[i + 1] != 'C':
-                         boolean_dict[f'{cat}_CpG_site'][i] = True
-                     elif sequence[i - 1] != 'C' and sequence[i + 1] != 'C':
-                         boolean_dict[f'{cat}_other_C'][i] = True
-         else:
-             print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')
-
-         for site_type in site_types:
-             adata.var[f'{cat}_{site_type}'] = boolean_dict[f'{cat}_{site_type}'].astype(bool)
-             adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].X
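For context on how the removed helper was invoked, a minimal usage sketch (illustrative only; the file path and key names are placeholders, and it assumes an AnnData whose `obs['Reference']` categories contain 'top'/'bottom' and whose `uns` holds the `*_FASTA_sequence` entries the function reads):

    import anndata as ad
    adata = ad.read_h5ad("conversion_smf.h5ad")  # placeholder path
    append_C_context(adata, obs_column='Reference', use_consensus=False, native=False)
    # Per-category boolean site masks land in adata.var and per-read matrices in adata.obsm,
    # e.g. adata.var['<reference>_top_GpC_site'] and adata.obsm['<reference>_top_GpC_site']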
@@ -1,94 +0,0 @@
- ## calculate_converted_read_methylation_stats
-
- ## Conversion SMF Specific
- # Read methylation QC
-
- def calculate_converted_read_methylation_stats(adata, reference_column, sample_names_col):
-     """
-     Adds methylation statistics for each read. Indicates whether the read GpC methylation exceeded other_C methylation (background false positives).
-
-     Parameters:
-         adata (AnnData): An adata object
-         reference_column (str): String representing the name of the Reference column to use
-         sample_names_col (str): String representing the name of the sample name column to use
-
-     Returns:
-         None
-     """
-     import numpy as np
-     import anndata as ad
-     import pandas as pd
-
-     print('Calculating read level methylation statistics')
-
-     references = set(adata.obs[reference_column])
-     sample_names = set(adata.obs[sample_names_col])
-
-     site_types = ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C']
-
-     for site_type in site_types:
-         adata.obs[f'{site_type}_row_methylation_sums'] = pd.Series(0, index=adata.obs_names, dtype=int)
-         adata.obs[f'{site_type}_row_methylation_means'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
-         adata.obs[f'number_valid_{site_type}_in_read'] = pd.Series(0, index=adata.obs_names, dtype=int)
-         adata.obs[f'fraction_valid_{site_type}_in_range'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
-     for cat in references:
-         cat_subset = adata[adata.obs[reference_column] == cat].copy()
-         for site_type in site_types:
-             print(f'Iterating over {cat}_{site_type}')
-             observation_matrix = cat_subset.obsm[f'{cat}_{site_type}']
-             number_valid_positions_in_read = np.nansum(~np.isnan(observation_matrix), axis=1)
-             row_methylation_sums = np.nansum(observation_matrix, axis=1)
-             number_valid_positions_in_read[number_valid_positions_in_read == 0] = 1
-             fraction_valid_positions_in_range = number_valid_positions_in_read / np.max(number_valid_positions_in_read)
-             row_methylation_means = np.divide(row_methylation_sums, number_valid_positions_in_read)
-             temp_obs_data = pd.DataFrame({f'number_valid_{site_type}_in_read': number_valid_positions_in_read,
-                                           f'fraction_valid_{site_type}_in_range': fraction_valid_positions_in_range,
-                                           f'{site_type}_row_methylation_sums': row_methylation_sums,
-                                           f'{site_type}_row_methylation_means': row_methylation_means}, index=cat_subset.obs.index)
-             adata.obs.update(temp_obs_data)
-     # Indicate whether the read-level GpC methylation rate exceeds the false methylation rate of the read
-     pass_array = np.array(adata.obs[f'GpC_site_row_methylation_means'] > adata.obs[f'other_C_row_methylation_means'])
-     adata.obs['GpC_above_other_C'] = pd.Series(pass_array, index=adata.obs.index, dtype=bool)
-
-     # Below should be a plotting function
-     # adata.uns['methylation_dict'] = {}
-     # n_bins = 50
-     # site_types_to_analyze = ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C']
-
-     # for reference in references:
-     #     reference_adata = adata[adata.obs[reference_column] == reference].copy()
-     #     split_reference = reference.split('_')[0][1:]
-     #     for sample in sample_names:
-     #         sample_adata = reference_adata[reference_adata.obs[sample_names_col] == sample].copy()
-     #         for site_type in site_types_to_analyze:
-     #             methylation_data = sample_adata.obs[f'{site_type}_row_methylation_means']
-     #             max_meth = np.max(sample_adata.obs[f'{site_type}_row_methylation_sums'])
-     #             if not np.isnan(max_meth):
-     #                 n_bins = int(max_meth // 2)
-     #             else:
-     #                 n_bins = 1
-     #             mean = np.mean(methylation_data)
-     #             median = np.median(methylation_data)
-     #             stdev = np.std(methylation_data)
-     #             adata.uns['methylation_dict'][f'{reference}_{sample}_{site_type}'] = [mean, median, stdev]
-     #             if show_methylation_histogram or save_methylation_histogram:
-     #                 fig, ax = plt.subplots(figsize=(6, 4))
-     #                 count, bins, patches = plt.hist(methylation_data, bins=n_bins, weights=np.ones(len(methylation_data)) / len(methylation_data), alpha=0.7, color='blue', edgecolor='black')
-     #                 plt.axvline(median, color='red', linestyle='dashed', linewidth=1)
-     #                 plt.text(median + stdev, max(count)*0.8, f'Median: {median:.2f}', color='red')
-     #                 plt.axvline(median - stdev, color='green', linestyle='dashed', linewidth=1, label=f'Stdev: {stdev:.2f}')
-     #                 plt.axvline(median + stdev, color='green', linestyle='dashed', linewidth=1)
-     #                 plt.text(median + stdev + 0.05, max(count) / 3, f'+1 Stdev: {stdev:.2f}', color='green')
-     #                 plt.xlabel('Fraction methylated')
-     #                 plt.ylabel('Proportion')
-     #                 title = f'Distribution of {methylation_data.shape[0]} read {site_type} methylation means \nfor {sample} sample on {split_reference} after filtering'
-     #                 plt.title(title, pad=20)
-     #                 plt.xlim(-0.05, 1.05)  # Set x-axis range from 0 to 1
-     #                 ax.spines['right'].set_visible(False)
-     #                 ax.spines['top'].set_visible(False)
-     #                 save_name = output_directory + f'/{readwrite.date_string()} {title}'
-     #                 if save_methylation_histogram:
-     #                     plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
-     #                     plt.close()
-     #                 else:
-     #                     plt.show()
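A hedged sketch of the expected call order, since this helper reads the per-category obsm matrices written by append_C_context (the sample-name column is a placeholder):

    append_C_context(adata, obs_column='Reference')
    calculate_converted_read_methylation_stats(adata, reference_column='Reference',
                                               sample_names_col='Sample_names')  # placeholder column name
    # Adds per-read QC columns such as adata.obs['GpC_site_row_methylation_means']
    # and the boolean adata.obs['GpC_above_other_C'] used by the filtering step below.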
@@ -1,44 +0,0 @@
- ## filter_converted_reads_on_methylation
-
- ## Conversion SMF Specific
- def filter_converted_reads_on_methylation(adata, valid_SMF_site_threshold=0.8, min_SMF_threshold=0.025, max_SMF_threshold=0.975):
-     """
-     Filter adata object using minimum thresholds for valid SMF site fraction in read, as well as minimum methylation content in read.
-
-     Parameters:
-         adata (AnnData): An adata object.
-         valid_SMF_site_threshold (float): A minimum proportion of valid SMF sites that must be present in the read. Default is 0.8
-         min_SMF_threshold (float): A minimum read methylation level. Default is 0.025
-         max_SMF_threshold (float): A maximum read methylation level. Default is 0.975
-
-     Returns:
-         AnnData
-     """
-     import numpy as np
-     import anndata as ad
-     import pandas as pd
-
-     if valid_SMF_site_threshold:
-         # Keep reads that have over a given valid GpC site content
-         adata = adata[adata.obs['fraction_valid_GpC_site_in_range'] > valid_SMF_site_threshold].copy()
-
-     if min_SMF_threshold:
-         # Keep reads with SMF methylation over background methylation.
-         below_background = (~adata.obs['GpC_above_other_C']).sum()
-         print(f'Removing {below_background} reads that have GpC conversion below background conversion rate')
-         adata = adata[adata.obs['GpC_above_other_C'] == True].copy()
-         # Keep reads over a defined methylation threshold
-         s0 = adata.shape[0]
-         adata = adata[adata.obs['GpC_site_row_methylation_means'] > min_SMF_threshold].copy()
-         s1 = adata.shape[0]
-         below_threshold = s0 - s1
-         print(f'Removing {below_threshold} reads that have GpC conversion below a minimum threshold conversion rate')
-
-     if max_SMF_threshold:
-         # Keep reads below a defined methylation threshold
-         s0 = adata.shape[0]
-         adata = adata[adata.obs['GpC_site_row_methylation_means'] < max_SMF_threshold].copy()
-         s1 = adata.shape[0]
-         above_threshold = s0 - s1
-         print(f'Removing {above_threshold} reads that have GpC conversion above a maximum threshold conversion rate')
-
-     return adata
-
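Illustrative usage, assuming the QC columns from calculate_converted_read_methylation_stats are already present; the thresholds shown are simply the defaults from the signature:

    adata = filter_converted_reads_on_methylation(
        adata,
        valid_SMF_site_threshold=0.8,   # keep reads with >80% of GpC sites observed
        min_SMF_threshold=0.025,        # drop reads with near-zero GpC methylation
        max_SMF_threshold=0.975,        # drop reads that are essentially fully methylated
    )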
@@ -1,51 +0,0 @@
- ## filter_reads_on_length
-
- def filter_reads_on_length(adata, filter_on_coordinates=False, min_read_length=2700, max_read_length=3200):
-     """
-     Filters the adata object to keep a defined coordinate window, as well as reads that are over a minimum threshold in length.
-
-     Parameters:
-         adata (AnnData): An adata object.
-         filter_on_coordinates (bool | list): If False, skips filtering. Otherwise, provide a list containing integers representing the lower and upper bound coordinates to filter on. Default is False.
-         min_read_length (int): The minimum read length to keep in the filtered dataset. Default is 2700.
-         max_read_length (int): The maximum query read length to keep in the filtered dataset. Default is 3200.
-
-     Returns:
-         adata
-     """
-     import numpy as np
-     import anndata as ad
-     import pandas as pd
-
-     if filter_on_coordinates:
-         lower_bound, upper_bound = filter_on_coordinates
-         # Extract the position information from the adata object as an np array
-         var_names_arr = adata.var_names.astype(int).to_numpy()
-         # Find the upper bound coordinate that is closest to the specified value
-         closest_end_index = np.argmin(np.abs(var_names_arr - upper_bound))
-         upper_bound = int(adata.var_names[closest_end_index])
-         # Find the lower bound coordinate that is closest to the specified value
-         closest_start_index = np.argmin(np.abs(var_names_arr - lower_bound))
-         lower_bound = int(adata.var_names[closest_start_index])
-         # Get a list of positional indexes that encompass the lower and upper bounds of the dataset
-         position_list = list(range(lower_bound, upper_bound + 1))
-         position_list = [str(pos) for pos in position_list]
-         position_set = set(position_list)
-         print(f'Subsetting adata to keep data between coordinates {lower_bound} and {upper_bound}')
-         adata = adata[:, adata.var_names.isin(position_set)].copy()
-
-     if min_read_length:
-         print(f'Subsetting adata to keep reads longer than {min_read_length}')
-         s0 = adata.shape[0]
-         adata = adata[adata.obs['read_length'] > min_read_length].copy()
-         s1 = adata.shape[0]
-         print(f'Removed {s0-s1} reads less than {min_read_length} basepairs in length')
-
-     if max_read_length:
-         print(f'Subsetting adata to keep reads shorter than {max_read_length}')
-         s0 = adata.shape[0]
-         adata = adata[adata.obs['read_length'] < max_read_length].copy()
-         s1 = adata.shape[0]
-         print(f'Removed {s0-s1} reads greater than {max_read_length} basepairs in length')
-
-     return adata
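A short usage sketch; the coordinate window and length bounds below are example values, not recommendations:

    # Keep a ~3 kb window and reads between 2700 and 3200 bp (example values)
    adata = filter_reads_on_length(
        adata,
        filter_on_coordinates=[1000, 4000],
        min_read_length=2700,
        max_read_length=3200,
    )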
@@ -1,105 +0,0 @@
- def call_hmm_peaks(adata, feature_configs, obs_column='Reference_strand', site_types=['GpC_site', 'CpG_site'], save_plot=False, output_dir=None, date_tag=None):
-     """
-     Calls peaks from HMM feature layers and annotates them into the AnnData object.
-
-     Parameters:
-         adata : AnnData object with HMM layers (from apply_hmm)
-         feature_configs : dict
-             min_distance : minimum distance between peaks
-             peak_width : window size around peak centers
-             peak_prominence : required peak prominence
-             peak_threshold : threshold for labeling a read as "present" at a peak
-         site_types : list of var site types to aggregate
-         save_plot : whether to save the plot
-         output_dir : path to save the figure if save_plot=True
-         date_tag : optional tag for filename
-     """
-     import matplotlib.pyplot as plt
-     from scipy.signal import find_peaks
-     import os
-     import numpy as np
-
-     peak_columns = []
-
-     for feature_layer, config in feature_configs.items():
-         min_distance = config.get('min_distance', 200)
-         peak_width = config.get('peak_width', 200)
-         peak_prominence = config.get('peak_prominence', 0.2)
-         peak_threshold = config.get('peak_threshold', 0.8)
-
-         # 1️⃣ Calculate mean intensity profile
-         matrix = adata.layers[feature_layer]
-         means = np.mean(matrix, axis=0)
-         feature_peak_columns = []
-
-         # 2️⃣ Peak calling
-         peak_centers, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
-         adata.uns[f'{feature_layer} peak_centers'] = peak_centers
-
-         # 3️⃣ Plot
-         plt.figure(figsize=(6, 3))
-         plt.plot(range(len(means)), means)
-         plt.title(f"{feature_layer} density with peak calls")
-         plt.xlabel("Genomic position")
-         plt.ylabel("Mean feature density")
-         y = max(means) / 2
-         for i, center in enumerate(peak_centers):
-             plus_minus_width = peak_width // 2
-             start = center - plus_minus_width
-             end = center + plus_minus_width
-             plt.axvspan(start, end, color='purple', alpha=0.2)
-             plt.axvline(center, color='red', linestyle='--')
-             if i % 2:
-                 aligned = [end, 'left']
-             else:
-                 aligned = [start, 'right']
-             plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
-
-         if save_plot and output_dir:
-             filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
-             plt.savefig(filename, bbox_inches='tight')
-             print(f"Saved plot to {filename}")
-         else:
-             plt.show()
-
-         # 4️⃣ Annotate peaks back into adata.obs
-         for center in peak_centers:
-             half_width = peak_width // 2
-             start, end = center - half_width, center + half_width
-             colname = f'{feature_layer}_peak_{center}'
-             peak_columns.append(colname)
-             feature_peak_columns.append(colname)
-
-             adata.var[colname] = (
-                 (adata.var_names.astype(int) >= start) &
-                 (adata.var_names.astype(int) <= end)
-             )
-
-             # Feature layer intensity around peak
-             mean_values = np.mean(matrix[:, start:end+1], axis=1)
-             sum_values = np.sum(matrix[:, start:end+1], axis=1)
-             adata.obs[f'mean_{feature_layer}_around_{center}'] = mean_values
-             adata.obs[f'sum_{feature_layer}_around_{center}'] = sum_values
-             adata.obs[f'{feature_layer}_present_at_{center}'] = mean_values > peak_threshold
-
-             # Site-type based aggregation
-             for site_type in site_types:
-                 adata.obs[f'{site_type}_sum_around_{center}'] = 0
-                 adata.obs[f'{site_type}_mean_around_{center}'] = np.nan
-
-             references = adata.obs[obs_column].cat.categories
-             for ref in adata.obs[obs_column].cat.categories:
-                 subset = adata[adata.obs[obs_column] == ref]
-                 for site_type in site_types:
-                     mask = subset.var.get(f'{ref}_{site_type}', None)
-                     if mask is not None:
-                         region_mask = (subset.var_names[mask].astype(int) >= start) & (subset.var_names[mask].astype(int) <= end)
-                         region = subset[:, mask].X[:, region_mask]
-                         adata.obs.loc[subset.obs.index, f'{site_type}_sum_around_{center}'] = np.nansum(region, axis=1)
-                         adata.obs.loc[subset.obs.index, f'{site_type}_mean_around_{center}'] = np.nanmean(region, axis=1)
-
-         adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
-         print(f"✅ Peak annotation completed for {feature_layer} with {len(peak_centers)} peaks.")
-
-     # Combine all peaks into a single "is_in_any_peak" column
-     adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
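The feature_configs argument is a dict keyed by HMM layer name; each value supplies the peak-calling knobs read via config.get above. A minimal sketch (the layer name and output directory are placeholders, and the numeric values are just the function defaults):

    feature_configs = {
        "hmm_accessible_patches": {     # placeholder name of a layer written by apply_hmm
            "min_distance": 200,
            "peak_width": 200,
            "peak_prominence": 0.2,
            "peak_threshold": 0.8,
        },
    }
    call_hmm_peaks(adata, feature_configs, obs_column='Reference_strand',
                   site_types=['GpC_site', 'CpG_site'],
                   save_plot=True, output_dir='figures', date_tag='20240101')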
@@ -1,2 +0,0 @@
- from .anndata_data_module import AnnDataModule
- from .preprocessing import random_fill_nans
@@ -1,90 +0,0 @@
- import torch
- from torch.utils.data import DataLoader, TensorDataset, random_split
- import pytorch_lightning as pl
- import numpy as np
- import pandas as pd
-
- class AnnDataModule(pl.LightningDataModule):
-     def __init__(self, adata, tensor_source="X", tensor_key=None, label_col="labels",
-                  batch_size=64, train_frac=0.7, random_seed=42, split_col='train_val_split', split_save_path=None, load_existing_split=False,
-                  inference_mode=False):
-         super().__init__()
-         self.adata = adata  # The adata object
-         self.tensor_source = tensor_source  # X, layers, obsm
-         self.tensor_key = tensor_key  # name of the layer or obsm key
-         self.label_col = label_col  # name of the label column in obs
-         self.batch_size = batch_size
-         self.train_frac = train_frac
-         self.random_seed = random_seed
-         self.split_col = split_col  # Name of obs column to store "train"/"val"
-         self.split_save_path = split_save_path  # Where to save the obs_names and train/test split logging
-         self.load_existing_split = load_existing_split  # Whether to load from an existing split
-         self.inference_mode = inference_mode  # Whether to load the AnnDataModule in inference mode.
-
-     def setup(self, stage=None):
-         # Load feature matrix
-         if self.tensor_source == "X":
-             X = self.adata.X
-         elif self.tensor_source == "layers":
-             assert self.tensor_key in self.adata.layers, f"Layer '{self.tensor_key}' not found."
-             X = self.adata.layers[self.tensor_key]
-         elif self.tensor_source == "obsm":
-             assert self.tensor_key in self.adata.obsm, f"obsm key '{self.tensor_key}' not found."
-             X = self.adata.obsm[self.tensor_key]
-         else:
-             raise ValueError(f"Invalid tensor_source: {self.tensor_source}")
-
-         # Convert to tensor
-         X_tensor = torch.tensor(X, dtype=torch.float32)
-
-         if self.inference_mode:
-             self.infer_dataset = TensorDataset(X_tensor)
-
-         else:
-             # Load and encode labels
-             y = self.adata.obs[self.label_col]
-             if y.dtype.name == 'category':
-                 y = y.cat.codes
-             y_tensor = torch.tensor(y.values, dtype=torch.long)
-
-             # Use existing split
-             if self.load_existing_split:
-                 split_df = pd.read_csv(self.split_save_path, index_col=0)
-                 assert self.split_col in split_df.columns, f"'{self.split_col}' column missing in split file."
-                 self.adata.obs[self.split_col] = split_df.loc[self.adata.obs_names][self.split_col].values
-
-             # If no split exists, create one
-             if self.split_col not in self.adata.obs:
-                 full_dataset = TensorDataset(X_tensor, y_tensor)
-                 n_train = int(self.train_frac * len(full_dataset))
-                 n_val = len(full_dataset) - n_train
-                 self.train_set, self.val_set = random_split(
-                     full_dataset, [n_train, n_val],
-                     generator=torch.Generator().manual_seed(self.random_seed)
-                 )
-                 # Assign split labels
-                 split_array = np.full(len(self.adata), "val", dtype=object)
-                 train_idx = self.train_set.indices if hasattr(self.train_set, "indices") else self.train_set._indices
-                 split_array[train_idx] = "train"
-                 self.adata.obs[self.split_col] = split_array
-
-                 # Save to disk
-                 if self.split_save_path:
-                     self.adata.obs[[self.split_col]].to_csv(self.split_save_path)
-             else:
-                 split_labels = self.adata.obs[self.split_col].values
-                 train_mask = split_labels == "train"
-                 val_mask = split_labels == "val"
-                 self.train_set = TensorDataset(X_tensor[train_mask], y_tensor[train_mask])
-                 self.val_set = TensorDataset(X_tensor[val_mask], y_tensor[val_mask])
-
-     def train_dataloader(self):
-         return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
-
-     def val_dataloader(self):
-         return DataLoader(self.val_set, batch_size=self.batch_size)
-
-     def predict_dataloader(self):
-         if not self.inference_mode:
-             raise RuntimeError("predict_dataloader only available in inference mode.")
-         return DataLoader(self.infer_dataset, batch_size=self.batch_size)
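A hedged training-setup sketch for the removed data module; the layer name, label column, and CSV path are placeholders:

    dm = AnnDataModule(
        adata,
        tensor_source="layers", tensor_key="nan_filled",   # placeholder layer name
        label_col="labels", batch_size=64, train_frac=0.7,
        split_save_path="train_val_split.csv",
    )
    dm.setup()
    train_loader = dm.train_dataloader()
    val_loader = dm.val_dataloader()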
@@ -1 +0,0 @@
- from .lightning_inference import run_lightning_inference
@@ -1,41 +0,0 @@
- import torch
- import pandas as pd
- import numpy as np
- from pytorch_lightning import Trainer
-
- def run_lightning_inference(
-     adata,
-     model,
-     datamodule,
-     label_col="labels",
-     prefix="model"
- ):
-
-     # Get class labels
-     if label_col in adata.obs and pd.api.types.is_categorical_dtype(adata.obs[label_col]):
-         class_labels = adata.obs[label_col].cat.categories.tolist()
-     else:
-         raise ValueError("label_col must be a categorical column in adata.obs")
-
-     # Run predictions
-     trainer = Trainer(accelerator="auto", devices=1, logger=False, enable_checkpointing=False)
-     preds = trainer.predict(model, datamodule=datamodule)
-     probs = torch.cat(preds, dim=0).cpu().numpy()  # (N, C)
-     pred_class_idx = probs.argmax(axis=1)
-     pred_class_labels = [class_labels[i] for i in pred_class_idx]
-     pred_class_probs = probs[np.arange(len(probs)), pred_class_idx]
-
-     # Construct full prefix with label_col
-     full_prefix = f"{prefix}_{label_col}"
-
-     # Store predictions in obs
-     adata.obs[f"{full_prefix}_pred"] = pred_class_idx
-     adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
-     adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
-
-     # Per-class probabilities
-     for i, class_name in enumerate(class_labels):
-         adata.obs[f"{full_prefix}_prob_{class_name}"] = probs[:, i]
-
-     # Full probability matrix in obsm
-     adata.obsm[f"{full_prefix}_pred_prob_all"] = probs
@@ -1,14 +0,0 @@
- import torch.nn as nn
- from ..utils.device import detect_device
-
- class BaseTorchModel(nn.Module):
-     """
-     Minimal base class for torch models that:
-     - Stores device
-     - Moves model to detected device on init
-     """
-     def __init__(self, dropout_rate=0.2):
-         super().__init__()
-         self.device = detect_device()  # detects available devices
-         self.dropout_rate = dropout_rate  # default dropout rate to be used in regularization.
-         self.to(self.device)  # move model to device
@@ -1,34 +0,0 @@
- import torch
- import torch.nn as nn
- from .base import BaseTorchModel
-
- class CNNClassifier(BaseTorchModel):
-     def __init__(self, input_size, num_classes, **kwargs):
-         super().__init__(**kwargs)
-         # Define convolutional layers
-         self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
-         self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
-         # Define activation function
-         self.relu = nn.ReLU()
-
-         # Determine the flattened size dynamically
-         dummy_input = torch.zeros(1, 1, input_size).to(self.device)
-         with torch.no_grad():
-             dummy_output = self._forward_conv(dummy_input)
-         flattened_size = dummy_output.view(1, -1).shape[1]
-
-         # Define fully connected layers
-         self.fc1 = nn.Linear(flattened_size, 64)
-         self.fc2 = nn.Linear(64, num_classes)
-
-     def _forward_conv(self, x):
-         x = self.relu(self.conv1(x))
-         x = self.relu(self.conv2(x))
-         return x
-
-     def forward(self, x):
-         x = x.unsqueeze(1)  # [B, 1, L]
-         x = self._forward_conv(x)
-         x = x.view(x.size(0), -1)  # flatten
-         x = self.relu(self.fc1(x))
-         return self.fc2(x)
@@ -1,41 +0,0 @@
- import torch
- import pytorch_lightning as pl
-
- class TorchClassifierWrapper(pl.LightningModule):
-     def __init__(
-         self,
-         model: torch.nn.Module,
-         optimizer_cls=torch.optim.AdamW,
-         optimizer_kwargs=None,
-         criterion_cls=torch.nn.CrossEntropyLoss,
-         criterion_kwargs=None,
-         lr: float = 1e-3,
-     ):
-         super().__init__()
-         self.model = model
-         self.save_hyperparameters(ignore=['model'])  # logs all except actual model instance
-         self.optimizer_cls = optimizer_cls
-         self.optimizer_kwargs = optimizer_kwargs or {}
-         self.criterion = criterion_cls(**(criterion_kwargs or {}))
-         self.lr = lr
-
-     def forward(self, x):
-         return self.model(x)
-
-     def training_step(self, batch, batch_idx):
-         x, y = batch
-         logits = self(x)
-         loss = self.criterion(logits, y)
-         self.log("train_loss", loss, prog_bar=True)
-         return loss
-
-     def validation_step(self, batch, batch_idx):
-         x, y = batch
-         logits = self(x)
-         loss = self.criterion(logits, y)
-         acc = (logits.argmax(dim=1) == y).float().mean()
-         self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=True)
-         return loss
-
-     def configure_optimizers(self):
-         return self.optimizer_cls(self.parameters(), lr=self.lr, **self.optimizer_kwargs)
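A minimal sketch of how the wrapper was meant to combine with the models and data module above (object names and hyperparameters are illustrative, not prescribed by the package):

    import pytorch_lightning as pl
    net = CNNClassifier(input_size=adata.shape[1], num_classes=adata.obs['labels'].nunique())
    wrapper = TorchClassifierWrapper(net, lr=1e-3)
    trainer = pl.Trainer(max_epochs=10, accelerator="auto", devices=1)
    trainer.fit(wrapper, datamodule=dm)  # dm: the AnnDataModule from the earlier sketch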
@@ -1,17 +0,0 @@
- import torch
- import torch.nn as nn
- from .base import BaseTorchModel
-
- class MLPClassifier(BaseTorchModel):
-     def __init__(self, input_dim, num_classes, hidden_sizes=(128, 64), **kwargs):
-         super().__init__(**kwargs)
-         layers = []
-         prev = input_dim
-         for h in hidden_sizes:
-             layers.extend([nn.Linear(prev, h), nn.ReLU(), nn.Dropout(self.dropout_rate)])
-             prev = h
-         layers.append(nn.Linear(prev, num_classes))
-         self.model = nn.Sequential(*layers)
-
-     def forward(self, x):
-         return self.model(x)