smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/converted_BAM_to_adata.py

@@ -1,61 +1,65 @@
-import numpy as np
-import time
-import os
 import gc
-import pandas as pd
-import anndata as ad
-from tqdm import tqdm
 import multiprocessing
-
+import shutil
+import time
 import traceback
-import
+from multiprocessing import Manager, Pool, current_process
+from pathlib import Path
+from typing import Iterable, Optional, Union
+
+import anndata as ad
+import numpy as np
+import pandas as pd
 import torch
 
-import
-from pathlib import Path
-from typing import Union, Iterable, Optional
+from smftools.logging_utils import get_logger
 
-from ..readwrite import make_dirs
+from ..readwrite import make_dirs
+from .bam_functions import count_aligned_reads, extract_base_identities
 from .binarize_converted_base_identities import binarize_converted_base_identities
 from .fasta_functions import find_conversion_sites
-from .bam_functions import count_aligned_reads, extract_base_identities
 from .ohe import ohe_batching
 
+logger = get_logger(__name__)
+
 if __name__ == "__main__":
     multiprocessing.set_start_method("forkserver", force=True)
 
-def converted_BAM_to_adata(converted_FASTA,
-                           split_dir,
-                           output_dir,
-                           input_already_demuxed,
-                           mapping_threshold,
-                           experiment_name,
-                           conversions,
-                           bam_suffix,
-                           device='cpu',
-                           num_threads=8,
-                           deaminase_footprinting=False,
-                           delete_intermediates=True,
-                           double_barcoded_path = None,
-                           ):
-    """
-    Converts BAM files into an AnnData object by binarizing modified base identities.
 
-
-
-
-
-
-
-
-
-
-
-
-
+def converted_BAM_to_adata(
+    converted_FASTA: str | Path,
+    split_dir: Path,
+    output_dir: Path,
+    input_already_demuxed: bool,
+    mapping_threshold: float,
+    experiment_name: str,
+    conversions: list[str],
+    bam_suffix: str,
+    device: str | torch.device = "cpu",
+    num_threads: int = 8,
+    deaminase_footprinting: bool = False,
+    delete_intermediates: bool = True,
+    double_barcoded_path: Path | None = None,
+) -> tuple[ad.AnnData | None, Path]:
+    """Convert BAM files into an AnnData object by binarizing modified base identities.
+
+    Args:
+        converted_FASTA: Path to the converted FASTA reference.
+        split_dir: Directory containing converted BAM files.
+        output_dir: Output directory for intermediate and final files.
+        input_already_demuxed: Whether input reads were originally demultiplexed.
+        mapping_threshold: Minimum fraction of aligned reads required for inclusion.
+        experiment_name: Name for the output AnnData object.
+        conversions: List of modification types (e.g., ``["unconverted", "5mC", "6mA"]``).
+        bam_suffix: File suffix for BAM files.
+        device: Torch device or device string.
+        num_threads: Number of parallel processing threads.
+        deaminase_footprinting: Whether the footprinting used direct deamination chemistry.
+        delete_intermediates: Whether to remove intermediate files after processing.
+        double_barcoded_path: Path to dorado demux summary file of double-ended barcodes.
 
     Returns:
-
+        tuple[anndata.AnnData | None, Path]: The AnnData object (if generated) and its path.
     """
     if torch.cuda.is_available():
         device = torch.device("cuda")
@@ -64,69 +68,88 @@ def converted_BAM_to_adata(converted_FASTA,
     else:
         device = torch.device("cpu")
 
-
+    logger.debug(f"Using device: {device}")
 
     ## Set Up Directories and File Paths
-    h5_dir = output_dir /
-    tmp_dir = output_dir /
+    h5_dir = output_dir / "h5ads"
+    tmp_dir = output_dir / "tmp"
     final_adata = None
-    final_adata_path = h5_dir / f
+    final_adata_path = h5_dir / f"{experiment_name}.h5ad.gz"
 
     if final_adata_path.exists():
-
+        logger.debug(f"{final_adata_path} already exists. Using existing AnnData object.")
        return final_adata, final_adata_path
 
     make_dirs([h5_dir, tmp_dir])
 
     bam_files = sorted(
-        p
-
-        and p.suffix == ".bam"
-        and "unclassified" not in p.name
+        p
+        for p in split_dir.iterdir()
+        if p.is_file() and p.suffix == ".bam" and "unclassified" not in p.name
     )
 
-    bam_path_list =
-
+    bam_path_list = bam_files
+    logger.info(f"Found {len(bam_files)} BAM files: {bam_files}")
 
     ## Process Conversion Sites
-    max_reference_length, record_FASTA_dict, chromosome_FASTA_dict = process_conversion_sites(
+    max_reference_length, record_FASTA_dict, chromosome_FASTA_dict = process_conversion_sites(
+        converted_FASTA, conversions, deaminase_footprinting
+    )
 
     ## Filter BAM Files by Mapping Threshold
-    records_to_analyze = filter_bams_by_mapping_threshold(
+    records_to_analyze = filter_bams_by_mapping_threshold(
+        bam_path_list, bam_files, mapping_threshold
+    )
 
     ## Process BAMs in Parallel
-    final_adata = process_bams_parallel(
+    final_adata = process_bams_parallel(
+        bam_path_list,
+        records_to_analyze,
+        record_FASTA_dict,
+        chromosome_FASTA_dict,
+        tmp_dir,
+        h5_dir,
+        num_threads,
+        max_reference_length,
+        device,
+        deaminase_footprinting,
+    )
 
-    final_adata.uns[
+    final_adata.uns["References"] = {}
     for chromosome, [seq, comp] in chromosome_FASTA_dict.items():
-        final_adata.var[f
-        final_adata.var[f
-        final_adata.uns[f
-        final_adata.uns[
+        final_adata.var[f"{chromosome}_top_strand_FASTA_base"] = list(seq)
+        final_adata.var[f"{chromosome}_bottom_strand_FASTA_base"] = list(comp)
+        final_adata.uns[f"{chromosome}_FASTA_sequence"] = seq
+        final_adata.uns["References"][f"{chromosome}_FASTA_sequence"] = seq
 
     final_adata.obs_names_make_unique()
     cols = final_adata.obs.columns
 
     # Make obs cols categorical
     for col in cols:
-        final_adata.obs[col] = final_adata.obs[col].astype(
+        final_adata.obs[col] = final_adata.obs[col].astype("category")
 
     if input_already_demuxed:
         final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
         final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
     else:
         from .h5ad_functions import add_demux_type_annotation
+
         double_barcoded_reads = double_barcoded_path / "barcoding_summary.txt"
+        logger.info("Adding demux type to each read")
         add_demux_type_annotation(final_adata, double_barcoded_reads)
 
     ## Delete intermediate h5ad files and temp directories
     if delete_intermediates:
+        logger.info("Deleting intermediate h5ad files")
         delete_intermediate_h5ads_and_tmpdir(h5_dir, tmp_dir)
-
+
     return final_adata, final_adata_path
 
 
-def process_conversion_sites(
+def process_conversion_sites(
+    converted_FASTA, conversions=["unconverted", "5mC"], deaminase_footprinting=False
+):
     """
     Extracts conversion sites and determines the max reference length.
 
@@ -147,7 +170,9 @@ def process_conversion_sites(converted_FASTA, conversions=['unconverted', '5mC']
     conversion_types = conversions[1:]
 
     # Process the unconverted sequence once
-    modification_dict[unconverted] = find_conversion_sites(
+    modification_dict[unconverted] = find_conversion_sites(
+        converted_FASTA, unconverted, conversions, deaminase_footprinting
+    )
     # Above points to record_dict[record.id] = [sequence_length, [], [], sequence, complement] with only unconverted record.id keys
 
     # Get **max sequence length** from unconverted records
@@ -166,15 +191,25 @@ def process_conversion_sites(converted_FASTA, conversions=['unconverted', '5mC']
         record_FASTA_dict[record] = [
             sequence + "N" * (max_reference_length - sequence_length),
             complement + "N" * (max_reference_length - sequence_length),
-            chromosome,
+            chromosome,
+            record,
+            sequence_length,
+            max_reference_length - sequence_length,
+            unconverted,
+            "top",
         ]
 
         if chromosome not in chromosome_FASTA_dict:
-            chromosome_FASTA_dict[chromosome] = [
+            chromosome_FASTA_dict[chromosome] = [
+                sequence + "N" * (max_reference_length - sequence_length),
+                complement + "N" * (max_reference_length - sequence_length),
+            ]
 
     # Process converted records
     for conversion in conversion_types:
-        modification_dict[conversion] = find_conversion_sites(
+        modification_dict[conversion] = find_conversion_sites(
+            converted_FASTA, conversion, conversions, deaminase_footprinting
+        )
         # Above points to record_dict[record.id] = [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement] with only unconverted record.id keys
 
         for record, values in modification_dict[conversion].items():
@@ -193,11 +228,15 @@ def process_conversion_sites(converted_FASTA, conversions=['unconverted', '5mC']
             record_FASTA_dict[converted_name] = [
                 sequence + "N" * (max_reference_length - sequence_length),
                 complement + "N" * (max_reference_length - sequence_length),
-                chromosome,
-
+                chromosome,
+                unconverted_name,
+                sequence_length,
+                max_reference_length - sequence_length,
+                conversion,
+                strand,
             ]
 
-
+    logger.debug("Updated record_FASTA_dict Keys:", list(record_FASTA_dict.keys()))
     return max_reference_length, record_FASTA_dict, chromosome_FASTA_dict
 
 
@@ -214,11 +253,21 @@ def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold
         if percent >= mapping_threshold:
             records_to_analyze.add(record)
 
-
+    logger.info(f"Analyzing the following FASTA records: {records_to_analyze}")
     return records_to_analyze
 
 
-def process_single_bam(
+def process_single_bam(
+    bam_index,
+    bam,
+    records_to_analyze,
+    record_FASTA_dict,
+    chromosome_FASTA_dict,
+    tmp_dir,
+    max_reference_length,
+    device,
+    deaminase_footprinting,
+):
     """Worker function to process a single BAM file (must be at top-level for multiprocessing)."""
     adata_list = []
 
@@ -230,34 +279,58 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch
         sequence = chromosome_FASTA_dict[chromosome][0]
 
         # Extract Base Identities
-        fwd_bases, rev_bases, mismatch_counts_per_read, mismatch_trend_per_read =
+        fwd_bases, rev_bases, mismatch_counts_per_read, mismatch_trend_per_read = (
+            extract_base_identities(
+                bam, record, range(current_length), max_reference_length, sequence
+            )
+        )
         mismatch_trend_series = pd.Series(mismatch_trend_per_read)
 
         # Skip processing if both forward and reverse base identities are empty
         if not fwd_bases and not rev_bases:
-
+            logger.debug(
+                f"[Worker {current_process().pid}] Skipping {sample} - No valid base identities for {record}."
+            )
             continue
 
         merged_bin = {}
 
         # Binarize the Base Identities if they exist
         if fwd_bases:
-            fwd_bin = binarize_converted_base_identities(
+            fwd_bin = binarize_converted_base_identities(
+                fwd_bases,
+                strand,
+                mod_type,
+                bam,
+                device,
+                deaminase_footprinting,
+                mismatch_trend_per_read,
+            )
             merged_bin.update(fwd_bin)
 
         if rev_bases:
-            rev_bin = binarize_converted_base_identities(
+            rev_bin = binarize_converted_base_identities(
+                rev_bases,
+                strand,
+                mod_type,
+                bam,
+                device,
+                deaminase_footprinting,
+                mismatch_trend_per_read,
+            )
             merged_bin.update(rev_bin)
 
         # Skip if merged_bin is empty (no valid binarized data)
         if not merged_bin:
-
+            logger.debug(
+                f"[Worker {current_process().pid}] Skipping {sample} - No valid binarized data for {record}."
+            )
            continue
 
         # Convert to DataFrame
         # for key in merged_bin:
         # merged_bin[key] = merged_bin[key].cpu().numpy() # Move to CPU & convert to NumPy
-        bin_df = pd.DataFrame.from_dict(merged_bin, orient=
+        bin_df = pd.DataFrame.from_dict(merged_bin, orient="index")
         sorted_index = sorted(bin_df.index)
         bin_df = bin_df.reindex(sorted_index)
 
@@ -265,14 +338,18 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch
         one_hot_reads = {}
 
         if fwd_bases:
-            fwd_ohe_files = ohe_batching(
+            fwd_ohe_files = ohe_batching(
+                fwd_bases, tmp_dir, record, f"{bam_index}_fwd", batch_size=100000
+            )
             for ohe_file in fwd_ohe_files:
                 tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
                 one_hot_reads.update(tmp_ohe_dict)
                 del tmp_ohe_dict
 
         if rev_bases:
-            rev_ohe_files = ohe_batching(
+            rev_ohe_files = ohe_batching(
+                rev_bases, tmp_dir, record, f"{bam_index}_rev", batch_size=100000
+            )
             for ohe_file in rev_ohe_files:
                 tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
                 one_hot_reads.update(tmp_ohe_dict)
@@ -280,7 +357,9 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch
 
         # Skip if one_hot_reads is empty
         if not one_hot_reads:
-
+            logger.debug(
+                f"[Worker {current_process().pid}] Skipping {sample} - No valid one-hot encoded data for {record}."
+            )
             continue
 
         gc.collect()
@@ -291,11 +370,15 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch
 
         # Skip if no read names exist
         if not read_names:
-
+            logger.debug(
+                f"[Worker {current_process().pid}] Skipping {sample} - No reads found in one-hot encoded data for {record}."
+            )
             continue
 
         sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
-        df_A, df_C, df_G, df_T, df_N = [
+        df_A, df_C, df_G, df_T, df_N = [
+            np.zeros((len(sorted_index), sequence_length), dtype=int) for _ in range(5)
+        ]
 
         # Populate One-Hot Arrays
         for j, read_name in enumerate(sorted_index):
@@ -310,8 +393,8 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch
         adata.var_names = bin_df.columns.astype(str)
         adata.obs["Sample"] = [sample] * len(adata)
         try:
-            barcode = sample.split(
-        except:
+            barcode = sample.split("barcode")[1]
+        except Exception:
             barcode = np.nan
         adata.obs["Barcode"] = [int(barcode)] * len(adata)
         adata.obs["Barcode"] = adata.obs["Barcode"].astype(str)
@@ -323,49 +406,76 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch
         adata.obs["Read_mismatch_trend"] = adata.obs_names.map(mismatch_trend_series)
 
         # Attach One-Hot Encodings to Layers
-        adata.layers["
-        adata.layers["
-        adata.layers["
-        adata.layers["
-        adata.layers["
+        adata.layers["A_binary_sequence_encoding"] = df_A
+        adata.layers["C_binary_sequence_encoding"] = df_C
+        adata.layers["G_binary_sequence_encoding"] = df_G
+        adata.layers["T_binary_sequence_encoding"] = df_T
+        adata.layers["N_binary_sequence_encoding"] = df_N
 
         adata_list.append(adata)
 
     return ad.concat(adata_list, join="outer") if adata_list else None
 
+
 def timestamp():
     """Returns a formatted timestamp for logging."""
     return time.strftime("[%Y-%m-%d %H:%M:%S]")
 
 
-def worker_function(
+def worker_function(
+    bam_index,
+    bam,
+    records_to_analyze,
+    shared_record_FASTA_dict,
+    chromosome_FASTA_dict,
+    tmp_dir,
+    h5_dir,
+    max_reference_length,
+    device,
+    deaminase_footprinting,
+    progress_queue,
+):
     """Worker function that processes a single BAM and writes the output to an H5AD file."""
     worker_id = current_process().pid # Get worker process ID
     sample = bam.stem
 
     try:
-
+        logger.info(f"[Worker {worker_id}] Processing BAM: {sample}")
 
         h5ad_path = h5_dir / bam.with_suffix(".h5ad").name
         if h5ad_path.exists():
-
+            logger.debug(f"[Worker {worker_id}] Skipping {sample}: Already processed.")
             progress_queue.put(sample)
             return
 
         # Filter records specific to this BAM
-        bam_records_to_analyze = {
+        bam_records_to_analyze = {
+            record for record in records_to_analyze if record in shared_record_FASTA_dict
+        }
 
         if not bam_records_to_analyze:
-
+            logger.debug(
+                f"[Worker {worker_id}] No valid records to analyze for {sample}. Skipping."
+            )
            progress_queue.put(sample)
            return
 
        # Process BAM
-        adata = process_single_bam(
+        adata = process_single_bam(
+            bam_index,
+            bam,
+            bam_records_to_analyze,
+            shared_record_FASTA_dict,
+            chromosome_FASTA_dict,
+            tmp_dir,
+            max_reference_length,
+            device,
+            deaminase_footprinting,
+        )
 
         if adata is not None:
             adata.write_h5ad(str(h5ad_path))
-
+            logger.info(f"[Worker {worker_id}] Completed processing for BAM: {sample}")
 
         # Free memory
         del adata
@@ -373,22 +483,37 @@ def worker_function(bam_index, bam, records_to_analyze, shared_record_FASTA_dict
 
         progress_queue.put(sample)
 
-    except Exception
-
+    except Exception:
+        logger.warning(
+            f"[Worker {worker_id}] ERROR while processing {sample}:\n{traceback.format_exc()}"
+        )
        progress_queue.put(sample) # Still signal completion to prevent deadlock
 
-
+
+def process_bams_parallel(
+    bam_path_list,
+    records_to_analyze,
+    record_FASTA_dict,
+    chromosome_FASTA_dict,
+    tmp_dir,
+    h5_dir,
+    num_threads,
+    max_reference_length,
+    device,
+    deaminase_footprinting,
+):
     """Processes BAM files in parallel, writes each H5AD to disk, and concatenates them at the end."""
     make_dirs(h5_dir) # Ensure h5_dir exists
 
-
+    logger.info(f"Starting parallel BAM processing with {num_threads} threads...")
 
     # Ensure macOS uses forkserver to avoid spawning issues
     try:
         import multiprocessing
+
         multiprocessing.set_start_method("forkserver", force=True)
     except RuntimeError:
-
+        logger.warning(f"Multiprocessing context already set. Skipping set_start_method.")
 
     with Manager() as manager:
         progress_queue = manager.Queue()
@@ -396,11 +521,26 @@ def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict,
 
         with Pool(processes=num_threads) as pool:
             results = [
-                pool.apply_async(
+                pool.apply_async(
+                    worker_function,
+                    (
+                        i,
+                        bam,
+                        records_to_analyze,
+                        shared_record_FASTA_dict,
+                        chromosome_FASTA_dict,
+                        tmp_dir,
+                        h5_dir,
+                        max_reference_length,
+                        device,
+                        deaminase_footprinting,
+                        progress_queue,
+                    ),
+                )
                 for i, bam in enumerate(bam_path_list)
             ]
 
-
+            logger.info(f"Submitted {len(bam_path_list)} BAMs for processing.")
 
             # Track completed BAMs
             completed_bams = set()
@@ -409,24 +549,25 @@ def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict,
                     processed_bam = progress_queue.get(timeout=2400) # Wait for a finished BAM
                     completed_bams.add(processed_bam)
                 except Exception as e:
-
+                    logger.error(f"Timeout waiting for worker process. Possible crash? {e}")
 
             pool.close()
             pool.join() # Ensure all workers finish
 
        # Final Concatenation Step
-        h5ad_files = [
+        h5ad_files = [f for f in h5_dir.iterdir() if f.suffix == ".h5ad"]
 
        if not h5ad_files:
-
+            logger.debug(f"No valid H5AD files generated. Exiting.")
            return None
 
-
+        logger.info(f"Concatenating {len(h5ad_files)} H5AD files into final output...")
        final_adata = ad.concat([ad.read_h5ad(f) for f in h5ad_files], join="outer")
 
-
+        logger.info(f"Successfully generated final AnnData object.")
        return final_adata
 
+
 def delete_intermediate_h5ads_and_tmpdir(
     h5_dir: Union[str, Path, Iterable[str], None],
     tmp_dir: Optional[Union[str, Path]] = None,
@@ -450,25 +591,27 @@ def delete_intermediate_h5ads_and_tmpdir(
     verbose : bool
         Print progress / warnings.
     """
+
     # Helper: remove a single file path (Path-like or string)
     def _maybe_unlink(p: Path):
+        """Remove a file path if it exists and is a file."""
        if not p.exists():
            if verbose:
-
+                logger.debug(f"[skip] not found: {p}")
            return
        if not p.is_file():
            if verbose:
-
+                logger.debug(f"[skip] not a file: {p}")
            return
        if dry_run:
-
+            logger.debug(f"[dry-run] would remove file: {p}")
            return
        try:
            p.unlink()
            if verbose:
-
+                logger.info(f"Removed file: {p}")
        except Exception as e:
-
+            logger.warning(f"[error] failed to remove file {p}: {e}")
 
     # Handle h5_dir input (directory OR iterable of file paths)
     if h5_dir is not None:
@@ -483,7 +626,7 @@ def delete_intermediate_h5ads_and_tmpdir(
                else:
                    if verbose:
                        # optional: comment this out if too noisy
-
+                        logger.debug(f"[skip] not matching pattern: {p.name}")
        else:
            # treat as iterable of file paths
            for f in h5_dir:
@@ -493,25 +636,25 @@ def delete_intermediate_h5ads_and_tmpdir(
                    _maybe_unlink(p)
                else:
                    if verbose:
-
+                        logger.debug(f"[skip] not matching pattern or not a file: {p}")
 
     # Remove tmp_dir recursively (if provided)
     if tmp_dir is not None:
        td = Path(tmp_dir)
        if not td.exists():
            if verbose:
-
+                logger.debug(f"[skip] tmp_dir not found: {td}")
        else:
            if not td.is_dir():
                if verbose:
-
+                    logger.debug(f"[skip] tmp_dir is not a directory: {td}")
            else:
                if dry_run:
-
+                    logger.debug(f"[dry-run] would remove directory tree: {td}")
                else:
                    try:
                        shutil.rmtree(td)
                        if verbose:
-
+                            logger.info(f"Removed directory tree: {td}")
                    except Exception as e:
-
+                        logger.warning(f"[error] failed to remove tmp dir {td}: {e}")
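
For orientation, a minimal sketch of how the reworked entry point might be called with the keyword signature introduced in this release. The argument names, the `output_dir / "h5ads" / "{experiment_name}.h5ad.gz"` output location, and the returned `(AnnData, path)` pair come from the diff above; the concrete paths and values below are hypothetical, and the import path simply mirrors the file location.

from pathlib import Path

from smftools.informatics.converted_BAM_to_adata import converted_BAM_to_adata

# Hypothetical inputs for illustration; only the keyword names follow the signature above.
adata, adata_path = converted_BAM_to_adata(
    converted_FASTA="references/converted.fasta",
    split_dir=Path("aligned/split_bams"),
    output_dir=Path("results"),
    input_already_demuxed=True,   # skips the dorado barcoding_summary.txt lookup
    mapping_threshold=0.01,
    experiment_name="experiment_01",
    conversions=["unconverted", "5mC"],
    bam_suffix=".bam",
    device="cpu",                 # overridden to CUDA inside the function when available
    num_threads=8,
    deaminase_footprinting=False,
    delete_intermediates=True,
    double_barcoded_path=None,    # only consulted when input_already_demuxed is False
)
print(adata_path)  # results/h5ads/experiment_01.h5ad.gz; adata is None if that file already existed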