smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/modkit_functions.py
@@ -1,10 +1,19 @@
-import os
 import subprocess
-import glob
-import zipfile
-from pathlib import Path
 
-def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassified=True, modkit_summary=False, threads=None):
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+
+
+def extract_mods(
+    thresholds,
+    mod_tsv_dir,
+    split_dir,
+    bam_suffix,
+    skip_unclassified=True,
+    modkit_summary=False,
+    threads=None,
+):
     """
     Takes all of the aligned, sorted, split modified BAM files and runs Nanopore Modkit Extract to load the modification data into zipped TSV files
 
@@ -23,10 +32,12 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassifi
 
     """
     filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
-    bam_files = sorted(p for p in split_dir.iterdir() if bam_suffix in p.name and '.bai' not in p.name)
+    bam_files = sorted(
+        p for p in split_dir.iterdir() if bam_suffix in p.name and ".bai" not in p.name
+    )
     if skip_unclassified:
         bam_files = [p for p in bam_files if "unclassified" not in p.name]
-    print(f"Running modkit extract for the following bam files: {bam_files}")
+    logger.info(f"Running modkit extract for the following bam files: {bam_files}")
 
     if threads:
         threads = str(threads)
@@ -34,14 +45,14 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassifi
         pass
 
     for input_file in bam_files:
-        print(input_file)
+        logger.debug(input_file)
         # Construct the output TSV file path
         output_tsv = mod_tsv_dir / (input_file.stem + "_extract.tsv")
-        output_tsv_gz = output_tsv.parent / (output_tsv.name + '.gz')
+        output_tsv_gz = output_tsv.parent / (output_tsv.name + ".gz")
         if output_tsv_gz.exists():
-            print(f"{output_tsv_gz} already exists, skipping modkit extract")
+            logger.debug(f"{output_tsv_gz} already exists, skipping modkit extract")
         else:
-            print(f"Extracting modification data from {input_file}")
+            logger.info(f"Extracting modification data from {input_file}")
             if modkit_summary:
                 # Run modkit summary
                 subprocess.run(["modkit", "summary", str(input_file)])
@@ -50,28 +61,43 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassifi
             # Run modkit extract
             if threads:
                 extract_command = [
-                    "modkit", "extract",
-                    "calls", "--mapped-only",
-                    "--filter-threshold", f'{filter_threshold}',
-                    "--mod-thresholds", f"m:{m5C_threshold}",
-                    "--mod-thresholds", f"a:{m6A_threshold}",
-                    "--mod-thresholds", f"h:{hm5C_threshold}",
-                    "-t", threads,
-                    str(input_file), str(output_tsv)
-                ]
+                    "modkit",
+                    "extract",
+                    "calls",
+                    "--mapped-only",
+                    "--filter-threshold",
+                    f"{filter_threshold}",
+                    "--mod-thresholds",
+                    f"m:{m5C_threshold}",
+                    "--mod-thresholds",
+                    f"a:{m6A_threshold}",
+                    "--mod-thresholds",
+                    f"h:{hm5C_threshold}",
+                    "-t",
+                    threads,
+                    str(input_file),
+                    str(output_tsv),
+                ]
             else:
                 extract_command = [
-                    "modkit", "extract",
-                    "calls", "--mapped-only",
-                    "--filter-threshold", f'{filter_threshold}',
-                    "--mod-thresholds", f"m:{m5C_threshold}",
-                    "--mod-thresholds", f"a:{m6A_threshold}",
-                    "--mod-thresholds", f"h:{hm5C_threshold}",
-                    str(input_file), str(output_tsv)
-                ]
+                    "modkit",
+                    "extract",
+                    "calls",
+                    "--mapped-only",
+                    "--filter-threshold",
+                    f"{filter_threshold}",
+                    "--mod-thresholds",
+                    f"m:{m5C_threshold}",
+                    "--mod-thresholds",
+                    f"a:{m6A_threshold}",
+                    "--mod-thresholds",
+                    f"h:{hm5C_threshold}",
+                    str(input_file),
+                    str(output_tsv),
+                ]
             subprocess.run(extract_command)
             # Zip the output TSV
-            print(f'zipping {output_tsv}')
+            logger.info(f"zipping {output_tsv}")
             if threads:
                 zip_command = ["pigz", "-f", "-p", threads, str(output_tsv)]
             else:
@@ -79,30 +105,39 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassifi
             subprocess.run(zip_command, check=True)
     return
 
+
 def make_modbed(aligned_sorted_output, thresholds, mod_bed_dir):
     """
     Generating position methylation summaries for each barcoded sample starting from the overall BAM file that was direct output of dorado aligner.
     Parameters:
         aligned_sorted_output (str): A string representing the file path to the aligned_sorted non-split BAM file.
-
+
     Returns:
         None
     """
-    import os
     import subprocess
-
+
    filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
     command = [
-        "modkit", "pileup", str(aligned_sorted_output), str(mod_bed_dir),
-        "--partition-tag", "BC",
+        "modkit",
+        "pileup",
+        str(aligned_sorted_output),
+        str(mod_bed_dir),
+        "--partition-tag",
+        "BC",
         "--only-tabs",
-        "--filter-threshold", f'{filter_threshold}',
-        "--mod-thresholds", f"m:{m5C_threshold}",
-        "--mod-thresholds", f"a:{m6A_threshold}",
-        "--mod-thresholds", f"h:{hm5C_threshold}"
+        "--filter-threshold",
+        f"{filter_threshold}",
+        "--mod-thresholds",
+        f"m:{m5C_threshold}",
+        "--mod-thresholds",
+        f"a:{m6A_threshold}",
+        "--mod-thresholds",
+        f"h:{hm5C_threshold}",
     ]
     subprocess.run(command)
 
+
 def modQC(aligned_sorted_output, thresholds):
     """
     Output the percentile of bases falling at a call threshold (threshold is a probability between 0-1) for the overall BAM file.
@@ -120,10 +155,16 @@ def modQC(aligned_sorted_output, thresholds):
     filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
     subprocess.run(["modkit", "sample-probs", str(aligned_sorted_output)])
     command = [
-        "modkit", "summary", str(aligned_sorted_output),
-        "--filter-threshold", f"{filter_threshold}",
-        "--mod-thresholds", f"m:{m5C_threshold}",
-        "--mod-thresholds", f"a:{m6A_threshold}",
-        "--mod-thresholds", f"h:{hm5C_threshold}"
+        "modkit",
+        "summary",
+        str(aligned_sorted_output),
+        "--filter-threshold",
+        f"{filter_threshold}",
+        "--mod-thresholds",
+        f"m:{m5C_threshold}",
+        "--mod-thresholds",
+        f"a:{m6A_threshold}",
+        "--mod-thresholds",
+        f"h:{hm5C_threshold}",
     ]
-    subprocess.run(command)
+    subprocess.run(command)
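
The changes above swap print calls for module-level loggers and reflow the modkit command lists to one argument per line; behavior is otherwise unchanged. For orientation, a minimal sketch of driving the refactored extract_mods; the import path is inferred from the file list above, and the paths and threshold values are hypothetical:

    from pathlib import Path

    from smftools.informatics.modkit_functions import extract_mods

    # Threshold order follows the unpacking in the function body:
    # filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold (values hypothetical)
    thresholds = [0.8, 0.9, 0.9, 0.9]
    extract_mods(
        thresholds,
        mod_tsv_dir=Path("mod_tsvs"),   # hypothetical dir that receives *_extract.tsv.gz
        split_dir=Path("split_bams"),   # hypothetical dir of aligned, sorted, split BAMs
        bam_suffix=".bam",
        threads=4,                      # forwarded to "modkit extract -t" and "pigz -p"
    )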
smftools/informatics/ohe.py
@@ -1,10 +1,15 @@
-import numpy as np
+import concurrent.futures
+import os
+
 import anndata as ad
+import numpy as np
 
-import os
-import concurrent.futures
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
 
-def one_hot_encode(sequence, device='auto'):
+
+def one_hot_encode(sequence, device="auto"):
     """
     One-hot encodes a DNA sequence.
 
@@ -14,7 +19,7 @@ def one_hot_encode(sequence, device='auto'):
     Returns:
         ndarray: Flattened one-hot encoded representation of the input sequence.
     """
-    mapping = np.array(['A', 'C', 'G', 'T', 'N'])
+    mapping = np.array(["A", "C", "G", "T", "N"])
 
     # Ensure input is a list of characters
     if not isinstance(sequence, list):
@@ -22,14 +27,14 @@ def one_hot_encode(sequence, device='auto'):
 
     # Handle empty sequences
     if len(sequence) == 0:
-        print("Warning: Empty sequence encountered in one_hot_encode()")
+        logger.warning("Empty sequence encountered in one_hot_encode()")
         return np.zeros(len(mapping))  # Return empty encoding instead of failing
 
     # Convert sequence to NumPy array
-    seq_array = np.array(sequence, dtype='<U1')
+    seq_array = np.array(sequence, dtype="<U1")
 
     # Replace invalid bases with 'N'
-    seq_array = np.where(np.isin(seq_array, mapping), seq_array, 'N')
+    seq_array = np.where(np.isin(seq_array, mapping), seq_array, "N")
 
     # Create one-hot encoding matrix
     one_hot_matrix = (seq_array[:, None] == mapping).astype(int)
@@ -37,6 +42,7 @@ def one_hot_encode(sequence, device='auto'):
     # Flatten and return
     return one_hot_matrix.flatten()
 
+
 def one_hot_decode(ohe_array):
     """
     Takes a flattened one hot encoded array and returns the sequence string from that array.
@@ -47,20 +53,21 @@ def one_hot_decode(ohe_array):
         sequence (str): Sequence string of the one hot encoded array
     """
     # Define the mapping of one-hot encoded indices to DNA bases
-    mapping = ['A', 'C', 'G', 'T', 'N']
-
+    mapping = ["A", "C", "G", "T", "N"]
+
     # Reshape the flattened array into a 2D matrix with 5 columns (one for each base)
     one_hot_matrix = ohe_array.reshape(-1, 5)
-
+
     # Get the index of the maximum value (which will be 1) in each row
     decoded_indices = np.argmax(one_hot_matrix, axis=1)
-
+
     # Map the indices back to the corresponding bases
     sequence_list = [mapping[i] for i in decoded_indices]
-    sequence = ''.join(sequence_list)
-
+    sequence = "".join(sequence_list)
+
     return sequence
 
+
 def ohe_layers_decode(adata, obs_names):
     """
     Takes an anndata object and a list of observation names. Returns a list of sequence strings for the reads of interest.
@@ -72,7 +79,7 @@ def ohe_layers_decode(adata, obs_names):
         sequences (list of str): List of strings of the one hot encoded array
     """
     # Define the mapping of one-hot encoded indices to DNA bases
-    mapping = ['A', 'C', 'G', 'T', 'N']
+    mapping = ["A", "C", "G", "T", "N"]
 
     ohe_layers = [f"{base}_binary_encoding" for base in mapping]
     sequences = []
@@ -85,9 +92,10 @@ def ohe_layers_decode(adata, obs_names):
         ohe_array = np.array(ohe_list)
         sequence = one_hot_decode(ohe_array)
         sequences.append(sequence)
-
+
     return sequences
 
+
 def _encode_sequence(args):
     """Parallel helper function for one-hot encoding."""
     read_name, seq, device = args
@@ -97,18 +105,29 @@ def _encode_sequence(args):
     except Exception:
         return None  # Skip invalid sequences
 
+
 def _encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number):
     """Encodes a batch and writes to disk immediately."""
     batch = {read_name: matrix for read_name, matrix in batch_data if matrix is not None}
 
     if batch:
-        save_name = os.path.join(tmp_dir, f'tmp_{prefix}_{record}_{batch_number}.h5ad')
+        save_name = os.path.join(tmp_dir, f"tmp_{prefix}_{record}_{batch_number}.h5ad")
         tmp_ad = ad.AnnData(X=np.zeros((1, 1)), uns=batch)  # Placeholder X
         tmp_ad.write_h5ad(save_name)
         return save_name
     return None
 
-def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000, progress_bar=None, device='auto', threads=None):
+
+def ohe_batching(
+    base_identities,
+    tmp_dir,
+    record,
+    prefix="",
+    batch_size=100000,
+    progress_bar=None,
+    device="auto",
+    threads=None,
+):
     """
     Efficient version of ohe_batching: one-hot encodes sequences in parallel and writes batches immediately.
 
@@ -131,7 +150,9 @@ def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000,
     file_names = []
 
     # Step 1: Prepare Data for Parallel Encoding
-    encoding_args = [(read_name, seq, device) for read_name, seq in base_identities.items() if seq is not None]
+    encoding_args = [
+        (read_name, seq, device) for read_name, seq in base_identities.items() if seq is not None
+    ]
 
     # Step 2: Parallel One-Hot Encoding using threads (to avoid nested processes)
     with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
@@ -141,7 +162,9 @@ def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000,
 
         if len(batch_data) >= batch_size:
             # Step 3: Process and Write Batch Immediately
-            file_name = _encode_and_save_batch(batch_data.copy(), tmp_dir, prefix, record, batch_number)
+            file_name = _encode_and_save_batch(
+                batch_data.copy(), tmp_dir, prefix, record, batch_number
+            )
             if file_name:
                 file_names.append(file_name)
 
@@ -157,4 +180,4 @@ def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000,
     if file_name:
         file_names.append(file_name)
 
-    return file_names
+    return file_names
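
The ohe.py changes are likewise mechanical (import ordering, double quotes, loggers instead of print), so the encoding contract is unchanged: five channels in A/C/G/T/N order, one row per base, flattened row-major. A quick round-trip sketch, assuming the import path implied by the file list above:

    from smftools.informatics.ohe import one_hot_decode, one_hot_encode

    ohe = one_hot_encode(list("ACGTX"))  # bases outside A/C/G/T/N are coerced to "N"
    print(ohe.shape)                     # (25,): 5 positions x 5 channels, flattened
    print(one_hot_decode(ohe))           # "ACGTN"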
smftools/informatics/pod5_functions.py
@@ -1,26 +1,29 @@
-from ..config import LoadExperimentConfig
-from ..readwrite import make_dirs
+from __future__ import annotations
 
 import os
 import subprocess
 from pathlib import Path
+from typing import Iterable
 
 import pod5 as p5
 
-from typing import Union, List
+from smftools.logging_utils import get_logger
 
-def basecall_pod5s(config_path):
-    """
-    Basecall from pod5s given a config file.
+from ..config import LoadExperimentConfig
+from ..informatics.basecalling import canoncall, modcall
+from ..readwrite import make_dirs
 
-    Parameters:
-        config_path (str): File path to the basecall configuration file
+logger = get_logger(__name__)
 
-    Returns:
-        None
+
+def basecall_pod5s(config_path: str | Path) -> None:
+    """Basecall POD5 inputs using a configuration file.
+
+    Args:
+        config_path: Path to the basecall configuration file.
     """
     # Default params
-    bam_suffix = '.bam'  # If different, change from here.
+    bam_suffix = ".bam"  # If different, change from here.
 
     # Load experiment config parameters into global variables
     experiment_config = LoadExperimentConfig(config_path)
@@ -30,66 +33,89 @@ def basecall_pod5s(config_path):
     default_value = None
 
     # General config variable init
-    input_data_path = Path(var_dict.get('input_data_path', default_value))  # Path to a directory of POD5s/FAST5s or to a BAM/FASTQ file. Necessary.
-    output_directory = Path(var_dict.get('output_directory', default_value))  # Path to the output directory to make for the analysis. Necessary.
-    model = var_dict.get('model', default_value)  # needed for dorado basecaller
-    model_dir = Path(var_dict.get('model_dir', default_value))  # model directory
-    barcode_kit = var_dict.get('barcode_kit', default_value)  # needed for dorado basecaller
-    barcode_both_ends = var_dict.get('barcode_both_ends', default_value)  # dorado demultiplexing
-    trim = var_dict.get('trim', default_value)  # dorado adapter and barcode removal
-    device = var_dict.get('device', 'auto')
+    input_data_path = Path(
+        var_dict.get("input_data_path", default_value)
+    )  # Path to a directory of POD5s/FAST5s or to a BAM/FASTQ file. Necessary.
+    output_directory = Path(
+        var_dict.get("output_directory", default_value)
+    )  # Path to the output directory to make for the analysis. Necessary.
+    model = var_dict.get("model", default_value)  # needed for dorado basecaller
+    model_dir = Path(var_dict.get("model_dir", default_value))  # model directory
+    barcode_kit = var_dict.get("barcode_kit", default_value)  # needed for dorado basecaller
+    barcode_both_ends = var_dict.get("barcode_both_ends", default_value)  # dorado demultiplexing
+    trim = var_dict.get("trim", default_value)  # dorado adapter and barcode removal
+    device = var_dict.get("device", "auto")
 
     # Modified basecalling specific variable init
-    filter_threshold = var_dict.get('filter_threshold', default_value)
-    m6A_threshold = var_dict.get('m6A_threshold', default_value)
-    m5C_threshold = var_dict.get('m5C_threshold', default_value)
-    hm5C_threshold = var_dict.get('hm5C_threshold', default_value)
+    filter_threshold = var_dict.get("filter_threshold", default_value)
+    m6A_threshold = var_dict.get("m6A_threshold", default_value)
+    m5C_threshold = var_dict.get("m5C_threshold", default_value)
+    hm5C_threshold = var_dict.get("hm5C_threshold", default_value)
     thresholds = [filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold]
-    mod_list = var_dict.get('mod_list', default_value)
-
+    mod_list = var_dict.get("mod_list", default_value)
+
     # Make initial output directory
     make_dirs([output_directory])
 
     # Get the input filetype
     if input_data_path.is_file():
         input_data_filetype = input_data_path.suffixes[0]
-        input_is_pod5 = input_data_filetype in ['.pod5','.p5']
-        input_is_fast5 = input_data_filetype in ['.fast5','.f5']
+        input_is_pod5 = input_data_filetype in [".pod5", ".p5"]
+        input_is_fast5 = input_data_filetype in [".fast5", ".f5"]
 
     elif input_data_path.is_dir():
         # Get the file names in the input data dir
         input_files = input_data_path.iterdir()
-        input_is_pod5 = sum([True for file in input_files if '.pod5' in file or '.p5' in file])
-        input_is_fast5 = sum([True for file in input_files if '.fast5' in file or '.f5' in file])
+        input_is_pod5 = sum([True for file in input_files if ".pod5" in file or ".p5" in file])
+        input_is_fast5 = sum([True for file in input_files if ".fast5" in file or ".f5" in file])
 
     # If the input files are not pod5 files, and they are fast5 files, convert the files to a pod5 file before proceeding.
     if input_is_fast5 and not input_is_pod5:
         # take the input directory of fast5 files and write out a single pod5 file into the output directory.
-        output_pod5 = output_directory / 'FAST5s_to_POD5.pod5'
-        print(f'Input directory contains fast5 files, converting them and concatenating into a single pod5 file in the {output_pod5}')
+        output_pod5 = output_directory / "FAST5s_to_POD5.pod5"
+        logger.info(
+            f"Input directory contains fast5 files, converting them and concatenating into a single pod5 file in the {output_pod5}"
+        )
         fast5_to_pod5(input_data_path, output_pod5)
         # Reassign the pod5_dir variable to point to the new pod5 file.
         input_data_path = output_pod5
 
     model_basename = model.name
-    model_basename = model_basename.replace('.', '_')
+    model_basename = model_basename.replace(".", "_")
 
     if mod_list:
         mod_string = "_".join(mod_list)
         bam = output_directory / f"{model_basename}_{mod_string}_calls"
-        modcall(model, input_data_path, barcode_kit, mod_list, bam, bam_suffix, barcode_both_ends, trim, device)
+        modcall(
+            model,
+            input_data_path,
+            barcode_kit,
+            mod_list,
+            bam,
+            bam_suffix,
+            barcode_both_ends,
+            trim,
+            device,
+        )
     else:
         bam = output_directory / f"{model_basename}_canonical_basecalls"
-        canoncall(model, input_data_path, barcode_kit, bam, bam_suffix, barcode_both_ends, trim, device)
+        canoncall(
+            model, input_data_path, barcode_kit, bam, bam_suffix, barcode_both_ends, trim, device
+        )
 
 
 def fast5_to_pod5(
-    fast5_dir: Union[str, Path, List[Union[str, Path]]],
-    output_pod5: Union[str, Path] = "FAST5s_to_POD5.pod5"
+    fast5_dir: str | Path | Iterable[str | Path],
+    output_pod5: str | Path = "FAST5s_to_POD5.pod5",
 ) -> None:
-    """
-    Convert Nanopore FAST5 files (single file, list of files, or directory)
-    into a single .pod5 output using the 'pod5 convert fast5' CLI tool.
+    """Convert FAST5 inputs into a single POD5 file.
+
+    Args:
+        fast5_dir: FAST5 file path, directory, or iterable of file paths to convert.
+        output_pod5: Output POD5 file path.
+
+    Raises:
+        FileNotFoundError: If no FAST5 files are found or the input path is invalid.
     """
 
     output_pod5 = str(output_pod5)  # ensure string
@@ -122,45 +148,51 @@ def fast5_to_pod5(
 
     raise FileNotFoundError(f"Input path invalid: {fast5_dir}")
 
-def subsample_pod5(pod5_path, read_name_path, output_directory):
-    """
-    Takes a POD5 file and a text file containing read names of interest and writes out a subsampled POD5 for just those reads.
-    This is a useful function when you have a list of read names that mapped to a region of interest that you want to reanalyze from the pod5 level.
 
-    Parameters:
-        pod5_path (str): File path to the POD5 file (or directory of multiple pod5 files) to subsample.
-        read_name_path (str | int): File path to a text file of read names. One read name per line. If an int value is passed, a random subset of that many reads will occur
-        output_directory (str): A file path to the directory to output the file.
+def subsample_pod5(
+    pod5_path: str | Path,
+    read_name_path: str | int,
+    output_directory: str | Path,
+) -> None:
+    """Write a subsampled POD5 containing selected reads.
 
-    Returns:
-        None
+    Args:
+        pod5_path: POD5 file path or directory of POD5 files to subsample.
+        read_name_path: Path to a text file of read names (one per line) or an integer
+            specifying a random subset size.
+        output_directory: Directory to write the subsampled POD5 file.
     """
 
     if os.path.isdir(pod5_path):
         pod5_path_is_dir = True
-        input_pod5_base = 'input_pod5s.pod5'
+        input_pod5_base = "input_pod5s.pod5"
         files = os.listdir(pod5_path)
-        pod5_files = [os.path.join(pod5_path, file) for file in files if '.pod5' in file]
+        pod5_files = [os.path.join(pod5_path, file) for file in files if ".pod5" in file]
         pod5_files.sort()
-        print(f'Found input pod5s: {pod5_files}')
-
+        logger.info(f"Found input pod5s: {pod5_files}")
+
     elif os.path.exists(pod5_path):
         pod5_path_is_dir = False
         input_pod5_base = os.path.basename(pod5_path)
 
     else:
-        print('Error: pod5_path passed does not exist')
+        logger.error("pod5_path passed does not exist")
         return None
 
-    if type(read_name_path) == str:
+    if type(read_name_path) is str:
         input_read_name_base = os.path.basename(read_name_path)
-        output_base = input_pod5_base.split('.pod5')[0] + '_' + input_read_name_base.split('.txt')[0] + '_subsampled.pod5'
+        output_base = (
+            input_pod5_base.split(".pod5")[0]
+            + "_"
+            + input_read_name_base.split(".txt")[0]
+            + "_subsampled.pod5"
+        )
 
         # extract read names into a list of strings
-        with open(read_name_path, 'r') as file:
+        with open(read_name_path, "r") as file:
            read_names = [line.strip() for line in file]
 
-        print(f'Looking for read_ids: {read_names}')
+        logger.info(f"Looking for read_ids: {read_names}")
        read_records = []
 
         if pod5_path_is_dir:
@@ -168,22 +200,25 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
             with p5.Reader(input_pod5) as reader:
                 try:
                     for read_record in reader.reads(selection=read_names, missing_ok=True):
-                        read_records.append(read_record.to_read())
-                        print(f'Found read in {input_pod5}: {read_record.read_id}')
-                except:
-                    print('Skipping pod5, could not find reads')
-        else:
+                        read_records.append(read_record.to_read())
+                        logger.info(f"Found read in {input_pod5}: {read_record.read_id}")
+                except Exception:
+                    logger.warning("Skipping pod5, could not find reads")
+        else:
             with p5.Reader(pod5_path) as reader:
                 try:
                     for read_record in reader.reads(selection=read_names):
                         read_records.append(read_record.to_read())
-                        print(f'Found read in {input_pod5}: {read_record}')
-                except:
-                    print('Could not find reads')
+                        logger.info(f"Found read in {input_pod5}: {read_record}")
+                except Exception:
+                    logger.warning("Could not find reads")
 
-    elif type(read_name_path) == int:
+    elif type(read_name_path) is int:
         import random
-        output_base = input_pod5_base.split('.pod5')[0] + f'_{read_name_path}_randomly_subsampled.pod5'
+
+        output_base = (
+            input_pod5_base.split(".pod5")[0] + f"_{read_name_path}_randomly_subsampled.pod5"
+        )
         all_read_records = []
 
         if pod5_path_is_dir:
@@ -191,7 +226,7 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
             random.shuffle(pod5_files)
             for input_pod5 in pod5_files:
                 # iterate over the input pod5s
-                print(f'Opening pod5 file {input_pod5}')
+                logger.info(f"Opening pod5 file {input_pod5}")
                 with p5.Reader(pod5_path) as reader:
                     for read_record in reader.reads():
                         all_read_records.append(read_record.to_read())
@@ -202,9 +237,11 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
             if read_name_path <= len(all_read_records):
                 read_records = random.sample(all_read_records, read_name_path)
             else:
-                print('Trying to sample more reads than are contained in the input pod5s, taking all reads')
+                logger.info(
+                    "Trying to sample more reads than are contained in the input pod5s, taking all reads"
+                )
                 read_records = all_read_records
-
+
         else:
             with p5.Reader(pod5_path) as reader:
                 for read_record in reader.reads():
@@ -214,11 +251,13 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
                 # if the subsampling amount is less than the record amount in the file, randomly subsample the reads
                 read_records = random.sample(all_read_records, read_name_path)
             else:
-                print('Trying to sample more reads than are contained in the input pod5s, taking all reads')
+                logger.info(
+                    "Trying to sample more reads than are contained in the input pod5s, taking all reads"
+                )
                 read_records = all_read_records
 
     output_pod5 = os.path.join(output_directory, output_base)
 
     # Write the subsampled POD5
     with p5.Writer(output_pod5) as writer:
-        writer.add_reads(read_records)
+        writer.add_reads(read_records)
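
Beyond formatting and logging, this file gains PEP 604 annotations (str | Path) via from __future__ import annotations and narrows the bare except clauses to except Exception. A hypothetical call to the subsampling helper, with made-up file names and the import path inferred from the file list above:

    from smftools.informatics.pod5_functions import subsample_pod5

    # Keep only the reads listed in reads_of_interest.txt (one read ID per line)
    subsample_pod5("run.pod5", "reads_of_interest.txt", "subsampled/")

    # Or pass an int to take a random subset of that many reads instead
    subsample_pod5("run.pod5", 500, "subsampled/")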