smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/ohe.py (new file)
@@ -0,0 +1,160 @@
+ import numpy as np
+ import anndata as ad
+
+ import os
+ import concurrent.futures
+
+ def one_hot_encode(sequence, device='auto'):
+     """
+     One-hot encodes a DNA sequence.
+
+     Parameters:
+         sequence (str or list): DNA sequence (e.g., "ACGTN" or ['A', 'C', 'G', 'T', 'N']).
+
+     Returns:
+         ndarray: Flattened one-hot encoded representation of the input sequence.
+     """
+     mapping = np.array(['A', 'C', 'G', 'T', 'N'])
+
+     # Ensure input is a list of characters
+     if not isinstance(sequence, list):
+         sequence = list(sequence) # Convert string to list of characters
+
+     # Handle empty sequences
+     if len(sequence) == 0:
+         print("Warning: Empty sequence encountered in one_hot_encode()")
+         return np.zeros(len(mapping)) # Return empty encoding instead of failing
+
+     # Convert sequence to NumPy array
+     seq_array = np.array(sequence, dtype='<U1')
+
+     # Replace invalid bases with 'N'
+     seq_array = np.where(np.isin(seq_array, mapping), seq_array, 'N')
+
+     # Create one-hot encoding matrix
+     one_hot_matrix = (seq_array[:, None] == mapping).astype(int)
+
+     # Flatten and return
+     return one_hot_matrix.flatten()
+
+ def one_hot_decode(ohe_array):
+     """
+     Takes a flattened one hot encoded array and returns the sequence string from that array.
+     Parameters:
+         ohe_array (np.array): A one hot encoded array
+
+     Returns:
+         sequence (str): Sequence string of the one hot encoded array
+     """
+     # Define the mapping of one-hot encoded indices to DNA bases
+     mapping = ['A', 'C', 'G', 'T', 'N']
+
+     # Reshape the flattened array into a 2D matrix with 5 columns (one for each base)
+     one_hot_matrix = ohe_array.reshape(-1, 5)
+
+     # Get the index of the maximum value (which will be 1) in each row
+     decoded_indices = np.argmax(one_hot_matrix, axis=1)
+
+     # Map the indices back to the corresponding bases
+     sequence_list = [mapping[i] for i in decoded_indices]
+     sequence = ''.join(sequence_list)
+
+     return sequence
+
+ def ohe_layers_decode(adata, obs_names):
+     """
+     Takes an anndata object and a list of observation names. Returns a list of sequence strings for the reads of interest.
+     Parameters:
+         adata (AnnData): An anndata object.
+         obs_names (list): A list of observation name strings to retrieve sequences for.
+
+     Returns:
+         sequences (list of str): List of strings of the one hot encoded array
+     """
+     # Define the mapping of one-hot encoded indices to DNA bases
+     mapping = ['A', 'C', 'G', 'T', 'N']
+
+     ohe_layers = [f"{base}_binary_encoding" for base in mapping]
+     sequences = []
+
+     for obs_name in obs_names:
+         obs_subset = adata[obs_name]
+         ohe_list = []
+         for layer in ohe_layers:
+             ohe_list += list(obs_subset.layers[layer])
+         ohe_array = np.array(ohe_list)
+         sequence = one_hot_decode(ohe_array)
+         sequences.append(sequence)
+
+     return sequences
+
+ def _encode_sequence(args):
+     """Parallel helper function for one-hot encoding."""
+     read_name, seq, device = args
+     try:
+         one_hot_matrix = one_hot_encode(seq, device)
+         return read_name, one_hot_matrix
+     except Exception:
+         return None # Skip invalid sequences
+
+ def _encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number):
+     """Encodes a batch and writes to disk immediately."""
+     batch = {read_name: matrix for read_name, matrix in batch_data if matrix is not None}
+
+     if batch:
+         save_name = os.path.join(tmp_dir, f'tmp_{prefix}_{record}_{batch_number}.h5ad')
+         tmp_ad = ad.AnnData(X=np.zeros((1, 1)), uns=batch) # Placeholder X
+         tmp_ad.write_h5ad(save_name)
+         return save_name
+     return None
+
+ def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000, progress_bar=None, device='auto', threads=None):
+     """
+     Efficient version of ohe_batching: one-hot encodes sequences in parallel and writes batches immediately.
+
+     Parameters:
+         base_identities (dict): Dictionary mapping read names to sequences.
+         tmp_dir (str): Directory for storing temporary files.
+         record (str): Record name.
+         prefix (str): Prefix for file naming.
+         batch_size (int): Number of reads per batch.
+         progress_bar (tqdm instance, optional): Shared progress bar.
+         device (str): Device for encoding.
+         threads (int, optional): Number of parallel workers.
+
+     Returns:
+         list: List of valid H5AD file paths.
+     """
+     threads = threads or os.cpu_count() # Default to max available CPU cores
+     batch_data = []
+     batch_number = 0
+     file_names = []
+
+     # Step 1: Prepare Data for Parallel Encoding
+     encoding_args = [(read_name, seq, device) for read_name, seq in base_identities.items() if seq is not None]
+
+     # Step 2: Parallel One-Hot Encoding using threads (to avoid nested processes)
+     with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
+         for result in executor.map(_encode_sequence, encoding_args):
+             if result:
+                 batch_data.append(result)
+
+             if len(batch_data) >= batch_size:
+                 # Step 3: Process and Write Batch Immediately
+                 file_name = _encode_and_save_batch(batch_data.copy(), tmp_dir, prefix, record, batch_number)
+                 if file_name:
+                     file_names.append(file_name)
+
+                 batch_data.clear()
+                 batch_number += 1
+
+             if progress_bar:
+                 progress_bar.update(1)
+
+     # Step 4: Process Remaining Batch
+     if batch_data:
+         file_name = _encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number)
+         if file_name:
+             file_names.append(file_name)
+
+     return file_names
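
For orientation, a minimal round trip through the ohe.py helpers above might look like the sketch below. It assumes the module is importable as smftools.informatics.ohe (matching the file layout in the list above); the read names, prefix, and temporary directory are made up for illustration.

    import os
    import numpy as np
    from smftools.informatics.ohe import one_hot_encode, one_hot_decode, ohe_batching

    encoded = one_hot_encode("ACGTN")          # flattened 5x5 one-hot, length 25
    assert one_hot_decode(encoded) == "ACGTN"  # bases outside A/C/G/T/N are encoded as 'N'

    # Hypothetical reads; ohe_batching writes one placeholder-X .h5ad per batch into tmp_dir,
    # with the per-read encodings stored in .uns.
    reads = {"read_1": "ACGT", "read_2": "GGTTAC"}
    os.makedirs("tmp_ohe", exist_ok=True)
    files = ohe_batching(reads, tmp_dir="tmp_ohe", record="chr1", prefix="demo", batch_size=1000)
    print(files)                               # e.g. ['tmp_ohe/tmp_demo_chr1_0.h5ad']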
smftools/informatics/pod5_functions.py (new file)
@@ -0,0 +1,224 @@
+ from ..config import LoadExperimentConfig
+ from ..readwrite import make_dirs
+
+ import os
+ import subprocess
+ from pathlib import Path
+
+ import pod5 as p5
+
+ from typing import Union, List
+
+ def basecall_pod5s(config_path):
+     """
+     Basecall from pod5s given a config file.
+
+     Parameters:
+         config_path (str): File path to the basecall configuration file
+
+     Returns:
+         None
+     """
+     # Default params
+     bam_suffix = '.bam' # If different, change from here.
+
+     # Load experiment config parameters into global variables
+     experiment_config = LoadExperimentConfig(config_path)
+     var_dict = experiment_config.var_dict
+
+     # These below variables will point to default_value if they are empty in the experiment_config.csv or if the variable is fully omitted from the csv.
+     default_value = None
+
+     # General config variable init
+     input_data_path = Path(var_dict.get('input_data_path', default_value)) # Path to a directory of POD5s/FAST5s or to a BAM/FASTQ file. Necessary.
+     output_directory = Path(var_dict.get('output_directory', default_value)) # Path to the output directory to make for the analysis. Necessary.
+     model = var_dict.get('model', default_value) # needed for dorado basecaller
+     model_dir = Path(var_dict.get('model_dir', default_value)) # model directory
+     barcode_kit = var_dict.get('barcode_kit', default_value) # needed for dorado basecaller
+     barcode_both_ends = var_dict.get('barcode_both_ends', default_value) # dorado demultiplexing
+     trim = var_dict.get('trim', default_value) # dorado adapter and barcode removal
+     device = var_dict.get('device', 'auto')
+
+     # Modified basecalling specific variable init
+     filter_threshold = var_dict.get('filter_threshold', default_value)
+     m6A_threshold = var_dict.get('m6A_threshold', default_value)
+     m5C_threshold = var_dict.get('m5C_threshold', default_value)
+     hm5C_threshold = var_dict.get('hm5C_threshold', default_value)
+     thresholds = [filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold]
+     mod_list = var_dict.get('mod_list', default_value)
+
+     # Make initial output directory
+     make_dirs([output_directory])
+
+     # Get the input filetype
+     if input_data_path.is_file():
+         input_data_filetype = input_data_path.suffixes[0]
+         input_is_pod5 = input_data_filetype in ['.pod5','.p5']
+         input_is_fast5 = input_data_filetype in ['.fast5','.f5']
+
+     elif input_data_path.is_dir():
+         # Get the file names in the input data dir
+         input_files = input_data_path.iterdir()
+         input_is_pod5 = sum([True for file in input_files if '.pod5' in file or '.p5' in file])
+         input_is_fast5 = sum([True for file in input_files if '.fast5' in file or '.f5' in file])
+
+     # If the input files are not pod5 files, and they are fast5 files, convert the files to a pod5 file before proceeding.
+     if input_is_fast5 and not input_is_pod5:
+         # take the input directory of fast5 files and write out a single pod5 file into the output directory.
+         output_pod5 = output_directory / 'FAST5s_to_POD5.pod5'
+         print(f'Input directory contains fast5 files, converting them and concatenating into a single pod5 file in the {output_pod5}')
+         fast5_to_pod5(input_data_path, output_pod5)
+         # Reassign the pod5_dir variable to point to the new pod5 file.
+         input_data_path = output_pod5
+
+     model_basename = model.name
+     model_basename = model_basename.replace('.', '_')
+
+     if mod_list:
+         mod_string = "_".join(mod_list)
+         bam = output_directory / f"{model_basename}_{mod_string}_calls"
+         modcall(model, input_data_path, barcode_kit, mod_list, bam, bam_suffix, barcode_both_ends, trim, device)
+     else:
+         bam = output_directory / f"{model_basename}_canonical_basecalls"
+         canoncall(model, input_data_path, barcode_kit, bam, bam_suffix, barcode_both_ends, trim, device)
+
+
+ def fast5_to_pod5(
+     fast5_dir: Union[str, Path, List[Union[str, Path]]],
+     output_pod5: Union[str, Path] = "FAST5s_to_POD5.pod5"
+ ) -> None:
+     """
+     Convert Nanopore FAST5 files (single file, list of files, or directory)
+     into a single .pod5 output using the 'pod5 convert fast5' CLI tool.
+     """
+
+     output_pod5 = str(output_pod5) # ensure string
+
+     # 1) If user gives a list of FAST5 files
+     if isinstance(fast5_dir, (list, tuple)):
+         fast5_paths = [str(Path(f)) for f in fast5_dir]
+         cmd = ["pod5", "convert", "fast5", *fast5_paths, "--output", output_pod5]
+         subprocess.run(cmd, check=True)
+         return
+
+     # Ensure Path object
+     p = Path(fast5_dir)
+
+     # 2) If user gives a single file
+     if p.is_file():
+         cmd = ["pod5", "convert", "fast5", str(p), "--output", output_pod5]
+         subprocess.run(cmd, check=True)
+         return
+
+     # 3) If user gives a directory → collect FAST5s
+     if p.is_dir():
+         fast5_paths = sorted(str(f) for f in p.glob("*.fast5"))
+         if not fast5_paths:
+             raise FileNotFoundError(f"No FAST5 files found in {p}")
+
+         cmd = ["pod5", "convert", "fast5", *fast5_paths, "--output", output_pod5]
+         subprocess.run(cmd, check=True)
+         return
+
+     raise FileNotFoundError(f"Input path invalid: {fast5_dir}")
+
+ def subsample_pod5(pod5_path, read_name_path, output_directory):
+     """
+     Takes a POD5 file and a text file containing read names of interest and writes out a subsampled POD5 for just those reads.
+     This is a useful function when you have a list of read names that mapped to a region of interest that you want to reanalyze from the pod5 level.
+
+     Parameters:
+         pod5_path (str): File path to the POD5 file (or directory of multiple pod5 files) to subsample.
+         read_name_path (str | int): File path to a text file of read names. One read name per line. If an int value is passed, a random subset of that many reads will occur
+         output_directory (str): A file path to the directory to output the file.
+
+     Returns:
+         None
+     """
+
+     if os.path.isdir(pod5_path):
+         pod5_path_is_dir = True
+         input_pod5_base = 'input_pod5s.pod5'
+         files = os.listdir(pod5_path)
+         pod5_files = [os.path.join(pod5_path, file) for file in files if '.pod5' in file]
+         pod5_files.sort()
+         print(f'Found input pod5s: {pod5_files}')
+
+     elif os.path.exists(pod5_path):
+         pod5_path_is_dir = False
+         input_pod5_base = os.path.basename(pod5_path)
+
+     else:
+         print('Error: pod5_path passed does not exist')
+         return None
+
+     if type(read_name_path) == str:
+         input_read_name_base = os.path.basename(read_name_path)
+         output_base = input_pod5_base.split('.pod5')[0] + '_' + input_read_name_base.split('.txt')[0] + '_subsampled.pod5'
+
+         # extract read names into a list of strings
+         with open(read_name_path, 'r') as file:
+             read_names = [line.strip() for line in file]
+
+         print(f'Looking for read_ids: {read_names}')
+         read_records = []
+
+         if pod5_path_is_dir:
+             for input_pod5 in pod5_files:
+                 with p5.Reader(input_pod5) as reader:
+                     try:
+                         for read_record in reader.reads(selection=read_names, missing_ok=True):
+                             read_records.append(read_record.to_read())
+                             print(f'Found read in {input_pod5}: {read_record.read_id}')
+                     except:
+                         print('Skipping pod5, could not find reads')
+         else:
+             with p5.Reader(pod5_path) as reader:
+                 try:
+                     for read_record in reader.reads(selection=read_names):
+                         read_records.append(read_record.to_read())
+                         print(f'Found read in {input_pod5}: {read_record}')
+                 except:
+                     print('Could not find reads')
+
+     elif type(read_name_path) == int:
+         import random
+         output_base = input_pod5_base.split('.pod5')[0] + f'_{read_name_path}_randomly_subsampled.pod5'
+         all_read_records = []
+
+         if pod5_path_is_dir:
+             # Shuffle the list of input pod5 paths
+             random.shuffle(pod5_files)
+             for input_pod5 in pod5_files:
+                 # iterate over the input pod5s
+                 print(f'Opening pod5 file {input_pod5}')
+                 with p5.Reader(pod5_path) as reader:
+                     for read_record in reader.reads():
+                         all_read_records.append(read_record.to_read())
+                 # When enough reads are in all_read_records, stop accumulating reads.
+                 if len(all_read_records) >= read_name_path:
+                     break
+
+             if read_name_path <= len(all_read_records):
+                 read_records = random.sample(all_read_records, read_name_path)
+             else:
+                 print('Trying to sample more reads than are contained in the input pod5s, taking all reads')
+                 read_records = all_read_records
+
+         else:
+             with p5.Reader(pod5_path) as reader:
+                 for read_record in reader.reads():
+                     # get all read records from the input pod5
+                     all_read_records.append(read_record.to_read())
+             if read_name_path <= len(all_read_records):
+                 # if the subsampling amount is less than the record amount in the file, randomly subsample the reads
+                 read_records = random.sample(all_read_records, read_name_path)
+             else:
+                 print('Trying to sample more reads than are contained in the input pod5s, taking all reads')
+                 read_records = all_read_records
+
+     output_pod5 = os.path.join(output_directory, output_base)
+
+     # Write the subsampled POD5
+     with p5.Writer(output_pod5) as writer:
+         writer.add_reads(read_records)
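
A hedged usage sketch of the pod5_functions.py helpers above, assuming the module is importable as smftools.informatics.pod5_functions. The paths are hypothetical, the pod5 package and its `pod5 convert fast5` CLI must be installed, and the output directory passed to subsample_pod5 must already exist.

    from smftools.informatics.pod5_functions import fast5_to_pod5, subsample_pod5

    # Convert a directory of FAST5s into a single POD5 by shelling out to `pod5 convert fast5`.
    fast5_to_pod5("runs/fast5_pass", output_pod5="runs/converted.pod5")

    # Keep only the reads named in a text file (one read ID per line).
    subsample_pod5("runs/converted.pod5", "reads_of_interest.txt", "runs/subsampled")

    # Or draw a random subsample of 500 reads from a single POD5 instead of a named list.
    subsample_pod5("runs/converted.pod5", 500, "runs/subsampled")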
smftools/informatics/run_multiqc.py
@@ -9,10 +9,13 @@ def run_multiqc(input_dir, output_dir):
      Returns:
      - None: The function executes MultiQC and prints the status.
      """
-     import os
+     from ..readwrite import make_dirs
      import subprocess
      # Ensure the output directory exists
-     os.makedirs(output_dir, exist_ok=True)
+     make_dirs(output_dir)
+
+     input_dir = str(input_dir)
+     output_dir = str(output_dir)

      # Construct MultiQC command
      command = ["multiqc", input_dir, "-o", output_dir]
smftools/machine_learning/__init__.py (new file)
@@ -0,0 +1,12 @@
+ from . import models
+ from . import data
+ from . import utils
+ from . import evaluation
+ from . import inference
+ from . import training
+
+ __all__ = [
+     "calculate_relative_risk_on_activity",
+     "evaluate_models_by_subgroup",
+     "prepare_melted_model_data",
+ ]
smftools/machine_learning/data/__init__.py (new file)
@@ -0,0 +1,2 @@
+ from .anndata_data_module import AnnDataModule, build_anndata_loader
+ from .preprocessing import random_fill_nans
smftools/machine_learning/data/anndata_data_module.py (new file)
@@ -0,0 +1,234 @@
+ import torch
+ from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset, Subset
+ import pytorch_lightning as pl
+ import numpy as np
+ import pandas as pd
+ from .preprocessing import random_fill_nans
+ from sklearn.utils.class_weight import compute_class_weight
+
+
+ class AnnDataDataset(Dataset):
+     """
+     Generic PyTorch Dataset from AnnData.
+     """
+     def __init__(self, adata, tensor_source="X", tensor_key=None, label_col=None, window_start=None, window_size=None):
+         self.adata = adata
+         self.tensor_source = tensor_source
+         self.tensor_key = tensor_key
+         self.label_col = label_col
+         self.window_start = window_start
+         self.window_size = window_size
+
+         if tensor_source == "X":
+             X = adata.X
+         elif tensor_source == "layers":
+             assert tensor_key in adata.layers
+             X = adata.layers[tensor_key]
+         elif tensor_source == "obsm":
+             assert tensor_key in adata.obsm
+             X = adata.obsm[tensor_key]
+         else:
+             raise ValueError(f"Invalid tensor_source: {tensor_source}")
+
+         if self.window_start is not None and self.window_size is not None:
+             X = X[:, self.window_start : self.window_start + self.window_size]
+
+         X = random_fill_nans(X)
+
+         self.X_tensor = torch.tensor(X, dtype=torch.float32)
+
+         if label_col is not None:
+             y = adata.obs[label_col]
+             if y.dtype.name == 'category':
+                 y = y.cat.codes
+             self.y_tensor = torch.tensor(y.values, dtype=torch.long)
+         else:
+             self.y_tensor = None
+
+     def numpy(self, indices):
+         return self.X_tensor[indices].numpy(), self.y_tensor[indices].numpy()
+
+     def __len__(self):
+         return len(self.X_tensor)
+
+     def __getitem__(self, idx):
+         x = self.X_tensor[idx]
+         if self.y_tensor is not None:
+             y = self.y_tensor[idx]
+             return x, y
+         else:
+             return (x,)
+
+
+ def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
+                   random_seed=42, split_col="train_val_test_split",
+                   load_existing_split=False, split_save_path=None):
+     """
+     Perform split and record assignment into adata.obs[split_col].
+     """
+     total_len = len(dataset)
+
+     if load_existing_split:
+         if split_col in adata.obs:
+             pass # use existing
+         elif split_save_path:
+             split_df = pd.read_csv(split_save_path, index_col=0)
+             adata.obs[split_col] = split_df.loc[adata.obs_names][split_col].values
+         else:
+             raise ValueError("No existing split column found and no file provided.")
+     else:
+         indices = np.arange(total_len)
+         np.random.seed(random_seed)
+         np.random.shuffle(indices)
+
+         n_train = int(train_frac * total_len)
+         n_val = int(val_frac * total_len)
+         n_test = total_len - n_train - n_val
+
+         split_array = np.full(total_len, "test", dtype=object)
+         split_array[indices[:n_train]] = "train"
+         split_array[indices[n_train:n_train + n_val]] = "val"
+         adata.obs[split_col] = split_array
+
+         if split_save_path:
+             adata.obs[[split_col]].to_csv(split_save_path)
+
+     split_labels = adata.obs[split_col].values
+     train_indices = np.where(split_labels == "train")[0]
+     val_indices = np.where(split_labels == "val")[0]
+     test_indices = np.where(split_labels == "test")[0]
+
+     train_set = Subset(dataset, train_indices)
+     val_set = Subset(dataset, val_indices)
+     test_set = Subset(dataset, test_indices)
+
+     return train_set, val_set, test_set
+
+ class AnnDataModule(pl.LightningDataModule):
+     """
+     Unified LightningDataModule version of AnnDataDataset + splitting with adata.obs recording.
+     """
+     def __init__(self, adata, tensor_source="X", tensor_key=None, label_col="labels",
+                  batch_size=64, train_frac=0.6, val_frac=0.1, test_frac=0.3, random_seed=42,
+                  inference_mode=False, split_col="train_val_test_split", split_save_path=None,
+                  load_existing_split=False, window_start=None, window_size=None, num_workers=None, persistent_workers=False):
+         super().__init__()
+         self.adata = adata
+         self.tensor_source = tensor_source
+         self.tensor_key = tensor_key
+         self.label_col = label_col
+         self.batch_size = batch_size
+         self.train_frac = train_frac
+         self.val_frac = val_frac
+         self.test_frac = test_frac
+         self.random_seed = random_seed
+         self.inference_mode = inference_mode
+         self.split_col = split_col
+         self.split_save_path = split_save_path
+         self.load_existing_split = load_existing_split
+         self.var_names = adata.var_names.copy()
+         self.window_start = window_start
+         self.window_size = window_size
+         self.num_workers = num_workers
+         self.persistent_workers = persistent_workers
+
+     def setup(self, stage=None):
+         dataset = AnnDataDataset(self.adata, self.tensor_source, self.tensor_key,
+                                  None if self.inference_mode else self.label_col,
+                                  window_start=self.window_start, window_size=self.window_size)
+
+         if self.inference_mode:
+             self.infer_dataset = dataset
+             return
+
+         self.train_set, self.val_set, self.test_set = split_dataset(
+             self.adata, dataset, train_frac=self.train_frac, val_frac=self.val_frac,
+             test_frac=self.test_frac, random_seed=self.random_seed,
+             split_col=self.split_col, split_save_path=self.split_save_path,
+             load_existing_split=self.load_existing_split
+         )
+
+     def train_dataloader(self):
+         if self.num_workers:
+             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, persistent_workers=self.persistent_workers)
+         else:
+             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
+
+     def val_dataloader(self):
+         if self.num_workers:
+             return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=self.persistent_workers)
+         else:
+             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
+
+     def test_dataloader(self):
+         if self.num_workers:
+             return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=self.persistent_workers)
+         else:
+             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
+
+     def predict_dataloader(self):
+         if not self.inference_mode:
+             raise RuntimeError("Only valid in inference mode")
+         return DataLoader(self.infer_dataset, batch_size=self.batch_size)
+
+     def compute_class_weights(self):
+         train_indices = self.train_set.indices # get the indices of the training set
+         y_all = self.train_set.dataset.y_tensor # get labels for the entire dataset (We are pulling from a Subset object, so this syntax can be confusing)
+         y_train = y_all[train_indices].cpu().numpy() # get the labels for the training set and move to a numpy array
+
+         class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
+         return torch.tensor(class_weights, dtype=torch.float32)
+
+     def inference_numpy(self):
+         """
+         Return inference data as numpy for use in sklearn inference.
+         """
+         if not self.inference_mode:
+             raise RuntimeError("Must be in inference_mode=True to use inference_numpy()")
+         X_np = self.infer_dataset.X_tensor.numpy()
+         return X_np
+
+     def to_numpy(self):
+         """
+         Move the AnnDataModule tensors into numpy arrays
+         """
+         if not self.inference_mode:
+             train_X, train_y = self.train_set.dataset.numpy(self.train_set.indices)
+             val_X, val_y = self.val_set.dataset.numpy(self.val_set.indices)
+             test_X, test_Y = self.test_set.dataset.numpy(self.test_set.indices)
+             return train_X, train_y, val_X, val_y, test_X, test_Y
+         else:
+             return self.inference_numpy()
+
+
+ def build_anndata_loader(
+     adata, tensor_source="X", tensor_key=None, label_col=None, train_frac=0.6, val_frac=0.1,
+     test_frac=0.3, random_seed=42, batch_size=64, lightning=True, inference_mode=False,
+     split_col="train_val_test_split", split_save_path=None, load_existing_split=False
+ ):
+     """
+     Unified pipeline for both Lightning and raw PyTorch.
+     The lightning loader works for both Lightning and the Sklearn wrapper.
+     Set lightning to False if you want to make data loaders for base PyTorch or base sklearn models
+     """
+     if lightning:
+         return AnnDataModule(
+             adata, tensor_source=tensor_source, tensor_key=tensor_key, label_col=label_col,
+             batch_size=batch_size, train_frac=train_frac, val_frac=val_frac, test_frac=test_frac,
+             random_seed=random_seed, inference_mode=inference_mode,
+             split_col=split_col, split_save_path=split_save_path, load_existing_split=load_existing_split
+         )
+     else:
+         var_names = adata.var_names.copy()
+         dataset = AnnDataDataset(adata, tensor_source, tensor_key, None if inference_mode else label_col)
+         if inference_mode:
+             return DataLoader(dataset, batch_size=batch_size)
+         else:
+             train_set, val_set, test_set = split_dataset(
+                 adata, dataset, train_frac, val_frac, test_frac, random_seed,
+                 split_col, split_save_path, load_existing_split
+             )
+             train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
+             val_loader = DataLoader(val_set, batch_size=batch_size)
+             test_loader = DataLoader(test_set, batch_size=batch_size)
+             return train_loader, val_loader, test_loader
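
A hedged sketch of driving the new data module end to end; the AnnData below is synthetic and the label column name is invented. build_anndata_loader with lightning=True returns an AnnDataModule, and setup() must run before the dataloaders or compute_class_weights() are available. Note that in the code above, val_dataloader() and test_dataloader() return a loader over self.train_set when num_workers is not set, so validation and test batches come from the training split unless num_workers is passed to AnnDataModule directly.

    import anndata as ad
    import numpy as np
    import pandas as pd
    from smftools.machine_learning.data import build_anndata_loader

    # Toy AnnData: 100 reads x 50 positions with a categorical label column (illustrative only).
    X = np.random.rand(100, 50).astype(np.float32)
    obs = pd.DataFrame({"activity_status": pd.Categorical(np.random.choice(["active", "silent"], 100))})
    adata = ad.AnnData(X=X, obs=obs)

    dm = build_anndata_loader(adata, tensor_source="X", label_col="activity_status", batch_size=16)
    dm.setup()
    train_loader = dm.train_dataloader()        # shuffled training batches
    class_weights = dm.compute_class_weights()  # balanced weights from the training split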
smftools/machine_learning/evaluation/__init__.py (new file)
@@ -0,0 +1,2 @@
+ from .evaluators import ModelEvaluator, PostInferenceModelEvaluator
+ from .eval_utils import flatten_sliding_window_results
smftools/machine_learning/evaluation/eval_utils.py (new file)
@@ -0,0 +1,31 @@
+ import pandas as pd
+
+ def flatten_sliding_window_results(results_dict):
+     """
+     Flatten nested sliding window results into pandas DataFrame.
+
+     Expects structure:
+         results[model_name][window_size][window_center]['metrics'][metric_name]
+     """
+     records = []
+
+     for model_name, model_results in results_dict.items():
+         for window_size, window_results in model_results.items():
+             for center_var, result in window_results.items():
+                 metrics = result['metrics']
+                 record = {
+                     'model': model_name,
+                     'window_size': window_size,
+                     'center_var': center_var
+                 }
+                 # Add all metrics
+                 record.update(metrics)
+                 records.append(record)
+
+     df = pd.DataFrame.from_records(records)
+
+     # Convert center_var to numeric if possible (optional but helpful for plotting)
+     df['center_var'] = pd.to_numeric(df['center_var'], errors='coerce')
+     df = df.sort_values(['model', 'window_size', 'center_var'])
+
+     return df
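
To illustrate the nesting flatten_sliding_window_results expects, a small synthetic example (the model name, window sizes, centers, and metric names are invented):

    from smftools.machine_learning.evaluation import flatten_sliding_window_results

    results = {
        "cnn": {
            51: {
                "1200": {"metrics": {"accuracy": 0.81, "auroc": 0.88}},
                "1250": {"metrics": {"accuracy": 0.79, "auroc": 0.86}},
            },
        },
    }
    df = flatten_sliding_window_results(results)
    # df has columns: model, window_size, center_var, accuracy, auroc,
    # with center_var coerced to numeric and rows sorted by model, window_size, center_var.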