smftools 0.2.3-py3-none-any.whl → 0.2.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
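The most visible cross-cutting change in 0.2.5, throughout the file diff below, is the replacement of bare print() calls with a module-level logger obtained from the new smftools/logging_utils.py (+51 -0 above). That module's contents are not part of this excerpt; the sketch below is only a minimal illustration of the kind of helper its call sites imply (get_logger(__name__) returning a standard logging.Logger), and everything beyond the get_logger name and its usage is an assumption, not the package's actual implementation.

    import logging

    def get_logger(name: str) -> logging.Logger:
        # Return a logger for the given module name, attaching a stream handler
        # with a simple format only once so repeated calls are safe.
        logger = logging.getLogger(name)
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
            )
            logger.addHandler(handler)
            logger.setLevel(logging.INFO)
        return logger

    # Usage pattern seen in the diff below:
    # logger = get_logger(__name__)
    # logger.info("Records to analyze: ...")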
@@ -1,64 +1,81 @@
  import concurrent.futures
  import gc
- from .bam_functions import count_aligned_reads
+ import re
+ import shutil
+ from pathlib import Path
+ from typing import Iterable, Optional, Union
+
+ import numpy as np
  import pandas as pd
  from tqdm import tqdm
- import numpy as np
- from pathlib import Path
- from typing import Union, Iterable, Optional
- import shutil
+
+ from smftools.logging_utils import get_logger
+
+ from .bam_functions import count_aligned_reads
+
+ logger = get_logger(__name__)
+

  def filter_bam_records(bam, mapping_threshold):
  """Processes a single BAM file, counts reads, and determines records to analyze."""
  aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(bam)
-
+
  total_reads = aligned_reads_count + unaligned_reads_count
  percent_aligned = (aligned_reads_count * 100 / total_reads) if total_reads > 0 else 0
- print(f'{percent_aligned:.2f}% of reads in {bam} aligned successfully')
+ logger.info(f"{percent_aligned:.2f}% of reads in {bam} aligned successfully")

  records = []
  for record, (count, percentage) in record_counts_dict.items():
- print(f'{count} reads mapped to reference {record}. This is {percentage*100:.2f}% of all mapped reads in {bam}')
+ logger.info(
+ f"{count} reads mapped to reference {record}. This is {percentage * 100:.2f}% of all mapped reads in {bam}"
+ )
  if percentage >= mapping_threshold:
  records.append(record)
-
+
  return set(records)

+
  def parallel_filter_bams(bam_path_list, mapping_threshold):
  """Parallel processing for multiple BAM files."""
  records_to_analyze = set()

  with concurrent.futures.ProcessPoolExecutor() as executor:
- results = executor.map(filter_bam_records, bam_path_list, [mapping_threshold] * len(bam_path_list))
+ results = executor.map(
+ filter_bam_records, bam_path_list, [mapping_threshold] * len(bam_path_list)
+ )

  # Aggregate results
  for result in results:
  records_to_analyze.update(result)

- print(f'Records to analyze: {records_to_analyze}')
+ logger.info(f"Records to analyze: {records_to_analyze}")
  return records_to_analyze

+
  def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
  """
  Loads and filters a single TSV file based on chromosome and position criteria.
  """
- temp_df = pd.read_csv(tsv, sep='\t', header=0)
+ temp_df = pd.read_csv(tsv, sep="\t", header=0)
  filtered_records = {}

  for record in records_to_analyze:
  if record not in reference_dict:
  continue
-
+
  ref_length = reference_dict[record][0]
- filtered_df = temp_df[(temp_df['chrom'] == record) &
- (temp_df['ref_position'] >= 0) &
- (temp_df['ref_position'] < ref_length)]
+ filtered_df = temp_df[
+ (temp_df["chrom"] == record)
+ & (temp_df["ref_position"] >= 0)
+ & (temp_df["ref_position"] < ref_length)
+ ]

  if not filtered_df.empty:
  filtered_records[record] = {sample_index: filtered_df}

  return filtered_records

+
  def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size, threads=4):
  """
  Loads and filters TSV files in parallel.
@@ -78,52 +95,60 @@ def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, bat

  with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
  futures = {
- executor.submit(process_tsv, tsv, records_to_analyze, reference_dict, sample_index): sample_index
+ executor.submit(
+ process_tsv, tsv, records_to_analyze, reference_dict, sample_index
+ ): sample_index
  for sample_index, tsv in enumerate(tsv_batch)
  }

- for future in tqdm(concurrent.futures.as_completed(futures), desc=f'Processing batch {batch}', total=batch_size):
+ for future in tqdm(
+ concurrent.futures.as_completed(futures),
+ desc=f"Processing batch {batch}",
+ total=batch_size,
+ ):
  result = future.result()
  for record, sample_data in result.items():
  dict_total[record].update(sample_data)

  return dict_total

+
  def update_dict_to_skip(dict_to_skip, detected_modifications):
  """
  Updates the dict_to_skip set based on the detected modifications.
-
+
  Parameters:
  dict_to_skip (set): The initial set of dictionary indices to skip.
  detected_modifications (list or set): The modifications (e.g. ['6mA', '5mC']) present.
-
+
  Returns:
  set: The updated dict_to_skip set.
  """
  # Define which indices correspond to modification-specific or strand-specific dictionaries
- A_stranded_dicts = {2, 3} # m6A bottom and top strand dictionaries
- C_stranded_dicts = {5, 6} # 5mC bottom and top strand dictionaries
- combined_dicts = {7, 8} # Combined strand dictionaries
+ A_stranded_dicts = {2, 3}  # m6A bottom and top strand dictionaries
+ C_stranded_dicts = {5, 6}  # 5mC bottom and top strand dictionaries
+ combined_dicts = {7, 8}  # Combined strand dictionaries

  # If '6mA' is present, remove the A_stranded indices from the skip set
- if '6mA' in detected_modifications:
+ if "6mA" in detected_modifications:
  dict_to_skip -= A_stranded_dicts
  # If '5mC' is present, remove the C_stranded indices from the skip set
- if '5mC' in detected_modifications:
+ if "5mC" in detected_modifications:
  dict_to_skip -= C_stranded_dicts
  # If both modifications are present, remove the combined indices from the skip set
- if '6mA' in detected_modifications and '5mC' in detected_modifications:
+ if "6mA" in detected_modifications and "5mC" in detected_modifications:
  dict_to_skip -= combined_dicts

  return dict_to_skip

+
  def process_modifications_for_sample(args):
  """
  Processes a single (record, sample) pair to extract modification-specific data.
-
+
  Parameters:
  args: (record, sample_index, sample_df, mods, max_reference_length)
-
+
  Returns:
  (record, sample_index, result) where result is a dict with keys:
  'm6A', 'm6A_minus', 'm6A_plus', '5mC', '5mC_minus', '5mC_plus', and
@@ -131,29 +156,30 @@ def process_modifications_for_sample(args):
  """
  record, sample_index, sample_df, mods, max_reference_length = args
  result = {}
- if '6mA' in mods:
- m6a_df = sample_df[sample_df['modified_primary_base'] == 'A']
- result['m6A'] = m6a_df
- result['m6A_minus'] = m6a_df[m6a_df['ref_strand'] == '-']
- result['m6A_plus'] = m6a_df[m6a_df['ref_strand'] == '+']
+ if "6mA" in mods:
+ m6a_df = sample_df[sample_df["modified_primary_base"] == "A"]
+ result["m6A"] = m6a_df
+ result["m6A_minus"] = m6a_df[m6a_df["ref_strand"] == "-"]
+ result["m6A_plus"] = m6a_df[m6a_df["ref_strand"] == "+"]
  m6a_df = None
  gc.collect()
- if '5mC' in mods:
- m5c_df = sample_df[sample_df['modified_primary_base'] == 'C']
- result['5mC'] = m5c_df
- result['5mC_minus'] = m5c_df[m5c_df['ref_strand'] == '-']
- result['5mC_plus'] = m5c_df[m5c_df['ref_strand'] == '+']
+ if "5mC" in mods:
+ m5c_df = sample_df[sample_df["modified_primary_base"] == "C"]
+ result["5mC"] = m5c_df
+ result["5mC_minus"] = m5c_df[m5c_df["ref_strand"] == "-"]
+ result["5mC_plus"] = m5c_df[m5c_df["ref_strand"] == "+"]
  m5c_df = None
  gc.collect()
- if '6mA' in mods and '5mC' in mods:
- result['combined_minus'] = []
- result['combined_plus'] = []
+ if "6mA" in mods and "5mC" in mods:
+ result["combined_minus"] = []
+ result["combined_plus"] = []
  return record, sample_index, result

+
  def parallel_process_modifications(dict_total, mods, max_reference_length, threads=4):
  """
  Processes each (record, sample) pair in dict_total in parallel to extract modification-specific data.
-
+
  Returns:
  processed_results: Dict keyed by record, with sub-dict keyed by sample index and the processed results.
  """
@@ -164,18 +190,20 @@ def parallel_process_modifications(dict_total, mods, max_reference_length, threa
  processed_results = {}
  with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
  for record, sample_index, result in tqdm(
- executor.map(process_modifications_for_sample, tasks),
- total=len(tasks),
- desc="Processing modifications"):
+ executor.map(process_modifications_for_sample, tasks),
+ total=len(tasks),
+ desc="Processing modifications",
+ ):
  if record not in processed_results:
  processed_results[record] = {}
  processed_results[record][sample_index] = result
  return processed_results

+
  def merge_modification_results(processed_results, mods):
  """
  Merges individual sample results into global dictionaries.
-
+
  Returns:
  A tuple: (m6A_dict, m6A_minus, m6A_plus, c5m_dict, c5m_minus, c5m_plus, combined_minus, combined_plus)
  """
@@ -189,44 +217,52 @@ def merge_modification_results(processed_results, mods):
  combined_plus = {}
  for record, sample_results in processed_results.items():
  for sample_index, res in sample_results.items():
- if '6mA' in mods:
+ if "6mA" in mods:
  if record not in m6A_dict:
  m6A_dict[record], m6A_minus[record], m6A_plus[record] = {}, {}, {}
- m6A_dict[record][sample_index] = res.get('m6A', pd.DataFrame())
- m6A_minus[record][sample_index] = res.get('m6A_minus', pd.DataFrame())
- m6A_plus[record][sample_index] = res.get('m6A_plus', pd.DataFrame())
- if '5mC' in mods:
+ m6A_dict[record][sample_index] = res.get("m6A", pd.DataFrame())
+ m6A_minus[record][sample_index] = res.get("m6A_minus", pd.DataFrame())
+ m6A_plus[record][sample_index] = res.get("m6A_plus", pd.DataFrame())
+ if "5mC" in mods:
  if record not in c5m_dict:
  c5m_dict[record], c5m_minus[record], c5m_plus[record] = {}, {}, {}
- c5m_dict[record][sample_index] = res.get('5mC', pd.DataFrame())
- c5m_minus[record][sample_index] = res.get('5mC_minus', pd.DataFrame())
- c5m_plus[record][sample_index] = res.get('5mC_plus', pd.DataFrame())
- if '6mA' in mods and '5mC' in mods:
+ c5m_dict[record][sample_index] = res.get("5mC", pd.DataFrame())
+ c5m_minus[record][sample_index] = res.get("5mC_minus", pd.DataFrame())
+ c5m_plus[record][sample_index] = res.get("5mC_plus", pd.DataFrame())
+ if "6mA" in mods and "5mC" in mods:
  if record not in combined_minus:
  combined_minus[record], combined_plus[record] = {}, {}
- combined_minus[record][sample_index] = res.get('combined_minus', [])
- combined_plus[record][sample_index] = res.get('combined_plus', [])
- return (m6A_dict, m6A_minus, m6A_plus,
- c5m_dict, c5m_minus, c5m_plus,
- combined_minus, combined_plus)
+ combined_minus[record][sample_index] = res.get("combined_minus", [])
+ combined_plus[record][sample_index] = res.get("combined_plus", [])
+ return (
+ m6A_dict,
+ m6A_minus,
+ m6A_plus,
+ c5m_dict,
+ c5m_minus,
+ c5m_plus,
+ combined_minus,
+ combined_plus,
+ )
+

  def process_stranded_methylation(args):
  """
  Processes a single (dict_index, record, sample) task.
-
+
  For combined dictionaries (indices 7 or 8), it merges the corresponding A-stranded and C-stranded data.
- For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
+ For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
  NumPy methylation array (of float type). Non-numeric values (e.g. '-') are coerced to NaN.
-
+
  Parameters:
  args: (dict_index, record, sample, dict_list, max_reference_length)
-
+
  Returns:
  (dict_index, record, sample, processed_data)
  """
  dict_index, record, sample, dict_list, max_reference_length = args
  processed_data = {}
-
+
  # For combined bottom strand (index 7)
  if dict_index == 7:
  temp_a = dict_list[2][record].get(sample, {}).copy()
@@ -235,18 +271,18 @@ def process_stranded_methylation(args):
  for read in set(temp_a.keys()) | set(temp_c.keys()):
  if read in temp_a:
  # Convert using pd.to_numeric with errors='coerce'
- value_a = pd.to_numeric(np.array(temp_a[read]), errors='coerce')
+ value_a = pd.to_numeric(np.array(temp_a[read]), errors="coerce")
  else:
  value_a = None
  if read in temp_c:
- value_c = pd.to_numeric(np.array(temp_c[read]), errors='coerce')
+ value_c = pd.to_numeric(np.array(temp_c[read]), errors="coerce")
  else:
  value_c = None
  if value_a is not None and value_c is not None:
  processed_data[read] = np.where(
  np.isnan(value_a) & np.isnan(value_c),
  np.nan,
- np.nan_to_num(value_a) + np.nan_to_num(value_c)
+ np.nan_to_num(value_a) + np.nan_to_num(value_c),
  )
  elif value_a is not None:
  processed_data[read] = value_a
@@ -261,18 +297,18 @@ def process_stranded_methylation(args):
  processed_data = {}
  for read in set(temp_a.keys()) | set(temp_c.keys()):
  if read in temp_a:
- value_a = pd.to_numeric(np.array(temp_a[read]), errors='coerce')
+ value_a = pd.to_numeric(np.array(temp_a[read]), errors="coerce")
  else:
  value_a = None
  if read in temp_c:
- value_c = pd.to_numeric(np.array(temp_c[read]), errors='coerce')
+ value_c = pd.to_numeric(np.array(temp_c[read]), errors="coerce")
  else:
  value_c = None
  if value_a is not None and value_c is not None:
  processed_data[read] = np.where(
  np.isnan(value_a) & np.isnan(value_c),
  np.nan,
- np.nan_to_num(value_a) + np.nan_to_num(value_c)
+ np.nan_to_num(value_a) + np.nan_to_num(value_c),
  )
  elif value_a is not None:
  processed_data[read] = value_a
@@ -286,24 +322,28 @@ def process_stranded_methylation(args):
  temp_df = dict_list[dict_index][record][sample]
  processed_data = {}
  # Extract columns and convert probabilities to float (coercing errors)
- read_ids = temp_df['read_id'].values
- positions = temp_df['ref_position'].values
- call_codes = temp_df['call_code'].values
- probabilities = pd.to_numeric(temp_df['call_prob'].values, errors='coerce')
-
- modified_codes = {'a', 'h', 'm'}
- canonical_codes = {'-'}
-
+ read_ids = temp_df["read_id"].values
+ positions = temp_df["ref_position"].values
+ call_codes = temp_df["call_code"].values
+ probabilities = pd.to_numeric(temp_df["call_prob"].values, errors="coerce")
+
+ modified_codes = {"a", "h", "m"}
+ canonical_codes = {"-"}
+
  # Compute methylation probabilities (vectorized)
  methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
- methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[np.isin(call_codes, list(modified_codes))]
- methylation_prob[np.isin(call_codes, list(canonical_codes))] = 1 - probabilities[np.isin(call_codes, list(canonical_codes))]
-
+ methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[
+ np.isin(call_codes, list(modified_codes))
+ ]
+ methylation_prob[np.isin(call_codes, list(canonical_codes))] = (
+ 1 - probabilities[np.isin(call_codes, list(canonical_codes))]
+ )
+
  # Preallocate storage for each unique read
  unique_reads = np.unique(read_ids)
  for read in unique_reads:
  processed_data[read] = np.full(max_reference_length, np.nan, dtype=float)
-
+
  # Assign values efficiently
  for i in range(len(read_ids)):
  read = read_ids[i]
@@ -314,10 +354,11 @@ def process_stranded_methylation(args):
  gc.collect()
  return dict_index, record, sample, processed_data

+
  def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=4):
  """
  Processes all (dict_index, record, sample) tasks in dict_list (excluding indices in dict_to_skip) in parallel.
-
+
  Returns:
  Updated dict_list with processed (nested) dictionaries.
  """
@@ -327,16 +368,17 @@ def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference
  for record in current_dict.keys():
  for sample in current_dict[record].keys():
  tasks.append((dict_index, record, sample, dict_list, max_reference_length))
-
+
  with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
  for dict_index, record, sample, processed_data in tqdm(
  executor.map(process_stranded_methylation, tasks),
  total=len(tasks),
- desc="Extracting stranded methylation states"
+ desc="Extracting stranded methylation states",
  ):
  dict_list[dict_index][record][sample] = processed_data
  return dict_list

+
  def delete_intermediate_h5ads_and_tmpdir(
  h5_dir: Union[str, Path, Iterable[str], None],
  tmp_dir: Optional[Union[str, Path]] = None,
@@ -360,25 +402,27 @@ def delete_intermediate_h5ads_and_tmpdir(
  verbose : bool
  Print progress / warnings.
  """
+
  # Helper: remove a single file path (Path-like or string)
  def _maybe_unlink(p: Path):
+ """Remove a file path if it exists and is a file."""
  if not p.exists():
  if verbose:
- print(f"[skip] not found: {p}")
+ logger.debug(f"[skip] not found: {p}")
  return
  if not p.is_file():
  if verbose:
- print(f"[skip] not a file: {p}")
+ logger.debug(f"[skip] not a file: {p}")
  return
  if dry_run:
- print(f"[dry-run] would remove file: {p}")
+ logger.debug(f"[dry-run] would remove file: {p}")
  return
  try:
  p.unlink()
  if verbose:
- print(f"Removed file: {p}")
+ logger.info(f"Removed file: {p}")
  except Exception as e:
- print(f"[error] failed to remove file {p}: {e}")
+ logger.warning(f"[error] failed to remove file {p}: {e}")

  # Handle h5_dir input (directory OR iterable of file paths)
  if h5_dir is not None:
@@ -393,7 +437,7 @@ def delete_intermediate_h5ads_and_tmpdir(
  else:
  if verbose:
  # optional: comment this out if too noisy
- print(f"[skip] not matching pattern: {p.name}")
+ logger.debug(f"[skip] not matching pattern: {p.name}")
  else:
  # treat as iterable of file paths
  for f in h5_dir:
@@ -403,30 +447,44 @@ def delete_intermediate_h5ads_and_tmpdir(
  _maybe_unlink(p)
  else:
  if verbose:
- print(f"[skip] not matching pattern or not a file: {p}")
+ logger.debug(f"[skip] not matching pattern or not a file: {p}")

  # Remove tmp_dir recursively (if provided)
  if tmp_dir is not None:
  td = Path(tmp_dir)
  if not td.exists():
  if verbose:
- print(f"[skip] tmp_dir not found: {td}")
+ logger.debug(f"[skip] tmp_dir not found: {td}")
  else:
  if not td.is_dir():
  if verbose:
- print(f"[skip] tmp_dir is not a directory: {td}")
+ logger.debug(f"[skip] tmp_dir is not a directory: {td}")
  else:
  if dry_run:
- print(f"[dry-run] would remove directory tree: {td}")
+ logger.debug(f"[dry-run] would remove directory tree: {td}")
  else:
  try:
  shutil.rmtree(td)
  if verbose:
- print(f"Removed directory tree: {td}")
+ logger.info(f"Removed directory tree: {td}")
  except Exception as e:
- print(f"[error] failed to remove tmp dir {td}: {e}")
-
- def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapping_threshold, experiment_name, mods, batch_size, mod_tsv_dir, delete_batch_hdfs=False, threads=None, double_barcoded_path = None):
+ logger.warning(f"[error] failed to remove tmp dir {td}: {e}")
+
+
+ def modkit_extract_to_adata(
+ fasta,
+ bam_dir,
+ out_dir,
+ input_already_demuxed,
+ mapping_threshold,
+ experiment_name,
+ mods,
+ batch_size,
+ mod_tsv_dir,
+ delete_batch_hdfs=False,
+ threads=None,
+ double_barcoded_path=None,
+ ):
  """
  Takes modkit extract outputs and organizes it into an adata object

@@ -448,50 +506,87 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  """
  ###################################################
  # Package imports
- from .. import readwrite
- from ..readwrite import safe_write_h5ad, make_dirs
- from .fasta_functions import get_native_references
- from .bam_functions import extract_base_identities
- from .ohe import ohe_batching
- import pandas as pd
- import anndata as ad
- import os
  import gc
  import math
+
+ import anndata as ad
  import numpy as np
+ import pandas as pd
  from Bio.Seq import Seq
  from tqdm import tqdm
- import h5py
+
+ from .. import readwrite
+ from ..readwrite import make_dirs
+ from .bam_functions import extract_base_identities
+ from .fasta_functions import get_native_references
+ from .ohe import ohe_batching
  ###################################################

  ################## Get input tsv and bam file names into a sorted list ################
  # Make output dirs
- h5_dir = out_dir / 'h5ads'
- tmp_dir = out_dir / 'tmp'
+ h5_dir = out_dir / "h5ads"
+ tmp_dir = out_dir / "tmp"
  make_dirs([h5_dir, tmp_dir])

- existing_h5s = h5_dir.iterdir()
- existing_h5s = [h5 for h5 in existing_h5s if '.h5ad.gz' in str(h5)]
- final_hdf = f'{experiment_name}.h5ad.gz'
+ existing_h5s = h5_dir.iterdir()
+ existing_h5s = [h5 for h5 in existing_h5s if ".h5ad.gz" in str(h5)]
+ final_hdf = f"{experiment_name}.h5ad.gz"
  final_adata_path = h5_dir / final_hdf
  final_adata = None
-
+
  if final_adata_path.exists():
- print(f'{final_adata_path} already exists. Using existing adata')
+ logger.debug(f"{final_adata_path} already exists. Using existing adata")
  return final_adata, final_adata_path
-
+
  # List all files in the directory
  tsvs = sorted(
- p for p in mod_tsv_dir.iterdir()
- if p.is_file() and 'unclassified' not in p.name and 'extract.tsv' in p.name)
+ p
+ for p in mod_tsv_dir.iterdir()
+ if p.is_file() and "unclassified" not in p.name and "extract.tsv" in p.name
+ )
  bams = sorted(
- p for p in bam_dir.iterdir()
- if p.is_file() and p.suffix == '.bam' and 'unclassified' not in p.name and '.bai' not in p.name)
+ p
+ for p in bam_dir.iterdir()
+ if p.is_file()
+ and p.suffix == ".bam"
+ and "unclassified" not in p.name
+ and ".bai" not in p.name
+ )
+
+ tsv_path_list = [tsv for tsv in tsvs]
+ bam_path_list = [bam for bam in bams]
+ logger.info(f"{len(tsvs)} sample tsv files found: {tsvs}")
+ logger.info(f"{len(bams)} sample bams found: {bams}")
+
+ # Map global sample index (bami / final_sample_index) -> sample name / barcode
+ sample_name_map = {}
+ barcode_map = {}
+
+ for idx, bam_path in enumerate(bam_path_list):
+ stem = bam_path.stem
+
+ # Try to peel off a "barcode..." suffix if present.
+ # This handles things like:
+ # "mySample_barcode01" -> sample="mySample", barcode="barcode01"
+ # "run1-s1_barcode05" -> sample="run1-s1", barcode="barcode05"
+ # "barcode01" -> sample="barcode01", barcode="barcode01"
+ m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
+ if m:
+ sample_name = m.group(1) or stem
+ barcode = m.group(2)
+ else:
+ # Fallback: treat the whole stem as both sample & barcode
+ sample_name = stem
+ barcode = stem
+
+ # make sample name of the format of the bam file stem
+ sample_name = sample_name + f"_{barcode}"
+
+ # Clean the barcode name to be an integer
+ barcode = int(barcode.split("barcode")[1])

- tsv_path_list = [mod_tsv_dir / tsv for tsv in tsvs]
- bam_path_list = [bam_dir / bam for bam in bams]
- print(f'{len(tsvs)} sample tsv files found: {tsvs}')
- print(f'{len(bams)} sample bams found: {bams}')
+ sample_name_map[idx] = sample_name
+ barcode_map[idx] = str(barcode)
  ##########################################################################################

  ######### Get Record names that have over a passed threshold of mapped reads #############
@@ -503,27 +598,29 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  ########### Determine the maximum record length to analyze in the dataset ################
  # Get all references within the FASTA and indicate the length and identity of the record sequence
  max_reference_length = 0
- reference_dict = get_native_references(str(fasta)) # returns a dict keyed by record name. Points to a tuple of (reference length, reference sequence)
+ reference_dict = get_native_references(
+ str(fasta)
+ )  # returns a dict keyed by record name. Points to a tuple of (reference length, reference sequence)
  # Get the max record length in the dataset.
  for record in records_to_analyze:
  if reference_dict[record][0] > max_reference_length:
  max_reference_length = reference_dict[record][0]
- print(f'{readwrite.time_string()}: Max reference length in dataset: {max_reference_length}')
- batches = math.ceil(len(tsvs) / batch_size) # Number of batches to process
- print('{0}: Processing input tsvs in {1} batches of {2} tsvs '.format(readwrite.time_string(), batches, batch_size))
+ logger.info(f"Max reference length in dataset: {max_reference_length}")
+ batches = math.ceil(len(tsvs) / batch_size)  # Number of batches to process
+ logger.info("Processing input tsvs in {0} batches of {1} tsvs ".format(batches, batch_size))
  ##########################################################################################

  ##########################################################################################
- # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
+ # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
  # Save the file paths in the bam_record_ohe_files dict.
  bam_record_ohe_files = {}
- bam_record_save = tmp_dir / 'tmp_file_dict.h5ad'
+ bam_record_save = tmp_dir / "tmp_file_dict.h5ad"
  fwd_mapped_reads = set()
  rev_mapped_reads = set()
  # If this step has already been performed, read in the tmp_dile_dict
  if bam_record_save.exists():
  bam_record_ohe_files = ad.read_h5ad(bam_record_save).uns
- print('Found existing OHE reads, using these')
+ logger.debug("Found existing OHE reads, using these")
  else:
  # Iterate over split bams
  for bami, bam in enumerate(bam_path_list):
@@ -533,18 +630,37 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  positions = range(current_reference_length)
  ref_seq = reference_dict[record][1]
  # Extract the base identities of reads aligned to the record
- fwd_base_identities, rev_base_identities, mismatch_counts_per_read, mismatch_trend_per_read = extract_base_identities(bam, record, positions, max_reference_length, ref_seq)
+ (
+ fwd_base_identities,
+ rev_base_identities,
+ mismatch_counts_per_read,
+ mismatch_trend_per_read,
+ ) = extract_base_identities(bam, record, positions, max_reference_length, ref_seq)
  # Store read names of fwd and rev mapped reads
  fwd_mapped_reads.update(fwd_base_identities.keys())
  rev_mapped_reads.update(rev_base_identities.keys())
  # One hot encode the sequence string of the reads
- fwd_ohe_files = ohe_batching(fwd_base_identities, tmp_dir, record, f"{bami}_fwd",batch_size=100000, threads=threads)
- rev_ohe_files = ohe_batching(rev_base_identities, tmp_dir, record, f"{bami}_rev",batch_size=100000, threads=threads)
- bam_record_ohe_files[f'{bami}_{record}'] = fwd_ohe_files + rev_ohe_files
+ fwd_ohe_files = ohe_batching(
+ fwd_base_identities,
+ tmp_dir,
+ record,
+ f"{bami}_fwd",
+ batch_size=100000,
+ threads=threads,
+ )
+ rev_ohe_files = ohe_batching(
+ rev_base_identities,
+ tmp_dir,
+ record,
+ f"{bami}_rev",
+ batch_size=100000,
+ threads=threads,
+ )
+ bam_record_ohe_files[f"{bami}_{record}"] = fwd_ohe_files + rev_ohe_files
  del fwd_base_identities, rev_base_identities
  # Save out the ohe file paths
  X = np.random.rand(1, 1)
- tmp_ad = ad.AnnData(X=X, uns=bam_record_ohe_files)
+ tmp_ad = ad.AnnData(X=X, uns=bam_record_ohe_files)
  tmp_ad.write_h5ad(bam_record_save)
  ##########################################################################################

@@ -554,39 +670,73 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  for record in records_to_analyze:
  current_reference_length = reference_dict[record][0]
  delta_max_length = max_reference_length - current_reference_length
- sequence = reference_dict[record][1] + 'N'*delta_max_length
- complement = str(Seq(reference_dict[record][1]).complement()).upper() + 'N'*delta_max_length
+ sequence = reference_dict[record][1] + "N" * delta_max_length
+ complement = (
+ str(Seq(reference_dict[record][1]).complement()).upper() + "N" * delta_max_length
+ )
  record_seq_dict[record] = (sequence, complement)
  ##########################################################################################

  ###################################################
  # Begin iterating over batches
  for batch in range(batches):
- print('{0}: Processing tsvs for batch {1} '.format(readwrite.time_string(), batch))
+ logger.info("Processing tsvs for batch {0} ".format(batch))
  # For the final batch, just take the remaining tsv and bam files
  if batch == batches - 1:
  tsv_batch = tsv_path_list
  bam_batch = bam_path_list
- # For all other batches, take the next batch of tsvs and bams out of the file queue.
+ # For all other batches, take the next batch of tsvs and bams out of the file queue.
  else:
  tsv_batch = tsv_path_list[:batch_size]
  bam_batch = bam_path_list[:batch_size]
  tsv_path_list = tsv_path_list[batch_size:]
  bam_path_list = bam_path_list[batch_size:]
- print('{0}: tsvs in batch {1} '.format(readwrite.time_string(), tsv_batch))
+ logger.info("tsvs in batch {0} ".format(tsv_batch))

- batch_already_processed = sum([1 for h5 in existing_h5s if f'_{batch}_' in h5.name])
- ###################################################
+ batch_already_processed = sum([1 for h5 in existing_h5s if f"_{batch}_" in h5.name])
+ ###################################################
  if batch_already_processed:
- print(f'Batch {batch} has already been processed into h5ads. Skipping batch and using existing files')
+ logger.debug(
+ f"Batch {batch} has already been processed into h5ads. Skipping batch and using existing files"
+ )
  else:
  ###################################################
  ### Add the tsvs as dataframes to a dictionary (dict_total) keyed by integer index. Also make modification specific dictionaries and strand specific dictionaries.
  # # Initialize dictionaries and place them in a list
- dict_total, dict_a, dict_a_bottom, dict_a_top, dict_c, dict_c_bottom, dict_c_top, dict_combined_bottom, dict_combined_top = {},{},{},{},{},{},{},{},{}
- dict_list = [dict_total, dict_a, dict_a_bottom, dict_a_top, dict_c, dict_c_bottom, dict_c_top, dict_combined_bottom, dict_combined_top]
+ (
+ dict_total,
+ dict_a,
+ dict_a_bottom,
+ dict_a_top,
+ dict_c,
+ dict_c_bottom,
+ dict_c_top,
+ dict_combined_bottom,
+ dict_combined_top,
+ ) = {}, {}, {}, {}, {}, {}, {}, {}, {}
+ dict_list = [
+ dict_total,
+ dict_a,
+ dict_a_bottom,
+ dict_a_top,
+ dict_c,
+ dict_c_bottom,
+ dict_c_top,
+ dict_combined_bottom,
+ dict_combined_top,
+ ]
  # Give names to represent each dictionary in the list
- sample_types = ['total', 'm6A', 'm6A_bottom_strand', 'm6A_top_strand', '5mC', '5mC_bottom_strand', '5mC_top_strand', 'combined_bottom_strand', 'combined_top_strand']
+ sample_types = [
+ "total",
+ "m6A",
+ "m6A_bottom_strand",
+ "m6A_top_strand",
+ "5mC",
+ "5mC_bottom_strand",
+ "5mC_top_strand",
+ "combined_bottom_strand",
+ "combined_top_strand",
+ ]
  # Give indices of dictionaries to skip for analysis and final dictionary saving.
  dict_to_skip = [0, 1, 4]
  combined_dicts = [7, 8]
@@ -596,7 +746,14 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  dict_to_skip = set(dict_to_skip)

  # # Step 1):Load the dict_total dictionary with all of the batch tsv files as dataframes.
- dict_total = parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size=len(tsv_batch), threads=threads)
+ dict_total = parallel_load_tsvs(
+ tsv_batch,
+ records_to_analyze,
+ reference_dict,
+ batch,
+ batch_size=len(tsv_batch),
+ threads=threads,
+ )

  # # Step 2: Extract modification-specific data (per (record,sample)) in parallel
  # processed_mod_results = parallel_process_modifications(dict_total, mods, max_reference_length, threads=threads or 4)
@@ -621,56 +778,112 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  # Iterate over dict_total of all the tsv files and extract the modification specific and strand specific dataframes into dictionaries
  for record in dict_total.keys():
  for sample_index in dict_total[record].keys():
- if '6mA' in mods:
+ if "6mA" in mods:
  # Remove Adenine stranded dicts from the dicts to skip set
  dict_to_skip.difference_update(set(A_stranded_dicts))

- if record not in dict_a.keys() and record not in dict_a_bottom.keys() and record not in dict_a_top.keys():
+ if (
+ record not in dict_a.keys()
+ and record not in dict_a_bottom.keys()
+ and record not in dict_a_top.keys()
+ ):
  dict_a[record], dict_a_bottom[record], dict_a_top[record] = {}, {}, {}

  # get a dictionary of dataframes that only contain methylated adenine positions
- dict_a[record][sample_index] = dict_total[record][sample_index][dict_total[record][sample_index]['modified_primary_base'] == 'A']
- print('{}: Successfully loaded a methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ dict_a[record][sample_index] = dict_total[record][sample_index][
+ dict_total[record][sample_index]["modified_primary_base"] == "A"
+ ]
+ logger.debug(
+ "Successfully loaded a methyl-adenine dictionary for {}".format(
+ str(sample_index)
+ )
+ )
+
  # Stratify the adenine dictionary into two strand specific dictionaries.
- dict_a_bottom[record][sample_index] = dict_a[record][sample_index][dict_a[record][sample_index]['ref_strand'] == '-']
- print('{}: Successfully loaded a minus strand methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
- dict_a_top[record][sample_index] = dict_a[record][sample_index][dict_a[record][sample_index]['ref_strand'] == '+']
- print('{}: Successfully loaded a plus strand methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ dict_a_bottom[record][sample_index] = dict_a[record][sample_index][
+ dict_a[record][sample_index]["ref_strand"] == "-"
+ ]
+ logger.debug(
+ "Successfully loaded a minus strand methyl-adenine dictionary for {}".format(
+ str(sample_index)
+ )
+ )
+ dict_a_top[record][sample_index] = dict_a[record][sample_index][
+ dict_a[record][sample_index]["ref_strand"] == "+"
+ ]
+ logger.debug(
+ "Successfully loaded a plus strand methyl-adenine dictionary for ".format(
+ str(sample_index)
+ )
+ )

  # Reassign pointer for dict_a to None and delete the original value that it pointed to in order to decrease memory usage.
  dict_a[record][sample_index] = None
  gc.collect()

- if '5mC' in mods:
+ if "5mC" in mods:
  # Remove Cytosine stranded dicts from the dicts to skip set
  dict_to_skip.difference_update(set(C_stranded_dicts))

- if record not in dict_c.keys() and record not in dict_c_bottom.keys() and record not in dict_c_top.keys():
+ if (
+ record not in dict_c.keys()
+ and record not in dict_c_bottom.keys()
+ and record not in dict_c_top.keys()
+ ):
  dict_c[record], dict_c_bottom[record], dict_c_top[record] = {}, {}, {}

  # get a dictionary of dataframes that only contain methylated cytosine positions
- dict_c[record][sample_index] = dict_total[record][sample_index][dict_total[record][sample_index]['modified_primary_base'] == 'C']
- print('{}: Successfully loaded a methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ dict_c[record][sample_index] = dict_total[record][sample_index][
+ dict_total[record][sample_index]["modified_primary_base"] == "C"
+ ]
+ logger.debug(
+ "Successfully loaded a methyl-cytosine dictionary for {}".format(
+ str(sample_index)
+ )
+ )
  # Stratify the cytosine dictionary into two strand specific dictionaries.
- dict_c_bottom[record][sample_index] = dict_c[record][sample_index][dict_c[record][sample_index]['ref_strand'] == '-']
- print('{}: Successfully loaded a minus strand methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
- dict_c_top[record][sample_index] = dict_c[record][sample_index][dict_c[record][sample_index]['ref_strand'] == '+']
- print('{}: Successfully loaded a plus strand methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ dict_c_bottom[record][sample_index] = dict_c[record][sample_index][
+ dict_c[record][sample_index]["ref_strand"] == "-"
+ ]
+ logger.debug(
+ "Successfully loaded a minus strand methyl-cytosine dictionary for {}".format(
+ str(sample_index)
+ )
+ )
+ dict_c_top[record][sample_index] = dict_c[record][sample_index][
+ dict_c[record][sample_index]["ref_strand"] == "+"
+ ]
+ logger.debug(
+ "Successfully loaded a plus strand methyl-cytosine dictionary for {}".format(
+ str(sample_index)
+ )
+ )
  # Reassign pointer for dict_c to None and delete the original value that it pointed to in order to decrease memory usage.
  dict_c[record][sample_index] = None
  gc.collect()
-
- if '6mA' in mods and '5mC' in mods:
+
+ if "6mA" in mods and "5mC" in mods:
  # Remove combined stranded dicts from the dicts to skip set
- dict_to_skip.difference_update(set(combined_dicts))
+ dict_to_skip.difference_update(set(combined_dicts))
  # Initialize the sample keys for the combined dictionaries

- if record not in dict_combined_bottom.keys() and record not in dict_combined_top.keys():
- dict_combined_bottom[record], dict_combined_top[record]= {}, {}
-
- print('{}: Successfully created a minus strand combined methylation dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ if (
+ record not in dict_combined_bottom.keys()
+ and record not in dict_combined_top.keys()
+ ):
+ dict_combined_bottom[record], dict_combined_top[record] = {}, {}
+
+ logger.debug(
+ "Successfully created a minus strand combined methylation dictionary for {}".format(
+ str(sample_index)
+ )
+ )
  dict_combined_bottom[record][sample_index] = []
- print('{}: Successfully created a plus strand combined methylation dictionary for '.format(readwrite.time_string()) + str(sample_index))
+ logger.debug(
+ "Successfully created a plus strand combined methylation dictionary for {}".format(
+ str(sample_index)
+ )
+ )
  dict_combined_top[record][sample_index] = []

  # Reassign pointer for dict_total to None and delete the original value that it pointed to in order to decrease memory usage.
@@ -681,14 +894,24 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  for dict_index, dict_type in enumerate(dict_list):
  # Only iterate over stranded dictionaries
  if dict_index not in dict_to_skip:
- print('{0}: Extracting methylation states for {1} dictionary'.format(readwrite.time_string(), sample_types[dict_index]))
+ logger.debug(
+ "Extracting methylation states for {} dictionary".format(
+ sample_types[dict_index]
+ )
+ )
  for record in dict_type.keys():
  # Get the dictionary for the modification type of interest from the reference mapping of interest
  mod_strand_record_sample_dict = dict_type[record]
- print('{0}: Extracting methylation states for {1} dictionary'.format(readwrite.time_string(), record))
+ logger.debug(
+ "Extracting methylation states for {} dictionary".format(record)
+ )
  # For each sample in a stranded dictionary
  n_samples = len(mod_strand_record_sample_dict.keys())
- for sample in tqdm(mod_strand_record_sample_dict.keys(), desc=f'Extracting {sample_types[dict_index]} dictionary from record {record} for sample', total=n_samples):
+ for sample in tqdm(
+ mod_strand_record_sample_dict.keys(),
+ desc=f"Extracting {sample_types[dict_index]} dictionary from record {record} for sample",
+ total=n_samples,
+ ):
  # Load the combined bottom strand dictionary after all the individual dictionaries have been made for the sample
  if dict_index == 7:
  # Load the minus strand dictionaries for each sample into temporary variables
@@ -699,16 +922,26 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  for read in set(temp_a_dict) | set(temp_c_dict):
  # Add the arrays element-wise if the read is present in both dictionaries
  if read in temp_a_dict and read in temp_c_dict:
- mod_strand_record_sample_dict[sample][read] = np.where(np.isnan(temp_a_dict[read]) & np.isnan(temp_c_dict[read]), np.nan, np.nan_to_num(temp_a_dict[read]) + np.nan_to_num(temp_c_dict[read]))
+ mod_strand_record_sample_dict[sample][read] = np.where(
+ np.isnan(temp_a_dict[read])
+ & np.isnan(temp_c_dict[read]),
+ np.nan,
+ np.nan_to_num(temp_a_dict[read])
+ + np.nan_to_num(temp_c_dict[read]),
+ )
  # If the read is present in only one dictionary, copy its value
  elif read in temp_a_dict:
- mod_strand_record_sample_dict[sample][read] = temp_a_dict[read]
+ mod_strand_record_sample_dict[sample][read] = temp_a_dict[
+ read
+ ]
  elif read in temp_c_dict:
- mod_strand_record_sample_dict[sample][read] = temp_c_dict[read]
+ mod_strand_record_sample_dict[sample][read] = temp_c_dict[
+ read
+ ]
  del temp_a_dict, temp_c_dict
- # Load the combined top strand dictionary after all the individual dictionaries have been made for the sample
+ # Load the combined top strand dictionary after all the individual dictionaries have been made for the sample
  elif dict_index == 8:
- # Load the plus strand dictionaries for each sample into temporary variables
+ # Load the plus strand dictionaries for each sample into temporary variables
  temp_a_dict = dict_list[3][record][sample].copy()
  temp_c_dict = dict_list[6][record][sample].copy()
  mod_strand_record_sample_dict[sample] = {}
@@ -716,105 +949,163 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
716
949
  for read in set(temp_a_dict) | set(temp_c_dict):
717
950
  # Add the arrays element-wise if the read is present in both dictionaries
718
951
  if read in temp_a_dict and read in temp_c_dict:
719
- mod_strand_record_sample_dict[sample][read] = np.where(np.isnan(temp_a_dict[read]) & np.isnan(temp_c_dict[read]), np.nan, np.nan_to_num(temp_a_dict[read]) + np.nan_to_num(temp_c_dict[read]))
952
+ mod_strand_record_sample_dict[sample][read] = np.where(
953
+ np.isnan(temp_a_dict[read])
954
+ & np.isnan(temp_c_dict[read]),
955
+ np.nan,
956
+ np.nan_to_num(temp_a_dict[read])
957
+ + np.nan_to_num(temp_c_dict[read]),
958
+ )
720
959
  # If the read is present in only one dictionary, copy its value
721
960
  elif read in temp_a_dict:
722
- mod_strand_record_sample_dict[sample][read] = temp_a_dict[read]
961
+ mod_strand_record_sample_dict[sample][read] = temp_a_dict[
962
+ read
963
+ ]
723
964
  elif read in temp_c_dict:
724
- mod_strand_record_sample_dict[sample][read] = temp_c_dict[read]
965
+ mod_strand_record_sample_dict[sample][read] = temp_c_dict[
966
+ read
967
+ ]
725
968
  del temp_a_dict, temp_c_dict
726
969
  # For all other dictionaries
  else:
-
  # use temp_df to point to the dataframe held in mod_strand_record_sample_dict[sample]
  temp_df = mod_strand_record_sample_dict[sample]
  # reassign the dictionary pointer to a nested dictionary.
  mod_strand_record_sample_dict[sample] = {}

  # Get relevant columns as NumPy arrays
- read_ids = temp_df['read_id'].values
- positions = temp_df['ref_position'].values
- call_codes = temp_df['call_code'].values
- probabilities = temp_df['call_prob'].values
+ read_ids = temp_df["read_id"].values
+ positions = temp_df["ref_position"].values
+ call_codes = temp_df["call_code"].values
+ probabilities = temp_df["call_prob"].values

  # Define valid call code categories
- modified_codes = {'a', 'h', 'm'}
- canonical_codes = {'-'}
+ modified_codes = {"a", "h", "m"}
+ canonical_codes = {"-"}

  # Vectorized methylation calculation with NaN for other codes
- methylation_prob = np.full_like(probabilities, np.nan) # Default all to NaN
- methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[np.isin(call_codes, list(modified_codes))]
- methylation_prob[np.isin(call_codes, list(canonical_codes))] = 1 - probabilities[np.isin(call_codes, list(canonical_codes))]
+ methylation_prob = np.full_like(
+ probabilities, np.nan
+ ) # Default all to NaN
+ methylation_prob[np.isin(call_codes, list(modified_codes))] = (
+ probabilities[np.isin(call_codes, list(modified_codes))]
+ )
+ methylation_prob[np.isin(call_codes, list(canonical_codes))] = (
+ 1 - probabilities[np.isin(call_codes, list(canonical_codes))]
+ )

  # Find unique reads
  unique_reads = np.unique(read_ids)
  # Preallocate storage for each read
  for read in unique_reads:
- mod_strand_record_sample_dict[sample][read] = np.full(max_reference_length, np.nan)
+ mod_strand_record_sample_dict[sample][read] = np.full(
+ max_reference_length, np.nan
+ )

  # Efficient NumPy indexing to assign values
  for i in range(len(read_ids)):
  read = read_ids[i]
  pos = positions[i]
  prob = methylation_prob[i]
-
+
  # Assign methylation probability
  mod_strand_record_sample_dict[sample][read][pos] = prob

-
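The vectorized block above turns modkit call codes into a single methylation-probability track: modified codes ("a", "h", "m") keep call_prob, the canonical code ("-") becomes 1 - call_prob, and every other code is left as NaN. A small illustrative sketch of the same masking pattern (toy arrays, not package data):

    import numpy as np

    call_codes = np.array(["m", "-", "x", "a"])
    call_probs = np.array([0.95, 0.90, 0.80, 0.60])

    modified = np.isin(call_codes, ["a", "h", "m"])  # modified-base calls
    canonical = np.isin(call_codes, ["-"])           # canonical-base calls

    meth_prob = np.full_like(call_probs, np.nan)     # unrecognized codes stay NaN
    meth_prob[modified] = call_probs[modified]
    meth_prob[canonical] = 1 - call_probs[canonical]
    print(meth_prob)  # [0.95 0.1 nan 0.6]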
  # Save the sample files in the batch as gzipped hdf5 files
- print('{0}: Converting batch {1} dictionaries to anndata objects'.format(readwrite.time_string(), batch))
+ logger.info("Converting batch {} dictionaries to anndata objects".format(batch))
  for dict_index, dict_type in enumerate(dict_list):
  if dict_index not in dict_to_skip:
  # Initialize an hdf5 file for the current modified strand
  adata = None
- print('{0}: Converting {1} dictionary to an anndata object'.format(readwrite.time_string(), sample_types[dict_index]))
+ logger.info(
+ "Converting {} dictionary to an anndata object".format(
+ sample_types[dict_index]
+ )
+ )
  for record in dict_type.keys():
  # Get the dictionary for the modification type of interest from the reference mapping of interest
  mod_strand_record_sample_dict = dict_type[record]
  for sample in mod_strand_record_sample_dict.keys():
- print('{0}: Converting {1} dictionary for sample {2} to an anndata object'.format(readwrite.time_string(), sample_types[dict_index], sample))
+ logger.info(
+ "Converting {0} dictionary for sample {1} to an anndata object".format(
+ sample_types[dict_index], sample
+ )
+ )
  sample = int(sample)
  final_sample_index = sample + (batch * batch_size)
- print('{0}: Final sample index for sample: {1}'.format(readwrite.time_string(), final_sample_index))
- print('{0}: Converting {1} dictionary for sample {2} to a dataframe'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
- temp_df = pd.DataFrame.from_dict(mod_strand_record_sample_dict[sample], orient='index')
- mod_strand_record_sample_dict[sample] = None # reassign pointer to facilitate memory usage
+ logger.info(
+ "Final sample index for sample: {}".format(final_sample_index)
+ )
+ logger.debug(
+ "Converting {0} dictionary for sample {1} to a dataframe".format(
+ sample_types[dict_index],
+ final_sample_index,
+ )
+ )
+ temp_df = pd.DataFrame.from_dict(
+ mod_strand_record_sample_dict[sample], orient="index"
+ )
+ mod_strand_record_sample_dict[sample] = (
+ None # reassign pointer to facilitate memory usage
+ )
  sorted_index = sorted(temp_df.index)
  temp_df = temp_df.reindex(sorted_index)
  X = temp_df.values
- dataset, strand = sample_types[dict_index].split('_')[:2]
-
- print('{0}: Loading {1} dataframe for sample {2} into a temp anndata object'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
+ dataset, strand = sample_types[dict_index].split("_")[:2]
+
+ logger.info(
+ "Loading {0} dataframe for sample {1} into a temp anndata object".format(
+ sample_types[dict_index],
+ final_sample_index,
+ )
+ )
  temp_adata = ad.AnnData(X)
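Each sample's per-read arrays are then reshaped into a reads-by-positions matrix with pd.DataFrame.from_dict(..., orient="index"), sorted by read ID, and wrapped in an AnnData whose obs are reads and whose vars are reference positions. A minimal sketch of that conversion (read IDs and values below are made up):

    import anndata as ad
    import numpy as np
    import pandas as pd

    # Per-read methylation-probability tracks keyed by read ID (illustrative values).
    reads = {
        "read_b": np.array([0.9, np.nan, 0.1]),
        "read_a": np.array([np.nan, 0.5, 0.7]),
    }

    df = pd.DataFrame.from_dict(reads, orient="index")  # rows = reads, columns = positions
    df = df.reindex(sorted(df.index))                    # deterministic read order

    adata = ad.AnnData(df.values)
    adata.obs_names = df.index.astype(str)               # read IDs
    adata.var_names = df.columns.astype(str)             # reference positions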
  if temp_adata.shape[0] > 0:
- print('{0}: Adding read names and position ids to {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
+ logger.info(
+ "Adding read names and position ids to {0} anndata for sample {1}".format(
+ sample_types[dict_index],
+ final_sample_index,
+ )
+ )
  temp_adata.obs_names = temp_df.index
  temp_adata.obs_names = temp_adata.obs_names.astype(str)
  temp_adata.var_names = temp_df.columns
  temp_adata.var_names = temp_adata.var_names.astype(str)
- print('{0}: Adding {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
- temp_adata.obs['Sample'] = [str(final_sample_index)] * len(temp_adata)
- temp_adata.obs['Barcode'] = [str(final_sample_index)] * len(temp_adata)
- temp_adata.obs['Reference'] = [f'{record}'] * len(temp_adata)
- temp_adata.obs['Strand'] = [strand] * len(temp_adata)
- temp_adata.obs['Dataset'] = [dataset] * len(temp_adata)
- temp_adata.obs['Reference_dataset_strand'] = [f'{record}_{dataset}_{strand}'] * len(temp_adata)
- temp_adata.obs['Reference_strand'] = [f'{record}_{strand}'] * len(temp_adata)
-
+ logger.info(
+ "Adding {0} anndata for sample {1}".format(
+ sample_types[dict_index],
+ final_sample_index,
+ )
+ )
+ temp_adata.obs["Sample"] = [
+ sample_name_map[final_sample_index]
+ ] * len(temp_adata)
+ temp_adata.obs["Barcode"] = [barcode_map[final_sample_index]] * len(
+ temp_adata
+ )
+ temp_adata.obs["Reference"] = [f"{record}"] * len(temp_adata)
+ temp_adata.obs["Strand"] = [strand] * len(temp_adata)
+ temp_adata.obs["Dataset"] = [dataset] * len(temp_adata)
+ temp_adata.obs["Reference_dataset_strand"] = [
+ f"{record}_{dataset}_{strand}"
+ ] * len(temp_adata)
+ temp_adata.obs["Reference_strand"] = [f"{record}_{strand}"] * len(
+ temp_adata
+ )
+
  # Load in the one hot encoded reads from the current sample and record
  one_hot_reads = {}
  n_rows_OHE = 5
- ohe_files = bam_record_ohe_files[f'{final_sample_index}_{record}']
- print(f'Loading OHEs from {ohe_files}')
+ ohe_files = bam_record_ohe_files[f"{final_sample_index}_{record}"]
+ logger.info(f"Loading OHEs from {ohe_files}")
  fwd_mapped_reads = set()
  rev_mapped_reads = set()
  for ohe_file in ohe_files:
  tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
  one_hot_reads.update(tmp_ohe_dict)
- if '_fwd_' in ohe_file:
+ if "_fwd_" in ohe_file:
  fwd_mapped_reads.update(tmp_ohe_dict.keys())
- elif '_rev_' in ohe_file:
+ elif "_rev_" in ohe_file:
  rev_mapped_reads.update(tmp_ohe_dict.keys())
  del tmp_ohe_dict
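In this hunk the one-hot-encoded reads are recovered from the .uns mapping of small intermediate .h5ad files, and each read is later labeled "fwd", "rev", or "unk" according to whether its source filename contains "_fwd_" or "_rev_". A rough sketch of that loading pattern (the file names below are hypothetical):

    import anndata as ad

    ohe_files = ["0_record1_fwd_OHE.h5ad", "0_record1_rev_OHE.h5ad"]  # hypothetical paths

    one_hot_reads = {}
    fwd_mapped_reads, rev_mapped_reads = set(), set()

    for ohe_file in ohe_files:
        uns_dict = ad.read_h5ad(ohe_file).uns  # read_id -> flattened one-hot array
        one_hot_reads.update(uns_dict)
        if "_fwd_" in ohe_file:
            fwd_mapped_reads.update(uns_dict.keys())
        elif "_rev_" in ohe_file:
            rev_mapped_reads.update(uns_dict.keys())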
 
@@ -823,18 +1114,20 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  read_mapping_direction = []
  for read_id in temp_adata.obs_names:
  if read_id in fwd_mapped_reads:
- read_mapping_direction.append('fwd')
+ read_mapping_direction.append("fwd")
  elif read_id in rev_mapped_reads:
- read_mapping_direction.append('rev')
+ read_mapping_direction.append("rev")
  else:
- read_mapping_direction.append('unk')
+ read_mapping_direction.append("unk")

- temp_adata.obs['Read_mapping_direction'] = read_mapping_direction
+ temp_adata.obs["Read_mapping_direction"] = read_mapping_direction

  del temp_df
-
+
  # Initialize NumPy arrays
- sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
+ sequence_length = (
+ one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
+ )
  df_A = np.zeros((len(sorted_index), sequence_length), dtype=int)
  df_C = np.zeros((len(sorted_index), sequence_length), dtype=int)
  df_G = np.zeros((len(sorted_index), sequence_length), dtype=int)
@@ -855,7 +1148,11 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  gc.collect()

  # Fill the arrays
- for j, read_name in tqdm(enumerate(sorted_index), desc='Loading dataframes of OHE reads', total=len(sorted_index)):
+ for j, read_name in tqdm(
+ enumerate(sorted_index),
+ desc="Loading dataframes of OHE reads",
+ total=len(sorted_index),
+ ):
  df_A[j, :] = dict_A[read_name]
  df_C[j, :] = dict_C[read_name]
  df_G[j, :] = dict_G[read_name]
@@ -867,43 +1164,78 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp

  # Store the results in AnnData layers
  ohe_df_map = {0: df_A, 1: df_C, 2: df_G, 3: df_T, 4: df_N}
- for j, base in enumerate(['A', 'C', 'G', 'T', 'N']):
- temp_adata.layers[f'{base}_binary_encoding'] = ohe_df_map[j]
- ohe_df_map[j] = None # Reassign pointer for memory usage purposes
-
- # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
+ for j, base in enumerate(["A", "C", "G", "T", "N"]):
+ temp_adata.layers[f"{base}_binary_sequence_encoding"] = (
+ ohe_df_map[j]
+ )
+ ohe_df_map[j] = (
+ None # Reassign pointer for memory usage purposes
+ )
+
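Here each base's one-hot matrix (reads x positions) is stored as its own AnnData layer, renamed in this version from {base}_binary_encoding to {base}_binary_sequence_encoding. A compact sketch of the layer-stacking idea with tiny made-up matrices:

    import anndata as ad
    import numpy as np

    n_reads, n_pos = 2, 4
    adata = ad.AnnData(np.zeros((n_reads, n_pos)))

    # One binary matrix per base; a 1 marks where a read carries that base (toy data).
    ohe = {base: np.random.randint(0, 2, size=(n_reads, n_pos)) for base in "ACGTN"}

    for base, matrix in ohe.items():
        adata.layers[f"{base}_binary_sequence_encoding"] = matrix

    print(list(adata.layers.keys()))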
+ # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
  if adata:
  if temp_adata.shape[0] > 0:
- print('{0}: Concatenating {1} anndata object for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
- adata = ad.concat([adata, temp_adata], join='outer', index_unique=None)
+ logger.info(
+ "Concatenating {0} anndata object for sample {1}".format(
+ sample_types[dict_index],
+ final_sample_index,
+ )
+ )
+ adata = ad.concat(
+ [adata, temp_adata], join="outer", index_unique=None
+ )
  del temp_adata
  else:
- print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata")
+ logger.warning(
+ f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata"
+ )
  else:
  if temp_adata.shape[0] > 0:
- print('{0}: Initializing {1} anndata object for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
+ logger.info(
+ "Initializing {0} anndata object for sample {1}".format(
+ sample_types[dict_index],
+ final_sample_index,
+ )
+ )
  adata = temp_adata
  else:
- print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata")
+ logger.warning(
+ f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata"
+ )

  gc.collect()
  else:
- print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata. Skipping sample.")
+ logger.warning(
+ f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata. Skipping sample."
+ )

  try:
- print('{0}: Writing {1} anndata out as a hdf5 file'.format(readwrite.time_string(), sample_types[dict_index]))
- adata.write_h5ad(h5_dir / '{0}_{1}_{2}_SMF_binarized_sample_hdf5.h5ad.gz'.format(readwrite.date_string(), batch, sample_types[dict_index]), compression='gzip')
- except:
- print(f"Skipping writing anndata for sample")
-
- # Delete the batch dictionaries from memory
- del dict_list, adata
+ logger.info(
+ "Writing {0} anndata out as a hdf5 file".format(
+ sample_types[dict_index]
+ )
+ )
+ adata.write_h5ad(
+ h5_dir
+ / "{0}_{1}_{2}_SMF_binarized_sample_hdf5.h5ad.gz".format(
+ readwrite.date_string(), batch, sample_types[dict_index]
+ ),
+ compression="gzip",
+ )
+ except Exception:
+ logger.debug("Skipping writing anndata for sample")
+
+ try:
+ # Delete the batch dictionaries from memory
+ del dict_list, adata
+ except Exception:
+ pass
  gc.collect()

  # Iterate over all of the batched hdf5 files and concatenate them.
- files = h5_dir.iterdir()
+ files = h5_dir.iterdir()
  # Filter file names that contain the search string in their filename and keep them in a list
- hdfs = [hdf for hdf in files if 'hdf5.h5ad' in hdf.name and hdf != final_hdf]
+ hdfs = [hdf for hdf in files if "hdf5.h5ad" in hdf.name and hdf != final_hdf]
  combined_hdfs = [hdf for hdf in hdfs if "combined" in hdf.name]
  if len(combined_hdfs) > 0:
  hdfs = combined_hdfs
@@ -911,55 +1243,62 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  pass
  # Sort file list by names and print the list of file names
  hdfs.sort()
- print('{0} sample files found: {1}'.format(len(hdfs), hdfs))
- hdf_paths = [h5_dir / hd5 for hd5 in hdfs]
+ logger.info("{0} sample files found: {1}".format(len(hdfs), hdfs))
+ hdf_paths = [hd5 for hd5 in hdfs]
  final_adata = None
  for hdf_index, hdf in enumerate(hdf_paths):
- print('{0}: Reading in {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
+ logger.info("Reading in {} hdf5 file".format(hdfs[hdf_index]))
  temp_adata = ad.read_h5ad(hdf)
  if final_adata:
- print('{0}: Concatenating final adata object with {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
- final_adata = ad.concat([final_adata, temp_adata], join='outer', index_unique=None)
+ logger.info(
+ "Concatenating final adata object with {} hdf5 file".format(hdfs[hdf_index])
+ )
+ final_adata = ad.concat([final_adata, temp_adata], join="outer", index_unique=None)
  else:
- print('{0}: Initializing final adata object with {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
+ logger.info("Initializing final adata object with {} hdf5 file".format(hdfs[hdf_index]))
  final_adata = temp_adata
  del temp_adata

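The final assembly pass reads each batched .h5ad back in and outer-joins it onto a growing AnnData, so variables present in only some batches are retained. A small sketch of that incremental ad.concat pattern (the file list below is hypothetical):

    import anndata as ad

    batch_files = ["batch0_hdf5.h5ad.gz", "batch1_hdf5.h5ad.gz"]  # hypothetical paths

    final_adata = None
    for path in sorted(batch_files):
        temp = ad.read_h5ad(path)
        if final_adata is None:
            final_adata = temp  # first file initializes the object
        else:
            # outer join keeps the union of variables across batches
            final_adata = ad.concat([final_adata, temp], join="outer", index_unique=None)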
  # Set obs columns to type 'category'
  for col in final_adata.obs.columns:
- final_adata.obs[col] = final_adata.obs[col].astype('category')
+ final_adata.obs[col] = final_adata.obs[col].astype("category")

- ohe_bases = ['A', 'C', 'G', 'T'] # ignore N bases for consensus
- ohe_layers = [f"{ohe_base}_binary_encoding" for ohe_base in ohe_bases]
- final_adata.uns['References'] = {}
+ ohe_bases = ["A", "C", "G", "T"] # ignore N bases for consensus
+ ohe_layers = [f"{ohe_base}_binary_sequence_encoding" for ohe_base in ohe_bases]
+ final_adata.uns["References"] = {}
  for record in records_to_analyze:
  # Add FASTA sequence to the object
  sequence = record_seq_dict[record][0]
  complement = record_seq_dict[record][1]
- final_adata.var[f'{record}_top_strand_FASTA_base'] = list(sequence)
- final_adata.var[f'{record}_bottom_strand_FASTA_base'] = list(complement)
- final_adata.uns[f'{record}_FASTA_sequence'] = sequence
- final_adata.uns['References'][f'{record}_FASTA_sequence'] = sequence
+ final_adata.var[f"{record}_top_strand_FASTA_base"] = list(sequence)
+ final_adata.var[f"{record}_bottom_strand_FASTA_base"] = list(complement)
+ final_adata.uns[f"{record}_FASTA_sequence"] = sequence
+ final_adata.uns["References"][f"{record}_FASTA_sequence"] = sequence
  # Add consensus sequence of samples mapped to the record to the object
- record_subset = final_adata[final_adata.obs['Reference'] == record]
- for strand in record_subset.obs['Strand'].cat.categories:
- strand_subset = record_subset[record_subset.obs['Strand'] == strand]
- for mapping_dir in strand_subset.obs['Read_mapping_direction'].cat.categories:
- mapping_dir_subset = strand_subset[strand_subset.obs['Read_mapping_direction'] == mapping_dir]
+ record_subset = final_adata[final_adata.obs["Reference"] == record]
+ for strand in record_subset.obs["Strand"].cat.categories:
+ strand_subset = record_subset[record_subset.obs["Strand"] == strand]
+ for mapping_dir in strand_subset.obs["Read_mapping_direction"].cat.categories:
+ mapping_dir_subset = strand_subset[
+ strand_subset.obs["Read_mapping_direction"] == mapping_dir
+ ]
  layer_map, layer_counts = {}, []
  for i, layer in enumerate(ohe_layers):
- layer_map[i] = layer.split('_')[0]
+ layer_map[i] = layer.split("_")[0]
  layer_counts.append(np.sum(mapping_dir_subset.layers[layer], axis=0))
  count_array = np.array(layer_counts)
  nucleotide_indexes = np.argmax(count_array, axis=0)
  consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
- final_adata.var[f'{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples'] = consensus_sequence_list
+ final_adata.var[
+ f"{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples"
+ ] = consensus_sequence_list

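The consensus block above sums each base's one-hot layer over the reads in a subset and takes, per position, the base with the highest count. An illustrative sketch of that argmax-over-counts step (toy matrices, not package data):

    import numpy as np

    bases = ["A", "C", "G", "T"]
    # One (reads x positions) binary matrix per base: 3 reads, 4 positions (toy data).
    layers = {
        "A": np.array([[1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]]),
        "C": np.array([[0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0]]),
        "G": np.array([[0, 0, 1, 0], [0, 0, 0, 0], [1, 0, 1, 0]]),
        "T": np.array([[0, 0, 0, 1], [0, 0, 1, 1], [0, 0, 0, 1]]),
    }

    # Per-base coverage at each position, stacked into a (bases x positions) array.
    counts = np.array([layers[b].sum(axis=0) for b in bases])
    consensus = [bases[i] for i in np.argmax(counts, axis=0)]
    print("".join(consensus))  # "ACGT" for these counts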
  if input_already_demuxed:
  final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
  final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
  else:
  from .h5ad_functions import add_demux_type_annotation
+
  double_barcoded_reads = double_barcoded_path / "barcoding_summary.txt"
  add_demux_type_annotation(final_adata, double_barcoded_reads)

@@ -967,4 +1306,4 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  if delete_batch_hdfs:
  delete_intermediate_h5ads_and_tmpdir(h5_dir, tmp_dir)

- return final_adata, final_adata_path
+ return final_adata, final_adata_path