smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff compares the contents of publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,21 @@
-import os
+from __future__ import annotations
+
 import subprocess
-import glob
-import zipfile
-from pathlib import Path
 
-def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassified=True, modkit_summary=False, threads=None):
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+
+
+def extract_mods(
+    thresholds,
+    mod_tsv_dir,
+    split_dir,
+    bam_suffix,
+    skip_unclassified=True,
+    modkit_summary=False,
+    threads=None,
+):
     """
     Takes all of the aligned, sorted, split modified BAM files and runs Nanopore Modkit Extract to load the modification data into zipped TSV files
 
@@ -23,10 +34,12 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassifi
 
     """
     filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
-    bam_files = sorted(p for p in split_dir.iterdir() if bam_suffix in p.name and '.bai' not in p.name)
+    bam_files = sorted(
+        p for p in split_dir.iterdir() if bam_suffix in p.name and ".bai" not in p.name
+    )
     if skip_unclassified:
         bam_files = [p for p in bam_files if "unclassified" not in p.name]
-    print(f"Running modkit extract for the following bam files: {bam_files}")
+    logger.info(f"Running modkit extract for the following bam files: {bam_files}")
 
     if threads:
         threads = str(threads)
@@ -34,14 +47,14 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassifi
         pass
 
     for input_file in bam_files:
-        print(input_file)
+        logger.debug(input_file)
         # Construct the output TSV file path
         output_tsv = mod_tsv_dir / (input_file.stem + "_extract.tsv")
-        output_tsv_gz = output_tsv.parent / (output_tsv.name + '.gz')
+        output_tsv_gz = output_tsv.parent / (output_tsv.name + ".gz")
         if output_tsv_gz.exists():
-            print(f"{output_tsv_gz} already exists, skipping modkit extract")
+            logger.debug(f"{output_tsv_gz} already exists, skipping modkit extract")
         else:
-            print(f"Extracting modification data from {input_file}")
+            logger.info(f"Extracting modification data from {input_file}")
             if modkit_summary:
                 # Run modkit summary
                 subprocess.run(["modkit", "summary", str(input_file)])
@@ -50,28 +63,43 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassifi
             # Run modkit extract
             if threads:
                 extract_command = [
-                    "modkit", "extract",
-                    "calls", "--mapped-only",
-                    "--filter-threshold", f'{filter_threshold}',
-                    "--mod-thresholds", f"m:{m5C_threshold}",
-                    "--mod-thresholds", f"a:{m6A_threshold}",
-                    "--mod-thresholds", f"h:{hm5C_threshold}",
-                    "-t", threads,
-                    str(input_file), str(output_tsv)
-                ]
+                    "modkit",
+                    "extract",
+                    "calls",
+                    "--mapped-only",
+                    "--filter-threshold",
+                    f"{filter_threshold}",
+                    "--mod-thresholds",
+                    f"m:{m5C_threshold}",
+                    "--mod-thresholds",
+                    f"a:{m6A_threshold}",
+                    "--mod-thresholds",
+                    f"h:{hm5C_threshold}",
+                    "-t",
+                    threads,
+                    str(input_file),
+                    str(output_tsv),
+                ]
             else:
                 extract_command = [
-                    "modkit", "extract",
-                    "calls", "--mapped-only",
-                    "--filter-threshold", f'{filter_threshold}',
-                    "--mod-thresholds", f"m:{m5C_threshold}",
-                    "--mod-thresholds", f"a:{m6A_threshold}",
-                    "--mod-thresholds", f"h:{hm5C_threshold}",
-                    str(input_file), str(output_tsv)
-                ]
+                    "modkit",
+                    "extract",
+                    "calls",
+                    "--mapped-only",
+                    "--filter-threshold",
+                    f"{filter_threshold}",
+                    "--mod-thresholds",
+                    f"m:{m5C_threshold}",
+                    "--mod-thresholds",
+                    f"a:{m6A_threshold}",
+                    "--mod-thresholds",
+                    f"h:{hm5C_threshold}",
+                    str(input_file),
+                    str(output_tsv),
+                ]
             subprocess.run(extract_command)
             # Zip the output TSV
-            print(f'zipping {output_tsv}')
+            logger.info(f"zipping {output_tsv}")
             if threads:
                 zip_command = ["pigz", "-f", "-p", threads, str(output_tsv)]
             else:
@@ -79,30 +107,39 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassifi
             subprocess.run(zip_command, check=True)
     return
 
+
 def make_modbed(aligned_sorted_output, thresholds, mod_bed_dir):
     """
     Generating position methylation summaries for each barcoded sample starting from the overall BAM file that was direct output of dorado aligner.
     Parameters:
         aligned_sorted_output (str): A string representing the file path to the aligned_sorted non-split BAM file.
-
+
     Returns:
         None
    """
-    import os
    import subprocess
-
+
    filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
    command = [
-        "modkit", "pileup", str(aligned_sorted_output), str(mod_bed_dir),
-        "--partition-tag", "BC",
+        "modkit",
+        "pileup",
+        str(aligned_sorted_output),
+        str(mod_bed_dir),
+        "--partition-tag",
+        "BC",
        "--only-tabs",
-        "--filter-threshold", f'{filter_threshold}',
-        "--mod-thresholds", f"m:{m5C_threshold}",
-        "--mod-thresholds", f"a:{m6A_threshold}",
-        "--mod-thresholds", f"h:{hm5C_threshold}"
+        "--filter-threshold",
+        f"{filter_threshold}",
+        "--mod-thresholds",
+        f"m:{m5C_threshold}",
+        "--mod-thresholds",
+        f"a:{m6A_threshold}",
+        "--mod-thresholds",
+        f"h:{hm5C_threshold}",
    ]
    subprocess.run(command)
 
+
 def modQC(aligned_sorted_output, thresholds):
    """
    Output the percentile of bases falling at a call threshold (threshold is a probability between 0-1) for the overall BAM file.
@@ -120,10 +157,16 @@ def modQC(aligned_sorted_output, thresholds):
    filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
    subprocess.run(["modkit", "sample-probs", str(aligned_sorted_output)])
    command = [
-        "modkit", "summary", str(aligned_sorted_output),
-        "--filter-threshold", f"{filter_threshold}",
-        "--mod-thresholds", f"m:{m5C_threshold}",
-        "--mod-thresholds", f"a:{m6A_threshold}",
-        "--mod-thresholds", f"h:{hm5C_threshold}"
+        "modkit",
+        "summary",
+        str(aligned_sorted_output),
+        "--filter-threshold",
+        f"{filter_threshold}",
+        "--mod-thresholds",
+        f"m:{m5C_threshold}",
+        "--mod-thresholds",
+        f"a:{m6A_threshold}",
+        "--mod-thresholds",
+        f"h:{hm5C_threshold}",
    ]
-    subprocess.run(command)
+    subprocess.run(command)
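
Note: the recurring change in this file, as throughout 0.3.0, is the swap from bare print calls to a module-level logger obtained from the new smftools/logging_utils.py (+51 lines, contents not shown in this diff). Only the get_logger(__name__) call is confirmed by the diff; a minimal sketch of what such a helper typically looks like, under that assumption:

```python
# Hypothetical sketch of a get_logger helper; the real smftools/logging_utils.py
# (+51 lines, not shown in this diff) may differ.
import logging


def get_logger(name: str) -> logging.Logger:
    """Return a logger under the 'smftools' namespace, configuring the root once."""
    root = logging.getLogger("smftools")
    if not root.handlers:  # one-time setup: attach a stream handler to the package root
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
        root.addHandler(handler)
        root.setLevel(logging.INFO)
    return logging.getLogger(name)
```

With this pattern, the logger.debug(...) calls used above for per-file progress stay silent at the default INFO level, while logger.info(...) messages replace the old always-on prints.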
@@ -1,10 +1,17 @@
-import numpy as np
-import anndata as ad
+from __future__ import annotations
 
-import os
 import concurrent.futures
+import os
+
+import anndata as ad
+import numpy as np
+
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+
 
-def one_hot_encode(sequence, device='auto'):
+def one_hot_encode(sequence, device="auto"):
     """
     One-hot encodes a DNA sequence.
 
@@ -14,7 +21,7 @@ def one_hot_encode(sequence, device='auto'):
     Returns:
         ndarray: Flattened one-hot encoded representation of the input sequence.
     """
-    mapping = np.array(['A', 'C', 'G', 'T', 'N'])
+    mapping = np.array(["A", "C", "G", "T", "N"])
 
     # Ensure input is a list of characters
     if not isinstance(sequence, list):
@@ -22,14 +29,14 @@ def one_hot_encode(sequence, device='auto'):
 
     # Handle empty sequences
     if len(sequence) == 0:
-        print("Warning: Empty sequence encountered in one_hot_encode()")
+        logger.warning("Empty sequence encountered in one_hot_encode()")
         return np.zeros(len(mapping))  # Return empty encoding instead of failing
 
     # Convert sequence to NumPy array
-    seq_array = np.array(sequence, dtype='<U1')
+    seq_array = np.array(sequence, dtype="<U1")
 
     # Replace invalid bases with 'N'
-    seq_array = np.where(np.isin(seq_array, mapping), seq_array, 'N')
+    seq_array = np.where(np.isin(seq_array, mapping), seq_array, "N")
 
     # Create one-hot encoding matrix
     one_hot_matrix = (seq_array[:, None] == mapping).astype(int)
@@ -37,6 +44,7 @@ def one_hot_encode(sequence, device='auto'):
     # Flatten and return
     return one_hot_matrix.flatten()
 
+
 def one_hot_decode(ohe_array):
     """
     Takes a flattened one hot encoded array and returns the sequence string from that array.
@@ -47,20 +55,21 @@ def one_hot_decode(ohe_array):
         sequence (str): Sequence string of the one hot encoded array
     """
     # Define the mapping of one-hot encoded indices to DNA bases
-    mapping = ['A', 'C', 'G', 'T', 'N']
-
+    mapping = ["A", "C", "G", "T", "N"]
+
     # Reshape the flattened array into a 2D matrix with 5 columns (one for each base)
     one_hot_matrix = ohe_array.reshape(-1, 5)
-
+
     # Get the index of the maximum value (which will be 1) in each row
     decoded_indices = np.argmax(one_hot_matrix, axis=1)
-
+
     # Map the indices back to the corresponding bases
     sequence_list = [mapping[i] for i in decoded_indices]
-    sequence = ''.join(sequence_list)
-
+    sequence = "".join(sequence_list)
+
     return sequence
 
+
 def ohe_layers_decode(adata, obs_names):
     """
     Takes an anndata object and a list of observation names. Returns a list of sequence strings for the reads of interest.
@@ -72,7 +81,7 @@ def ohe_layers_decode(adata, obs_names):
         sequences (list of str): List of strings of the one hot encoded array
     """
     # Define the mapping of one-hot encoded indices to DNA bases
-    mapping = ['A', 'C', 'G', 'T', 'N']
+    mapping = ["A", "C", "G", "T", "N"]
 
     ohe_layers = [f"{base}_binary_encoding" for base in mapping]
     sequences = []
@@ -85,9 +94,10 @@ def ohe_layers_decode(adata, obs_names):
         ohe_array = np.array(ohe_list)
         sequence = one_hot_decode(ohe_array)
         sequences.append(sequence)
-
+
     return sequences
 
+
 def _encode_sequence(args):
     """Parallel helper function for one-hot encoding."""
     read_name, seq, device = args
@@ -97,18 +107,29 @@ def _encode_sequence(args):
     except Exception:
         return None  # Skip invalid sequences
 
+
 def _encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number):
     """Encodes a batch and writes to disk immediately."""
     batch = {read_name: matrix for read_name, matrix in batch_data if matrix is not None}
 
     if batch:
-        save_name = os.path.join(tmp_dir, f'tmp_{prefix}_{record}_{batch_number}.h5ad')
+        save_name = os.path.join(tmp_dir, f"tmp_{prefix}_{record}_{batch_number}.h5ad")
         tmp_ad = ad.AnnData(X=np.zeros((1, 1)), uns=batch)  # Placeholder X
         tmp_ad.write_h5ad(save_name)
         return save_name
     return None
 
-def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000, progress_bar=None, device='auto', threads=None):
+
+def ohe_batching(
+    base_identities,
+    tmp_dir,
+    record,
+    prefix="",
+    batch_size=100000,
+    progress_bar=None,
+    device="auto",
+    threads=None,
+):
     """
     Efficient version of ohe_batching: one-hot encodes sequences in parallel and writes batches immediately.
 
@@ -131,7 +152,9 @@ def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000,
     file_names = []
 
     # Step 1: Prepare Data for Parallel Encoding
-    encoding_args = [(read_name, seq, device) for read_name, seq in base_identities.items() if seq is not None]
+    encoding_args = [
+        (read_name, seq, device) for read_name, seq in base_identities.items() if seq is not None
+    ]
 
     # Step 2: Parallel One-Hot Encoding using threads (to avoid nested processes)
     with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
@@ -141,7 +164,9 @@ def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000,
 
         if len(batch_data) >= batch_size:
             # Step 3: Process and Write Batch Immediately
-            file_name = _encode_and_save_batch(batch_data.copy(), tmp_dir, prefix, record, batch_number)
+            file_name = _encode_and_save_batch(
+                batch_data.copy(), tmp_dir, prefix, record, batch_number
+            )
             if file_name:
                 file_names.append(file_name)
 
@@ -157,4 +182,4 @@ def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000,
         if file_name:
             file_names.append(file_name)
 
-    return file_names
+    return file_names
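
The encode/decode pair above is unchanged in behavior: each base maps to a 5-wide indicator vector over (A, C, G, T, N) and the resulting matrix is flattened. A standalone round trip of that logic, recreated here for illustration (not imported from the package):

```python
# Recreates the vectorized logic of one_hot_encode/one_hot_decode shown above;
# the package versions additionally handle empty input and AnnData batching.
import numpy as np

mapping = np.array(["A", "C", "G", "T", "N"])

seq_array = np.array(list("ACGTX"), dtype="<U1")
seq_array = np.where(np.isin(seq_array, mapping), seq_array, "N")  # invalid 'X' -> 'N'

# One-hot: compare every base against the 5-letter alphabet, then flatten
ohe = (seq_array[:, None] == mapping).astype(int).flatten()  # shape (25,)

# Decode: reshape to (n, 5), argmax per row, map indices back to bases
decoded = "".join(mapping[i] for i in np.argmax(ohe.reshape(-1, 5), axis=1))
assert decoded == "ACGTN"
```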
@@ -1,26 +1,30 @@
-from ..config import LoadExperimentConfig
-from ..readwrite import make_dirs
+from __future__ import annotations
 
 import os
 import subprocess
 from pathlib import Path
+from typing import Iterable
 
-import pod5 as p5
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
-from typing import Union, List
+from ..config import LoadExperimentConfig
+from ..informatics.basecalling import canoncall, modcall
+from ..readwrite import make_dirs
 
-def basecall_pod5s(config_path):
-    """
-    Basecall from pod5s given a config file.
+logger = get_logger(__name__)
 
-    Parameters:
-        config_path (str): File path to the basecall configuration file
+p5 = require("pod5", extra="ont", purpose="POD5 IO")
 
-    Returns:
-        None
+
+def basecall_pod5s(config_path: str | Path) -> None:
+    """Basecall POD5 inputs using a configuration file.
+
+    Args:
+        config_path: Path to the basecall configuration file.
     """
     # Default params
-    bam_suffix = '.bam' # If different, change from here.
+    bam_suffix = ".bam"  # If different, change from here.
 
     # Load experiment config parameters into global variables
     experiment_config = LoadExperimentConfig(config_path)
@@ -30,66 +34,89 @@
     default_value = None
 
     # General config variable init
-    input_data_path = Path(var_dict.get('input_data_path', default_value)) # Path to a directory of POD5s/FAST5s or to a BAM/FASTQ file. Necessary.
-    output_directory = Path(var_dict.get('output_directory', default_value)) # Path to the output directory to make for the analysis. Necessary.
-    model = var_dict.get('model', default_value) # needed for dorado basecaller
-    model_dir = Path(var_dict.get('model_dir', default_value)) # model directory
-    barcode_kit = var_dict.get('barcode_kit', default_value) # needed for dorado basecaller
-    barcode_both_ends = var_dict.get('barcode_both_ends', default_value) # dorado demultiplexing
-    trim = var_dict.get('trim', default_value) # dorado adapter and barcode removal
-    device = var_dict.get('device', 'auto')
+    input_data_path = Path(
+        var_dict.get("input_data_path", default_value)
+    )  # Path to a directory of POD5s/FAST5s or to a BAM/FASTQ file. Necessary.
+    output_directory = Path(
+        var_dict.get("output_directory", default_value)
+    )  # Path to the output directory to make for the analysis. Necessary.
+    model = var_dict.get("model", default_value)  # needed for dorado basecaller
+    model_dir = Path(var_dict.get("model_dir", default_value))  # model directory
+    barcode_kit = var_dict.get("barcode_kit", default_value)  # needed for dorado basecaller
+    barcode_both_ends = var_dict.get("barcode_both_ends", default_value)  # dorado demultiplexing
+    trim = var_dict.get("trim", default_value)  # dorado adapter and barcode removal
+    device = var_dict.get("device", "auto")
 
     # Modified basecalling specific variable init
-    filter_threshold = var_dict.get('filter_threshold', default_value)
-    m6A_threshold = var_dict.get('m6A_threshold', default_value)
-    m5C_threshold = var_dict.get('m5C_threshold', default_value)
-    hm5C_threshold = var_dict.get('hm5C_threshold', default_value)
+    filter_threshold = var_dict.get("filter_threshold", default_value)
+    m6A_threshold = var_dict.get("m6A_threshold", default_value)
+    m5C_threshold = var_dict.get("m5C_threshold", default_value)
+    hm5C_threshold = var_dict.get("hm5C_threshold", default_value)
     thresholds = [filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold]
-    mod_list = var_dict.get('mod_list', default_value)
-
+    mod_list = var_dict.get("mod_list", default_value)
+
     # Make initial output directory
     make_dirs([output_directory])
 
     # Get the input filetype
     if input_data_path.is_file():
         input_data_filetype = input_data_path.suffixes[0]
-        input_is_pod5 = input_data_filetype in ['.pod5','.p5']
-        input_is_fast5 = input_data_filetype in ['.fast5','.f5']
+        input_is_pod5 = input_data_filetype in [".pod5", ".p5"]
+        input_is_fast5 = input_data_filetype in [".fast5", ".f5"]
 
     elif input_data_path.is_dir():
         # Get the file names in the input data dir
         input_files = input_data_path.iterdir()
-        input_is_pod5 = sum([True for file in input_files if '.pod5' in file or '.p5' in file])
-        input_is_fast5 = sum([True for file in input_files if '.fast5' in file or '.f5' in file])
+        input_is_pod5 = sum([True for file in input_files if ".pod5" in file or ".p5" in file])
+        input_is_fast5 = sum([True for file in input_files if ".fast5" in file or ".f5" in file])
 
     # If the input files are not pod5 files, and they are fast5 files, convert the files to a pod5 file before proceeding.
     if input_is_fast5 and not input_is_pod5:
         # take the input directory of fast5 files and write out a single pod5 file into the output directory.
-        output_pod5 = output_directory / 'FAST5s_to_POD5.pod5'
-        print(f'Input directory contains fast5 files, converting them and concatenating into a single pod5 file in the {output_pod5}')
+        output_pod5 = output_directory / "FAST5s_to_POD5.pod5"
+        logger.info(
+            f"Input directory contains fast5 files, converting them and concatenating into a single pod5 file in the {output_pod5}"
+        )
         fast5_to_pod5(input_data_path, output_pod5)
         # Reassign the pod5_dir variable to point to the new pod5 file.
         input_data_path = output_pod5
 
     model_basename = model.name
-    model_basename = model_basename.replace('.', '_')
+    model_basename = model_basename.replace(".", "_")
 
     if mod_list:
         mod_string = "_".join(mod_list)
         bam = output_directory / f"{model_basename}_{mod_string}_calls"
-        modcall(model, input_data_path, barcode_kit, mod_list, bam, bam_suffix, barcode_both_ends, trim, device)
+        modcall(
+            model,
+            input_data_path,
+            barcode_kit,
+            mod_list,
+            bam,
+            bam_suffix,
+            barcode_both_ends,
+            trim,
+            device,
+        )
     else:
         bam = output_directory / f"{model_basename}_canonical_basecalls"
-        canoncall(model, input_data_path, barcode_kit, bam, bam_suffix, barcode_both_ends, trim, device)
+        canoncall(
+            model, input_data_path, barcode_kit, bam, bam_suffix, barcode_both_ends, trim, device
+        )
 
 
 def fast5_to_pod5(
-    fast5_dir: Union[str, Path, List[Union[str, Path]]],
-    output_pod5: Union[str, Path] = "FAST5s_to_POD5.pod5"
+    fast5_dir: str | Path | Iterable[str | Path],
+    output_pod5: str | Path = "FAST5s_to_POD5.pod5",
 ) -> None:
-    """
-    Convert Nanopore FAST5 files (single file, list of files, or directory)
-    into a single .pod5 output using the 'pod5 convert fast5' CLI tool.
+    """Convert FAST5 inputs into a single POD5 file.
+
+    Args:
+        fast5_dir: FAST5 file path, directory, or iterable of file paths to convert.
+        output_pod5: Output POD5 file path.
+
+    Raises:
+        FileNotFoundError: If no FAST5 files are found or the input path is invalid.
     """
 
     output_pod5 = str(output_pod5)  # ensure string
@@ -122,45 +149,51 @@
 
     raise FileNotFoundError(f"Input path invalid: {fast5_dir}")
 
-def subsample_pod5(pod5_path, read_name_path, output_directory):
-    """
-    Takes a POD5 file and a text file containing read names of interest and writes out a subsampled POD5 for just those reads.
-    This is a useful function when you have a list of read names that mapped to a region of interest that you want to reanalyze from the pod5 level.
 
-    Parameters:
-        pod5_path (str): File path to the POD5 file (or directory of multiple pod5 files) to subsample.
-        read_name_path (str | int): File path to a text file of read names. One read name per line. If an int value is passed, a random subset of that many reads will occur
-        output_directory (str): A file path to the directory to output the file.
+def subsample_pod5(
+    pod5_path: str | Path,
+    read_name_path: str | int,
+    output_directory: str | Path,
+) -> None:
+    """Write a subsampled POD5 containing selected reads.
 
-    Returns:
-        None
+    Args:
+        pod5_path: POD5 file path or directory of POD5 files to subsample.
+        read_name_path: Path to a text file of read names (one per line) or an integer
+            specifying a random subset size.
+        output_directory: Directory to write the subsampled POD5 file.
     """
 
     if os.path.isdir(pod5_path):
         pod5_path_is_dir = True
-        input_pod5_base = 'input_pod5s.pod5'
+        input_pod5_base = "input_pod5s.pod5"
         files = os.listdir(pod5_path)
-        pod5_files = [os.path.join(pod5_path, file) for file in files if '.pod5' in file]
+        pod5_files = [os.path.join(pod5_path, file) for file in files if ".pod5" in file]
         pod5_files.sort()
-        print(f'Found input pod5s: {pod5_files}')
-
+        logger.info(f"Found input pod5s: {pod5_files}")
+
     elif os.path.exists(pod5_path):
         pod5_path_is_dir = False
         input_pod5_base = os.path.basename(pod5_path)
 
     else:
-        print('Error: pod5_path passed does not exist')
+        logger.error("pod5_path passed does not exist")
         return None
 
-    if type(read_name_path) == str:
+    if type(read_name_path) is str:
         input_read_name_base = os.path.basename(read_name_path)
-        output_base = input_pod5_base.split('.pod5')[0] + '_' + input_read_name_base.split('.txt')[0] + '_subsampled.pod5'
+        output_base = (
+            input_pod5_base.split(".pod5")[0]
+            + "_"
+            + input_read_name_base.split(".txt")[0]
+            + "_subsampled.pod5"
+        )
 
         # extract read names into a list of strings
-        with open(read_name_path, 'r') as file:
+        with open(read_name_path, "r") as file:
            read_names = [line.strip() for line in file]
 
-        print(f'Looking for read_ids: {read_names}')
+        logger.info(f"Looking for read_ids: {read_names}")
        read_records = []
 
        if pod5_path_is_dir:
@@ -168,22 +201,25 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
            with p5.Reader(input_pod5) as reader:
                try:
                    for read_record in reader.reads(selection=read_names, missing_ok=True):
-                        read_records.append(read_record.to_read())
-                        print(f'Found read in {input_pod5}: {read_record.read_id}')
-                except:
-                    print('Skipping pod5, could not find reads')
-        else:
+                        read_records.append(read_record.to_read())
+                        logger.info(f"Found read in {input_pod5}: {read_record.read_id}")
+                except Exception:
+                    logger.warning("Skipping pod5, could not find reads")
+        else:
            with p5.Reader(pod5_path) as reader:
                try:
                    for read_record in reader.reads(selection=read_names):
                        read_records.append(read_record.to_read())
-                        print(f'Found read in {input_pod5}: {read_record}')
-                except:
-                    print('Could not find reads')
+                        logger.info(f"Found read in {input_pod5}: {read_record}")
+                except Exception:
+                    logger.warning("Could not find reads")
 
-    elif type(read_name_path) == int:
+    elif type(read_name_path) is int:
        import random
-        output_base = input_pod5_base.split('.pod5')[0] + f'_{read_name_path}_randomly_subsampled.pod5'
+
+        output_base = (
+            input_pod5_base.split(".pod5")[0] + f"_{read_name_path}_randomly_subsampled.pod5"
+        )
        all_read_records = []
 
        if pod5_path_is_dir:
@@ -191,7 +227,7 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
            random.shuffle(pod5_files)
            for input_pod5 in pod5_files:
                # iterate over the input pod5s
-                print(f'Opening pod5 file {input_pod5}')
+                logger.info(f"Opening pod5 file {input_pod5}")
                with p5.Reader(pod5_path) as reader:
                    for read_record in reader.reads():
                        all_read_records.append(read_record.to_read())
@@ -202,9 +238,11 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
            if read_name_path <= len(all_read_records):
                read_records = random.sample(all_read_records, read_name_path)
            else:
-                print('Trying to sample more reads than are contained in the input pod5s, taking all reads')
+                logger.info(
+                    "Trying to sample more reads than are contained in the input pod5s, taking all reads"
+                )
                read_records = all_read_records
-
+
        else:
            with p5.Reader(pod5_path) as reader:
                for read_record in reader.reads():
@@ -214,11 +252,13 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
                # if the subsampling amount is less than the record amount in the file, randomly subsample the reads
                read_records = random.sample(all_read_records, read_name_path)
            else:
-                print('Trying to sample more reads than are contained in the input pod5s, taking all reads')
+                logger.info(
+                    "Trying to sample more reads than are contained in the input pod5s, taking all reads"
+                )
                read_records = all_read_records
 
    output_pod5 = os.path.join(output_directory, output_base)
 
    # Write the subsampled POD5
    with p5.Writer(output_pod5) as writer:
-        writer.add_reads(read_records)
+        writer.add_reads(read_records)
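
Note the import change at the top of this file: pod5 is no longer imported unconditionally but resolved through the new smftools/optional_imports.py (+31 lines, contents not shown in this diff), so users without the ONT extras get an actionable error instead of a bare ImportError. The diff confirms only the call require("pod5", extra="ont", purpose="POD5 IO"); a minimal sketch of what such a guard might do:

```python
# Hypothetical sketch of smftools.optional_imports.require; only the call
# signature require("pod5", extra="ont", purpose="POD5 IO") is confirmed above.
from __future__ import annotations

import importlib
from types import ModuleType


def require(module: str, extra: str | None = None, purpose: str | None = None) -> ModuleType:
    """Import an optional dependency, raising a helpful error if it is absent."""
    try:
        return importlib.import_module(module)
    except ImportError as err:
        hint = f"pip install 'smftools[{extra}]'" if extra else f"pip install {module}"
        why = f" (needed for {purpose})" if purpose else ""
        raise ImportError(
            f"Optional dependency '{module}'{why} is missing. Try: {hint}"
        ) from err
```

One caveat of this pattern: because the module assigns p5 = require(...) at import time, the failure still occurs as soon as this file is imported; the gain is the targeted install hint rather than true lazy loading.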