smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
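(Based on the function and import paths it contains, the unified diff below appears to be for smftools/informatics/modkit_extract_to_adata.py, entry 84 in the list above.)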
@@ -1,64 +1,88 @@
+ from __future__ import annotations
+
  import concurrent.futures
  import gc
- from .bam_functions import count_aligned_reads
+ import re
+ import shutil
+ from pathlib import Path
+ from typing import Iterable, Optional, Union
+
+ import numpy as np
  import pandas as pd
  from tqdm import tqdm
- import numpy as np
- from pathlib import Path
- from typing import Union, Iterable, Optional
- import shutil

- def filter_bam_records(bam, mapping_threshold):
+ from smftools.logging_utils import get_logger
+
+ from .bam_functions import count_aligned_reads
+
+ logger = get_logger(__name__)
+
+
+ def filter_bam_records(bam, mapping_threshold, samtools_backend: str | None = "auto"):
  """Processes a single BAM file, counts reads, and determines records to analyze."""
- aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(bam)
-
+ aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(
+ bam, samtools_backend
+ )
+
  total_reads = aligned_reads_count + unaligned_reads_count
  percent_aligned = (aligned_reads_count * 100 / total_reads) if total_reads > 0 else 0
- print(f'{percent_aligned:.2f}% of reads in {bam} aligned successfully')
+ logger.info(f"{percent_aligned:.2f}% of reads in {bam} aligned successfully")

  records = []
  for record, (count, percentage) in record_counts_dict.items():
- print(f'{count} reads mapped to reference {record}. This is {percentage*100:.2f}% of all mapped reads in {bam}')
+ logger.info(
+ f"{count} reads mapped to reference {record}. This is {percentage * 100:.2f}% of all mapped reads in {bam}"
+ )
  if percentage >= mapping_threshold:
  records.append(record)
-
+
  return set(records)

- def parallel_filter_bams(bam_path_list, mapping_threshold):
+
+ def parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend: str | None = "auto"):
  """Parallel processing for multiple BAM files."""
  records_to_analyze = set()

  with concurrent.futures.ProcessPoolExecutor() as executor:
- results = executor.map(filter_bam_records, bam_path_list, [mapping_threshold] * len(bam_path_list))
+ results = executor.map(
+ filter_bam_records,
+ bam_path_list,
+ [mapping_threshold] * len(bam_path_list),
+ [samtools_backend] * len(bam_path_list),
+ )

  # Aggregate results
  for result in results:
  records_to_analyze.update(result)

- print(f'Records to analyze: {records_to_analyze}')
+ logger.info(f"Records to analyze: {records_to_analyze}")
  return records_to_analyze

+
  def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
  """
  Loads and filters a single TSV file based on chromosome and position criteria.
  """
- temp_df = pd.read_csv(tsv, sep='\t', header=0)
+ temp_df = pd.read_csv(tsv, sep="\t", header=0)
  filtered_records = {}

  for record in records_to_analyze:
  if record not in reference_dict:
  continue
-
+
  ref_length = reference_dict[record][0]
- filtered_df = temp_df[(temp_df['chrom'] == record) &
- (temp_df['ref_position'] >= 0) &
- (temp_df['ref_position'] < ref_length)]
+ filtered_df = temp_df[
+ (temp_df["chrom"] == record)
+ & (temp_df["ref_position"] >= 0)
+ & (temp_df["ref_position"] < ref_length)
+ ]

  if not filtered_df.empty:
  filtered_records[record] = {sample_index: filtered_df}

  return filtered_records

+
  def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size, threads=4):
  """
  Loads and filters TSV files in parallel.
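For orientation, a minimal sketch of how the updated helpers in the hunk above might be driven; the paths and threshold are hypothetical and not part of this diff:

from pathlib import Path

# Hypothetical demultiplexed BAMs and a per-record mapped-read fraction threshold.
bam_paths = [Path("demux/barcode01.bam"), Path("demux/barcode02.bam")]
records = parallel_filter_bams(bam_paths, mapping_threshold=0.05, samtools_backend="auto")
# Any record meeting the threshold in at least one BAM ends up in the returned set.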
@@ -78,52 +102,60 @@ def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, bat

  with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
  futures = {
- executor.submit(process_tsv, tsv, records_to_analyze, reference_dict, sample_index): sample_index
+ executor.submit(
+ process_tsv, tsv, records_to_analyze, reference_dict, sample_index
+ ): sample_index
  for sample_index, tsv in enumerate(tsv_batch)
  }

- for future in tqdm(concurrent.futures.as_completed(futures), desc=f'Processing batch {batch}', total=batch_size):
+ for future in tqdm(
+ concurrent.futures.as_completed(futures),
+ desc=f"Processing batch {batch}",
+ total=batch_size,
+ ):
  result = future.result()
  for record, sample_data in result.items():
  dict_total[record].update(sample_data)

  return dict_total

+
  def update_dict_to_skip(dict_to_skip, detected_modifications):
  """
  Updates the dict_to_skip set based on the detected modifications.
-
+
  Parameters:
  dict_to_skip (set): The initial set of dictionary indices to skip.
  detected_modifications (list or set): The modifications (e.g. ['6mA', '5mC']) present.
-
+
  Returns:
  set: The updated dict_to_skip set.
  """
  # Define which indices correspond to modification-specific or strand-specific dictionaries
- A_stranded_dicts = {2, 3} # m6A bottom and top strand dictionaries
- C_stranded_dicts = {5, 6} # 5mC bottom and top strand dictionaries
- combined_dicts = {7, 8} # Combined strand dictionaries
+ A_stranded_dicts = {2, 3}  # m6A bottom and top strand dictionaries
+ C_stranded_dicts = {5, 6}  # 5mC bottom and top strand dictionaries
+ combined_dicts = {7, 8}  # Combined strand dictionaries

  # If '6mA' is present, remove the A_stranded indices from the skip set
- if '6mA' in detected_modifications:
+ if "6mA" in detected_modifications:
  dict_to_skip -= A_stranded_dicts
  # If '5mC' is present, remove the C_stranded indices from the skip set
- if '5mC' in detected_modifications:
+ if "5mC" in detected_modifications:
  dict_to_skip -= C_stranded_dicts
  # If both modifications are present, remove the combined indices from the skip set
- if '6mA' in detected_modifications and '5mC' in detected_modifications:
+ if "6mA" in detected_modifications and "5mC" in detected_modifications:
  dict_to_skip -= combined_dicts

  return dict_to_skip

+
  def process_modifications_for_sample(args):
  """
  Processes a single (record, sample) pair to extract modification-specific data.
-
+
  Parameters:
  args: (record, sample_index, sample_df, mods, max_reference_length)
-
+
  Returns:
  (record, sample_index, result) where result is a dict with keys:
  'm6A', 'm6A_minus', 'm6A_plus', '5mC', '5mC_minus', '5mC_plus', and
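As a small illustration of the set arithmetic in update_dict_to_skip above (the starting set here is illustrative; the indices follow the dictionary layout used later in modkit_extract_to_adata):

skip = {2, 3, 5, 6, 7, 8}
update_dict_to_skip(skip, ["6mA"])           # removes {2, 3} -> {5, 6, 7, 8}
update_dict_to_skip(skip, ["6mA", "5mC"])    # also removes {5, 6} and {7, 8} -> set()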
@@ -131,29 +163,30 @@ def process_modifications_for_sample(args):
  """
  record, sample_index, sample_df, mods, max_reference_length = args
  result = {}
- if '6mA' in mods:
- m6a_df = sample_df[sample_df['modified_primary_base'] == 'A']
- result['m6A'] = m6a_df
- result['m6A_minus'] = m6a_df[m6a_df['ref_strand'] == '-']
- result['m6A_plus'] = m6a_df[m6a_df['ref_strand'] == '+']
+ if "6mA" in mods:
+ m6a_df = sample_df[sample_df["modified_primary_base"] == "A"]
+ result["m6A"] = m6a_df
+ result["m6A_minus"] = m6a_df[m6a_df["ref_strand"] == "-"]
+ result["m6A_plus"] = m6a_df[m6a_df["ref_strand"] == "+"]
  m6a_df = None
  gc.collect()
- if '5mC' in mods:
- m5c_df = sample_df[sample_df['modified_primary_base'] == 'C']
- result['5mC'] = m5c_df
- result['5mC_minus'] = m5c_df[m5c_df['ref_strand'] == '-']
- result['5mC_plus'] = m5c_df[m5c_df['ref_strand'] == '+']
+ if "5mC" in mods:
+ m5c_df = sample_df[sample_df["modified_primary_base"] == "C"]
+ result["5mC"] = m5c_df
+ result["5mC_minus"] = m5c_df[m5c_df["ref_strand"] == "-"]
+ result["5mC_plus"] = m5c_df[m5c_df["ref_strand"] == "+"]
  m5c_df = None
  gc.collect()
- if '6mA' in mods and '5mC' in mods:
- result['combined_minus'] = []
- result['combined_plus'] = []
+ if "6mA" in mods and "5mC" in mods:
+ result["combined_minus"] = []
+ result["combined_plus"] = []
  return record, sample_index, result

+
  def parallel_process_modifications(dict_total, mods, max_reference_length, threads=4):
  """
  Processes each (record, sample) pair in dict_total in parallel to extract modification-specific data.
-
+
  Returns:
  processed_results: Dict keyed by record, with sub-dict keyed by sample index and the processed results.
  """
@@ -164,18 +197,20 @@ def parallel_process_modifications(dict_total, mods, max_reference_length, threa
  processed_results = {}
  with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
  for record, sample_index, result in tqdm(
- executor.map(process_modifications_for_sample, tasks),
- total=len(tasks),
- desc="Processing modifications"):
+ executor.map(process_modifications_for_sample, tasks),
+ total=len(tasks),
+ desc="Processing modifications",
+ ):
  if record not in processed_results:
  processed_results[record] = {}
  processed_results[record][sample_index] = result
  return processed_results

+
  def merge_modification_results(processed_results, mods):
  """
  Merges individual sample results into global dictionaries.
-
+
  Returns:
  A tuple: (m6A_dict, m6A_minus, m6A_plus, c5m_dict, c5m_minus, c5m_plus, combined_minus, combined_plus)
  """
@@ -189,44 +224,52 @@ def merge_modification_results(processed_results, mods):
  combined_plus = {}
  for record, sample_results in processed_results.items():
  for sample_index, res in sample_results.items():
- if '6mA' in mods:
+ if "6mA" in mods:
  if record not in m6A_dict:
  m6A_dict[record], m6A_minus[record], m6A_plus[record] = {}, {}, {}
- m6A_dict[record][sample_index] = res.get('m6A', pd.DataFrame())
- m6A_minus[record][sample_index] = res.get('m6A_minus', pd.DataFrame())
- m6A_plus[record][sample_index] = res.get('m6A_plus', pd.DataFrame())
- if '5mC' in mods:
+ m6A_dict[record][sample_index] = res.get("m6A", pd.DataFrame())
+ m6A_minus[record][sample_index] = res.get("m6A_minus", pd.DataFrame())
+ m6A_plus[record][sample_index] = res.get("m6A_plus", pd.DataFrame())
+ if "5mC" in mods:
  if record not in c5m_dict:
  c5m_dict[record], c5m_minus[record], c5m_plus[record] = {}, {}, {}
- c5m_dict[record][sample_index] = res.get('5mC', pd.DataFrame())
- c5m_minus[record][sample_index] = res.get('5mC_minus', pd.DataFrame())
- c5m_plus[record][sample_index] = res.get('5mC_plus', pd.DataFrame())
- if '6mA' in mods and '5mC' in mods:
+ c5m_dict[record][sample_index] = res.get("5mC", pd.DataFrame())
+ c5m_minus[record][sample_index] = res.get("5mC_minus", pd.DataFrame())
+ c5m_plus[record][sample_index] = res.get("5mC_plus", pd.DataFrame())
+ if "6mA" in mods and "5mC" in mods:
  if record not in combined_minus:
  combined_minus[record], combined_plus[record] = {}, {}
- combined_minus[record][sample_index] = res.get('combined_minus', [])
- combined_plus[record][sample_index] = res.get('combined_plus', [])
- return (m6A_dict, m6A_minus, m6A_plus,
- c5m_dict, c5m_minus, c5m_plus,
- combined_minus, combined_plus)
+ combined_minus[record][sample_index] = res.get("combined_minus", [])
+ combined_plus[record][sample_index] = res.get("combined_plus", [])
+ return (
+ m6A_dict,
+ m6A_minus,
+ m6A_plus,
+ c5m_dict,
+ c5m_minus,
+ c5m_plus,
+ combined_minus,
+ combined_plus,
+ )
+

  def process_stranded_methylation(args):
  """
  Processes a single (dict_index, record, sample) task.
-
+
  For combined dictionaries (indices 7 or 8), it merges the corresponding A-stranded and C-stranded data.
- For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
+ For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
  NumPy methylation array (of float type). Non-numeric values (e.g. '-') are coerced to NaN.
-
+
  Parameters:
  args: (dict_index, record, sample, dict_list, max_reference_length)
-
+
  Returns:
  (dict_index, record, sample, processed_data)
  """
  dict_index, record, sample, dict_list, max_reference_length = args
  processed_data = {}
-
+
  # For combined bottom strand (index 7)
  if dict_index == 7:
  temp_a = dict_list[2][record].get(sample, {}).copy()
@@ -235,18 +278,18 @@ def process_stranded_methylation(args):
  for read in set(temp_a.keys()) | set(temp_c.keys()):
  if read in temp_a:
  # Convert using pd.to_numeric with errors='coerce'
- value_a = pd.to_numeric(np.array(temp_a[read]), errors='coerce')
+ value_a = pd.to_numeric(np.array(temp_a[read]), errors="coerce")
  else:
  value_a = None
  if read in temp_c:
- value_c = pd.to_numeric(np.array(temp_c[read]), errors='coerce')
+ value_c = pd.to_numeric(np.array(temp_c[read]), errors="coerce")
  else:
  value_c = None
  if value_a is not None and value_c is not None:
  processed_data[read] = np.where(
  np.isnan(value_a) & np.isnan(value_c),
  np.nan,
- np.nan_to_num(value_a) + np.nan_to_num(value_c)
+ np.nan_to_num(value_a) + np.nan_to_num(value_c),
  )
  elif value_a is not None:
  processed_data[read] = value_a
@@ -261,18 +304,18 @@ def process_stranded_methylation(args):
  processed_data = {}
  for read in set(temp_a.keys()) | set(temp_c.keys()):
  if read in temp_a:
- value_a = pd.to_numeric(np.array(temp_a[read]), errors='coerce')
+ value_a = pd.to_numeric(np.array(temp_a[read]), errors="coerce")
  else:
  value_a = None
  if read in temp_c:
- value_c = pd.to_numeric(np.array(temp_c[read]), errors='coerce')
+ value_c = pd.to_numeric(np.array(temp_c[read]), errors="coerce")
  else:
  value_c = None
  if value_a is not None and value_c is not None:
  processed_data[read] = np.where(
  np.isnan(value_a) & np.isnan(value_c),
  np.nan,
- np.nan_to_num(value_a) + np.nan_to_num(value_c)
+ np.nan_to_num(value_a) + np.nan_to_num(value_c),
  )
  elif value_a is not None:
  processed_data[read] = value_a
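The NaN handling in the combined-strand branches above follows a simple rule: a position stays NaN only where both strands' arrays are NaN; otherwise NaN is treated as 0 and the values are summed. A worked example with illustrative values:

import numpy as np

value_a = np.array([np.nan, 1.0, np.nan])
value_c = np.array([np.nan, np.nan, 0.0])
combined = np.where(
    np.isnan(value_a) & np.isnan(value_c),   # NaN only where both inputs are NaN
    np.nan,
    np.nan_to_num(value_a) + np.nan_to_num(value_c),
)
# combined -> array([nan, 1.0, 0.0])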
@@ -286,24 +329,28 @@ def process_stranded_methylation(args):
  temp_df = dict_list[dict_index][record][sample]
  processed_data = {}
  # Extract columns and convert probabilities to float (coercing errors)
- read_ids = temp_df['read_id'].values
- positions = temp_df['ref_position'].values
- call_codes = temp_df['call_code'].values
- probabilities = pd.to_numeric(temp_df['call_prob'].values, errors='coerce')
-
- modified_codes = {'a', 'h', 'm'}
- canonical_codes = {'-'}
-
+ read_ids = temp_df["read_id"].values
+ positions = temp_df["ref_position"].values
+ call_codes = temp_df["call_code"].values
+ probabilities = pd.to_numeric(temp_df["call_prob"].values, errors="coerce")
+
+ modified_codes = {"a", "h", "m"}
+ canonical_codes = {"-"}
+
  # Compute methylation probabilities (vectorized)
  methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
- methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[np.isin(call_codes, list(modified_codes))]
- methylation_prob[np.isin(call_codes, list(canonical_codes))] = 1 - probabilities[np.isin(call_codes, list(canonical_codes))]
-
+ methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[
+ np.isin(call_codes, list(modified_codes))
+ ]
+ methylation_prob[np.isin(call_codes, list(canonical_codes))] = (
+ 1 - probabilities[np.isin(call_codes, list(canonical_codes))]
+ )
+
  # Preallocate storage for each unique read
  unique_reads = np.unique(read_ids)
  for read in unique_reads:
  processed_data[read] = np.full(max_reference_length, np.nan, dtype=float)
-
+
  # Assign values efficiently
  for i in range(len(read_ids)):
  read = read_ids[i]
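A small worked example of the vectorized call-code handling above, with illustrative codes and probabilities: modified codes ('a', 'h', 'm') keep the reported call probability, the canonical code '-' is flipped to 1 - probability, and anything else stays NaN.

import numpy as np

call_codes = np.array(["a", "-", "."])
probabilities = np.array([0.9, 0.8, 0.7])

methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
modified_mask = np.isin(call_codes, ["a", "h", "m"])
canonical_mask = np.isin(call_codes, ["-"])
methylation_prob[modified_mask] = probabilities[modified_mask]
methylation_prob[canonical_mask] = 1 - probabilities[canonical_mask]
# methylation_prob -> array([0.9, 0.2, nan])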
@@ -314,10 +361,11 @@ def process_stranded_methylation(args):
  gc.collect()
  return dict_index, record, sample, processed_data

+
  def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=4):
  """
  Processes all (dict_index, record, sample) tasks in dict_list (excluding indices in dict_to_skip) in parallel.
-
+
  Returns:
  Updated dict_list with processed (nested) dictionaries.
  """
@@ -327,16 +375,17 @@ def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference
  for record in current_dict.keys():
  for sample in current_dict[record].keys():
  tasks.append((dict_index, record, sample, dict_list, max_reference_length))
-
+
  with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
  for dict_index, record, sample, processed_data in tqdm(
  executor.map(process_stranded_methylation, tasks),
  total=len(tasks),
- desc="Extracting stranded methylation states"
+ desc="Extracting stranded methylation states",
  ):
  dict_list[dict_index][record][sample] = processed_data
  return dict_list

+
  def delete_intermediate_h5ads_and_tmpdir(
  h5_dir: Union[str, Path, Iterable[str], None],
  tmp_dir: Optional[Union[str, Path]] = None,
@@ -360,25 +409,27 @@ def delete_intermediate_h5ads_and_tmpdir(
  verbose : bool
  Print progress / warnings.
  """
+
  # Helper: remove a single file path (Path-like or string)
  def _maybe_unlink(p: Path):
+ """Remove a file path if it exists and is a file."""
  if not p.exists():
  if verbose:
- print(f"[skip] not found: {p}")
+ logger.debug(f"[skip] not found: {p}")
  return
  if not p.is_file():
  if verbose:
- print(f"[skip] not a file: {p}")
+ logger.debug(f"[skip] not a file: {p}")
  return
  if dry_run:
- print(f"[dry-run] would remove file: {p}")
+ logger.debug(f"[dry-run] would remove file: {p}")
  return
  try:
  p.unlink()
  if verbose:
- print(f"Removed file: {p}")
+ logger.info(f"Removed file: {p}")
  except Exception as e:
- print(f"[error] failed to remove file {p}: {e}")
+ logger.warning(f"[error] failed to remove file {p}: {e}")

  # Handle h5_dir input (directory OR iterable of file paths)
  if h5_dir is not None:
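A usage sketch for the cleanup helper shown above; the directory names are hypothetical, and the dry_run/verbose keywords are assumed from the docstring and body in this diff:

# Preview what would be deleted without touching the filesystem.
delete_intermediate_h5ads_and_tmpdir(
    "outputs/h5ads",        # directory (or iterable of paths) holding intermediate .h5ad files
    tmp_dir="outputs/tmp",  # temporary directory removed recursively
    dry_run=True,
    verbose=True,
)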
@@ -393,7 +444,7 @@ def delete_intermediate_h5ads_and_tmpdir(
  else:
  if verbose:
  # optional: comment this out if too noisy
- print(f"[skip] not matching pattern: {p.name}")
+ logger.debug(f"[skip] not matching pattern: {p.name}")
  else:
  # treat as iterable of file paths
  for f in h5_dir:
@@ -403,30 +454,45 @@ def delete_intermediate_h5ads_and_tmpdir(
  _maybe_unlink(p)
  else:
  if verbose:
- print(f"[skip] not matching pattern or not a file: {p}")
+ logger.debug(f"[skip] not matching pattern or not a file: {p}")

  # Remove tmp_dir recursively (if provided)
  if tmp_dir is not None:
  td = Path(tmp_dir)
  if not td.exists():
  if verbose:
- print(f"[skip] tmp_dir not found: {td}")
+ logger.debug(f"[skip] tmp_dir not found: {td}")
  else:
  if not td.is_dir():
  if verbose:
- print(f"[skip] tmp_dir is not a directory: {td}")
+ logger.debug(f"[skip] tmp_dir is not a directory: {td}")
  else:
  if dry_run:
- print(f"[dry-run] would remove directory tree: {td}")
+ logger.debug(f"[dry-run] would remove directory tree: {td}")
  else:
  try:
  shutil.rmtree(td)
  if verbose:
- print(f"Removed directory tree: {td}")
+ logger.info(f"Removed directory tree: {td}")
  except Exception as e:
- print(f"[error] failed to remove tmp dir {td}: {e}")
-
- def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapping_threshold, experiment_name, mods, batch_size, mod_tsv_dir, delete_batch_hdfs=False, threads=None, double_barcoded_path = None):
+ logger.warning(f"[error] failed to remove tmp dir {td}: {e}")
+
+
+ def modkit_extract_to_adata(
+ fasta,
+ bam_dir,
+ out_dir,
+ input_already_demuxed,
+ mapping_threshold,
+ experiment_name,
+ mods,
+ batch_size,
+ mod_tsv_dir,
+ delete_batch_hdfs=False,
+ threads=None,
+ double_barcoded_path=None,
+ samtools_backend: str | None = "auto",
+ ):
  """
  Takes modkit extract outputs and organizes it into an adata object

@@ -448,82 +514,121 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  """
  ###################################################
  # Package imports
- from .. import readwrite
- from ..readwrite import safe_write_h5ad, make_dirs
- from .fasta_functions import get_native_references
- from .bam_functions import extract_base_identities
- from .ohe import ohe_batching
- import pandas as pd
- import anndata as ad
- import os
  import gc
  import math
+
+ import anndata as ad
  import numpy as np
+ import pandas as pd
  from Bio.Seq import Seq
  from tqdm import tqdm
- import h5py
+
+ from .. import readwrite
+ from ..readwrite import make_dirs
+ from .bam_functions import extract_base_identities
+ from .fasta_functions import get_native_references
+ from .ohe import ohe_batching
  ###################################################

  ################## Get input tsv and bam file names into a sorted list ################
  # Make output dirs
- h5_dir = out_dir / 'h5ads'
- tmp_dir = out_dir / 'tmp'
+ h5_dir = out_dir / "h5ads"
+ tmp_dir = out_dir / "tmp"
  make_dirs([h5_dir, tmp_dir])

- existing_h5s = h5_dir.iterdir()
- existing_h5s = [h5 for h5 in existing_h5s if '.h5ad.gz' in str(h5)]
- final_hdf = f'{experiment_name}.h5ad.gz'
+ existing_h5s = h5_dir.iterdir()
+ existing_h5s = [h5 for h5 in existing_h5s if ".h5ad.gz" in str(h5)]
+ final_hdf = f"{experiment_name}.h5ad.gz"
  final_adata_path = h5_dir / final_hdf
  final_adata = None
-
+
  if final_adata_path.exists():
- print(f'{final_adata_path} already exists. Using existing adata')
+ logger.debug(f"{final_adata_path} already exists. Using existing adata")
  return final_adata, final_adata_path
-
+
  # List all files in the directory
  tsvs = sorted(
- p for p in mod_tsv_dir.iterdir()
- if p.is_file() and 'unclassified' not in p.name and 'extract.tsv' in p.name)
+ p
+ for p in mod_tsv_dir.iterdir()
+ if p.is_file() and "unclassified" not in p.name and "extract.tsv" in p.name
+ )
  bams = sorted(
- p for p in bam_dir.iterdir()
- if p.is_file() and p.suffix == '.bam' and 'unclassified' not in p.name and '.bai' not in p.name)
+ p
+ for p in bam_dir.iterdir()
+ if p.is_file()
+ and p.suffix == ".bam"
+ and "unclassified" not in p.name
+ and ".bai" not in p.name
+ )
+
+ tsv_path_list = [tsv for tsv in tsvs]
+ bam_path_list = [bam for bam in bams]
+ logger.info(f"{len(tsvs)} sample tsv files found: {tsvs}")
+ logger.info(f"{len(bams)} sample bams found: {bams}")
+
+ # Map global sample index (bami / final_sample_index) -> sample name / barcode
+ sample_name_map = {}
+ barcode_map = {}
+
+ for idx, bam_path in enumerate(bam_path_list):
+ stem = bam_path.stem
+
+ # Try to peel off a "barcode..." suffix if present.
+ # This handles things like:
+ # "mySample_barcode01" -> sample="mySample", barcode="barcode01"
+ # "run1-s1_barcode05" -> sample="run1-s1", barcode="barcode05"
+ # "barcode01" -> sample="barcode01", barcode="barcode01"
+ m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
+ if m:
+ sample_name = m.group(1) or stem
+ barcode = m.group(2)
+ else:
+ # Fallback: treat the whole stem as both sample & barcode
+ sample_name = stem
+ barcode = stem

- tsv_path_list = [mod_tsv_dir / tsv for tsv in tsvs]
- bam_path_list = [bam_dir / bam for bam in bams]
- print(f'{len(tsvs)} sample tsv files found: {tsvs}')
- print(f'{len(bams)} sample bams found: {bams}')
+ # make sample name of the format of the bam file stem
+ sample_name = sample_name + f"_{barcode}"
+
+ # Clean the barcode name to be an integer
+ barcode = int(barcode.split("barcode")[1])
+
+ sample_name_map[idx] = sample_name
+ barcode_map[idx] = str(barcode)
  ##########################################################################################

  ######### Get Record names that have over a passed threshold of mapped reads #############
  # get all records that are above a certain mapping threshold in at least one sample bam
- records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold)
+ records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend)

  ##########################################################################################

  ########### Determine the maximum record length to analyze in the dataset ################
  # Get all references within the FASTA and indicate the length and identity of the record sequence
  max_reference_length = 0
- reference_dict = get_native_references(str(fasta)) # returns a dict keyed by record name. Points to a tuple of (reference length, reference sequence)
+ reference_dict = get_native_references(
+ str(fasta)
+ )  # returns a dict keyed by record name. Points to a tuple of (reference length, reference sequence)
  # Get the max record length in the dataset.
  for record in records_to_analyze:
  if reference_dict[record][0] > max_reference_length:
  max_reference_length = reference_dict[record][0]
- print(f'{readwrite.time_string()}: Max reference length in dataset: {max_reference_length}')
- batches = math.ceil(len(tsvs) / batch_size) # Number of batches to process
- print('{0}: Processing input tsvs in {1} batches of {2} tsvs '.format(readwrite.time_string(), batches, batch_size))
+ logger.info(f"Max reference length in dataset: {max_reference_length}")
+ batches = math.ceil(len(tsvs) / batch_size)  # Number of batches to process
+ logger.info("Processing input tsvs in {0} batches of {1} tsvs ".format(batches, batch_size))
  ##########################################################################################

  ##########################################################################################
- # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
+ # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
  # Save the file paths in the bam_record_ohe_files dict.
  bam_record_ohe_files = {}
- bam_record_save = tmp_dir / 'tmp_file_dict.h5ad'
+ bam_record_save = tmp_dir / "tmp_file_dict.h5ad"
  fwd_mapped_reads = set()
  rev_mapped_reads = set()
  # If this step has already been performed, read in the tmp_dile_dict
  if bam_record_save.exists():
  bam_record_ohe_files = ad.read_h5ad(bam_record_save).uns
- print('Found existing OHE reads, using these')
+ logger.debug("Found existing OHE reads, using these")
  else:
  # Iterate over split bams
  for bami, bam in enumerate(bam_path_list):
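To make the new sample/barcode parsing concrete, the regex introduced in the hunk above behaves as follows on the stems named in its own comments (these are the raw match groups, before the later sample_name/barcode post-processing):

import re

pattern = r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$"

re.search(pattern, "mySample_barcode01").groups()   # ('mySample', 'barcode01')
re.search(pattern, "run1-s1_barcode05").groups()    # ('run1-s1', 'barcode05')
re.search(pattern, "barcode01").groups()            # ('', 'barcode01'); the "or stem" fallback restores the sample name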
@@ -533,18 +638,39 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  positions = range(current_reference_length)
  ref_seq = reference_dict[record][1]
  # Extract the base identities of reads aligned to the record
- fwd_base_identities, rev_base_identities, mismatch_counts_per_read, mismatch_trend_per_read = extract_base_identities(bam, record, positions, max_reference_length, ref_seq)
+ (
+ fwd_base_identities,
+ rev_base_identities,
+ mismatch_counts_per_read,
+ mismatch_trend_per_read,
+ ) = extract_base_identities(
+ bam, record, positions, max_reference_length, ref_seq, samtools_backend
+ )
  # Store read names of fwd and rev mapped reads
  fwd_mapped_reads.update(fwd_base_identities.keys())
  rev_mapped_reads.update(rev_base_identities.keys())
  # One hot encode the sequence string of the reads
- fwd_ohe_files = ohe_batching(fwd_base_identities, tmp_dir, record, f"{bami}_fwd",batch_size=100000, threads=threads)
- rev_ohe_files = ohe_batching(rev_base_identities, tmp_dir, record, f"{bami}_rev",batch_size=100000, threads=threads)
- bam_record_ohe_files[f'{bami}_{record}'] = fwd_ohe_files + rev_ohe_files
+ fwd_ohe_files = ohe_batching(
+ fwd_base_identities,
+ tmp_dir,
+ record,
+ f"{bami}_fwd",
+ batch_size=100000,
+ threads=threads,
+ )
+ rev_ohe_files = ohe_batching(
+ rev_base_identities,
+ tmp_dir,
+ record,
+ f"{bami}_rev",
+ batch_size=100000,
+ threads=threads,
+ )
+ bam_record_ohe_files[f"{bami}_{record}"] = fwd_ohe_files + rev_ohe_files
  del fwd_base_identities, rev_base_identities
  # Save out the ohe file paths
  X = np.random.rand(1, 1)
- tmp_ad = ad.AnnData(X=X, uns=bam_record_ohe_files)
+ tmp_ad = ad.AnnData(X=X, uns=bam_record_ohe_files)
  tmp_ad.write_h5ad(bam_record_save)
  ##########################################################################################

@@ -554,39 +680,73 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  for record in records_to_analyze:
  current_reference_length = reference_dict[record][0]
  delta_max_length = max_reference_length - current_reference_length
- sequence = reference_dict[record][1] + 'N'*delta_max_length
- complement = str(Seq(reference_dict[record][1]).complement()).upper() + 'N'*delta_max_length
+ sequence = reference_dict[record][1] + "N" * delta_max_length
+ complement = (
+ str(Seq(reference_dict[record][1]).complement()).upper() + "N" * delta_max_length
+ )
  record_seq_dict[record] = (sequence, complement)
  ##########################################################################################

  ###################################################
  # Begin iterating over batches
  for batch in range(batches):
- print('{0}: Processing tsvs for batch {1} '.format(readwrite.time_string(), batch))
+ logger.info("Processing tsvs for batch {0} ".format(batch))
  # For the final batch, just take the remaining tsv and bam files
  if batch == batches - 1:
  tsv_batch = tsv_path_list
  bam_batch = bam_path_list
- # For all other batches, take the next batch of tsvs and bams out of the file queue.
+ # For all other batches, take the next batch of tsvs and bams out of the file queue.
  else:
  tsv_batch = tsv_path_list[:batch_size]
  bam_batch = bam_path_list[:batch_size]
  tsv_path_list = tsv_path_list[batch_size:]
  bam_path_list = bam_path_list[batch_size:]
- print('{0}: tsvs in batch {1} '.format(readwrite.time_string(), tsv_batch))
+ logger.info("tsvs in batch {0} ".format(tsv_batch))

- batch_already_processed = sum([1 for h5 in existing_h5s if f'_{batch}_' in h5.name])
- ###################################################
+ batch_already_processed = sum([1 for h5 in existing_h5s if f"_{batch}_" in h5.name])
+ ###################################################
  if batch_already_processed:
- print(f'Batch {batch} has already been processed into h5ads. Skipping batch and using existing files')
+ logger.debug(
+ f"Batch {batch} has already been processed into h5ads. Skipping batch and using existing files"
+ )
  else:
  ###################################################
  ### Add the tsvs as dataframes to a dictionary (dict_total) keyed by integer index. Also make modification specific dictionaries and strand specific dictionaries.
  # # Initialize dictionaries and place them in a list
- dict_total, dict_a, dict_a_bottom, dict_a_top, dict_c, dict_c_bottom, dict_c_top, dict_combined_bottom, dict_combined_top = {},{},{},{},{},{},{},{},{}
- dict_list = [dict_total, dict_a, dict_a_bottom, dict_a_top, dict_c, dict_c_bottom, dict_c_top, dict_combined_bottom, dict_combined_top]
+ (
+ dict_total,
+ dict_a,
+ dict_a_bottom,
+ dict_a_top,
+ dict_c,
+ dict_c_bottom,
+ dict_c_top,
+ dict_combined_bottom,
+ dict_combined_top,
+ ) = {}, {}, {}, {}, {}, {}, {}, {}, {}
+ dict_list = [
+ dict_total,
+ dict_a,
+ dict_a_bottom,
+ dict_a_top,
+ dict_c,
+ dict_c_bottom,
+ dict_c_top,
+ dict_combined_bottom,
+ dict_combined_top,
+ ]
  # Give names to represent each dictionary in the list
- sample_types = ['total', 'm6A', 'm6A_bottom_strand', 'm6A_top_strand', '5mC', '5mC_bottom_strand', '5mC_top_strand', 'combined_bottom_strand', 'combined_top_strand']
+ sample_types = [
+ "total",
+ "m6A",
+ "m6A_bottom_strand",
+ "m6A_top_strand",
+ "5mC",
+ "5mC_bottom_strand",
+ "5mC_top_strand",
+ "combined_bottom_strand",
+ "combined_top_strand",
+ ]
  # Give indices of dictionaries to skip for analysis and final dictionary saving.
  dict_to_skip = [0, 1, 4]
  combined_dicts = [7, 8]
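The batch accounting shown above works out as in this small sketch (the numbers and file names are illustrative): with 10 TSVs and batch_size = 4, math.ceil(10 / 4) = 3 batches; the first two slices take 4 files each and the final batch takes whatever remains.

import math

tsv_path_list = [f"sample_{i}.extract.tsv" for i in range(10)]  # hypothetical file names
batch_size = 4
batches = math.ceil(len(tsv_path_list) / batch_size)            # 3

sizes = []
for batch in range(batches):
    if batch == batches - 1:
        sizes.append(len(tsv_path_list))        # last batch takes the remainder
        tsv_path_list = []
    else:
        sizes.append(len(tsv_path_list[:batch_size]))
        tsv_path_list = tsv_path_list[batch_size:]
# sizes -> [4, 4, 2]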
@@ -596,7 +756,14 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  dict_to_skip = set(dict_to_skip)

  # # Step 1):Load the dict_total dictionary with all of the batch tsv files as dataframes.
- dict_total = parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size=len(tsv_batch), threads=threads)
+ dict_total = parallel_load_tsvs(
+ tsv_batch,
+ records_to_analyze,
+ reference_dict,
+ batch,
+ batch_size=len(tsv_batch),
+ threads=threads,
+ )

  # # Step 2: Extract modification-specific data (per (record,sample)) in parallel
  # processed_mod_results = parallel_process_modifications(dict_total, mods, max_reference_length, threads=threads or 4)
@@ -621,56 +788,112 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  # Iterate over dict_total of all the tsv files and extract the modification specific and strand specific dataframes into dictionaries
  for record in dict_total.keys():
  for sample_index in dict_total[record].keys():
- if '6mA' in mods:
+ if "6mA" in mods:
  # Remove Adenine stranded dicts from the dicts to skip set
  dict_to_skip.difference_update(set(A_stranded_dicts))

- if record not in dict_a.keys() and record not in dict_a_bottom.keys() and record not in dict_a_top.keys():
+ if (
+ record not in dict_a.keys()
+ and record not in dict_a_bottom.keys()
+ and record not in dict_a_top.keys()
+ ):
  dict_a[record], dict_a_bottom[record], dict_a_top[record] = {}, {}, {}

  # get a dictionary of dataframes that only contain methylated adenine positions
- dict_a[record][sample_index] = dict_total[record][sample_index][dict_total[record][sample_index]['modified_primary_base'] == 'A']
- print('{}: Successfully loaded a methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ dict_a[record][sample_index] = dict_total[record][sample_index][
+ dict_total[record][sample_index]["modified_primary_base"] == "A"
+ ]
+ logger.debug(
+ "Successfully loaded a methyl-adenine dictionary for {}".format(
+ str(sample_index)
+ )
+ )
+
  # Stratify the adenine dictionary into two strand specific dictionaries.
- dict_a_bottom[record][sample_index] = dict_a[record][sample_index][dict_a[record][sample_index]['ref_strand'] == '-']
- print('{}: Successfully loaded a minus strand methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
- dict_a_top[record][sample_index] = dict_a[record][sample_index][dict_a[record][sample_index]['ref_strand'] == '+']
- print('{}: Successfully loaded a plus strand methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ dict_a_bottom[record][sample_index] = dict_a[record][sample_index][
+ dict_a[record][sample_index]["ref_strand"] == "-"
+ ]
+ logger.debug(
+ "Successfully loaded a minus strand methyl-adenine dictionary for {}".format(
+ str(sample_index)
+ )
+ )
+ dict_a_top[record][sample_index] = dict_a[record][sample_index][
+ dict_a[record][sample_index]["ref_strand"] == "+"
+ ]
+ logger.debug(
+ "Successfully loaded a plus strand methyl-adenine dictionary for ".format(
+ str(sample_index)
+ )
+ )

  # Reassign pointer for dict_a to None and delete the original value that it pointed to in order to decrease memory usage.
  dict_a[record][sample_index] = None
  gc.collect()

- if '5mC' in mods:
+ if "5mC" in mods:
  # Remove Cytosine stranded dicts from the dicts to skip set
  dict_to_skip.difference_update(set(C_stranded_dicts))

- if record not in dict_c.keys() and record not in dict_c_bottom.keys() and record not in dict_c_top.keys():
+ if (
+ record not in dict_c.keys()
+ and record not in dict_c_bottom.keys()
+ and record not in dict_c_top.keys()
+ ):
  dict_c[record], dict_c_bottom[record], dict_c_top[record] = {}, {}, {}

  # get a dictionary of dataframes that only contain methylated cytosine positions
- dict_c[record][sample_index] = dict_total[record][sample_index][dict_total[record][sample_index]['modified_primary_base'] == 'C']
- print('{}: Successfully loaded a methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ dict_c[record][sample_index] = dict_total[record][sample_index][
+ dict_total[record][sample_index]["modified_primary_base"] == "C"
+ ]
+ logger.debug(
+ "Successfully loaded a methyl-cytosine dictionary for {}".format(
+ str(sample_index)
+ )
+ )
  # Stratify the cytosine dictionary into two strand specific dictionaries.
- dict_c_bottom[record][sample_index] = dict_c[record][sample_index][dict_c[record][sample_index]['ref_strand'] == '-']
- print('{}: Successfully loaded a minus strand methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
- dict_c_top[record][sample_index] = dict_c[record][sample_index][dict_c[record][sample_index]['ref_strand'] == '+']
- print('{}: Successfully loaded a plus strand methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ dict_c_bottom[record][sample_index] = dict_c[record][sample_index][
+ dict_c[record][sample_index]["ref_strand"] == "-"
+ ]
+ logger.debug(
+ "Successfully loaded a minus strand methyl-cytosine dictionary for {}".format(
+ str(sample_index)
+ )
+ )
+ dict_c_top[record][sample_index] = dict_c[record][sample_index][
+ dict_c[record][sample_index]["ref_strand"] == "+"
+ ]
+ logger.debug(
+ "Successfully loaded a plus strand methyl-cytosine dictionary for {}".format(
+ str(sample_index)
+ )
+ )
  # Reassign pointer for dict_c to None and delete the original value that it pointed to in order to decrease memory usage.
  dict_c[record][sample_index] = None
  gc.collect()
-
- if '6mA' in mods and '5mC' in mods:
+
+ if "6mA" in mods and "5mC" in mods:
  # Remove combined stranded dicts from the dicts to skip set
- dict_to_skip.difference_update(set(combined_dicts))
+ dict_to_skip.difference_update(set(combined_dicts))
  # Initialize the sample keys for the combined dictionaries

- if record not in dict_combined_bottom.keys() and record not in dict_combined_top.keys():
- dict_combined_bottom[record], dict_combined_top[record]= {}, {}
-
- print('{}: Successfully created a minus strand combined methylation dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ if (
+ record not in dict_combined_bottom.keys()
+ and record not in dict_combined_top.keys()
+ ):
+ dict_combined_bottom[record], dict_combined_top[record] = {}, {}
+
+ logger.debug(
+ "Successfully created a minus strand combined methylation dictionary for {}".format(
+ str(sample_index)
+ )
+ )
  dict_combined_bottom[record][sample_index] = []
- print('{}: Successfully created a plus strand combined methylation dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ logger.debug(
+ "Successfully created a plus strand combined methylation dictionary for {}".format(
+ str(sample_index)
+ )
+ )
  dict_combined_top[record][sample_index] = []

  # Reassign pointer for dict_total to None and delete the original value that it pointed to in order to decrease memory usage.
@@ -681,14 +904,24 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  for dict_index, dict_type in enumerate(dict_list):
  # Only iterate over stranded dictionaries
  if dict_index not in dict_to_skip:
- print('{0}: Extracting methylation states for {1} dictionary'.format(readwrite.time_string(), sample_types[dict_index]))
+ logger.debug(
+ "Extracting methylation states for {} dictionary".format(
+ sample_types[dict_index]
+ )
+ )
  for record in dict_type.keys():
  # Get the dictionary for the modification type of interest from the reference mapping of interest
  mod_strand_record_sample_dict = dict_type[record]
- print('{0}: Extracting methylation states for {1} dictionary'.format(readwrite.time_string(), record))
+ logger.debug(
+ "Extracting methylation states for {} dictionary".format(record)
+ )
  # For each sample in a stranded dictionary
  n_samples = len(mod_strand_record_sample_dict.keys())
- for sample in tqdm(mod_strand_record_sample_dict.keys(), desc=f'Extracting {sample_types[dict_index]} dictionary from record {record} for sample', total=n_samples):
+ for sample in tqdm(
+ mod_strand_record_sample_dict.keys(),
+ desc=f"Extracting {sample_types[dict_index]} dictionary from record {record} for sample",
+ total=n_samples,
+ ):
  # Load the combined bottom strand dictionary after all the individual dictionaries have been made for the sample
  if dict_index == 7:
  # Load the minus strand dictionaries for each sample into temporary variables
@@ -699,16 +932,26 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
699
932
  for read in set(temp_a_dict) | set(temp_c_dict):
700
933
  # Add the arrays element-wise if the read is present in both dictionaries
701
934
  if read in temp_a_dict and read in temp_c_dict:
702
- mod_strand_record_sample_dict[sample][read] = np.where(np.isnan(temp_a_dict[read]) & np.isnan(temp_c_dict[read]), np.nan, np.nan_to_num(temp_a_dict[read]) + np.nan_to_num(temp_c_dict[read]))
935
+ mod_strand_record_sample_dict[sample][read] = np.where(
936
+ np.isnan(temp_a_dict[read])
937
+ & np.isnan(temp_c_dict[read]),
938
+ np.nan,
939
+ np.nan_to_num(temp_a_dict[read])
940
+ + np.nan_to_num(temp_c_dict[read]),
941
+ )
703
942
  # If the read is present in only one dictionary, copy its value
704
943
  elif read in temp_a_dict:
705
- mod_strand_record_sample_dict[sample][read] = temp_a_dict[read]
944
+ mod_strand_record_sample_dict[sample][read] = temp_a_dict[
945
+ read
946
+ ]
706
947
  elif read in temp_c_dict:
707
- mod_strand_record_sample_dict[sample][read] = temp_c_dict[read]
948
+ mod_strand_record_sample_dict[sample][read] = temp_c_dict[
949
+ read
950
+ ]
708
951
  del temp_a_dict, temp_c_dict
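The merge above combines the A-modification and C-modification arrays for a read without losing missing-data information: a position stays NaN only if neither input has a call there, otherwise NaNs are treated as zero and the values add. A small, self-contained sketch with illustrative values:

```python
# NaN-aware element-wise merge of two per-read probability vectors.
import numpy as np

a = np.array([np.nan, 0.9, np.nan, 0.25])   # e.g. 6mA calls
c = np.array([np.nan, np.nan, 0.75, 0.25])  # e.g. 5mC calls

merged = np.where(
    np.isnan(a) & np.isnan(c),            # NaN only where neither strand has a call
    np.nan,
    np.nan_to_num(a) + np.nan_to_num(c),  # otherwise sum, treating NaN as 0
)
print(merged)  # nan, 0.9, 0.75, 0.5
```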
709
- # Load the combined top strand dictionary after all the individual dictionaries have been made for the sample
952
+ # Load the combined top strand dictionary after all the individual dictionaries have been made for the sample
710
953
  elif dict_index == 8:
711
- # Load the plus strand dictionaries for each sample into temporary variables
954
+ # Load the plus strand dictionaries for each sample into temporary variables
712
955
  temp_a_dict = dict_list[3][record][sample].copy()
713
956
  temp_c_dict = dict_list[6][record][sample].copy()
714
957
  mod_strand_record_sample_dict[sample] = {}
@@ -716,105 +959,163 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
716
959
  for read in set(temp_a_dict) | set(temp_c_dict):
717
960
  # Add the arrays element-wise if the read is present in both dictionaries
718
961
  if read in temp_a_dict and read in temp_c_dict:
719
- mod_strand_record_sample_dict[sample][read] = np.where(np.isnan(temp_a_dict[read]) & np.isnan(temp_c_dict[read]), np.nan, np.nan_to_num(temp_a_dict[read]) + np.nan_to_num(temp_c_dict[read]))
962
+ mod_strand_record_sample_dict[sample][read] = np.where(
963
+ np.isnan(temp_a_dict[read])
964
+ & np.isnan(temp_c_dict[read]),
965
+ np.nan,
966
+ np.nan_to_num(temp_a_dict[read])
967
+ + np.nan_to_num(temp_c_dict[read]),
968
+ )
720
969
  # If the read is present in only one dictionary, copy its value
721
970
  elif read in temp_a_dict:
722
- mod_strand_record_sample_dict[sample][read] = temp_a_dict[read]
971
+ mod_strand_record_sample_dict[sample][read] = temp_a_dict[
972
+ read
973
+ ]
723
974
  elif read in temp_c_dict:
724
- mod_strand_record_sample_dict[sample][read] = temp_c_dict[read]
975
+ mod_strand_record_sample_dict[sample][read] = temp_c_dict[
976
+ read
977
+ ]
725
978
  del temp_a_dict, temp_c_dict
726
979
  # For all other dictionaries
727
980
  else:
728
-
729
981
  # use temp_df to point to the dataframe held in mod_strand_record_sample_dict[sample]
730
982
  temp_df = mod_strand_record_sample_dict[sample]
731
983
  # reassign the dictionary pointer to a nested dictionary.
732
984
  mod_strand_record_sample_dict[sample] = {}
733
985
 
734
986
  # Get relevant columns as NumPy arrays
735
- read_ids = temp_df['read_id'].values
736
- positions = temp_df['ref_position'].values
737
- call_codes = temp_df['call_code'].values
738
- probabilities = temp_df['call_prob'].values
987
+ read_ids = temp_df["read_id"].values
988
+ positions = temp_df["ref_position"].values
989
+ call_codes = temp_df["call_code"].values
990
+ probabilities = temp_df["call_prob"].values
739
991
 
740
992
  # Define valid call code categories
741
- modified_codes = {'a', 'h', 'm'}
742
- canonical_codes = {'-'}
993
+ modified_codes = {"a", "h", "m"}
994
+ canonical_codes = {"-"}
743
995
 
744
996
  # Vectorized methylation calculation with NaN for other codes
745
- methylation_prob = np.full_like(probabilities, np.nan) # Default all to NaN
746
- methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[np.isin(call_codes, list(modified_codes))]
747
- methylation_prob[np.isin(call_codes, list(canonical_codes))] = 1 - probabilities[np.isin(call_codes, list(canonical_codes))]
997
+ methylation_prob = np.full_like(
998
+ probabilities, np.nan
999
+ ) # Default all to NaN
1000
+ methylation_prob[np.isin(call_codes, list(modified_codes))] = (
1001
+ probabilities[np.isin(call_codes, list(modified_codes))]
1002
+ )
1003
+ methylation_prob[np.isin(call_codes, list(canonical_codes))] = (
1004
+ 1 - probabilities[np.isin(call_codes, list(canonical_codes))]
1005
+ )
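The vectorized call-code handling above maps each modkit call to a methylation probability: modified codes (`a`, `h`, `m`) keep their call probability, canonical calls (`-`) are flipped to `1 - prob`, and every other code stays NaN. A short sketch with illustrative codes and probabilities:

```python
# Vectorized call-code -> methylation probability mapping.
import numpy as np

call_codes = np.array(["m", "-", "h", "?", "a"])          # "?" stands in for any other code
probabilities = np.array([0.95, 0.75, 0.8, 0.6, 0.99])

modified_codes = {"a", "h", "m"}
canonical_codes = {"-"}

methylation_prob = np.full_like(probabilities, np.nan)     # default: NaN
is_mod = np.isin(call_codes, list(modified_codes))
is_canon = np.isin(call_codes, list(canonical_codes))
methylation_prob[is_mod] = probabilities[is_mod]
methylation_prob[is_canon] = 1 - probabilities[is_canon]

print(methylation_prob)  # 0.95, 0.25, 0.8, nan, 0.99
```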
748
1006
 
749
1007
  # Find unique reads
750
1008
  unique_reads = np.unique(read_ids)
751
1009
  # Preallocate storage for each read
752
1010
  for read in unique_reads:
753
- mod_strand_record_sample_dict[sample][read] = np.full(max_reference_length, np.nan)
1011
+ mod_strand_record_sample_dict[sample][read] = np.full(
1012
+ max_reference_length, np.nan
1013
+ )
754
1014
 
755
1015
  # Efficient NumPy indexing to assign values
756
1016
  for i in range(len(read_ids)):
757
1017
  read = read_ids[i]
758
1018
  pos = positions[i]
759
1019
  prob = methylation_prob[i]
760
-
1020
+
761
1021
  # Assign methylation probability
762
1022
  mod_strand_record_sample_dict[sample][read][pos] = prob
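The per-read assembly step preallocates a NaN vector spanning the reference for each read, then scatters the computed probabilities into it by reference position. A compact sketch with made-up read ids and positions:

```python
# Scatter per-call probabilities into NaN-initialized per-read vectors.
import numpy as np

max_reference_length = 8
read_ids = np.array(["r1", "r1", "r2"])
positions = np.array([2, 5, 0])
methylation_prob = np.array([0.9, np.nan, 0.25])

per_read = {read: np.full(max_reference_length, np.nan) for read in np.unique(read_ids)}
for i in range(len(read_ids)):
    per_read[read_ids[i]][positions[i]] = methylation_prob[i]

print(per_read["r1"])  # NaN everywhere except positions 2 and 5
```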
763
1023
 
764
-
765
1024
  # Save the sample files in the batch as gzipped hdf5 files
766
- print('{0}: Converting batch {1} dictionaries to anndata objects'.format(readwrite.time_string(), batch))
1025
+ logger.info("Converting batch {} dictionaries to anndata objects".format(batch))
767
1026
  for dict_index, dict_type in enumerate(dict_list):
768
1027
  if dict_index not in dict_to_skip:
769
1028
  # Initialize an hdf5 file for the current modified strand
770
1029
  adata = None
771
- print('{0}: Converting {1} dictionary to an anndata object'.format(readwrite.time_string(), sample_types[dict_index]))
1030
+ logger.info(
1031
+ "Converting {} dictionary to an anndata object".format(
1032
+ sample_types[dict_index]
1033
+ )
1034
+ )
772
1035
  for record in dict_type.keys():
773
1036
  # Get the dictionary for the modification type of interest from the reference mapping of interest
774
1037
  mod_strand_record_sample_dict = dict_type[record]
775
1038
  for sample in mod_strand_record_sample_dict.keys():
776
- print('{0}: Converting {1} dictionary for sample {2} to an anndata object'.format(readwrite.time_string(), sample_types[dict_index], sample))
1039
+ logger.info(
1040
+ "Converting {0} dictionary for sample {1} to an anndata object".format(
1041
+ sample_types[dict_index], sample
1042
+ )
1043
+ )
777
1044
  sample = int(sample)
778
1045
  final_sample_index = sample + (batch * batch_size)
779
- print('{0}: Final sample index for sample: {1}'.format(readwrite.time_string(), final_sample_index))
780
- print('{0}: Converting {1} dictionary for sample {2} to a dataframe'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
781
- temp_df = pd.DataFrame.from_dict(mod_strand_record_sample_dict[sample], orient='index')
782
- mod_strand_record_sample_dict[sample] = None # reassign pointer to facilitate memory usage
1046
+ logger.info(
1047
+ "Final sample index for sample: {}".format(final_sample_index)
1048
+ )
1049
+ logger.debug(
1050
+ "Converting {0} dictionary for sample {1} to a dataframe".format(
1051
+ sample_types[dict_index],
1052
+ final_sample_index,
1053
+ )
1054
+ )
1055
+ temp_df = pd.DataFrame.from_dict(
1056
+ mod_strand_record_sample_dict[sample], orient="index"
1057
+ )
1058
+ mod_strand_record_sample_dict[sample] = (
1059
+ None # reassign pointer to facilitate memory usage
1060
+ )
783
1061
  sorted_index = sorted(temp_df.index)
784
1062
  temp_df = temp_df.reindex(sorted_index)
785
1063
  X = temp_df.values
786
- dataset, strand = sample_types[dict_index].split('_')[:2]
787
-
788
- print('{0}: Loading {1} dataframe for sample {2} into a temp anndata object'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
1064
+ dataset, strand = sample_types[dict_index].split("_")[:2]
1065
+
1066
+ logger.info(
1067
+ "Loading {0} dataframe for sample {1} into a temp anndata object".format(
1068
+ sample_types[dict_index],
1069
+ final_sample_index,
1070
+ )
1071
+ )
789
1072
  temp_adata = ad.AnnData(X)
790
1073
  if temp_adata.shape[0] > 0:
791
- print('{0}: Adding read names and position ids to {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
1074
+ logger.info(
1075
+ "Adding read names and position ids to {0} anndata for sample {1}".format(
1076
+ sample_types[dict_index],
1077
+ final_sample_index,
1078
+ )
1079
+ )
792
1080
  temp_adata.obs_names = temp_df.index
793
1081
  temp_adata.obs_names = temp_adata.obs_names.astype(str)
794
1082
  temp_adata.var_names = temp_df.columns
795
1083
  temp_adata.var_names = temp_adata.var_names.astype(str)
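The conversion from the per-read dictionary to an AnnData follows a standard pattern: build a read × position DataFrame with `orient="index"`, sort the rows for determinism, and carry the read ids and position labels over as `obs_names` and `var_names`. A minimal sketch with illustrative data:

```python
# Per-read dict -> read x position matrix -> AnnData.
import numpy as np
import pandas as pd
import anndata as ad

per_read = {
    "read_b": np.array([np.nan, 0.9, 0.1]),
    "read_a": np.array([0.2, np.nan, np.nan]),
}

temp_df = pd.DataFrame.from_dict(per_read, orient="index")
temp_df = temp_df.reindex(sorted(temp_df.index))      # deterministic read order

temp_adata = ad.AnnData(temp_df.values)
temp_adata.obs_names = temp_df.index.astype(str)      # read ids
temp_adata.var_names = temp_df.columns.astype(str)    # reference positions

print(temp_adata)  # AnnData with n_obs x n_vars = 2 x 3
```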
796
- print('{0}: Adding {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
797
- temp_adata.obs['Sample'] = [str(final_sample_index)] * len(temp_adata)
798
- temp_adata.obs['Barcode'] = [str(final_sample_index)] * len(temp_adata)
799
- temp_adata.obs['Reference'] = [f'{record}'] * len(temp_adata)
800
- temp_adata.obs['Strand'] = [strand] * len(temp_adata)
801
- temp_adata.obs['Dataset'] = [dataset] * len(temp_adata)
802
- temp_adata.obs['Reference_dataset_strand'] = [f'{record}_{dataset}_{strand}'] * len(temp_adata)
803
- temp_adata.obs['Reference_strand'] = [f'{record}_{strand}'] * len(temp_adata)
804
-
1084
+ logger.info(
1085
+ "Adding {0} anndata for sample {1}".format(
1086
+ sample_types[dict_index],
1087
+ final_sample_index,
1088
+ )
1089
+ )
1090
+ temp_adata.obs["Sample"] = [
1091
+ sample_name_map[final_sample_index]
1092
+ ] * len(temp_adata)
1093
+ temp_adata.obs["Barcode"] = [barcode_map[final_sample_index]] * len(
1094
+ temp_adata
1095
+ )
1096
+ temp_adata.obs["Reference"] = [f"{record}"] * len(temp_adata)
1097
+ temp_adata.obs["Strand"] = [strand] * len(temp_adata)
1098
+ temp_adata.obs["Dataset"] = [dataset] * len(temp_adata)
1099
+ temp_adata.obs["Reference_dataset_strand"] = [
1100
+ f"{record}_{dataset}_{strand}"
1101
+ ] * len(temp_adata)
1102
+ temp_adata.obs["Reference_strand"] = [f"{record}_{strand}"] * len(
1103
+ temp_adata
1104
+ )
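Two details are worth illustrating here: the batch-local sample index is promoted to a global index with `sample + batch * batch_size`, and the per-sample metadata is broadcast across all reads of the AnnData. `sample_name_map` and `barcode_map` appear in the new code; their contents below are hypothetical.

```python
# Global sample indexing and broadcast obs annotation (illustrative values).
import numpy as np
import anndata as ad

batch, batch_size, sample = 2, 4, 1
final_sample_index = sample + (batch * batch_size)    # -> 9

sample_name_map = {9: "sampleC"}                       # hypothetical lookup
barcode_map = {9: "barcode09"}                         # hypothetical lookup

temp_adata = ad.AnnData(np.zeros((3, 5)))
temp_adata.obs["Sample"] = [sample_name_map[final_sample_index]] * temp_adata.n_obs
temp_adata.obs["Barcode"] = [barcode_map[final_sample_index]] * temp_adata.n_obs
temp_adata.obs["Reference_strand"] = ["chrI_top"] * temp_adata.n_obs
print(temp_adata.obs.head())
```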
1105
+
805
1106
  # Load in the one hot encoded reads from the current sample and record
806
1107
  one_hot_reads = {}
807
1108
  n_rows_OHE = 5
808
- ohe_files = bam_record_ohe_files[f'{final_sample_index}_{record}']
809
- print(f'Loading OHEs from {ohe_files}')
1109
+ ohe_files = bam_record_ohe_files[f"{final_sample_index}_{record}"]
1110
+ logger.info(f"Loading OHEs from {ohe_files}")
810
1111
  fwd_mapped_reads = set()
811
1112
  rev_mapped_reads = set()
812
1113
  for ohe_file in ohe_files:
813
1114
  tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
814
1115
  one_hot_reads.update(tmp_ohe_dict)
815
- if '_fwd_' in ohe_file:
1116
+ if "_fwd_" in ohe_file:
816
1117
  fwd_mapped_reads.update(tmp_ohe_dict.keys())
817
- elif '_rev_' in ohe_file:
1118
+ elif "_rev_" in ohe_file:
818
1119
  rev_mapped_reads.update(tmp_ohe_dict.keys())
819
1120
  del tmp_ohe_dict
820
1121
 
@@ -823,18 +1124,20 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
823
1124
  read_mapping_direction = []
824
1125
  for read_id in temp_adata.obs_names:
825
1126
  if read_id in fwd_mapped_reads:
826
- read_mapping_direction.append('fwd')
1127
+ read_mapping_direction.append("fwd")
827
1128
  elif read_id in rev_mapped_reads:
828
- read_mapping_direction.append('rev')
1129
+ read_mapping_direction.append("rev")
829
1130
  else:
830
- read_mapping_direction.append('unk')
1131
+ read_mapping_direction.append("unk")
831
1132
 
832
- temp_adata.obs['Read_mapping_direction'] = read_mapping_direction
1133
+ temp_adata.obs["Read_mapping_direction"] = read_mapping_direction
833
1134
 
834
1135
  del temp_df
835
-
1136
+
836
1137
  # Initialize NumPy arrays
837
- sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
1138
+ sequence_length = (
1139
+ one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
1140
+ )
838
1141
  df_A = np.zeros((len(sorted_index), sequence_length), dtype=int)
839
1142
  df_C = np.zeros((len(sorted_index), sequence_length), dtype=int)
840
1143
  df_G = np.zeros((len(sorted_index), sequence_length), dtype=int)
@@ -855,7 +1158,11 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
855
1158
  gc.collect()
856
1159
 
857
1160
  # Fill the arrays
858
- for j, read_name in tqdm(enumerate(sorted_index), desc='Loading dataframes of OHE reads', total=len(sorted_index)):
1161
+ for j, read_name in tqdm(
1162
+ enumerate(sorted_index),
1163
+ desc="Loading dataframes of OHE reads",
1164
+ total=len(sorted_index),
1165
+ ):
859
1166
  df_A[j, :] = dict_A[read_name]
860
1167
  df_C[j, :] = dict_C[read_name]
861
1168
  df_G[j, :] = dict_G[read_name]
@@ -867,43 +1174,78 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
867
1174
 
868
1175
  # Store the results in AnnData layers
869
1176
  ohe_df_map = {0: df_A, 1: df_C, 2: df_G, 3: df_T, 4: df_N}
870
- for j, base in enumerate(['A', 'C', 'G', 'T', 'N']):
871
- temp_adata.layers[f'{base}_binary_encoding'] = ohe_df_map[j]
872
- ohe_df_map[j] = None # Reassign pointer for memory usage purposes
873
-
874
- # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
1177
+ for j, base in enumerate(["A", "C", "G", "T", "N"]):
1178
+ temp_adata.layers[f"{base}_binary_sequence_encoding"] = (
1179
+ ohe_df_map[j]
1180
+ )
1181
+ ohe_df_map[j] = (
1182
+ None # Reassign pointer for memory usage purposes
1183
+ )
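The one-hot handling above takes each read's flattened 5 × L array (rows A, C, G, T, N), reshapes it, and stores one reads × positions binary matrix per base as an AnnData layer named `{base}_binary_sequence_encoding`. A minimal sketch, assuming toy one-hot arrays in place of the values loaded from the OHE files:

```python
# Per-base binary sequence-encoding layers from flattened one-hot reads.
import numpy as np
import anndata as ad

n_rows_OHE, sequence_length = 5, 4
one_hot_reads = {
    "r1": np.eye(n_rows_OHE, sequence_length).flatten(),  # toy flattened 5 x 4 OHE
    "r2": np.eye(n_rows_OHE, sequence_length).flatten(),
}
sorted_index = sorted(one_hot_reads)

base_arrays = {b: np.zeros((len(sorted_index), sequence_length), dtype=int) for b in "ACGTN"}
for j, read_name in enumerate(sorted_index):
    ohe = one_hot_reads[read_name].reshape(n_rows_OHE, -1)
    for k, base in enumerate("ACGTN"):
        base_arrays[base][j, :] = ohe[k]

adata = ad.AnnData(np.full((len(sorted_index), sequence_length), np.nan))
adata.obs_names = sorted_index
for base in "ACGTN":
    adata.layers[f"{base}_binary_sequence_encoding"] = base_arrays[base]
print(list(adata.layers.keys()))
```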
1184
+
1185
+ # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
875
1186
  if adata:
876
1187
  if temp_adata.shape[0] > 0:
877
- print('{0}: Concatenating {1} anndata object for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
878
- adata = ad.concat([adata, temp_adata], join='outer', index_unique=None)
1188
+ logger.info(
1189
+ "Concatenating {0} anndata object for sample {1}".format(
1190
+ sample_types[dict_index],
1191
+ final_sample_index,
1192
+ )
1193
+ )
1194
+ adata = ad.concat(
1195
+ [adata, temp_adata], join="outer", index_unique=None
1196
+ )
879
1197
  del temp_adata
880
1198
  else:
881
- print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata")
1199
+ logger.warning(
1200
+ f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata"
1201
+ )
882
1202
  else:
883
1203
  if temp_adata.shape[0] > 0:
884
- print('{0}: Initializing {1} anndata object for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
1204
+ logger.info(
1205
+ "Initializing {0} anndata object for sample {1}".format(
1206
+ sample_types[dict_index],
1207
+ final_sample_index,
1208
+ )
1209
+ )
885
1210
  adata = temp_adata
886
1211
  else:
887
- print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata")
1212
+ logger.warning(
1213
+ f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata"
1214
+ )
888
1215
 
889
1216
  gc.collect()
890
1217
  else:
891
- print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata. Skipping sample.")
1218
+ logger.warning(
1219
+ f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata. Skipping sample."
1220
+ )
892
1221
 
893
1222
  try:
894
- print('{0}: Writing {1} anndata out as a hdf5 file'.format(readwrite.time_string(), sample_types[dict_index]))
895
- adata.write_h5ad(h5_dir / '{0}_{1}_{2}_SMF_binarized_sample_hdf5.h5ad.gz'.format(readwrite.date_string(), batch, sample_types[dict_index]), compression='gzip')
896
- except:
897
- print(f"Skipping writing anndata for sample")
898
-
899
- # Delete the batch dictionaries from memory
900
- del dict_list, adata
1223
+ logger.info(
1224
+ "Writing {0} anndata out as a hdf5 file".format(
1225
+ sample_types[dict_index]
1226
+ )
1227
+ )
1228
+ adata.write_h5ad(
1229
+ h5_dir
1230
+ / "{0}_{1}_{2}_SMF_binarized_sample_hdf5.h5ad.gz".format(
1231
+ readwrite.date_string(), batch, sample_types[dict_index]
1232
+ ),
1233
+ compression="gzip",
1234
+ )
1235
+ except Exception:
1236
+ logger.debug("Skipping writing anndata for sample")
1237
+
1238
+ try:
1239
+ # Delete the batch dictionaries from memory
1240
+ del dict_list, adata
1241
+ except Exception:
1242
+ pass
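A minimal sketch of the per-batch output step, assuming a writable working directory: the batch AnnData is written as a gzip-compressed `.h5ad`. The file name below is a hypothetical stand-in for the date/batch/dictionary-type naming used above.

```python
# Write a batch AnnData as a gzip-compressed .h5ad file.
from pathlib import Path
import numpy as np
import anndata as ad

h5_dir = Path("h5ads")
h5_dir.mkdir(exist_ok=True)

adata = ad.AnnData(np.zeros((2, 3)))
out_path = h5_dir / "20250101_0_combined_top_SMF_binarized_sample_hdf5.h5ad.gz"
adata.write_h5ad(out_path, compression="gzip")
```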
901
1243
  gc.collect()
902
1244
 
903
1245
  # Iterate over all of the batched hdf5 files and concatenate them.
904
- files = h5_dir.iterdir()
1246
+ files = h5_dir.iterdir()
905
1247
  # Filter file names that contain the search string in their filename and keep them in a list
906
- hdfs = [hdf for hdf in files if 'hdf5.h5ad' in hdf.name and hdf != final_hdf]
1248
+ hdfs = [hdf for hdf in files if "hdf5.h5ad" in hdf.name and hdf != final_hdf]
907
1249
  combined_hdfs = [hdf for hdf in hdfs if "combined" in hdf.name]
908
1250
  if len(combined_hdfs) > 0:
909
1251
  hdfs = combined_hdfs
@@ -911,55 +1253,62 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
911
1253
  pass
912
1254
  # Sort file list by names and print the list of file names
913
1255
  hdfs.sort()
914
- print('{0} sample files found: {1}'.format(len(hdfs), hdfs))
915
- hdf_paths = [h5_dir / hd5 for hd5 in hdfs]
1256
+ logger.info("{0} sample files found: {1}".format(len(hdfs), hdfs))
1257
+ hdf_paths = [hd5 for hd5 in hdfs]
916
1258
  final_adata = None
917
1259
  for hdf_index, hdf in enumerate(hdf_paths):
918
- print('{0}: Reading in {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
1260
+ logger.info("Reading in {} hdf5 file".format(hdfs[hdf_index]))
919
1261
  temp_adata = ad.read_h5ad(hdf)
920
1262
  if final_adata:
921
- print('{0}: Concatenating final adata object with {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
922
- final_adata = ad.concat([final_adata, temp_adata], join='outer', index_unique=None)
1263
+ logger.info(
1264
+ "Concatenating final adata object with {} hdf5 file".format(hdfs[hdf_index])
1265
+ )
1266
+ final_adata = ad.concat([final_adata, temp_adata], join="outer", index_unique=None)
923
1267
  else:
924
- print('{0}: Initializing final adata object with {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
1268
+ logger.info("Initializing final adata object with {} hdf5 file".format(hdfs[hdf_index]))
925
1269
  final_adata = temp_adata
926
1270
  del temp_adata
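The same incremental outer-join pattern is used here as within a batch: read each file, concatenate it onto the growing object with `join="outer"` so the union of positions is kept, and use `index_unique=None` so read names are not suffixed. A self-contained sketch with two in-memory objects standing in for the batched h5ad files:

```python
# Incremental outer-join concatenation of AnnData batches.
import numpy as np
import anndata as ad

a = ad.AnnData(np.ones((2, 3)))
a.obs_names = ["read1", "read2"]
a.var_names = ["0", "1", "2"]

b = ad.AnnData(np.ones((1, 3)))
b.obs_names = ["read3"]
b.var_names = ["1", "2", "3"]

final_adata = None
for temp_adata in (a, b):
    if final_adata is None:
        final_adata = temp_adata
    else:
        final_adata = ad.concat([final_adata, temp_adata], join="outer", index_unique=None)

print(final_adata.shape)  # (3, 4): union of positions, missing values filled with NaN
```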
927
1271
 
928
1272
  # Set obs columns to type 'category'
929
1273
  for col in final_adata.obs.columns:
930
- final_adata.obs[col] = final_adata.obs[col].astype('category')
1274
+ final_adata.obs[col] = final_adata.obs[col].astype("category")
931
1275
 
932
- ohe_bases = ['A', 'C', 'G', 'T'] # ignore N bases for consensus
933
- ohe_layers = [f"{ohe_base}_binary_encoding" for ohe_base in ohe_bases]
934
- final_adata.uns['References'] = {}
1276
+ ohe_bases = ["A", "C", "G", "T"] # ignore N bases for consensus
1277
+ ohe_layers = [f"{ohe_base}_binary_sequence_encoding" for ohe_base in ohe_bases]
1278
+ final_adata.uns["References"] = {}
935
1279
  for record in records_to_analyze:
936
1280
  # Add FASTA sequence to the object
937
1281
  sequence = record_seq_dict[record][0]
938
1282
  complement = record_seq_dict[record][1]
939
- final_adata.var[f'{record}_top_strand_FASTA_base'] = list(sequence)
940
- final_adata.var[f'{record}_bottom_strand_FASTA_base'] = list(complement)
941
- final_adata.uns[f'{record}_FASTA_sequence'] = sequence
942
- final_adata.uns['References'][f'{record}_FASTA_sequence'] = sequence
1283
+ final_adata.var[f"{record}_top_strand_FASTA_base"] = list(sequence)
1284
+ final_adata.var[f"{record}_bottom_strand_FASTA_base"] = list(complement)
1285
+ final_adata.uns[f"{record}_FASTA_sequence"] = sequence
1286
+ final_adata.uns["References"][f"{record}_FASTA_sequence"] = sequence
943
1287
  # Add consensus sequence of samples mapped to the record to the object
944
- record_subset = final_adata[final_adata.obs['Reference'] == record]
945
- for strand in record_subset.obs['Strand'].cat.categories:
946
- strand_subset = record_subset[record_subset.obs['Strand'] == strand]
947
- for mapping_dir in strand_subset.obs['Read_mapping_direction'].cat.categories:
948
- mapping_dir_subset = strand_subset[strand_subset.obs['Read_mapping_direction'] == mapping_dir]
1288
+ record_subset = final_adata[final_adata.obs["Reference"] == record]
1289
+ for strand in record_subset.obs["Strand"].cat.categories:
1290
+ strand_subset = record_subset[record_subset.obs["Strand"] == strand]
1291
+ for mapping_dir in strand_subset.obs["Read_mapping_direction"].cat.categories:
1292
+ mapping_dir_subset = strand_subset[
1293
+ strand_subset.obs["Read_mapping_direction"] == mapping_dir
1294
+ ]
949
1295
  layer_map, layer_counts = {}, []
950
1296
  for i, layer in enumerate(ohe_layers):
951
- layer_map[i] = layer.split('_')[0]
1297
+ layer_map[i] = layer.split("_")[0]
952
1298
  layer_counts.append(np.sum(mapping_dir_subset.layers[layer], axis=0))
953
1299
  count_array = np.array(layer_counts)
954
1300
  nucleotide_indexes = np.argmax(count_array, axis=0)
955
1301
  consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
956
- final_adata.var[f'{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples'] = consensus_sequence_list
1302
+ final_adata.var[
1303
+ f"{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples"
1304
+ ] = consensus_sequence_list
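The consensus call above sums each base's binary-encoding layer over the reads in a subset and takes the most frequent base at every position via `argmax`. A small sketch with illustrative per-position counts standing in for the summed layers:

```python
# Consensus base per position from per-base count vectors.
import numpy as np

ohe_bases = ["A", "C", "G", "T"]  # N is ignored for the consensus
layer_counts = [
    np.array([5, 0, 1, 0]),  # A counts per position
    np.array([0, 4, 0, 1]),  # C counts
    np.array([1, 1, 6, 0]),  # G counts
    np.array([0, 0, 0, 7]),  # T counts
]

count_array = np.array(layer_counts)                # shape (n_bases, n_positions)
nucleotide_indexes = np.argmax(count_array, axis=0)
consensus = [ohe_bases[i] for i in nucleotide_indexes]
print("".join(consensus))  # ACGT
```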
957
1305
 
958
1306
  if input_already_demuxed:
959
1307
  final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
960
1308
  final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
961
1309
  else:
962
1310
  from .h5ad_functions import add_demux_type_annotation
1311
+
963
1312
  double_barcoded_reads = double_barcoded_path / "barcoding_summary.txt"
964
1313
  add_demux_type_annotation(final_adata, double_barcoded_reads)
965
1314
 
@@ -967,4 +1316,4 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
967
1316
  if delete_batch_hdfs:
968
1317
  delete_intermediate_h5ads_and_tmpdir(h5_dir, tmp_dir)
969
1318
 
970
- return final_adata, final_adata_path
1319
+ return final_adata, final_adata_path