smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. smftools/__init__.py +5 -1
  2. smftools/_version.py +1 -1
  3. smftools/informatics/__init__.py +2 -0
  4. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  5. smftools/informatics/basecall_pod5s.py +80 -0
  6. smftools/informatics/conversion_smf.py +63 -10
  7. smftools/informatics/direct_smf.py +66 -18
  8. smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
  9. smftools/informatics/helpers/__init__.py +16 -2
  10. smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
  11. smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
  12. smftools/informatics/helpers/bam_qc.py +66 -0
  13. smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
  14. smftools/informatics/helpers/canoncall.py +12 -3
  15. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
  16. smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
  17. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  18. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  19. smftools/informatics/helpers/extract_base_identities.py +33 -46
  20. smftools/informatics/helpers/extract_mods.py +55 -23
  21. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  22. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  23. smftools/informatics/helpers/find_conversion_sites.py +33 -44
  24. smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
  25. smftools/informatics/helpers/modcall.py +13 -5
  26. smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
  27. smftools/informatics/helpers/ohe_batching.py +65 -41
  28. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  29. smftools/informatics/helpers/one_hot_decode.py +27 -0
  30. smftools/informatics/helpers/one_hot_encode.py +45 -9
  31. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
  32. smftools/informatics/helpers/run_multiqc.py +28 -0
  33. smftools/informatics/helpers/split_and_index_BAM.py +3 -8
  34. smftools/informatics/load_adata.py +58 -3
  35. smftools/plotting/__init__.py +15 -0
  36. smftools/plotting/classifiers.py +355 -0
  37. smftools/plotting/general_plotting.py +205 -0
  38. smftools/plotting/position_stats.py +462 -0
  39. smftools/preprocessing/__init__.py +6 -7
  40. smftools/preprocessing/append_C_context.py +22 -9
  41. smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
  42. smftools/preprocessing/binarize_on_Youden.py +35 -32
  43. smftools/preprocessing/binary_layers_to_ohe.py +13 -3
  44. smftools/preprocessing/calculate_complexity.py +3 -2
  45. smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
  46. smftools/preprocessing/calculate_coverage.py +26 -25
  47. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  48. smftools/preprocessing/calculate_position_Youden.py +18 -7
  49. smftools/preprocessing/calculate_read_length_stats.py +39 -46
  50. smftools/preprocessing/clean_NaN.py +33 -25
  51. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  52. smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
  53. smftools/preprocessing/filter_reads_on_length.py +14 -4
  54. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  55. smftools/preprocessing/invert_adata.py +18 -11
  56. smftools/preprocessing/load_sample_sheet.py +30 -16
  57. smftools/preprocessing/recipes.py +22 -20
  58. smftools/preprocessing/subsample_adata.py +58 -0
  59. smftools/readwrite.py +105 -13
  60. smftools/tools/__init__.py +49 -0
  61. smftools/tools/apply_hmm.py +202 -0
  62. smftools/tools/apply_hmm_batched.py +241 -0
  63. smftools/tools/archived/classify_methylated_features.py +66 -0
  64. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  65. smftools/tools/archived/subset_adata_v1.py +32 -0
  66. smftools/tools/archived/subset_adata_v2.py +46 -0
  67. smftools/tools/calculate_distances.py +18 -0
  68. smftools/tools/calculate_umap.py +62 -0
  69. smftools/tools/call_hmm_peaks.py +105 -0
  70. smftools/tools/classifiers.py +787 -0
  71. smftools/tools/cluster_adata_on_methylation.py +105 -0
  72. smftools/tools/data/__init__.py +2 -0
  73. smftools/tools/data/anndata_data_module.py +90 -0
  74. smftools/tools/data/preprocessing.py +6 -0
  75. smftools/tools/display_hmm.py +18 -0
  76. smftools/tools/general_tools.py +69 -0
  77. smftools/tools/hmm_readwrite.py +16 -0
  78. smftools/tools/inference/__init__.py +1 -0
  79. smftools/tools/inference/lightning_inference.py +41 -0
  80. smftools/tools/models/__init__.py +9 -0
  81. smftools/tools/models/base.py +14 -0
  82. smftools/tools/models/cnn.py +34 -0
  83. smftools/tools/models/lightning_base.py +41 -0
  84. smftools/tools/models/mlp.py +17 -0
  85. smftools/tools/models/positional.py +17 -0
  86. smftools/tools/models/rnn.py +16 -0
  87. smftools/tools/models/sklearn_models.py +40 -0
  88. smftools/tools/models/transformer.py +133 -0
  89. smftools/tools/models/wrappers.py +20 -0
  90. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  91. smftools/tools/position_stats.py +239 -0
  92. smftools/tools/read_stats.py +70 -0
  93. smftools/tools/subset_adata.py +19 -23
  94. smftools/tools/train_hmm.py +78 -0
  95. smftools/tools/training/__init__.py +1 -0
  96. smftools/tools/training/train_lightning_model.py +47 -0
  97. smftools/tools/utils/__init__.py +2 -0
  98. smftools/tools/utils/device.py +10 -0
  99. smftools/tools/utils/grl.py +14 -0
  100. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
  101. smftools-0.1.7.dist-info/RECORD +136 -0
  102. smftools/tools/apply_HMM.py +0 -1
  103. smftools/tools/read_HMM.py +0 -1
  104. smftools/tools/train_HMM.py +0 -43
  105. smftools-0.1.3.dist-info/RECORD +0 -84
  106. /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
  107. /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
  108. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  109. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/helpers/ohe_batching.py
@@ -1,52 +1,76 @@
-# ohe_batching
+import os
+import anndata as ad
+import numpy as np
+import concurrent.futures
+from .one_hot_encode import one_hot_encode
 
-def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000):
+def encode_sequence(args):
+    """Parallel helper function for one-hot encoding."""
+    read_name, seq, device = args
+    try:
+        one_hot_matrix = one_hot_encode(seq, device)
+        return read_name, one_hot_matrix
+    except Exception:
+        return None # Skip invalid sequences
+
+def encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number):
+    """Encodes a batch and writes to disk immediately."""
+    batch = {read_name: matrix for read_name, matrix in batch_data if matrix is not None}
+
+    if batch:
+        save_name = os.path.join(tmp_dir, f'tmp_{prefix}_{record}_{batch_number}.h5ad')
+        tmp_ad = ad.AnnData(X=np.zeros((1, 1)), uns=batch) # Placeholder X
+        tmp_ad.write_h5ad(save_name)
+        return save_name
+    return None
+
+def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000, progress_bar=None, device='auto', threads=None):
     """
-    Processes base identities to one-hot encoded matrices and writes to a h5ad file in batches.
+    Efficient version of ohe_batching: one-hot encodes sequences in parallel and writes batches immediately.
 
     Parameters:
-        base_identities (dict): A dictionary of read names and sequences.
-        tmp_dir (str): Path to directory where the files will be saved.
-        record (str): Name of the record.
-        prefix (str): Prefix to add to the output file name
-        batch_size (int): Number of reads to process in each batch.
+        base_identities (dict): Dictionary mapping read names to sequences.
+        tmp_dir (str): Directory for storing temporary files.
+        record (str): Record name.
+        prefix (str): Prefix for file naming.
+        batch_size (int): Number of reads per batch.
+        progress_bar (tqdm instance, optional): Shared progress bar.
+        device (str): Device for encoding.
+        threads (int, optional): Number of parallel workers.
 
     Returns:
-        ohe_file (list): list of output file names
+        list: List of valid H5AD file paths.
     """
-    import os
-    import anndata as ad
-    import numpy as np
-    from tqdm import tqdm
-    from .one_hot_encode import one_hot_encode
-
-    batch = {}
-    count = 0
+    threads = threads or os.cpu_count() # Default to max available CPU cores
+    batch_data = []
     batch_number = 0
-    total_reads = len(base_identities)
     file_names = []
-
-    for read_name, seq in tqdm(base_identities.items(), desc="Encoding and writing one hot encoded reads", total=total_reads):
-        one_hot_matrix = one_hot_encode(seq)
-        batch[read_name] = one_hot_matrix
-        count += 1
-        # If the batch size is reached, write out the batch and reset
-        if count >= batch_size:
-            save_name = os.path.join(tmp_dir, f'tmp_{prefix}_{record}_{batch_number}.h5ad.gz')
-            X = np.random.rand(1, 1)
-            tmp_ad = ad.AnnData(X=X, uns=batch)
-            tmp_ad.write_h5ad(save_name, compression='gzip')
-            file_names.append(save_name)
-            batch.clear()
-            count = 0
-            batch_number += 1
-
-    # Write out any remaining reads in the final batch
-    if batch:
-        save_name = os.path.join(tmp_dir, f'tmp_{prefix}_{record}_{batch_number}.h5ad.gz')
-        X = np.random.rand(1, 1)
-        tmp_ad = ad.AnnData(X=X, uns=batch)
-        tmp_ad.write_h5ad(save_name, compression='gzip')
-        file_names.append(save_name)
+
+    # Step 1: Prepare Data for Parallel Encoding
+    encoding_args = [(read_name, seq, device) for read_name, seq in base_identities.items() if seq is not None]
+
+    # Step 2: Parallel One-Hot Encoding using threads (to avoid nested processes)
+    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
+        for result in executor.map(encode_sequence, encoding_args):
+            if result:
+                batch_data.append(result)
+
+            if len(batch_data) >= batch_size:
+                # Step 3: Process and Write Batch Immediately
+                file_name = encode_and_save_batch(batch_data.copy(), tmp_dir, prefix, record, batch_number)
+                if file_name:
+                    file_names.append(file_name)
+
+                batch_data.clear()
+                batch_number += 1
+
+            if progress_bar:
+                progress_bar.update(1)
+
+    # Step 4: Process Remaining Batch
+    if batch_data:
+        file_name = encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number)
+        if file_name:
+            file_names.append(file_name)
 
     return file_names
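For orientation, a minimal usage sketch of the reworked ohe_batching (not taken from the package; the import path, reads, and directories below are assumed placeholders):

import os
from tqdm import tqdm
from smftools.informatics.helpers.ohe_batching import ohe_batching  # assumed import path

# Hypothetical input: read names mapped to basecalled sequences.
base_identities = {"read_1": "ACGTN", "read_2": "GGCATT"}
os.makedirs("tmp_ohe", exist_ok=True)

with tqdm(total=len(base_identities), desc="One-hot encoding reads") as pbar:
    batch_files = ohe_batching(
        base_identities,
        tmp_dir="tmp_ohe",
        record="example_record",
        prefix="demo",
        batch_size=100000,
        progress_bar=pbar,       # updated once per encoded read
        device="auto",
        threads=os.cpu_count(),
    )

print(batch_files)  # tmp_*.h5ad batch files written by encode_and_save_batch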
smftools/informatics/helpers/ohe_layers_decode.py
@@ -0,0 +1,32 @@
+# ohe_layers_decode
+
+def ohe_layers_decode(adata, obs_names):
+    """
+    Takes an anndata object and a list of observation names. Returns a list of sequence strings for the reads of interest.
+    Parameters:
+        adata (AnnData): An anndata object.
+        obs_names (list): A list of observation name strings to retrieve sequences for.
+
+    Returns:
+        sequences (list of str): List of strings of the one hot encoded array
+    """
+    import anndata as ad
+    import numpy as np
+    from .ohe_decode import ohe_decode
+
+    # Define the mapping of one-hot encoded indices to DNA bases
+    mapping = ['A', 'C', 'G', 'T', 'N']
+
+    ohe_layers = [f"{base}_binary_encoding" for base in mapping]
+    sequences = []
+
+    for obs_name in obs_names:
+        obs_subset = adata[obs_name]
+        ohe_list = []
+        for layer in ohe_layers:
+            ohe_list += list(obs_subset.layers[layer])
+        ohe_array = np.array(ohe_list)
+        sequence = ohe_decode(ohe_array)
+        sequences.append(sequence)
+
+    return sequences
smftools/informatics/helpers/one_hot_decode.py
@@ -0,0 +1,27 @@
+# one_hot_decode
+
+# String encodings
+def one_hot_decode(ohe_array):
+    """
+    Takes a flattened one hot encoded array and returns the sequence string from that array.
+    Parameters:
+        ohe_array (np.array): A one hot encoded array
+
+    Returns:
+        sequence (str): Sequence string of the one hot encoded array
+    """
+    import numpy as np
+    # Define the mapping of one-hot encoded indices to DNA bases
+    mapping = ['A', 'C', 'G', 'T', 'N']
+
+    # Reshape the flattened array into a 2D matrix with 5 columns (one for each base)
+    one_hot_matrix = ohe_array.reshape(-1, 5)
+
+    # Get the index of the maximum value (which will be 1) in each row
+    decoded_indices = np.argmax(one_hot_matrix, axis=1)
+
+    # Map the indices back to the corresponding bases
+    sequence_list = [mapping[i] for i in decoded_indices]
+    sequence = ''.join(sequence_list)
+
+    return sequence
smftools/informatics/helpers/one_hot_encode.py
@@ -1,21 +1,57 @@
 # one_hot_encode
 
-# String encodings
-def one_hot_encode(sequence):
+def one_hot_encode(sequence, device='auto'):
     """
-    One hot encodes a sequence list.
+    One-hot encodes a DNA sequence.
+
     Parameters:
-        sequence (list): A list of DNA base sequences.
+        sequence (str or list): DNA sequence (e.g., "ACGTN" or ['A', 'C', 'G', 'T', 'N']).
 
     Returns:
-        flattened (ndarray): A numpy ndarray holding a flattened one hot encoding of the input sequence string.
+        ndarray: Flattened one-hot encoded representation of the input sequence.
     """
     import numpy as np
 
-    seq_array = np.array(sequence, dtype='<U1') # String dtype
     mapping = np.array(['A', 'C', 'G', 'T', 'N'])
-    seq_array[~np.isin(seq_array, mapping)] = 'N'
+
+    # Ensure input is a list of characters
+    if not isinstance(sequence, list):
+        sequence = list(sequence) # Convert string to list of characters
+
+    # Handle empty sequences
+    if len(sequence) == 0:
+        print("Warning: Empty sequence encountered in one_hot_encode()")
+        return np.zeros(len(mapping)) # Return empty encoding instead of failing
+
+    # Convert sequence to NumPy array
+    seq_array = np.array(sequence, dtype='<U1')
+
+    # Replace invalid bases with 'N'
+    seq_array = np.where(np.isin(seq_array, mapping), seq_array, 'N')
+
+    # Create one-hot encoding matrix
     one_hot_matrix = (seq_array[:, None] == mapping).astype(int)
-    flattened = one_hot_matrix.flatten()
 
-    return flattened
+    # Flatten and return
+    return one_hot_matrix.flatten()
+
+    # import torch
+    # bases = torch.tensor([ord('A'), ord('C'), ord('G'), ord('T'), ord('N')], dtype=torch.int8, device=device)
+
+    # # Convert input to tensor of character ASCII codes
+    # seq_tensor = torch.tensor([ord(c) for c in sequence], dtype=torch.int8, device=device)
+
+    # # Handle empty sequence
+    # if seq_tensor.numel() == 0:
+    #     print("Warning: Empty sequence encountered in one_hot_encode_torch()")
+    #     return torch.zeros(len(bases), device=device)
+
+    # # Replace invalid bases with 'N'
+    # is_valid = (seq_tensor[:, None] == bases) # Compare each base with mapping
+    # seq_tensor = torch.where(is_valid.any(dim=1), seq_tensor, ord('N'))
+
+    # # Create one-hot encoding matrix
+    # one_hot_matrix = (seq_tensor[:, None] == bases).int()
+
+    # # Flatten and return
+    # return one_hot_matrix.flatten()
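A small round-trip sketch of the two helpers above (illustrative only; the module import paths are assumed from the package layout). Any base outside the A/C/G/T/N mapping is coerced to 'N' before encoding:

from smftools.informatics.helpers.one_hot_encode import one_hot_encode
from smftools.informatics.helpers.one_hot_decode import one_hot_decode

flat = one_hot_encode("ACGTX")   # 'X' is not in the mapping, so it is encoded as 'N'
print(flat.shape)                # (25,): 5 positions x 5 bases, flattened row-major
print(one_hot_decode(flat))      # 'ACGTN'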
smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py
@@ -18,6 +18,7 @@ def plot_read_length_and_coverage_histograms(bed_file, plotting_directory):
 
     bed_basename = os.path.basename(bed_file).split('.bed')[0]
     # Load the BED file into a DataFrame
+    print(f"Loading BED to plot read length and coverage histograms: {bed_file}")
     df = pd.read_csv(bed_file, sep='\t', header=None, names=['chromosome', 'start', 'end', 'length', 'read_name'])
 
     # Group by chromosome
smftools/informatics/helpers/run_multiqc.py
@@ -0,0 +1,28 @@
+def run_multiqc(input_dir, output_dir):
+    """
+    Runs MultiQC on a given directory and saves the report to the specified output directory.
+
+    Parameters:
+    - input_dir (str): Path to the directory containing QC reports (e.g., FastQC, Samtools, bcftools outputs).
+    - output_dir (str): Path to the directory where MultiQC reports should be saved.
+
+    Returns:
+    - None: The function executes MultiQC and prints the status.
+    """
+    import os
+    import subprocess
+    # Ensure the output directory exists
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Construct MultiQC command
+    command = ["multiqc", input_dir, "-o", output_dir]
+
+    print(f"Running MultiQC on '{input_dir}' and saving results to '{output_dir}'...")
+
+    # Run MultiQC
+    try:
+        subprocess.run(command, check=True)
+        print(f"MultiQC report generated successfully in: {output_dir}")
+    except subprocess.CalledProcessError as e:
+        print(f"Error running MultiQC: {e}")
+
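Usage note: run_multiqc simply shells out to the multiqc command, so MultiQC must be installed and on PATH. A minimal call might look like this (directory names are placeholders):

from smftools.informatics.helpers.run_multiqc import run_multiqc  # assumed import path

# Aggregate FastQC/samtools/bcftools outputs under qc_reports/ into one report.
run_multiqc(input_dir="qc_reports", output_dir="multiqc_report")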
smftools/informatics/helpers/split_and_index_BAM.py
@@ -1,6 +1,6 @@
 ## split_and_index_BAM
 
-def split_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, output_directory, fasta):
+def split_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, output_directory):
     """
     A wrapper function for splitting BAMS and indexing them.
     Parameters:
@@ -8,7 +8,6 @@ def split_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, output_direct
         split_dir (str): A string representing the file path to the directory to split the BAMs into.
         bam_suffix (str): A suffix to add to the bam file.
         output_directory (str): A file path to the directory to output all the analyses.
-        fasta (str): File path to the reference genome to align to.
 
     Returns:
         None
@@ -19,8 +18,6 @@ def split_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, output_direct
     import subprocess
     import glob
     from .separate_bam_by_bc import separate_bam_by_bc
-    from .aligned_BAM_to_bed import aligned_BAM_to_bed
-    from .extract_readnames_from_BAM import extract_readnames_from_BAM
     from .make_dirs import make_dirs
 
     plotting_dir = os.path.join(output_directory, 'demultiplexed_bed_histograms')
@@ -35,7 +32,5 @@ def split_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, output_direct
     bam_files = [bam for bam in bam_files if '.bai' not in bam]
     for input_file in bam_files:
        subprocess.run(["samtools", "index", input_file])
-        # Make a bed file of coordinates for the BAM
-        aligned_BAM_to_bed(input_file, plotting_dir, bed_dir, fasta)
-        # Make a text file of reads for the BAM
-        extract_readnames_from_BAM(input_file)
+
+    return bam_files
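With this change, split_and_index_BAM no longer takes a fasta argument and returns the per-barcode BAM paths instead of producing BED and read-name files itself, so downstream per-BAM work presumably moves to the caller. A hypothetical sketch (paths are placeholders):

from smftools.informatics.helpers.split_and_index_BAM import split_and_index_BAM  # assumed import path

bam_files = split_and_index_BAM(
    aligned_sorted_BAM="sample.aligned.sorted.bam",
    split_dir="split_bams",
    bam_suffix=".bam",
    output_directory="analysis_out",
)
for bam in bam_files:
    print(bam)  # each split, indexed BAM is now available for downstream QC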
smftools/informatics/load_adata.py
@@ -14,11 +14,12 @@ def load_adata(config_path):
     None
     """
     # Lazy importing of packages
-    from .helpers import LoadExperimentConfig, make_dirs, concatenate_fastqs_to_bam
+    from .helpers import LoadExperimentConfig, make_dirs, concatenate_fastqs_to_bam, extract_read_features_from_bam
     from .fast5_to_pod5 import fast5_to_pod5
     from .subsample_fasta_from_bed import subsample_fasta_from_bed
     import os
     import numpy as np
+    import anndata as ad
     from pathlib import Path
 
     # Default params
@@ -42,8 +43,13 @@
     fasta_regions_of_interest = var_dict.get("fasta_regions_of_interest", default_value) # Path to a bed file listing coordinate regions of interest within the FASTA to include. Optional.
     mapping_threshold = var_dict.get('mapping_threshold', default_value) # Minimum proportion of mapped reads that need to fall within a region to include in the final AnnData.
     experiment_name = var_dict.get('experiment_name', default_value) # A key term to add to the AnnData file name.
+    model_dir = var_dict.get('model_dir', default_value) # needed for dorado basecaller
     model = var_dict.get('model', default_value) # needed for dorado basecaller
     barcode_kit = var_dict.get('barcode_kit', default_value) # needed for dorado basecaller
+    barcode_both_ends = var_dict.get('barcode_both_ends', default_value) # dorado demultiplexing
+    trim = var_dict.get('trim', default_value) # dorado adapter and barcode removal
+    input_already_demuxed = var_dict.get('input_already_demuxed', default_value) # If the input files are already demultiplexed.
+    threads = var_dict.get('threads', default_value) # number of cpu threads available for multiprocessing
     # Conversion specific variable init
     conversion_types = var_dict.get('conversion_types', default_value)
     # Direct methylation specific variable init
@@ -54,6 +60,10 @@
     thresholds = [filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold]
     mod_list = var_dict.get('mod_list', default_value)
     batch_size = var_dict.get('batch_size', default_value)
+    device = var_dict.get('device', 'auto')
+    make_bigwigs = var_dict.get('make_bigwigs', default_value)
+    skip_unclassified = var_dict.get('skip_unclassified', True)
+    delete_batch_hdfs = var_dict.get('delete_batch_hdfs', True)
 
     # Make initial output directory
     make_dirs([output_directory])
@@ -119,9 +129,54 @@
 
     if smf_modality == 'conversion':
         from .conversion_smf import conversion_smf
-        conversion_smf(fasta, output_directory, conversions, strands, model, input_data_path, split_path, barcode_kit, mapping_threshold, experiment_name, bam_suffix, basecall)
+        final_adata, final_adata_path, sorted_output, bam_files = conversion_smf(fasta, output_directory, conversions, strands, model_dir, model, input_data_path, split_path
+        , barcode_kit, mapping_threshold, experiment_name, bam_suffix, basecall, barcode_both_ends, trim, device, make_bigwigs, threads, input_already_demuxed)
     elif smf_modality == 'direct':
         from .direct_smf import direct_smf
-        direct_smf(fasta, output_directory, mod_list, model, thresholds, input_data_path, split_path, barcode_kit, mapping_threshold, experiment_name, bam_suffix, batch_size, basecall)
+        # need to add input_already_demuxed workflow here.
+        final_adata, final_adata_path, sorted_output, bam_files = direct_smf(fasta, output_directory, mod_list,model_dir, model, thresholds, input_data_path, split_path
+        , barcode_kit, mapping_threshold, experiment_name, bam_suffix, batch_size, basecall, barcode_both_ends, trim, device, make_bigwigs, skip_unclassified, delete_batch_hdfs, threads)
     else:
         print("Error")
+
+    # Read in the final adata object and append final metadata
+    #print(f'Reading in adata from {final_adata_path} to add final metadata')
+    # final_adata = ad.read_h5ad(final_adata_path)
+
+    # Adding read query length metadata to adata object.
+    read_metrics = {}
+    for bam_file in bam_files:
+        bam_read_metrics = extract_read_features_from_bam(bam_file)
+        read_metrics.update(bam_read_metrics)
+    #read_metrics = extract_read_features_from_bam(sorted_output)
+
+    query_read_length_values = []
+    query_read_quality_values = []
+    reference_lengths = []
+    # Iterate over each row of the AnnData object
+    for obs_name in final_adata.obs_names:
+        # Fetch the value from the dictionary using the obs_name as the key
+        value = read_metrics.get(obs_name, np.nan) # Use np.nan if the key is not found
+        if type(value) is list:
+            query_read_length_values.append(value[0])
+            query_read_quality_values.append(value[1])
+            reference_lengths.append(value[2])
+        else:
+            query_read_length_values.append(value)
+            query_read_quality_values.append(value)
+            reference_lengths.append(value)
+
+    # Add the new column to adata.obs
+    final_adata.obs['query_read_length'] = query_read_length_values
+    final_adata.obs['query_read_quality'] = query_read_quality_values
+    final_adata.obs['query_length_to_reference_length_ratio'] = np.array(query_read_length_values) / np.array(reference_lengths)
+
+    final_adata.obs['Raw_methylation_signal'] = np.nansum(final_adata.X, axis=1)
+    final_adata.obs['Raw_per_base_methylation_average'] = final_adata.obs['Raw_methylation_signal'] / final_adata.obs['query_read_length']
+
+    print('Saving final adata')
+    if ".gz" in final_adata_path:
+        final_adata.write_h5ad(f"{final_adata_path}", compression='gzip')
+    else:
+        final_adata.write_h5ad(f"{final_adata_path}.gz", compression='gzip')
+    print('Final adata saved')
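To illustrate the new per-read annotation step in load_adata: read_metrics appears to map read names to [query_length, query_quality, reference_length], with NaN used when a read has no BAM record. A toy, self-contained version of that loop (all values invented):

import numpy as np
import anndata as ad

final_adata = ad.AnnData(X=np.array([[1.0, 0.0], [np.nan, 1.0]]))
final_adata.obs_names = ["read_1", "read_2"]
read_metrics = {"read_1": [1200, 18.5, 1180]}  # read_2 deliberately missing

lengths, qualities, ref_lengths = [], [], []
for obs_name in final_adata.obs_names:
    value = read_metrics.get(obs_name, np.nan)  # NaN fallback, as in the diff
    if isinstance(value, list):
        lengths.append(value[0])
        qualities.append(value[1])
        ref_lengths.append(value[2])
    else:
        lengths.append(value)
        qualities.append(value)
        ref_lengths.append(value)

final_adata.obs["query_read_length"] = lengths
final_adata.obs["query_read_quality"] = qualities
final_adata.obs["query_length_to_reference_length_ratio"] = np.array(lengths) / np.array(ref_lengths)
final_adata.obs["Raw_methylation_signal"] = np.nansum(final_adata.X, axis=1)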
smftools/plotting/__init__.py
@@ -0,0 +1,15 @@
+from .position_stats import plot_bar_relative_risk, plot_volcano_relative_risk, plot_positionwise_matrix, plot_positionwise_matrix_grid
+from .general_plotting import combined_hmm_raw_clustermap
+from .classifiers import plot_model_performance, plot_feature_importances_or_saliency, plot_model_curves_from_adata, plot_model_curves_from_adata_with_frequency_grid
+
+__all__ = [
+    "combined_hmm_raw_clustermap",
+    "plot_bar_relative_risk",
+    "plot_positionwise_matrix",
+    "plot_positionwise_matrix_grid",
+    "plot_volcano_relative_risk",
+    "plot_feature_importances_or_saliency",
+    "plot_model_performance",
+    "plot_model_curves_from_adata",
+    "plot_model_curves_from_adata_with_frequency_grid"
+]
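With this new subpackage __init__, the plotting entry points listed in __all__ can be imported directly, for example:

from smftools.plotting import combined_hmm_raw_clustermap, plot_bar_relative_risk, plot_model_performance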