smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +5 -1
- smftools/_version.py +1 -1
- smftools/informatics/__init__.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +63 -10
- smftools/informatics/direct_smf.py +66 -18
- smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
- smftools/informatics/helpers/__init__.py +16 -2
- smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
- smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
- smftools/informatics/helpers/canoncall.py +12 -3
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
- smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +33 -46
- smftools/informatics/helpers/extract_mods.py +55 -23
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/find_conversion_sites.py +33 -44
- smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
- smftools/informatics/helpers/modcall.py +13 -5
- smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
- smftools/informatics/helpers/ohe_batching.py +65 -41
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +45 -9
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/split_and_index_BAM.py +3 -8
- smftools/informatics/load_adata.py +58 -3
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +6 -7
- smftools/preprocessing/append_C_context.py +22 -9
- smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
- smftools/preprocessing/binarize_on_Youden.py +35 -32
- smftools/preprocessing/binary_layers_to_ohe.py +13 -3
- smftools/preprocessing/calculate_complexity.py +3 -2
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
- smftools/preprocessing/calculate_coverage.py +26 -25
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_position_Youden.py +18 -7
- smftools/preprocessing/calculate_read_length_stats.py +39 -46
- smftools/preprocessing/clean_NaN.py +33 -25
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
- smftools/preprocessing/filter_reads_on_length.py +14 -4
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +18 -11
- smftools/preprocessing/load_sample_sheet.py +30 -16
- smftools/preprocessing/recipes.py +22 -20
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +105 -13
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +19 -23
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools/tools/apply_HMM.py +0 -1
- smftools/tools/read_HMM.py +0 -1
- smftools/tools/train_HMM.py +0 -43
- smftools-0.1.3.dist-info/RECORD +0 -84
- /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
- /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,52 +1,76 @@
|
|
|
1
|
-
|
|
1
|
+
import os
|
|
2
|
+
import anndata as ad
|
|
3
|
+
import numpy as np
|
|
4
|
+
import concurrent.futures
|
|
5
|
+
from .one_hot_encode import one_hot_encode
|
|
2
6
|
|
|
3
|
-
def
|
|
7
|
+
def encode_sequence(args):
    """
    Parallel helper: one-hot encode a single read.

    Parameters:
        args (tuple): (read_name, seq, device), as built by ohe_batching.

    Returns:
        tuple or None: (read_name, one_hot_matrix) on success, or None when
        one_hot_encode raised, so the caller can skip that read.
    """
    read_name, seq, device = args
    try:
        one_hot_matrix = one_hot_encode(seq, device)
    except Exception as err:
        # Best-effort by design: one bad read must not abort the whole batch,
        # but the failure is reported instead of being swallowed silently.
        print(f"Warning: skipping read {read_name!r}; one-hot encoding failed: {err}")
        return None
    return read_name, one_hot_matrix
|
|
15
|
+
|
|
16
|
+
def encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number):
    """
    Write one batch of already-encoded reads to a temporary H5AD file.

    No encoding happens here: ``batch_data`` is an iterable of
    ``(read_name, one_hot_matrix)`` pairs. Pairs whose matrix is None are
    ignored; if a read name repeats, the last matrix wins.

    Parameters:
        batch_data (iterable): (read_name, matrix) pairs.
        tmp_dir (str): Directory the file is written into.
        prefix (str): Prefix used in the file name.
        record (str): Record name used in the file name.
        batch_number (int): Batch index used in the file name.

    Returns:
        str or None: Path of the written file, or None when no pair had a
        matrix (nothing is written in that case).
    """
    encoded = {}
    for name, matrix in batch_data:
        if matrix is not None:
            encoded[name] = matrix

    if not encoded:
        return None

    out_path = os.path.join(tmp_dir, f'tmp_{prefix}_{record}_{batch_number}.h5ad')
    # X is only a 1x1 placeholder; the encodings are stored in .uns keyed by read name.
    holder = ad.AnnData(X=np.zeros((1, 1)), uns=encoded)
    holder.write_h5ad(out_path)
    return out_path
|
|
26
|
+
|
|
27
|
+
def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000, progress_bar=None, device='auto', threads=None):
    """
    One-hot encodes sequences in parallel and writes the encodings to temporary H5AD files in batches.

    Parameters:
        base_identities (dict): Dictionary mapping read names to sequences. Reads whose
            sequence is None are skipped.
        tmp_dir (str): Directory for storing temporary files.
        record (str): Record name (part of each temporary file name).
        prefix (str): Prefix for file naming.
        batch_size (int): Number of successfully encoded reads per written file.
        progress_bar (tqdm instance, optional): Shared progress bar, advanced once per
            read submitted for encoding.
        device (str): Passed through to one_hot_encode.
        threads (int, optional): Number of worker threads (defaults to os.cpu_count()).

    Returns:
        list: Paths of the H5AD files that were written, in batch order.
    """
    threads = threads or os.cpu_count()  # Default to max available CPU cores
    # Tasks are handed to the pool in chunks of this size (never less than 1).
    chunk_size = max(1, int(batch_size)) if batch_size else 1

    batch_data = []
    batch_number = 0
    file_names = []

    # Step 1: Prepare the encoding tasks (reads without a sequence are dropped).
    encoding_args = [(read_name, seq, device) for read_name, seq in base_identities.items() if seq is not None]

    # Step 2: Parallel one-hot encoding using threads (to avoid nested processes).
    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
        # Executor.map collects its whole input and submits every task up front, so
        # mapping over all reads at once could hold every finished encoding in memory.
        # Submitting one chunk at a time keeps roughly one batch of results in flight.
        for start in range(0, len(encoding_args), chunk_size):
            chunk = encoding_args[start:start + chunk_size]
            for result in executor.map(encode_sequence, chunk):
                if result is not None:  # None means encoding failed for that read
                    batch_data.append(result)

                if len(batch_data) >= batch_size:
                    # Step 3: Write the full batch immediately.
                    file_name = encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number)
                    if file_name:
                        file_names.append(file_name)
                    # encode_and_save_batch builds its own dict synchronously, so a fresh
                    # list can replace this one without copying.
                    batch_data = []
                    batch_number += 1

                # A tqdm bar's truthiness depends on its total (total=0 is falsy), so
                # compare against None rather than relying on bool(progress_bar).
                if progress_bar is not None:
                    progress_bar.update(1)

    # Step 4: Write whatever is left as a final, smaller batch.
    if batch_data:
        file_name = encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number)
        if file_name:
            file_names.append(file_name)

    return file_names
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# ohe_layers_decode
|
|
2
|
+
|
|
3
|
+
def ohe_layers_decode(adata, obs_names):
    """
    Decode the per-base one-hot layers of selected reads back into sequence strings.

    The encoding of each read is taken from the layers ``A_binary_encoding``,
    ``C_binary_encoding``, ``G_binary_encoding``, ``T_binary_encoding`` and
    ``N_binary_encoding``: one layer per base, one value per position (var).

    Parameters:
        adata (AnnData): An anndata object carrying the five ``*_binary_encoding`` layers.
        obs_names (list): A list of observation name strings to retrieve sequences for.

    Returns:
        sequences (list of str): One decoded sequence per entry of ``obs_names``, in order.
        Positions where no base layer holds a positive value (all zero or NaN)
        decode to 'N'.
    """
    import numpy as np

    # Define the mapping of one-hot encoded indices to DNA bases
    mapping = ['A', 'C', 'G', 'T', 'N']
    bases = np.array(mapping)
    n_index = mapping.index('N')

    ohe_layers = [f"{base}_binary_encoding" for base in mapping]
    sequences = []

    for obs_name in obs_names:
        obs_subset = adata[obs_name]
        columns = []
        for layer in ohe_layers:
            values = obs_subset.layers[layer]
            # Densify sparse layers (if any) before handing them to numpy.
            if hasattr(values, "toarray"):
                values = values.toarray()
            columns.append(np.asarray(values, dtype=float).reshape(-1))

        # Stack to shape (n_positions, 5) so that each row holds the five base values
        # of one position. Concatenating the layers instead would order the data
        # base-by-base and mis-pair values when decoding per position.
        per_position = np.nan_to_num(np.stack(columns, axis=1), nan=0.0)

        indices = np.argmax(per_position, axis=1)
        # argmax of an all-zero row is 0 ('A'); report positions with no signal as 'N'.
        indices[per_position.max(axis=1) <= 0] = n_index
        sequences.append(''.join(bases[indices]))

    return sequences
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# one_hot_decode
|
|
2
|
+
|
|
3
|
+
# String encodings
|
|
4
|
+
def one_hot_decode(ohe_array):
    """
    Takes a flattened one hot encoded array and returns the sequence string from that array.

    Parameters:
        ohe_array (array-like): Flattened one-hot encoding with 5 values per position,
            in the order A, C, G, T, N (the layout produced by one_hot_encode).

    Returns:
        sequence (str): The decoded sequence. Positions whose five values are all zero
        (or NaN) decode to 'N' instead of being reported as 'A'.

    Raises:
        ValueError: If the number of values is not a multiple of 5.
    """
    import numpy as np

    # Define the mapping of one-hot encoded indices to DNA bases
    mapping = np.array(['A', 'C', 'G', 'T', 'N'])

    # np.asarray also accepts lists; NaN is zeroed because argmax would otherwise select it.
    values = np.nan_to_num(np.asarray(ohe_array, dtype=float), nan=0.0)

    # Reshape into one row per position, one column per base
    one_hot_matrix = values.reshape(-1, len(mapping))

    # Index of the hot base in each row
    decoded_indices = np.argmax(one_hot_matrix, axis=1)
    # argmax of an all-zero row is 0 ('A'); mark positions with no hot base as unknown.
    decoded_indices[one_hot_matrix.max(axis=1) <= 0] = len(mapping) - 1

    return ''.join(mapping[decoded_indices])
|
|
@@ -1,21 +1,57 @@
|
|
|
1
1
|
# one_hot_encode
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
def one_hot_encode(sequence):
|
|
3
|
+
def one_hot_encode(sequence, device='auto'):
    """
    One-hot encodes a DNA sequence.

    Parameters:
        sequence (str or list): DNA sequence (e.g., "ACGTN" or ['A', 'C', 'G', 'T', 'N']).
        device (str): Currently unused; accepted so that callers passing a device
            keep working. The encoding is computed with NumPy.

    Returns:
        ndarray: Flattened one-hot encoding of length 5 * len(sequence): five integer
        values per position, in the order A, C, G, T, N. Any other character
        (including lowercase bases) is encoded as 'N'. An empty sequence yields an
        empty array.
    """
    import numpy as np

    mapping = np.array(['A', 'C', 'G', 'T', 'N'])

    # Ensure input is a list of characters
    if not isinstance(sequence, list):
        sequence = list(sequence)

    # Handle empty sequences
    if len(sequence) == 0:
        print("Warning: Empty sequence encountered in one_hot_encode()")
        # Zero positions -> zero values. A length-5 zero vector would look like one extra position.
        return np.zeros(0, dtype=int)

    # Convert sequence to NumPy array (one character per element)
    seq_array = np.array(sequence, dtype='<U1')

    # Replace invalid bases with 'N'
    seq_array = np.where(np.isin(seq_array, mapping), seq_array, 'N')

    # Create the (n_positions, 5) one-hot matrix, then flatten position-major
    one_hot_matrix = (seq_array[:, None] == mapping).astype(int)
    return one_hot_matrix.flatten()
|
|
@@ -18,6 +18,7 @@ def plot_read_length_and_coverage_histograms(bed_file, plotting_directory):
|
|
|
18
18
|
|
|
19
19
|
bed_basename = os.path.basename(bed_file).split('.bed')[0]
|
|
20
20
|
# Load the BED file into a DataFrame
|
|
21
|
+
print(f"Loading BED to plot read length and coverage histograms: {bed_file}")
|
|
21
22
|
df = pd.read_csv(bed_file, sep='\t', header=None, names=['chromosome', 'start', 'end', 'length', 'read_name'])
|
|
22
23
|
|
|
23
24
|
# Group by chromosome
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
def run_multiqc(input_dir, output_dir):
    """
    Runs MultiQC on a given directory and saves the report to the specified output directory.

    Parameters:
    - input_dir (str): Path to the directory containing QC reports (e.g., FastQC, Samtools, bcftools outputs).
    - output_dir (str): Path to the directory where MultiQC reports should be saved (created if missing).

    Returns:
    - None: The function executes MultiQC and prints the status. Failures (a non-zero
      exit status, or the ``multiqc`` executable not being found) are reported with a
      printed message rather than raised.
    """
    import os
    import subprocess

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Construct MultiQC command (argument list, no shell, so paths are passed verbatim)
    command = ["multiqc", input_dir, "-o", output_dir]

    print(f"Running MultiQC on '{input_dir}' and saving results to '{output_dir}'...")

    # Run MultiQC
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error running MultiQC: {e}")
    except FileNotFoundError:
        # subprocess raises this (not CalledProcessError) when the executable itself is missing.
        print("Error running MultiQC: 'multiqc' executable not found on PATH.")
    else:
        print(f"MultiQC report generated successfully in: {output_dir}")
|
|
28
|
+
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
## split_and_index_BAM
|
|
2
2
|
|
|
3
|
-
def split_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, output_directory
|
|
3
|
+
def split_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, output_directory):
|
|
4
4
|
"""
|
|
5
5
|
A wrapper function for splitting BAMS and indexing them.
|
|
6
6
|
Parameters:
|
|
@@ -8,7 +8,6 @@ def split_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, output_direct
|
|
|
8
8
|
split_dir (str): A string representing the file path to the directory to split the BAMs into.
|
|
9
9
|
bam_suffix (str): A suffix to add to the bam file.
|
|
10
10
|
output_directory (str): A file path to the directory to output all the analyses.
|
|
11
|
-
fasta (str): File path to the reference genome to align to.
|
|
12
11
|
|
|
13
12
|
Returns:
|
|
14
13
|
None
|
|
@@ -19,8 +18,6 @@ def split_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, output_direct
|
|
|
19
18
|
import subprocess
|
|
20
19
|
import glob
|
|
21
20
|
from .separate_bam_by_bc import separate_bam_by_bc
|
|
22
|
-
from .aligned_BAM_to_bed import aligned_BAM_to_bed
|
|
23
|
-
from .extract_readnames_from_BAM import extract_readnames_from_BAM
|
|
24
21
|
from .make_dirs import make_dirs
|
|
25
22
|
|
|
26
23
|
plotting_dir = os.path.join(output_directory, 'demultiplexed_bed_histograms')
|
|
@@ -35,7 +32,5 @@ def split_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, output_direct
|
|
|
35
32
|
bam_files = [bam for bam in bam_files if '.bai' not in bam]
|
|
36
33
|
for input_file in bam_files:
|
|
37
34
|
subprocess.run(["samtools", "index", input_file])
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
# Make a text file of reads for the BAM
|
|
41
|
-
extract_readnames_from_BAM(input_file)
|
|
35
|
+
|
|
36
|
+
return bam_files
|
|
@@ -14,11 +14,12 @@ def load_adata(config_path):
|
|
|
14
14
|
None
|
|
15
15
|
"""
|
|
16
16
|
# Lazy importing of packages
|
|
17
|
-
from .helpers import LoadExperimentConfig, make_dirs, concatenate_fastqs_to_bam
|
|
17
|
+
from .helpers import LoadExperimentConfig, make_dirs, concatenate_fastqs_to_bam, extract_read_features_from_bam
|
|
18
18
|
from .fast5_to_pod5 import fast5_to_pod5
|
|
19
19
|
from .subsample_fasta_from_bed import subsample_fasta_from_bed
|
|
20
20
|
import os
|
|
21
21
|
import numpy as np
|
|
22
|
+
import anndata as ad
|
|
22
23
|
from pathlib import Path
|
|
23
24
|
|
|
24
25
|
# Default params
|
|
@@ -42,8 +43,13 @@ def load_adata(config_path):
|
|
|
42
43
|
fasta_regions_of_interest = var_dict.get("fasta_regions_of_interest", default_value) # Path to a bed file listing coordinate regions of interest within the FASTA to include. Optional.
|
|
43
44
|
mapping_threshold = var_dict.get('mapping_threshold', default_value) # Minimum proportion of mapped reads that need to fall within a region to include in the final AnnData.
|
|
44
45
|
experiment_name = var_dict.get('experiment_name', default_value) # A key term to add to the AnnData file name.
|
|
46
|
+
model_dir = var_dict.get('model_dir', default_value) # needed for dorado basecaller
|
|
45
47
|
model = var_dict.get('model', default_value) # needed for dorado basecaller
|
|
46
48
|
barcode_kit = var_dict.get('barcode_kit', default_value) # needed for dorado basecaller
|
|
49
|
+
barcode_both_ends = var_dict.get('barcode_both_ends', default_value) # dorado demultiplexing
|
|
50
|
+
trim = var_dict.get('trim', default_value) # dorado adapter and barcode removal
|
|
51
|
+
input_already_demuxed = var_dict.get('input_already_demuxed', default_value) # If the input files are already demultiplexed.
|
|
52
|
+
threads = var_dict.get('threads', default_value) # number of cpu threads available for multiprocessing
|
|
47
53
|
# Conversion specific variable init
|
|
48
54
|
conversion_types = var_dict.get('conversion_types', default_value)
|
|
49
55
|
# Direct methylation specific variable init
|
|
@@ -54,6 +60,10 @@ def load_adata(config_path):
|
|
|
54
60
|
thresholds = [filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold]
|
|
55
61
|
mod_list = var_dict.get('mod_list', default_value)
|
|
56
62
|
batch_size = var_dict.get('batch_size', default_value)
|
|
63
|
+
device = var_dict.get('device', 'auto')
|
|
64
|
+
make_bigwigs = var_dict.get('make_bigwigs', default_value)
|
|
65
|
+
skip_unclassified = var_dict.get('skip_unclassified', True)
|
|
66
|
+
delete_batch_hdfs = var_dict.get('delete_batch_hdfs', True)
|
|
57
67
|
|
|
58
68
|
# Make initial output directory
|
|
59
69
|
make_dirs([output_directory])
|
|
@@ -119,9 +129,54 @@ def load_adata(config_path):
|
|
|
119
129
|
|
|
120
130
|
if smf_modality == 'conversion':
|
|
121
131
|
from .conversion_smf import conversion_smf
|
|
122
|
-
conversion_smf(fasta, output_directory, conversions, strands, model, input_data_path, split_path
|
|
132
|
+
final_adata, final_adata_path, sorted_output, bam_files = conversion_smf(fasta, output_directory, conversions, strands, model_dir, model, input_data_path, split_path
|
|
133
|
+
, barcode_kit, mapping_threshold, experiment_name, bam_suffix, basecall, barcode_both_ends, trim, device, make_bigwigs, threads, input_already_demuxed)
|
|
123
134
|
elif smf_modality == 'direct':
|
|
124
135
|
from .direct_smf import direct_smf
|
|
125
|
-
|
|
136
|
+
# need to add input_already_demuxed workflow here.
|
|
137
|
+
final_adata, final_adata_path, sorted_output, bam_files = direct_smf(fasta, output_directory, mod_list,model_dir, model, thresholds, input_data_path, split_path
|
|
138
|
+
, barcode_kit, mapping_threshold, experiment_name, bam_suffix, batch_size, basecall, barcode_both_ends, trim, device, make_bigwigs, skip_unclassified, delete_batch_hdfs, threads)
|
|
126
139
|
else:
|
|
127
140
|
print("Error")
|
|
141
|
+
|
|
142
|
+
# Read in the final adata object and append final metadata
|
|
143
|
+
#print(f'Reading in adata from {final_adata_path} to add final metadata')
|
|
144
|
+
# final_adata = ad.read_h5ad(final_adata_path)
|
|
145
|
+
|
|
146
|
+
# Adding read query length metadata to adata object.
|
|
147
|
+
read_metrics = {}
|
|
148
|
+
for bam_file in bam_files:
|
|
149
|
+
bam_read_metrics = extract_read_features_from_bam(bam_file)
|
|
150
|
+
read_metrics.update(bam_read_metrics)
|
|
151
|
+
#read_metrics = extract_read_features_from_bam(sorted_output)
|
|
152
|
+
|
|
153
|
+
query_read_length_values = []
|
|
154
|
+
query_read_quality_values = []
|
|
155
|
+
reference_lengths = []
|
|
156
|
+
# Iterate over each row of the AnnData object
|
|
157
|
+
for obs_name in final_adata.obs_names:
|
|
158
|
+
# Fetch the value from the dictionary using the obs_name as the key
|
|
159
|
+
value = read_metrics.get(obs_name, np.nan) # Use np.nan if the key is not found
|
|
160
|
+
if type(value) is list:
|
|
161
|
+
query_read_length_values.append(value[0])
|
|
162
|
+
query_read_quality_values.append(value[1])
|
|
163
|
+
reference_lengths.append(value[2])
|
|
164
|
+
else:
|
|
165
|
+
query_read_length_values.append(value)
|
|
166
|
+
query_read_quality_values.append(value)
|
|
167
|
+
reference_lengths.append(value)
|
|
168
|
+
|
|
169
|
+
# Add the new column to adata.obs
|
|
170
|
+
final_adata.obs['query_read_length'] = query_read_length_values
|
|
171
|
+
final_adata.obs['query_read_quality'] = query_read_quality_values
|
|
172
|
+
final_adata.obs['query_length_to_reference_length_ratio'] = np.array(query_read_length_values) / np.array(reference_lengths)
|
|
173
|
+
|
|
174
|
+
final_adata.obs['Raw_methylation_signal'] = np.nansum(final_adata.X, axis=1)
|
|
175
|
+
final_adata.obs['Raw_per_base_methylation_average'] = final_adata.obs['Raw_methylation_signal'] / final_adata.obs['query_read_length']
|
|
176
|
+
|
|
177
|
+
print('Saving final adata')
|
|
178
|
+
if ".gz" in final_adata_path:
|
|
179
|
+
final_adata.write_h5ad(f"{final_adata_path}", compression='gzip')
|
|
180
|
+
else:
|
|
181
|
+
final_adata.write_h5ad(f"{final_adata_path}.gz", compression='gzip')
|
|
182
|
+
print('Final adata saved')
|
smftools/plotting/__init__.py
CHANGED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .position_stats import plot_bar_relative_risk, plot_volcano_relative_risk, plot_positionwise_matrix, plot_positionwise_matrix_grid
|
|
2
|
+
from .general_plotting import combined_hmm_raw_clustermap
|
|
3
|
+
from .classifiers import plot_model_performance, plot_feature_importances_or_saliency, plot_model_curves_from_adata, plot_model_curves_from_adata_with_frequency_grid
|
|
4
|
+
|
|
5
|
+
# Public plotting API, grouped by the submodule each name is imported from.
__all__ = [
    # general_plotting
    "combined_hmm_raw_clustermap",
    # position_stats
    "plot_bar_relative_risk",
    "plot_positionwise_matrix",
    "plot_positionwise_matrix_grid",
    "plot_volcano_relative_risk",
    # classifiers
    "plot_feature_importances_or_saliency",
    "plot_model_performance",
    "plot_model_curves_from_adata",
    "plot_model_curves_from_adata_with_frequency_grid",
]
|