smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. smftools/__init__.py +5 -1
  2. smftools/_version.py +1 -1
  3. smftools/informatics/__init__.py +2 -0
  4. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  5. smftools/informatics/basecall_pod5s.py +80 -0
  6. smftools/informatics/conversion_smf.py +63 -10
  7. smftools/informatics/direct_smf.py +66 -18
  8. smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
  9. smftools/informatics/helpers/__init__.py +16 -2
  10. smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
  11. smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
  12. smftools/informatics/helpers/bam_qc.py +66 -0
  13. smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
  14. smftools/informatics/helpers/canoncall.py +12 -3
  15. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
  16. smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
  17. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  18. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  19. smftools/informatics/helpers/extract_base_identities.py +33 -46
  20. smftools/informatics/helpers/extract_mods.py +55 -23
  21. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  22. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  23. smftools/informatics/helpers/find_conversion_sites.py +33 -44
  24. smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
  25. smftools/informatics/helpers/modcall.py +13 -5
  26. smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
  27. smftools/informatics/helpers/ohe_batching.py +65 -41
  28. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  29. smftools/informatics/helpers/one_hot_decode.py +27 -0
  30. smftools/informatics/helpers/one_hot_encode.py +45 -9
  31. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
  32. smftools/informatics/helpers/run_multiqc.py +28 -0
  33. smftools/informatics/helpers/split_and_index_BAM.py +3 -8
  34. smftools/informatics/load_adata.py +58 -3
  35. smftools/plotting/__init__.py +15 -0
  36. smftools/plotting/classifiers.py +355 -0
  37. smftools/plotting/general_plotting.py +205 -0
  38. smftools/plotting/position_stats.py +462 -0
  39. smftools/preprocessing/__init__.py +6 -7
  40. smftools/preprocessing/append_C_context.py +22 -9
  41. smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
  42. smftools/preprocessing/binarize_on_Youden.py +35 -32
  43. smftools/preprocessing/binary_layers_to_ohe.py +13 -3
  44. smftools/preprocessing/calculate_complexity.py +3 -2
  45. smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
  46. smftools/preprocessing/calculate_coverage.py +26 -25
  47. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  48. smftools/preprocessing/calculate_position_Youden.py +18 -7
  49. smftools/preprocessing/calculate_read_length_stats.py +39 -46
  50. smftools/preprocessing/clean_NaN.py +33 -25
  51. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  52. smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
  53. smftools/preprocessing/filter_reads_on_length.py +14 -4
  54. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  55. smftools/preprocessing/invert_adata.py +18 -11
  56. smftools/preprocessing/load_sample_sheet.py +30 -16
  57. smftools/preprocessing/recipes.py +22 -20
  58. smftools/preprocessing/subsample_adata.py +58 -0
  59. smftools/readwrite.py +105 -13
  60. smftools/tools/__init__.py +49 -0
  61. smftools/tools/apply_hmm.py +202 -0
  62. smftools/tools/apply_hmm_batched.py +241 -0
  63. smftools/tools/archived/classify_methylated_features.py +66 -0
  64. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  65. smftools/tools/archived/subset_adata_v1.py +32 -0
  66. smftools/tools/archived/subset_adata_v2.py +46 -0
  67. smftools/tools/calculate_distances.py +18 -0
  68. smftools/tools/calculate_umap.py +62 -0
  69. smftools/tools/call_hmm_peaks.py +105 -0
  70. smftools/tools/classifiers.py +787 -0
  71. smftools/tools/cluster_adata_on_methylation.py +105 -0
  72. smftools/tools/data/__init__.py +2 -0
  73. smftools/tools/data/anndata_data_module.py +90 -0
  74. smftools/tools/data/preprocessing.py +6 -0
  75. smftools/tools/display_hmm.py +18 -0
  76. smftools/tools/general_tools.py +69 -0
  77. smftools/tools/hmm_readwrite.py +16 -0
  78. smftools/tools/inference/__init__.py +1 -0
  79. smftools/tools/inference/lightning_inference.py +41 -0
  80. smftools/tools/models/__init__.py +9 -0
  81. smftools/tools/models/base.py +14 -0
  82. smftools/tools/models/cnn.py +34 -0
  83. smftools/tools/models/lightning_base.py +41 -0
  84. smftools/tools/models/mlp.py +17 -0
  85. smftools/tools/models/positional.py +17 -0
  86. smftools/tools/models/rnn.py +16 -0
  87. smftools/tools/models/sklearn_models.py +40 -0
  88. smftools/tools/models/transformer.py +133 -0
  89. smftools/tools/models/wrappers.py +20 -0
  90. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  91. smftools/tools/position_stats.py +239 -0
  92. smftools/tools/read_stats.py +70 -0
  93. smftools/tools/subset_adata.py +19 -23
  94. smftools/tools/train_hmm.py +78 -0
  95. smftools/tools/training/__init__.py +1 -0
  96. smftools/tools/training/train_lightning_model.py +47 -0
  97. smftools/tools/utils/__init__.py +2 -0
  98. smftools/tools/utils/device.py +10 -0
  99. smftools/tools/utils/grl.py +14 -0
  100. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
  101. smftools-0.1.7.dist-info/RECORD +136 -0
  102. smftools/tools/apply_HMM.py +0 -1
  103. smftools/tools/read_HMM.py +0 -1
  104. smftools/tools/train_HMM.py +0 -43
  105. smftools-0.1.3.dist-info/RECORD +0 -84
  106. /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
  107. /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
  108. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  109. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
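The largest single change in this release is a rewrite of smftools/informatics/helpers/modkit_extract_to_adata.py (diff below): the monolithic function is split into module-level helpers (parallel_filter_bams, parallel_load_tsvs, process_stranded_methylation, and related functions), gains a threads argument, and now returns the path to the final .h5ad instead of None. A minimal usage sketch, assuming these helpers are importable from that module and using hypothetical file paths; it is not taken from the package documentation:

```python
# Hedged sketch based on the signatures shown in the diff below.
from smftools.informatics.helpers.modkit_extract_to_adata import (
    modkit_extract_to_adata,
    parallel_filter_bams,
    parallel_load_tsvs,
)

# Hypothetical inputs; substitute real paths from your experiment.
bam_paths = ["demux/sample0.bam", "demux/sample1.bam"]
tsv_paths = ["tsvs/sample0_extract.tsv", "tsvs/sample1_extract.tsv"]

# Keep references whose mapped-read fraction meets the threshold in any BAM.
records = parallel_filter_bams(bam_paths, mapping_threshold=0.1)

# reference_dict maps record name -> (reference_length, ...), as used by process_tsv.
reference_dict = {"plasmid_1": (4361,)}  # hypothetical reference length
dict_total = parallel_load_tsvs(tsv_paths, records, reference_dict,
                                batch=0, batch_size=len(tsv_paths), threads=4)

# Or run the full entry point, which now returns the final AnnData path.
final_adata_path = modkit_extract_to_adata(
    fasta="reference.fa", bam_dir="demux/", mapping_threshold=0.1,
    experiment_name="exp1", mods=["6mA", "5mC"], batch_size=2,
    mod_tsv_dir="tsvs/", delete_batch_hdfs=False, threads=4,
)
```

Calling the top-level entry point is still the expected usage; the helper calls above only illustrate which steps the refactor parallelizes.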
@@ -1,6 +1,342 @@
  ## modkit_extract_to_adata
 
- def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name, mods, batch_size, mod_tsv_dir, delete_batch_hdfs=False):
+ import concurrent.futures
+ import gc
+ from .count_aligned_reads import count_aligned_reads
+ import pandas as pd
+ from tqdm import tqdm
+ import numpy as np
+
+ def filter_bam_records(bam, mapping_threshold):
+ """Processes a single BAM file, counts reads, and determines records to analyze."""
+ aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(bam)
+
+ total_reads = aligned_reads_count + unaligned_reads_count
+ percent_aligned = (aligned_reads_count * 100 / total_reads) if total_reads > 0 else 0
+ print(f'{percent_aligned:.2f}% of reads in {bam} aligned successfully')
+
+ records = []
+ for record, (count, percentage) in record_counts_dict.items():
+ print(f'{count} reads mapped to reference {record}. This is {percentage*100:.2f}% of all mapped reads in {bam}')
+ if percentage >= mapping_threshold:
+ records.append(record)
+
+ return set(records)
+
+ def parallel_filter_bams(bam_path_list, mapping_threshold):
+ """Parallel processing for multiple BAM files."""
+ records_to_analyze = set()
+
+ with concurrent.futures.ProcessPoolExecutor() as executor:
+ results = executor.map(filter_bam_records, bam_path_list, [mapping_threshold] * len(bam_path_list))
+
+ # Aggregate results
+ for result in results:
+ records_to_analyze.update(result)
+
+ print(f'Records to analyze: {records_to_analyze}')
+ return records_to_analyze
+
+ def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
+ """
+ Loads and filters a single TSV file based on chromosome and position criteria.
+ """
+ temp_df = pd.read_csv(tsv, sep='\t', header=0)
+ filtered_records = {}
+
+ for record in records_to_analyze:
+ if record not in reference_dict:
+ continue
+
+ ref_length = reference_dict[record][0]
+ filtered_df = temp_df[(temp_df['chrom'] == record) &
+ (temp_df['ref_position'] >= 0) &
+ (temp_df['ref_position'] < ref_length)]
+
+ if not filtered_df.empty:
+ filtered_records[record] = {sample_index: filtered_df}
+
+ return filtered_records
+
+ def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size, threads=4):
+ """
+ Loads and filters TSV files in parallel.
+
+ Parameters:
+ tsv_batch (list): List of TSV file paths.
+ records_to_analyze (list): Chromosome records to analyze.
+ reference_dict (dict): Dictionary containing reference lengths.
+ batch (int): Current batch number.
+ batch_size (int): Total files in the batch.
+ threads (int): Number of parallel workers.
+
+ Returns:
+ dict: Processed `dict_total` dictionary.
+ """
+ dict_total = {record: {} for record in records_to_analyze}
+
+ with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
+ futures = {
+ executor.submit(process_tsv, tsv, records_to_analyze, reference_dict, sample_index): sample_index
+ for sample_index, tsv in enumerate(tsv_batch)
+ }
+
+ for future in tqdm(concurrent.futures.as_completed(futures), desc=f'Processing batch {batch}', total=batch_size):
+ result = future.result()
+ for record, sample_data in result.items():
+ dict_total[record].update(sample_data)
+
+ return dict_total
+
+ def update_dict_to_skip(dict_to_skip, detected_modifications):
+ """
+ Updates the dict_to_skip set based on the detected modifications.
+
+ Parameters:
+ dict_to_skip (set): The initial set of dictionary indices to skip.
+ detected_modifications (list or set): The modifications (e.g. ['6mA', '5mC']) present.
+
+ Returns:
+ set: The updated dict_to_skip set.
+ """
+ # Define which indices correspond to modification-specific or strand-specific dictionaries
+ A_stranded_dicts = {2, 3} # m6A bottom and top strand dictionaries
+ C_stranded_dicts = {5, 6} # 5mC bottom and top strand dictionaries
+ combined_dicts = {7, 8} # Combined strand dictionaries
+
+ # If '6mA' is present, remove the A_stranded indices from the skip set
+ if '6mA' in detected_modifications:
+ dict_to_skip -= A_stranded_dicts
+ # If '5mC' is present, remove the C_stranded indices from the skip set
+ if '5mC' in detected_modifications:
+ dict_to_skip -= C_stranded_dicts
+ # If both modifications are present, remove the combined indices from the skip set
+ if '6mA' in detected_modifications and '5mC' in detected_modifications:
+ dict_to_skip -= combined_dicts
+
+ return dict_to_skip
+
+ def process_modifications_for_sample(args):
+ """
+ Processes a single (record, sample) pair to extract modification-specific data.
+
+ Parameters:
+ args: (record, sample_index, sample_df, mods, max_reference_length)
+
+ Returns:
+ (record, sample_index, result) where result is a dict with keys:
+ 'm6A', 'm6A_minus', 'm6A_plus', '5mC', '5mC_minus', '5mC_plus', and
+ optionally 'combined_minus' and 'combined_plus' (initialized as empty lists).
+ """
+ record, sample_index, sample_df, mods, max_reference_length = args
+ result = {}
+ if '6mA' in mods:
+ m6a_df = sample_df[sample_df['modified_primary_base'] == 'A']
+ result['m6A'] = m6a_df
+ result['m6A_minus'] = m6a_df[m6a_df['ref_strand'] == '-']
+ result['m6A_plus'] = m6a_df[m6a_df['ref_strand'] == '+']
+ m6a_df = None
+ gc.collect()
+ if '5mC' in mods:
+ m5c_df = sample_df[sample_df['modified_primary_base'] == 'C']
+ result['5mC'] = m5c_df
+ result['5mC_minus'] = m5c_df[m5c_df['ref_strand'] == '-']
+ result['5mC_plus'] = m5c_df[m5c_df['ref_strand'] == '+']
+ m5c_df = None
+ gc.collect()
+ if '6mA' in mods and '5mC' in mods:
+ result['combined_minus'] = []
+ result['combined_plus'] = []
+ return record, sample_index, result
+
+ def parallel_process_modifications(dict_total, mods, max_reference_length, threads=4):
+ """
+ Processes each (record, sample) pair in dict_total in parallel to extract modification-specific data.
+
+ Returns:
+ processed_results: Dict keyed by record, with sub-dict keyed by sample index and the processed results.
+ """
+ tasks = []
+ for record, sample_dict in dict_total.items():
+ for sample_index, sample_df in sample_dict.items():
+ tasks.append((record, sample_index, sample_df, mods, max_reference_length))
+ processed_results = {}
+ with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
+ for record, sample_index, result in tqdm(
+ executor.map(process_modifications_for_sample, tasks),
+ total=len(tasks),
+ desc="Processing modifications"):
+ if record not in processed_results:
+ processed_results[record] = {}
+ processed_results[record][sample_index] = result
+ return processed_results
+
+ def merge_modification_results(processed_results, mods):
+ """
+ Merges individual sample results into global dictionaries.
+
+ Returns:
+ A tuple: (m6A_dict, m6A_minus, m6A_plus, c5m_dict, c5m_minus, c5m_plus, combined_minus, combined_plus)
+ """
+ m6A_dict = {}
+ m6A_minus = {}
+ m6A_plus = {}
+ c5m_dict = {}
+ c5m_minus = {}
+ c5m_plus = {}
+ combined_minus = {}
+ combined_plus = {}
+ for record, sample_results in processed_results.items():
+ for sample_index, res in sample_results.items():
+ if '6mA' in mods:
+ if record not in m6A_dict:
+ m6A_dict[record], m6A_minus[record], m6A_plus[record] = {}, {}, {}
+ m6A_dict[record][sample_index] = res.get('m6A', pd.DataFrame())
+ m6A_minus[record][sample_index] = res.get('m6A_minus', pd.DataFrame())
+ m6A_plus[record][sample_index] = res.get('m6A_plus', pd.DataFrame())
+ if '5mC' in mods:
+ if record not in c5m_dict:
+ c5m_dict[record], c5m_minus[record], c5m_plus[record] = {}, {}, {}
+ c5m_dict[record][sample_index] = res.get('5mC', pd.DataFrame())
+ c5m_minus[record][sample_index] = res.get('5mC_minus', pd.DataFrame())
+ c5m_plus[record][sample_index] = res.get('5mC_plus', pd.DataFrame())
+ if '6mA' in mods and '5mC' in mods:
+ if record not in combined_minus:
+ combined_minus[record], combined_plus[record] = {}, {}
+ combined_minus[record][sample_index] = res.get('combined_minus', [])
+ combined_plus[record][sample_index] = res.get('combined_plus', [])
+ return (m6A_dict, m6A_minus, m6A_plus,
+ c5m_dict, c5m_minus, c5m_plus,
+ combined_minus, combined_plus)
+
+ def process_stranded_methylation(args):
+ """
+ Processes a single (dict_index, record, sample) task.
+
+ For combined dictionaries (indices 7 or 8), it merges the corresponding A-stranded and C-stranded data.
+ For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
+ NumPy methylation array (of float type). Non-numeric values (e.g. '-') are coerced to NaN.
+
+ Parameters:
+ args: (dict_index, record, sample, dict_list, max_reference_length)
+
+ Returns:
+ (dict_index, record, sample, processed_data)
+ """
+ dict_index, record, sample, dict_list, max_reference_length = args
+ processed_data = {}
+
+ # For combined bottom strand (index 7)
+ if dict_index == 7:
+ temp_a = dict_list[2][record].get(sample, {}).copy()
+ temp_c = dict_list[5][record].get(sample, {}).copy()
+ processed_data = {}
+ for read in set(temp_a.keys()) | set(temp_c.keys()):
+ if read in temp_a:
+ # Convert using pd.to_numeric with errors='coerce'
+ value_a = pd.to_numeric(np.array(temp_a[read]), errors='coerce')
+ else:
+ value_a = None
+ if read in temp_c:
+ value_c = pd.to_numeric(np.array(temp_c[read]), errors='coerce')
+ else:
+ value_c = None
+ if value_a is not None and value_c is not None:
+ processed_data[read] = np.where(
+ np.isnan(value_a) & np.isnan(value_c),
+ np.nan,
+ np.nan_to_num(value_a) + np.nan_to_num(value_c)
+ )
+ elif value_a is not None:
+ processed_data[read] = value_a
+ elif value_c is not None:
+ processed_data[read] = value_c
+ del temp_a, temp_c
+
+ # For combined top strand (index 8)
+ elif dict_index == 8:
+ temp_a = dict_list[3][record].get(sample, {}).copy()
+ temp_c = dict_list[6][record].get(sample, {}).copy()
+ processed_data = {}
+ for read in set(temp_a.keys()) | set(temp_c.keys()):
+ if read in temp_a:
+ value_a = pd.to_numeric(np.array(temp_a[read]), errors='coerce')
+ else:
+ value_a = None
+ if read in temp_c:
+ value_c = pd.to_numeric(np.array(temp_c[read]), errors='coerce')
+ else:
+ value_c = None
+ if value_a is not None and value_c is not None:
+ processed_data[read] = np.where(
+ np.isnan(value_a) & np.isnan(value_c),
+ np.nan,
+ np.nan_to_num(value_a) + np.nan_to_num(value_c)
+ )
+ elif value_a is not None:
+ processed_data[read] = value_a
+ elif value_c is not None:
+ processed_data[read] = value_c
+ del temp_a, temp_c
+
+ # For all other dictionaries
+ else:
+ # current_data is a DataFrame
+ temp_df = dict_list[dict_index][record][sample]
+ processed_data = {}
+ # Extract columns and convert probabilities to float (coercing errors)
+ read_ids = temp_df['read_id'].values
+ positions = temp_df['ref_position'].values
+ call_codes = temp_df['call_code'].values
+ probabilities = pd.to_numeric(temp_df['call_prob'].values, errors='coerce')
+
+ modified_codes = {'a', 'h', 'm'}
+ canonical_codes = {'-'}
+
+ # Compute methylation probabilities (vectorized)
+ methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
+ methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[np.isin(call_codes, list(modified_codes))]
+ methylation_prob[np.isin(call_codes, list(canonical_codes))] = 1 - probabilities[np.isin(call_codes, list(canonical_codes))]
+
+ # Preallocate storage for each unique read
+ unique_reads = np.unique(read_ids)
+ for read in unique_reads:
+ processed_data[read] = np.full(max_reference_length, np.nan, dtype=float)
+
+ # Assign values efficiently
+ for i in range(len(read_ids)):
+ read = read_ids[i]
+ pos = positions[i]
+ prob = methylation_prob[i]
+ processed_data[read][pos] = prob
+
+ gc.collect()
+ return dict_index, record, sample, processed_data
+
+ def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=4):
+ """
+ Processes all (dict_index, record, sample) tasks in dict_list (excluding indices in dict_to_skip) in parallel.
+
+ Returns:
+ Updated dict_list with processed (nested) dictionaries.
+ """
+ tasks = []
+ for dict_index, current_dict in enumerate(dict_list):
+ if dict_index not in dict_to_skip:
+ for record in current_dict.keys():
+ for sample in current_dict[record].keys():
+ tasks.append((dict_index, record, sample, dict_list, max_reference_length))
+
+ with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
+ for dict_index, record, sample, processed_data in tqdm(
+ executor.map(process_stranded_methylation, tasks),
+ total=len(tasks),
+ desc="Extracting stranded methylation states"
+ ):
+ dict_list[dict_index][record][sample] = processed_data
+ return dict_list
+
+ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name, mods, batch_size, mod_tsv_dir, delete_batch_hdfs=False, threads=None):
  """
  Takes modkit extract outputs and organizes it into an adata object
 
@@ -15,15 +351,13 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
  delete_batch_hdfs (bool): Whether to delete the batch hdfs after writing out the final concatenated hdf. Default is False
 
  Returns:
- None
+ final_adata_path (str): Path to the final adata
  """
  ###################################################
  # Package imports
  from .. import readwrite
  from .get_native_references import get_native_references
- from .count_aligned_reads import count_aligned_reads
  from .extract_base_identities import extract_base_identities
- from .one_hot_encode import one_hot_encode
  from .ohe_batching import ohe_batching
  import pandas as pd
  import anndata as ad
@@ -43,13 +377,27 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
  bam_files = os.listdir(bam_dir)
  # get current working directory
  parent_dir = os.path.dirname(mod_tsv_dir)
+
  # Make output dirs
  h5_dir = os.path.join(parent_dir, 'h5ads')
  tmp_dir = os.path.join(parent_dir, 'tmp')
  make_dirs([h5_dir, tmp_dir])
+ existing_h5s = os.listdir(h5_dir)
+ existing_h5s = [h5 for h5 in existing_h5s if '.h5ad.gz' in h5]
+ final_hdf = f'{experiment_name}_final_experiment_hdf5.h5ad'
+ final_adata_path = os.path.join(h5_dir, final_hdf)
+
+ if os.path.exists(f"{final_adata_path}.gz"):
+ print(f'{final_adata_path}.gz already exists. Using existing adata')
+ return f"{final_adata_path}.gz"
+
+ elif os.path.exists(f"{final_adata_path}"):
+ print(f'{final_adata_path} already exists. Using existing adata')
+ return final_adata_path
+
  # Filter file names that contain the search string in their filename and keep them in a list
- tsvs = [tsv for tsv in tsv_files if 'extract.tsv' in tsv]
- bams = [bam for bam in bam_files if '.bam' in bam and '.bai' not in bam]
+ tsvs = [tsv for tsv in tsv_files if 'extract.tsv' in tsv and 'unclassified' not in tsv]
+ bams = [bam for bam in bam_files if '.bam' in bam and '.bai' not in bam and 'unclassified' not in bam]
  # Sort file list by names and print the list of file names
  tsvs.sort()
  tsv_path_list = [os.path.join(mod_tsv_dir, tsv) for tsv in tsvs]
@@ -61,18 +409,8 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
 
  ######### Get Record names that have over a passed threshold of mapped reads #############
  # get all records that are above a certain mapping threshold in at least one sample bam
- records_to_analyze = []
- for bami, bam in enumerate(bam_path_list):
- aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(bam)
- percent_aligned = aligned_reads_count*100 / (aligned_reads_count+unaligned_reads_count)
- print(f'{percent_aligned} percent of reads in {bams[bami]} aligned successfully')
- # Iterate over references and decide which to use in the analysis based on the mapping_threshold
- for record in record_counts_dict:
- print('{0} reads mapped to reference record {1}. This is {2} percent of all mapped reads in {3}'.format(record_counts_dict[record][0], record, record_counts_dict[record][1]*100, bams[bami]))
- if record_counts_dict[record][1] >= mapping_threshold:
- records_to_analyze.append(record)
- records_to_analyze = set(records_to_analyze)
- print(f'Records to analyze: {records_to_analyze}')
+ records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold)
+
  ##########################################################################################
 
  ########### Determine the maximum record length to analyze in the dataset ################
@@ -92,7 +430,7 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
  # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
  # Save the file paths in the bam_record_ohe_files dict.
  bam_record_ohe_files = {}
- bam_record_save = os.path.join(tmp_dir, 'tmp_file_dict.h5ad.gz')
+ bam_record_save = os.path.join(tmp_dir, 'tmp_file_dict.h5ad')
  fwd_mapped_reads = set()
  rev_mapped_reads = set()
  # If this step has already been performed, read in the tmp_dile_dict
@@ -112,14 +450,14 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
  fwd_mapped_reads.update(fwd_base_identities.keys())
  rev_mapped_reads.update(rev_base_identities.keys())
  # One hot encode the sequence string of the reads
- fwd_ohe_files = ohe_batching(fwd_base_identities, tmp_dir, record, f"{bami}_fwd",batch_size=100000)
- rev_ohe_files = ohe_batching(rev_base_identities, tmp_dir, record, f"{bami}_rev",batch_size=100000)
+ fwd_ohe_files = ohe_batching(fwd_base_identities, tmp_dir, record, f"{bami}_fwd",batch_size=100000, threads=threads)
+ rev_ohe_files = ohe_batching(rev_base_identities, tmp_dir, record, f"{bami}_rev",batch_size=100000, threads=threads)
  bam_record_ohe_files[f'{bami}_{record}'] = fwd_ohe_files + rev_ohe_files
  del fwd_base_identities, rev_base_identities
  # Save out the ohe file paths
  X = np.random.rand(1, 1)
  tmp_ad = ad.AnnData(X=X, uns=bam_record_ohe_files)
- tmp_ad.write_h5ad(bam_record_save, compression='gzip')
+ tmp_ad.write_h5ad(bam_record_save)
  ##########################################################################################
 
  ##########################################################################################
@@ -134,385 +472,413 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
134
472
  ##########################################################################################
135
473
 
136
474
  ###################################################
137
- existing_h5s = os.listdir(h5_dir)
138
- existing_h5s = [h5 for h5 in existing_h5s if '.h5ad.gz' in h5]
139
- final_hdf = f'{experiment_name}_final_experiment_hdf5.h5ad.gz'
140
- final_hdf_already_exists = final_hdf in existing_h5s
141
-
142
- if final_hdf_already_exists:
143
- print(f'{final_hdf} has already been made. Skipping processing.')
144
- else:
145
- # Begin iterating over batches
146
- for batch in range(batches):
147
- print('{0}: Processing tsvs for batch {1} '.format(readwrite.time_string(), batch))
148
- # For the final batch, just take the remaining tsv and bam files
149
- if batch == batches - 1:
150
- tsv_batch = tsv_path_list
151
- bam_batch = bam_path_list
152
- # For all other batches, take the next batch of tsvs and bams out of the file queue.
153
- else:
154
- tsv_batch = tsv_path_list[:batch_size]
155
- bam_batch = bam_path_list[:batch_size]
156
- tsv_path_list = tsv_path_list[batch_size:]
157
- bam_path_list = bam_path_list[batch_size:]
158
- print('{0}: tsvs in batch {1} '.format(readwrite.time_string(), tsv_batch))
159
-
160
- batch_already_processed = sum([1 for h5 in existing_h5s if f'_{batch}_' in h5])
161
- ###################################################
162
- if batch_already_processed:
163
- print(f'Batch {batch} has already been processed into h5ads. Skipping batch and using existing files')
164
- else:
165
- ###################################################
166
- ### Add the tsvs as dataframes to a dictionary (dict_total) keyed by integer index. Also make modification specific dictionaries and strand specific dictionaries.
167
- # Initialize dictionaries and place them in a list
168
- dict_total, dict_a, dict_a_bottom, dict_a_top, dict_c, dict_c_bottom, dict_c_top, dict_combined_bottom, dict_combined_top = {},{},{},{},{},{},{},{},{}
169
- dict_list = [dict_total, dict_a, dict_a_bottom, dict_a_top, dict_c, dict_c_bottom, dict_c_top, dict_combined_bottom, dict_combined_top]
170
- # Give names to represent each dictionary in the list
171
- sample_types = ['total', 'm6A', 'm6A_bottom_strand', 'm6A_top_strand', '5mC', '5mC_bottom_strand', '5mC_top_strand', 'combined_bottom_strand', 'combined_top_strand']
172
- # Give indices of dictionaries to skip for analysis and final dictionary saving.
173
- dict_to_skip = [0, 1, 4]
174
- combined_dicts = [7, 8]
175
- A_stranded_dicts = [2, 3]
176
- C_stranded_dicts = [5, 6]
177
- dict_to_skip = dict_to_skip + combined_dicts + A_stranded_dicts + C_stranded_dicts
178
- dict_to_skip = set(dict_to_skip)
179
-
180
- # Load the dict_total dictionary with all of the tsv files as dataframes.
181
- for sample_index, tsv in tqdm(enumerate(tsv_batch), desc=f'Loading TSVs into dataframes and filtering on chromosome/position for batch {batch}', total=batch_size):
182
- #print('{0}: Loading sample tsv {1} into dataframe'.format(readwrite.time_string(), tsv))
183
- temp_df = pd.read_csv(tsv, sep='\t', header=0)
184
- for record in records_to_analyze:
185
- if record not in dict_total.keys():
186
- dict_total[record] = {}
187
- # Only keep the reads aligned to the chromosomes of interest
188
- #print('{0}: Filtering sample dataframe to keep chromosome of interest'.format(readwrite.time_string()))
189
- dict_total[record][sample_index] = temp_df[temp_df['chrom'] == record]
190
- # Only keep the read positions that fall within the region of interest
191
- #print('{0}: Filtering sample dataframe to keep positions falling within region of interest'.format(readwrite.time_string()))
192
- current_reference_length = reference_dict[record][0]
193
- dict_total[record][sample_index] = dict_total[record][sample_index][(current_reference_length > dict_total[record][sample_index]['ref_position']) & (dict_total[record][sample_index]['ref_position']>= 0)]
194
-
195
- # Iterate over dict_total of all the tsv files and extract the modification specific and strand specific dataframes into dictionaries
196
- for record in dict_total.keys():
197
- for sample_index in dict_total[record].keys():
198
- if '6mA' in mods:
199
- # Remove Adenine stranded dicts from the dicts to skip set
200
- dict_to_skip.difference_update(set(A_stranded_dicts))
201
-
202
- if record not in dict_a.keys() and record not in dict_a_bottom.keys() and record not in dict_a_top.keys():
203
- dict_a[record], dict_a_bottom[record], dict_a_top[record] = {}, {}, {}
204
-
205
- # get a dictionary of dataframes that only contain methylated adenine positions
206
- dict_a[record][sample_index] = dict_total[record][sample_index][dict_total[record][sample_index]['modified_primary_base'] == 'A']
207
- print('{}: Successfully loaded a methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
208
- # Stratify the adenine dictionary into two strand specific dictionaries.
209
- dict_a_bottom[record][sample_index] = dict_a[record][sample_index][dict_a[record][sample_index]['ref_strand'] == '-']
210
- print('{}: Successfully loaded a minus strand methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
211
- dict_a_top[record][sample_index] = dict_a[record][sample_index][dict_a[record][sample_index]['ref_strand'] == '+']
212
- print('{}: Successfully loaded a plus strand methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
213
-
214
- # Reassign pointer for dict_a to None and delete the original value that it pointed to in order to decrease memory usage.
215
- dict_a[record][sample_index] = None
216
- gc.collect()
217
-
218
- if '5mC' in mods:
219
- # Remove Cytosine stranded dicts from the dicts to skip set
220
- dict_to_skip.difference_update(set(C_stranded_dicts))
221
-
222
- if record not in dict_c.keys() and record not in dict_c_bottom.keys() and record not in dict_c_top.keys():
223
- dict_c[record], dict_c_bottom[record], dict_c_top[record] = {}, {}, {}
224
-
225
- # get a dictionary of dataframes that only contain methylated cytosine positions
226
- dict_c[record][sample_index] = dict_total[record][sample_index][dict_total[record][sample_index]['modified_primary_base'] == 'C']
227
- print('{}: Successfully loaded a methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
228
- # Stratify the cytosine dictionary into two strand specific dictionaries.
229
- dict_c_bottom[record][sample_index] = dict_c[record][sample_index][dict_c[record][sample_index]['ref_strand'] == '-']
230
- print('{}: Successfully loaded a minus strand methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
231
- dict_c_top[record][sample_index] = dict_c[record][sample_index][dict_c[record][sample_index]['ref_strand'] == '+']
232
- print('{}: Successfully loaded a plus strand methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
233
- # In the strand specific dictionaries, only keep positions that are informative for GpC SMF
234
-
235
- # Reassign pointer for dict_c to None and delete the original value that it pointed to in order to decrease memory usage.
236
- dict_c[record][sample_index] = None
237
- gc.collect()
238
-
239
- if '6mA' in mods and '5mC' in mods:
240
- # Remove combined stranded dicts from the dicts to skip set
241
- dict_to_skip.difference_update(set(combined_dicts))
242
- # Initialize the sample keys for the combined dictionaries
243
-
244
- if record not in dict_combined_bottom.keys() and record not in dict_combined_top.keys():
245
- dict_combined_bottom[record], dict_combined_top[record]= {}, {}
246
-
247
- print('{}: Successfully created a minus strand combined methylation dictionary for '.format(readwrite.time_string()) + str(sample_index))
248
- dict_combined_bottom[record][sample_index] = []
249
- print('{}: Successfully created a plus strand combined methylation dictionary for '.format(readwrite.time_string()) + str(sample_index))
250
- dict_combined_top[record][sample_index] = []
251
-
252
- # Reassign pointer for dict_total to None and delete the original value that it pointed to in order to decrease memory usage.
253
- dict_total[record][sample_index] = None
475
+ # Begin iterating over batches
476
+ for batch in range(batches):
477
+ print('{0}: Processing tsvs for batch {1} '.format(readwrite.time_string(), batch))
478
+ # For the final batch, just take the remaining tsv and bam files
479
+ if batch == batches - 1:
480
+ tsv_batch = tsv_path_list
481
+ bam_batch = bam_path_list
482
+ # For all other batches, take the next batch of tsvs and bams out of the file queue.
483
+ else:
484
+ tsv_batch = tsv_path_list[:batch_size]
485
+ bam_batch = bam_path_list[:batch_size]
486
+ tsv_path_list = tsv_path_list[batch_size:]
487
+ bam_path_list = bam_path_list[batch_size:]
488
+ print('{0}: tsvs in batch {1} '.format(readwrite.time_string(), tsv_batch))
489
+
490
+ batch_already_processed = sum([1 for h5 in existing_h5s if f'_{batch}_' in h5])
491
+ ###################################################
492
+ if batch_already_processed:
493
+ print(f'Batch {batch} has already been processed into h5ads. Skipping batch and using existing files')
494
+ else:
495
+ ###################################################
496
+ ### Add the tsvs as dataframes to a dictionary (dict_total) keyed by integer index. Also make modification specific dictionaries and strand specific dictionaries.
497
+ # # Initialize dictionaries and place them in a list
498
+ dict_total, dict_a, dict_a_bottom, dict_a_top, dict_c, dict_c_bottom, dict_c_top, dict_combined_bottom, dict_combined_top = {},{},{},{},{},{},{},{},{}
499
+ dict_list = [dict_total, dict_a, dict_a_bottom, dict_a_top, dict_c, dict_c_bottom, dict_c_top, dict_combined_bottom, dict_combined_top]
500
+ # Give names to represent each dictionary in the list
501
+ sample_types = ['total', 'm6A', 'm6A_bottom_strand', 'm6A_top_strand', '5mC', '5mC_bottom_strand', '5mC_top_strand', 'combined_bottom_strand', 'combined_top_strand']
502
+ # Give indices of dictionaries to skip for analysis and final dictionary saving.
503
+ dict_to_skip = [0, 1, 4]
504
+ combined_dicts = [7, 8]
505
+ A_stranded_dicts = [2, 3]
506
+ C_stranded_dicts = [5, 6]
507
+ dict_to_skip = dict_to_skip + combined_dicts + A_stranded_dicts + C_stranded_dicts
508
+ dict_to_skip = set(dict_to_skip)
509
+
510
+ # # Step 1):Load the dict_total dictionary with all of the batch tsv files as dataframes.
511
+ dict_total = parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size=len(tsv_batch), threads=threads)
512
+
513
+ # # Step 2: Extract modification-specific data (per (record,sample)) in parallel
514
+ # processed_mod_results = parallel_process_modifications(dict_total, mods, max_reference_length, threads=threads or 4)
515
+ # (m6A_dict, m6A_minus_strand, m6A_plus_strand,
516
+ # c5m_dict, c5m_minus_strand, c5m_plus_strand,
517
+ # combined_minus_strand, combined_plus_strand) = merge_modification_results(processed_mod_results, mods)
518
+
519
+ # # Create dict_list with the desired ordering:
520
+ # # 0: dict_total, 1: m6A, 2: m6A_minus, 3: m6A_plus, 4: 5mC, 5: 5mC_minus, 6: 5mC_plus, 7: combined_minus, 8: combined_plus
521
+ # dict_list = [dict_total, m6A_dict, m6A_minus_strand, m6A_plus_strand,
522
+ # c5m_dict, c5m_minus_strand, c5m_plus_strand,
523
+ # combined_minus_strand, combined_plus_strand]
524
+
525
+ # # Initialize dict_to_skip (default skip all mod-specific indices)
526
+ # dict_to_skip = set([0, 1, 4, 7, 8, 2, 3, 5, 6])
527
+ # # Update dict_to_skip based on modifications present in mods
528
+ # dict_to_skip = update_dict_to_skip(dict_to_skip, mods)
529
+
530
+ # # Step 3: Process stranded methylation data in parallel
531
+ # dict_list = parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=threads or 4)
532
+
533
+ # Iterate over dict_total of all the tsv files and extract the modification specific and strand specific dataframes into dictionaries
534
+ for record in dict_total.keys():
535
+ for sample_index in dict_total[record].keys():
536
+ if '6mA' in mods:
537
+ # Remove Adenine stranded dicts from the dicts to skip set
538
+ dict_to_skip.difference_update(set(A_stranded_dicts))
539
+
540
+ if record not in dict_a.keys() and record not in dict_a_bottom.keys() and record not in dict_a_top.keys():
541
+ dict_a[record], dict_a_bottom[record], dict_a_top[record] = {}, {}, {}
542
+
543
+ # get a dictionary of dataframes that only contain methylated adenine positions
544
+ dict_a[record][sample_index] = dict_total[record][sample_index][dict_total[record][sample_index]['modified_primary_base'] == 'A']
545
+ print('{}: Successfully loaded a methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
546
+ # Stratify the adenine dictionary into two strand specific dictionaries.
547
+ dict_a_bottom[record][sample_index] = dict_a[record][sample_index][dict_a[record][sample_index]['ref_strand'] == '-']
548
+ print('{}: Successfully loaded a minus strand methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
549
+ dict_a_top[record][sample_index] = dict_a[record][sample_index][dict_a[record][sample_index]['ref_strand'] == '+']
550
+ print('{}: Successfully loaded a plus strand methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
551
+
552
+ # Reassign pointer for dict_a to None and delete the original value that it pointed to in order to decrease memory usage.
553
+ dict_a[record][sample_index] = None
254
554
  gc.collect()
255
555
 
256
- # Iterate over the stranded modification dictionaries and replace the dataframes with a dictionary of read names pointing to a list of values from the dataframe
257
- for dict_index, dict_type in enumerate(dict_list):
258
- # Only iterate over stranded dictionaries
259
- if dict_index not in dict_to_skip:
260
- print('{0}: Extracting methylation states for {1} dictionary'.format(readwrite.time_string(), sample_types[dict_index]))
261
- for record in dict_type.keys():
262
- # Get the dictionary for the modification type of interest from the reference mapping of interest
263
- mod_strand_record_sample_dict = dict_type[record]
264
- print('{0}: Extracting methylation states for {1} dictionary'.format(readwrite.time_string(), record))
265
- # For each sample in a stranded dictionary
266
- n_samples = len(mod_strand_record_sample_dict.keys())
267
- for sample in tqdm(mod_strand_record_sample_dict.keys(), desc=f'Extracting {sample_types[dict_index]} dictionary from record {record} for sample', total=n_samples):
268
- # Load the combined bottom strand dictionary after all the individual dictionaries have been made for the sample
269
- if dict_index == 7:
270
- # Load the minus strand dictionaries for each sample into temporary variables
271
- temp_a_dict = dict_list[2][record][sample].copy()
272
- temp_c_dict = dict_list[5][record][sample].copy()
273
- mod_strand_record_sample_dict[sample] = {}
274
- # Iterate over the reads present in the merge of both dictionaries
275
- for read in set(temp_a_dict) | set(temp_c_dict):
276
- # Add the arrays element-wise if the read is present in both dictionaries
277
- if read in temp_a_dict and read in temp_c_dict:
278
- mod_strand_record_sample_dict[sample][read] = np.nansum([temp_a_dict[read], temp_c_dict[read]], axis=0)
279
- # If the read is present in only one dictionary, copy its value
280
- elif read in temp_a_dict:
281
- mod_strand_record_sample_dict[sample][read] = temp_a_dict[read]
282
- elif read in temp_c_dict:
283
- mod_strand_record_sample_dict[sample][read] = temp_c_dict[read]
284
- del temp_a_dict, temp_c_dict
285
- # Load the combined top strand dictionary after all the individual dictionaries have been made for the sample
286
- elif dict_index == 8:
287
- # Load the plus strand dictionaries for each sample into temporary variables
288
- temp_a_dict = dict_list[3][record][sample].copy()
289
- temp_c_dict = dict_list[6][record][sample].copy()
290
- mod_strand_record_sample_dict[sample] = {}
291
- # Iterate over the reads present in the merge of both dictionaries
292
- for read in set(temp_a_dict) | set(temp_c_dict):
293
- # Add the arrays element-wise if the read is present in both dictionaries
294
- if read in temp_a_dict and read in temp_c_dict:
295
- mod_strand_record_sample_dict[sample][read] = np.nansum([temp_a_dict[read], temp_c_dict[read]], axis=0)
296
- # If the read is present in only one dictionary, copy its value
297
- elif read in temp_a_dict:
298
- mod_strand_record_sample_dict[sample][read] = temp_a_dict[read]
299
- elif read in temp_c_dict:
300
- mod_strand_record_sample_dict[sample][read] = temp_c_dict[read]
301
- del temp_a_dict, temp_c_dict
302
- # For all other dictionaries
303
- else:
304
- # use temp_df to point to the dataframe held in mod_strand_record_sample_dict[sample]
305
- temp_df = mod_strand_record_sample_dict[sample]
306
- # reassign the dictionary pointer to a nested dictionary.
307
- mod_strand_record_sample_dict[sample] = {}
308
- # # Iterate through rows in the temp DataFrame
309
- for index, row in temp_df.iterrows():
310
- read = row['read_id'] # read name
311
- position = row['ref_position'] # 1-indexed positional coordinate
312
- probability = row['call_prob'] # Get the probability of the given call
313
- # if the call_code is modified change methylated value to the probability of methylation
314
- if (row['call_code'] in ['a', 'h', 'm']):
315
- methylated = probability
316
- # If the call code is canonical, change the methylated value to 1 - the probability of canonical
317
- elif (row['call_code'] in ['-']):
318
- methylated = 1 - probability
319
-
320
- # If the current read is not in the dictionary yet, initalize the dictionary with a nan filled numpy array of proper size.
321
- if read not in mod_strand_record_sample_dict[sample]:
322
- mod_strand_record_sample_dict[sample][read] = np.full(max_reference_length, np.nan)
323
-
324
- # add the positional methylation state to the numpy array
325
- mod_strand_record_sample_dict[sample][read][position-1] = methylated
326
-
327
- # Save the sample files in the batch as gzipped hdf5 files
328
- os.chdir(h5_dir)
329
- print('{0}: Converting batch {1} dictionaries to anndata objects'.format(readwrite.time_string(), batch))
330
- for dict_index, dict_type in enumerate(dict_list):
331
- if dict_index not in dict_to_skip:
332
- # Initialize an hdf5 file for the current modified strand
333
- adata = None
334
- print('{0}: Converting {1} dictionary to an anndata object'.format(readwrite.time_string(), sample_types[dict_index]))
335
- for record in dict_type.keys():
336
- # Get the dictionary for the modification type of interest from the reference mapping of interest
337
- mod_strand_record_sample_dict = dict_type[record]
338
- for sample in mod_strand_record_sample_dict.keys():
339
- print('{0}: Converting {1} dictionary for sample {2} to an anndata object'.format(readwrite.time_string(), sample_types[dict_index], sample))
340
- sample = int(sample)
341
- final_sample_index = sample + (batch * batch_size)
342
- print('{0}: Final sample index for sample: {1}'.format(readwrite.time_string(), final_sample_index))
343
- print('{0}: Converting {1} dictionary for sample {2} to a dataframe'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
344
- temp_df = pd.DataFrame.from_dict(mod_strand_record_sample_dict[sample], orient='index')
345
- mod_strand_record_sample_dict[sample] = None # reassign pointer to facilitate memory usage
346
- sorted_index = sorted(temp_df.index)
347
- temp_df = temp_df.reindex(sorted_index)
348
- X = temp_df.values
349
-
350
- print('{0}: Loading {1} dataframe for sample {2} into a temp anndata object'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
351
- temp_adata = ad.AnnData(X, dtype=X.dtype)
352
- if temp_adata.shape[0] > 0:
353
- print('{0}: Adding read names and position ids to {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
354
- temp_adata.obs_names = temp_df.index
355
- temp_adata.obs_names = temp_adata.obs_names.astype(str)
356
- temp_adata.var_names = temp_df.columns
357
- temp_adata.var_names = temp_adata.var_names.astype(str)
358
- print('{0}: Adding {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
359
- temp_adata.obs['Sample'] = [str(final_sample_index)] * len(temp_adata)
360
- dataset, strand = sample_types[dict_index].split('_')[:2]
361
- temp_adata.obs['Strand'] = [strand] * len(temp_adata)
362
- temp_adata.obs['Dataset'] = [dataset] * len(temp_adata)
363
- temp_adata.obs['Reference'] = [f'{record}_{dataset}_{strand}'] * len(temp_adata)
364
- temp_adata.obs['Reference_chromosome'] = [f'{record}'] * len(temp_adata)
365
-
366
- # Load in the one hot encoded reads from the current sample and record
367
- one_hot_reads = {}
368
- n_rows_OHE = 5
369
- ohe_files = bam_record_ohe_files[f'{final_sample_index}_{record}']
370
- print(f'Loading OHEs from {ohe_files}')
371
- fwd_mapped_reads = set()
372
- rev_mapped_reads = set()
373
- for ohe_file in ohe_files:
374
- tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
375
- one_hot_reads.update(tmp_ohe_dict)
376
- if '_fwd_' in ohe_file:
377
- fwd_mapped_reads.update(tmp_ohe_dict.keys())
378
- elif '_rev_' in ohe_file:
379
- rev_mapped_reads.update(tmp_ohe_dict.keys())
380
- del tmp_ohe_dict
381
-
382
- read_names = list(one_hot_reads.keys())
383
-
384
- read_mapping_direction = []
385
- for read_id in temp_adata.obs_names:
386
- if read_id in fwd_mapped_reads:
387
- read_mapping_direction.append('fwd')
388
- elif read_id in rev_mapped_reads:
389
- read_mapping_direction.append('rev')
390
- else:
391
- read_mapping_direction.append('unk')
392
-
393
- temp_adata.obs['Read_mapping_direction'] = read_mapping_direction
394
-
395
- del temp_df
556
+ if '5mC' in mods:
557
+ # Remove Cytosine stranded dicts from the dicts to skip set
558
+ dict_to_skip.difference_update(set(C_stranded_dicts))
559
+
560
+ if record not in dict_c.keys() and record not in dict_c_bottom.keys() and record not in dict_c_top.keys():
561
+ dict_c[record], dict_c_bottom[record], dict_c_top[record] = {}, {}, {}
562
+
563
+ # get a dictionary of dataframes that only contain methylated cytosine positions
564
+ dict_c[record][sample_index] = dict_total[record][sample_index][dict_total[record][sample_index]['modified_primary_base'] == 'C']
565
+ print('{}: Successfully loaded a methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
566
+ # Stratify the cytosine dictionary into two strand specific dictionaries.
567
+ dict_c_bottom[record][sample_index] = dict_c[record][sample_index][dict_c[record][sample_index]['ref_strand'] == '-']
568
+ print('{}: Successfully loaded a minus strand methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
569
+ dict_c_top[record][sample_index] = dict_c[record][sample_index][dict_c[record][sample_index]['ref_strand'] == '+']
570
+ print('{}: Successfully loaded a plus strand methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
571
+ # Reassign pointer for dict_c to None and delete the original value that it pointed to in order to decrease memory usage.
572
+ dict_c[record][sample_index] = None
573
+ gc.collect()
574
+
575
+ if '6mA' in mods and '5mC' in mods:
576
+ # Remove combined stranded dicts from the dicts to skip set
577
+ dict_to_skip.difference_update(set(combined_dicts))
578
+ # Initialize the sample keys for the combined dictionaries
579
+
580
+ if record not in dict_combined_bottom.keys() and record not in dict_combined_top.keys():
581
+ dict_combined_bottom[record], dict_combined_top[record]= {}, {}
582
+
583
+ print('{}: Successfully created a minus strand combined methylation dictionary for '.format(readwrite.time_string()) + str(sample_index))
584
+ dict_combined_bottom[record][sample_index] = []
585
+ print('{}: Successfully created a plus strand combined methylation dictionary for '.format(readwrite.time_string()) + str(sample_index))
586
+ dict_combined_top[record][sample_index] = []
587
+
588
+ # Reassign pointer for dict_total to None and delete the original value that it pointed to in order to decrease memory usage.
589
+ dict_total[record][sample_index] = None
590
+ gc.collect()
591
+
592
+ # Iterate over the stranded modification dictionaries and replace the dataframes with a dictionary of read names pointing to a list of values from the dataframe
593
+ for dict_index, dict_type in enumerate(dict_list):
594
+ # Only iterate over stranded dictionaries
595
+ if dict_index not in dict_to_skip:
596
+ print('{0}: Extracting methylation states for {1} dictionary'.format(readwrite.time_string(), sample_types[dict_index]))
597
+ for record in dict_type.keys():
598
+ # Get the dictionary for the modification type of interest from the reference mapping of interest
599
+ mod_strand_record_sample_dict = dict_type[record]
600
+ print('{0}: Extracting methylation states for {1} dictionary'.format(readwrite.time_string(), record))
601
+ # For each sample in a stranded dictionary
602
+ n_samples = len(mod_strand_record_sample_dict.keys())
603
+ for sample in tqdm(mod_strand_record_sample_dict.keys(), desc=f'Extracting {sample_types[dict_index]} dictionary from record {record} for sample', total=n_samples):
604
+ # Load the combined bottom strand dictionary after all the individual dictionaries have been made for the sample
605
+ if dict_index == 7:
606
+ # Load the minus strand dictionaries for each sample into temporary variables
607
+ temp_a_dict = dict_list[2][record][sample].copy()
608
+ temp_c_dict = dict_list[5][record][sample].copy()
609
+ mod_strand_record_sample_dict[sample] = {}
610
+ # Iterate over the reads present in the merge of both dictionaries
611
+ for read in set(temp_a_dict) | set(temp_c_dict):
612
+ # Add the arrays element-wise if the read is present in both dictionaries
613
+ if read in temp_a_dict and read in temp_c_dict:
614
+ mod_strand_record_sample_dict[sample][read] = np.where(np.isnan(temp_a_dict[read]) & np.isnan(temp_c_dict[read]), np.nan, np.nan_to_num(temp_a_dict[read]) + np.nan_to_num(temp_c_dict[read]))
615
+ # If the read is present in only one dictionary, copy its value
616
+ elif read in temp_a_dict:
617
+ mod_strand_record_sample_dict[sample][read] = temp_a_dict[read]
618
+ elif read in temp_c_dict:
619
+ mod_strand_record_sample_dict[sample][read] = temp_c_dict[read]
620
+ del temp_a_dict, temp_c_dict
621
+ # Load the combined top strand dictionary after all the individual dictionaries have been made for the sample
622
+ elif dict_index == 8:
623
+ # Load the plus strand dictionaries for each sample into temporary variables
624
+ temp_a_dict = dict_list[3][record][sample].copy()
625
+ temp_c_dict = dict_list[6][record][sample].copy()
626
+ mod_strand_record_sample_dict[sample] = {}
627
+ # Iterate over the reads present in the merge of both dictionaries
628
+ for read in set(temp_a_dict) | set(temp_c_dict):
629
+ # Add the arrays element-wise if the read is present in both dictionaries
630
+ if read in temp_a_dict and read in temp_c_dict:
631
+ mod_strand_record_sample_dict[sample][read] = np.where(np.isnan(temp_a_dict[read]) & np.isnan(temp_c_dict[read]), np.nan, np.nan_to_num(temp_a_dict[read]) + np.nan_to_num(temp_c_dict[read]))
632
+ # If the read is present in only one dictionary, copy its value
633
+ elif read in temp_a_dict:
634
+ mod_strand_record_sample_dict[sample][read] = temp_a_dict[read]
635
+ elif read in temp_c_dict:
636
+ mod_strand_record_sample_dict[sample][read] = temp_c_dict[read]
637
+ del temp_a_dict, temp_c_dict
638
+ # For all other dictionaries
639
+ else:
640
+
641
+ # use temp_df to point to the dataframe held in mod_strand_record_sample_dict[sample]
642
+ temp_df = mod_strand_record_sample_dict[sample]
643
+ # reassign the dictionary pointer to a nested dictionary.
644
+ mod_strand_record_sample_dict[sample] = {}
645
+
646
+ # Get relevant columns as NumPy arrays
647
+ read_ids = temp_df['read_id'].values
648
+ positions = temp_df['ref_position'].values
649
+ call_codes = temp_df['call_code'].values
650
+ probabilities = temp_df['call_prob'].values
651
+
652
+ # Define valid call code categories
653
+ modified_codes = {'a', 'h', 'm'}
654
+ canonical_codes = {'-'}
655
+
656
+ # Vectorized methylation calculation with NaN for other codes
657
+ methylation_prob = np.full_like(probabilities, np.nan) # Default all to NaN
658
+ methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[np.isin(call_codes, list(modified_codes))]
659
+ methylation_prob[np.isin(call_codes, list(canonical_codes))] = 1 - probabilities[np.isin(call_codes, list(canonical_codes))]
660
+
661
+ # Find unique reads
662
+ unique_reads = np.unique(read_ids)
663
+ # Preallocate storage for each read
664
+ for read in unique_reads:
665
+ mod_strand_record_sample_dict[sample][read] = np.full(max_reference_length, np.nan)
666
+
667
+ # Efficient NumPy indexing to assign values
668
+ for i in range(len(read_ids)):
669
+ read = read_ids[i]
670
+ pos = positions[i]
671
+ prob = methylation_prob[i]
396
672
 
397
- dict_A, dict_C, dict_G, dict_T, dict_N = {}, {}, {}, {}, {}
398
- sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
399
- df_A = pd.DataFrame(0, index=sorted_index, columns=range(sequence_length))
400
- df_C = pd.DataFrame(0, index=sorted_index, columns=range(sequence_length))
401
- df_G = pd.DataFrame(0, index=sorted_index, columns=range(sequence_length))
402
- df_T = pd.DataFrame(0, index=sorted_index, columns=range(sequence_length))
403
- df_N = pd.DataFrame(0, index=sorted_index, columns=range(sequence_length))
404
-
405
- for read_name, one_hot_array in one_hot_reads.items():
406
- one_hot_array = one_hot_array.reshape(n_rows_OHE, -1)
407
- dict_A[read_name] = one_hot_array[0, :]
408
- dict_C[read_name] = one_hot_array[1, :]
409
- dict_G[read_name] = one_hot_array[2, :]
410
- dict_T[read_name] = one_hot_array[3, :]
411
- dict_N[read_name] = one_hot_array[4, :]
412
-
413
- del one_hot_reads
414
- gc.collect()
415
-
416
- for j, read_name in tqdm(enumerate(sorted_index), desc='Loading dataframes of OHE reads', total=len(sorted_index)):
417
- df_A.iloc[j] = dict_A[read_name]
418
- df_C.iloc[j] = dict_C[read_name]
419
- df_G.iloc[j] = dict_G[read_name]
420
- df_T.iloc[j] = dict_T[read_name]
421
- df_N.iloc[j] = dict_N[read_name]
422
-
423
- del dict_A, dict_C, dict_G, dict_T, dict_N
424
- gc.collect()
425
-
426
- ohe_df_map = {0: df_A, 1: df_C, 2: df_G, 3: df_T, 4: df_N}
427
-
428
- for j, base in enumerate(['A', 'C', 'G', 'T', 'N']):
429
- temp_adata.layers[f'{base}_binary_encoding'] = ohe_df_map[j].values
430
- ohe_df_map[j] = None # Reassign pointer for memory usage purposes
431
-
432
- # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
433
- if adata:
434
- if temp_adata.shape[0] > 0:
435
- print('{0}: Concatenating {1} anndata object for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
436
- adata = ad.concat([adata, temp_adata], join='outer', index_unique=None)
437
- del temp_adata
438
- else:
439
- print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata")
673
+ # Assign methylation probability
674
+ mod_strand_record_sample_dict[sample][read][pos] = prob
675
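
A toy sketch of the per-read preallocation and positional assignment above, using made-up read ids, positions, and a short reference:

    import numpy as np

    max_reference_length = 8
    read_ids = np.array(['r1', 'r1', 'r2'])
    positions = np.array([2, 5, 0])
    methylation_prob = np.array([0.9, 0.1, 0.7])

    reads = {r: np.full(max_reference_length, np.nan) for r in np.unique(read_ids)}
    for r, pos, prob in zip(read_ids, positions, methylation_prob):
        reads[r][pos] = prob   # place each call at its reference coordinate; uncovered positions stay NaN
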
+
676
+
677
+ # Save the sample files in the batch as gzipped hdf5 files
678
+ os.chdir(h5_dir)
679
+ print('{0}: Converting batch {1} dictionaries to anndata objects'.format(readwrite.time_string(), batch))
680
+ for dict_index, dict_type in enumerate(dict_list):
681
+ if dict_index not in dict_to_skip:
682
+ # Initialize an hdf5 file for the current modified strand
683
+ adata = None
684
+ print('{0}: Converting {1} dictionary to an anndata object'.format(readwrite.time_string(), sample_types[dict_index]))
685
+ for record in dict_type.keys():
686
+ # Get the dictionary for the modification type of interest from the reference mapping of interest
687
+ mod_strand_record_sample_dict = dict_type[record]
688
+ for sample in mod_strand_record_sample_dict.keys():
689
+ print('{0}: Converting {1} dictionary for sample {2} to an anndata object'.format(readwrite.time_string(), sample_types[dict_index], sample))
690
+ sample = int(sample)
691
+ final_sample_index = sample + (batch * batch_size)
692
+ print('{0}: Final sample index for sample: {1}'.format(readwrite.time_string(), final_sample_index))
693
+ print('{0}: Converting {1} dictionary for sample {2} to a dataframe'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
694
+ temp_df = pd.DataFrame.from_dict(mod_strand_record_sample_dict[sample], orient='index')
695
+ mod_strand_record_sample_dict[sample] = None # reassign pointer to facilitate memory usage
696
+ sorted_index = sorted(temp_df.index)
697
+ temp_df = temp_df.reindex(sorted_index)
698
+ X = temp_df.values
699
+ dataset, strand = sample_types[dict_index].split('_')[:2]
700
+
701
+ print('{0}: Loading {1} dataframe for sample {2} into a temp anndata object'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
702
+ temp_adata = ad.AnnData(X)
703
+ if temp_adata.shape[0] > 0:
704
+ print('{0}: Adding read names and position ids to {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
705
+ temp_adata.obs_names = temp_df.index
706
+ temp_adata.obs_names = temp_adata.obs_names.astype(str)
707
+ temp_adata.var_names = temp_df.columns
708
+ temp_adata.var_names = temp_adata.var_names.astype(str)
709
+ print('{0}: Adding {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
710
+ temp_adata.obs['Sample'] = [str(final_sample_index)] * len(temp_adata)
711
+ temp_adata.obs['Reference'] = [f'{record}'] * len(temp_adata)
712
+ temp_adata.obs['Strand'] = [strand] * len(temp_adata)
713
+ temp_adata.obs['Dataset'] = [dataset] * len(temp_adata)
714
+ temp_adata.obs['Reference_dataset_strand'] = [f'{record}_{dataset}_{strand}'] * len(temp_adata)
715
+ temp_adata.obs['Reference_strand'] = [f'{record}_{strand}'] * len(temp_adata)
716
+
717
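
Each per-read dictionary becomes a reads-by-positions matrix before it is wrapped in an AnnData object. A minimal, self-contained sketch of that conversion (read ids and values are hypothetical):

    import numpy as np
    import pandas as pd
    import anndata as ad

    reads = {'r2': np.array([np.nan, 0.9, 0.1]),
             'r1': np.array([0.7, np.nan, np.nan])}

    temp_df = pd.DataFrame.from_dict(reads, orient='index')  # rows = reads, columns = reference positions
    temp_df = temp_df.reindex(sorted(temp_df.index))          # deterministic read order
    temp_adata = ad.AnnData(temp_df.values)
    temp_adata.obs_names = temp_df.index.astype(str)
    temp_adata.var_names = temp_df.columns.astype(str)
    temp_adata.obs['Sample'] = ['0'] * temp_adata.n_obs       # sample label as a string
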
+ # Load in the one hot encoded reads from the current sample and record
718
+ one_hot_reads = {}
719
+ n_rows_OHE = 5
720
+ ohe_files = bam_record_ohe_files[f'{final_sample_index}_{record}']
721
+ print(f'Loading OHEs from {ohe_files}')
722
+ fwd_mapped_reads = set()
723
+ rev_mapped_reads = set()
724
+ for ohe_file in ohe_files:
725
+ tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
726
+ one_hot_reads.update(tmp_ohe_dict)
727
+ if '_fwd_' in ohe_file:
728
+ fwd_mapped_reads.update(tmp_ohe_dict.keys())
729
+ elif '_rev_' in ohe_file:
730
+ rev_mapped_reads.update(tmp_ohe_dict.keys())
731
+ del tmp_ohe_dict
732
+
733
+ read_names = list(one_hot_reads.keys())
734
+
735
+ read_mapping_direction = []
736
+ for read_id in temp_adata.obs_names:
737
+ if read_id in fwd_mapped_reads:
738
+ read_mapping_direction.append('fwd')
739
+ elif read_id in rev_mapped_reads:
740
+ read_mapping_direction.append('rev')
440
741
  else:
441
- if temp_adata.shape[0] > 0:
442
- print('{0}: Initializing {1} anndata object for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
443
- adata = temp_adata
444
- else:
445
- print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata")
446
-
447
- gc.collect()
742
+ read_mapping_direction.append('unk')
743
+
744
+ temp_adata.obs['Read_mapping_direction'] = read_mapping_direction
745
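
Read orientation is recovered from which one-hot file each read was loaded from. A compact sketch of the same lookup (read ids are hypothetical; 'unk' covers reads found in neither set):

    fwd_mapped_reads = {'r1', 'r3'}
    rev_mapped_reads = {'r2'}
    obs_names = ['r1', 'r2', 'r4']

    read_mapping_direction = [
        'fwd' if r in fwd_mapped_reads else 'rev' if r in rev_mapped_reads else 'unk'
        for r in obs_names
    ]
    # read_mapping_direction -> ['fwd', 'rev', 'unk']
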
+
746
+ del temp_df
747
+
748
+ # Initialize NumPy arrays
749
+ sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
750
+ df_A = np.zeros((len(sorted_index), sequence_length), dtype=int)
751
+ df_C = np.zeros((len(sorted_index), sequence_length), dtype=int)
752
+ df_G = np.zeros((len(sorted_index), sequence_length), dtype=int)
753
+ df_T = np.zeros((len(sorted_index), sequence_length), dtype=int)
754
+ df_N = np.zeros((len(sorted_index), sequence_length), dtype=int)
755
+
756
+ # Process one-hot data into dictionaries
757
+ dict_A, dict_C, dict_G, dict_T, dict_N = {}, {}, {}, {}, {}
758
+ for read_name, one_hot_array in one_hot_reads.items():
759
+ one_hot_array = one_hot_array.reshape(n_rows_OHE, -1)
760
+ dict_A[read_name] = one_hot_array[0, :]
761
+ dict_C[read_name] = one_hot_array[1, :]
762
+ dict_G[read_name] = one_hot_array[2, :]
763
+ dict_T[read_name] = one_hot_array[3, :]
764
+ dict_N[read_name] = one_hot_array[4, :]
765
+
766
+ del one_hot_reads
767
+ gc.collect()
768
+
769
+ # Fill the arrays
770
+ for j, read_name in tqdm(enumerate(sorted_index), desc='Loading dataframes of OHE reads', total=len(sorted_index)):
771
+ df_A[j, :] = dict_A[read_name]
772
+ df_C[j, :] = dict_C[read_name]
773
+ df_G[j, :] = dict_G[read_name]
774
+ df_T[j, :] = dict_T[read_name]
775
+ df_N[j, :] = dict_N[read_name]
776
+
777
+ del dict_A, dict_C, dict_G, dict_T, dict_N
778
+ gc.collect()
779
+
780
+ # Store the results in AnnData layers
781
+ ohe_df_map = {0: df_A, 1: df_C, 2: df_G, 3: df_T, 4: df_N}
782
+ for j, base in enumerate(['A', 'C', 'G', 'T', 'N']):
783
+ temp_adata.layers[f'{base}_binary_encoding'] = ohe_df_map[j]
784
+ ohe_df_map[j] = None # Reassign pointer for memory usage purposes
785
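
Each stored read is a flattened 5 x L one-hot matrix (rows A, C, G, T, N), which the loop above unpacks into the per-base rows that fill the '<base>_binary_encoding' layers. A small sketch for a hypothetical 4-base read 'ACGT', assuming the arrays are stored flattened as the reshape above implies:

    import numpy as np

    n_rows_OHE = 5
    one_hot = np.array([[1, 0, 0, 0],   # A
                        [0, 1, 0, 0],   # C
                        [0, 0, 1, 0],   # G
                        [0, 0, 0, 1],   # T
                        [0, 0, 0, 0]])  # N
    flat = one_hot.flatten()            # stand-in for one entry of one_hot_reads

    arr = flat.reshape(n_rows_OHE, -1)
    a_row, c_row, g_row, t_row, n_row = (arr[i, :] for i in range(n_rows_OHE))
    # a_row -> [1, 0, 0, 0]; each row becomes one read's row in the matching layer
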
+
786
+ # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
787
+ if adata:
788
+ if temp_adata.shape[0] > 0:
789
+ print('{0}: Concatenating {1} anndata object for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
790
+ adata = ad.concat([adata, temp_adata], join='outer', index_unique=None)
791
+ del temp_adata
792
+ else:
793
+ print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata")
448
794
  else:
449
- print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata. Skipping sample.")
795
+ if temp_adata.shape[0] > 0:
796
+ print('{0}: Initializing {1} anndata object for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
797
+ adata = temp_adata
798
+ else:
799
+ print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata")
800
+
801
+ gc.collect()
802
+ else:
803
+ print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata. Skipping sample.")
450
804
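
Per-sample AnnData objects are folded into one object per batch: initialize on the first non-empty sample, then outer-join the rest. A minimal sketch of that accumulator pattern with hypothetical inputs:

    import numpy as np
    import anndata as ad

    # Stand-ins for temp_adata across samples; the middle one has no mapped reads.
    batches = [ad.AnnData(np.array([[0.1, 0.9]])),
               ad.AnnData(np.zeros((0, 2))),
               ad.AnnData(np.array([[0.5, np.nan]]))]

    adata = None
    for temp_adata in batches:
        if temp_adata.shape[0] == 0:
            continue                    # skip samples with no mapped reads
        if adata is None:
            adata = temp_adata          # initialize on the first non-empty sample
        else:
            # join='outer' keeps the union of positions; missing entries become NaN
            adata = ad.concat([adata, temp_adata], join='outer', index_unique=None)
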
 
451
- print('{0}: Writing {1} anndata out as a gzipped hdf5 file'.format(readwrite.time_string(), sample_types[dict_index]))
805
+ try:
806
+ print('{0}: Writing {1} anndata out as an hdf5 file'.format(readwrite.time_string(), sample_types[dict_index]))
452
807
  adata.write_h5ad('{0}_{1}_{2}_SMF_binarized_sample_hdf5.h5ad.gz'.format(readwrite.date_string(), batch, sample_types[dict_index]), compression='gzip')
808
+ except Exception as e:
809
+ print(f"{readwrite.time_string()}: Skipping write of {sample_types[dict_index]} anndata: {e}")
453
810
 
454
- # Delete the batch dictionaries from memory
455
- del dict_list, adata
456
- gc.collect()
457
-
458
- # Iterate over all of the batched hdf5 files and concatenate them.
459
- os.chdir(h5_dir)
460
- files = os.listdir(h5_dir)
461
- # Filter file names that contain the search string in their filename and keep them in a list
462
- hdfs = [hdf for hdf in files if 'hdf5.h5ad' in hdf and hdf != final_hdf]
463
- # Sort file list by names and print the list of file names
464
- hdfs.sort()
465
- print('{0} sample files found: {1}'.format(len(hdfs), hdfs))
466
- hdf_paths = [os.path.join(h5_dir, hd5) for hd5 in hdfs]
467
- final_adata = None
468
- for hdf_index, hdf in enumerate(hdf_paths):
469
- print('{0}: Reading in {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
470
- temp_adata = ad.read_h5ad(hdf)
471
- if final_adata:
472
- print('{0}: Concatenating final adata object with {1} hdf5 file'.format(readwrite.time_string(), hdf[hdf_index]))
473
- final_adata = ad.concat([final_adata, temp_adata], join='outer', index_unique=None)
474
- else:
475
- print('{0}: Initializing final adata object with {1} hdf5 file'.format(readwrite.time_string(), hdf[hdf_index]))
476
- final_adata = temp_adata
477
- del temp_adata
478
-
479
- # Set obs columns to type 'category'
480
- for col in final_adata.obs.columns:
481
- final_adata.obs[col] = final_adata.obs[col].astype('category')
482
-
483
- for record in records_to_analyze:
484
- # Add FASTA sequence to the object
485
- sequence = record_seq_dict[record][0]
486
- complement = record_seq_dict[record][1]
487
- final_adata.var[f'{record}_top_strand_FASTA_base_at_coordinate'] = list(sequence)
488
- final_adata.var[f'{record}_bottom_strand_FASTA_base_at_coordinate'] = list(complement)
489
- final_adata.uns[f'{record}_FASTA_sequence'] = sequence
490
- # Add consensus sequence of samples mapped to the record to the object
491
- record_subset = final_adata[final_adata.obs['Reference_chromosome'] == record].copy()
492
- for strand in record_subset.obs['Strand'].cat.categories:
493
- strand_subset = record_subset[record_subset.obs['Strand'] == strand].copy()
494
- for mapping_dir in strand_subset.obs['Read_mapping_direction'].cat.categories:
495
- mapping_dir_subset = strand_subset[strand_subset.obs['Read_mapping_direction'] == mapping_dir].copy()
496
- layer_map, layer_counts = {}, []
497
- for i, layer in enumerate(mapping_dir_subset.layers):
498
- layer_map[i] = layer.split('_')[0]
499
- layer_counts.append(np.sum(mapping_dir_subset.layers[layer], axis=0))
500
- count_array = np.array(layer_counts)
501
- nucleotide_indexes = np.argmax(count_array, axis=0)
502
- consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
503
- final_adata.var[f'{record}_{strand}_strand_{mapping_dir}_mapping_dir_consensus_from_all_samples'] = consensus_sequence_list
504
-
505
- final_adata.write_h5ad(os.path.join(h5_dir, final_hdf), compression='gzip')
506
-
507
- # Delete the individual h5ad files and only keep the final concatenated file
508
- if delete_batch_hdfs:
509
- files = os.listdir(h5_dir)
510
- hdfs_to_delete = [hdf for hdf in files if 'hdf5.h5ad' in hdf and hdf != final_hdf]
511
- hdf_paths_to_delete = [os.path.join(h5_dir, hdf) for hdf in hdfs_to_delete]
512
- # Iterate over the files and delete them
513
- for hdf in hdf_paths_to_delete:
514
- try:
515
- os.remove(hdf)
516
- print(f"Deleted file: {hdf}")
517
- except OSError as e:
518
- print(f"Error deleting file {hdf}: {e}")
811
+ # Delete the batch dictionaries from memory
812
+ del dict_list, adata
813
+ gc.collect()
814
+
815
+ # Iterate over all of the batched hdf5 files and concatenate them.
816
+ os.chdir(h5_dir)
817
+ files = os.listdir(h5_dir)
818
+ # Filter file names that contain the search string in their filename and keep them in a list
819
+ hdfs = [hdf for hdf in files if 'hdf5.h5ad' in hdf and hdf != final_hdf]
820
+ combined_hdfs = [hdf for hdf in hdfs if "combined" in hdf]
821
+ if len(combined_hdfs) > 0:
822
+ hdfs = combined_hdfs
823
+ else:
824
+ pass
825
+ # Sort file list by names and print the list of file names
826
+ hdfs.sort()
827
+ print('{0} sample files found: {1}'.format(len(hdfs), hdfs))
828
+ hdf_paths = [os.path.join(h5_dir, hd5) for hd5 in hdfs]
829
+ final_adata = None
830
+ for hdf_index, hdf in enumerate(hdf_paths):
831
+ print('{0}: Reading in {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
832
+ temp_adata = ad.read_h5ad(hdf)
833
+ if final_adata:
834
+ print('{0}: Concatenating final adata object with {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
835
+ final_adata = ad.concat([final_adata, temp_adata], join='outer', index_unique=None)
836
+ else:
837
+ print('{0}: Initializing final adata object with {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
838
+ final_adata = temp_adata
839
+ del temp_adata
840
+
841
+ # Set obs columns to type 'category'
842
+ for col in final_adata.obs.columns:
843
+ final_adata.obs[col] = final_adata.obs[col].astype('category')
844
+
845
+ ohe_bases = ['A', 'C', 'G', 'T'] # ignore N bases for consensus
846
+ ohe_layers = [f"{ohe_base}_binary_encoding" for ohe_base in ohe_bases]
847
+ for record in records_to_analyze:
848
+ # Add FASTA sequence to the object
849
+ sequence = record_seq_dict[record][0]
850
+ complement = record_seq_dict[record][1]
851
+ final_adata.var[f'{record}_top_strand_FASTA_base'] = list(sequence)
852
+ final_adata.var[f'{record}_bottom_strand_FASTA_base'] = list(complement)
853
+ final_adata.uns[f'{record}_FASTA_sequence'] = sequence
854
+ # Add consensus sequence of samples mapped to the record to the object
855
+ record_subset = final_adata[final_adata.obs['Reference'] == record]
856
+ for strand in record_subset.obs['Strand'].cat.categories:
857
+ strand_subset = record_subset[record_subset.obs['Strand'] == strand]
858
+ for mapping_dir in strand_subset.obs['Read_mapping_direction'].cat.categories:
859
+ mapping_dir_subset = strand_subset[strand_subset.obs['Read_mapping_direction'] == mapping_dir]
860
+ layer_map, layer_counts = {}, []
861
+ for i, layer in enumerate(ohe_layers):
862
+ layer_map[i] = layer.split('_')[0]
863
+ layer_counts.append(np.sum(mapping_dir_subset.layers[layer], axis=0))
864
+ count_array = np.array(layer_counts)
865
+ nucleotide_indexes = np.argmax(count_array, axis=0)
866
+ consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
867
+ final_adata.var[f'{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples'] = consensus_sequence_list
868
+
869
+ #final_adata.write_h5ad(final_adata_path)
870
+
871
+ # Delete the individual h5ad files and only keep the final concatenated file
872
+ if delete_batch_hdfs:
873
+ files = os.listdir(h5_dir)
874
+ hdfs_to_delete = [hdf for hdf in files if 'hdf5.h5ad' in hdf and hdf != final_hdf]
875
+ hdf_paths_to_delete = [os.path.join(h5_dir, hdf) for hdf in hdfs_to_delete]
876
+ # Iterate over the files and delete them
877
+ for hdf in hdf_paths_to_delete:
878
+ try:
879
+ os.remove(hdf)
880
+ print(f"Deleted file: {hdf}")
881
+ except OSError as e:
882
+ print(f"Error deleting file {hdf}: {e}")
883
+
884
+ return final_adata, final_adata_path