smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +7 -6
- smftools/_version.py +1 -1
- smftools/cli/cli_flows.py +94 -0
- smftools/cli/hmm_adata.py +338 -0
- smftools/cli/load_adata.py +577 -0
- smftools/cli/preprocess_adata.py +363 -0
- smftools/cli/spatial_adata.py +564 -0
- smftools/cli_entry.py +435 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +38 -0
- smftools/config/deaminase.yaml +61 -0
- smftools/config/default.yaml +264 -0
- smftools/config/direct.yaml +41 -0
- smftools/config/discover_input_files.py +115 -0
- smftools/config/experiment_config.py +1288 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +13 -9
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/fast5_to_pod5.py +43 -0
- smftools/informatics/archived/helpers/archived/__init__.py +71 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
- smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
- smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
- smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
- smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
- smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
- smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
- smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
- smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
- smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
- smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
- smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
- smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
- smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
- smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
- smftools/informatics/bam_functions.py +812 -0
- smftools/informatics/basecalling.py +67 -0
- smftools/informatics/bed_functions.py +366 -0
- smftools/informatics/binarize_converted_base_identities.py +172 -0
- smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
- smftools/informatics/fasta_functions.py +255 -0
- smftools/informatics/h5ad_functions.py +197 -0
- smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
- smftools/informatics/modkit_functions.py +129 -0
- smftools/informatics/ohe.py +160 -0
- smftools/informatics/pod5_functions.py +224 -0
- smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +609 -0
- smftools/plotting/general_plotting.py +1292 -140
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +15 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/binarize.py +17 -0
- smftools/preprocessing/binarize_on_Youden.py +2 -2
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_position_Youden.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +1021 -89
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
- smftools-0.2.3.dist-info/RECORD +173 -0
- smftools-0.2.3.dist-info/entry_points.txt +2 -0
- smftools/informatics/fast5_to_pod5.py +0 -21
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/__init__.py +0 -74
- smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
- smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
- smftools/informatics/helpers/bam_qc.py +0 -66
- smftools/informatics/helpers/bed_to_bigwig.py +0 -39
- smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
- smftools/informatics/helpers/index_fasta.py +0 -12
- smftools/informatics/helpers/make_dirs.py +0 -21
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/informatics/readwrite.py +0 -106
- smftools/informatics/subsample_fasta_from_bed.py +0 -47
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools/evaluation → cli}/__init__.py +0 -0
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
- /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
- /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import anndata as ad
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import concurrent.futures
|
|
6
|
+
|
|
7
|
+
def one_hot_encode(sequence, device='auto'):
    """
    One-hot encode a DNA sequence over the alphabet A, C, G, T, N.

    Parameters:
        sequence (str or list): DNA sequence (e.g., "ACGTN" or ['A', 'C', 'G', 'T', 'N']).
            Any character outside the alphabet (including lowercase bases) is encoded as 'N'.
        device (str): Accepted for API compatibility; this NumPy implementation ignores it.

    Returns:
        ndarray: Flattened, position-major encoding with 5 indicator values per position
        (order A, C, G, T, N), i.e. length 5 * len(sequence). An empty sequence prints a
        warning and returns np.zeros(5).
    """
    alphabet = np.array(['A', 'C', 'G', 'T', 'N'])
    bases = sequence if isinstance(sequence, list) else list(sequence)

    if not bases:
        print("Warning: Empty sequence encountered in one_hot_encode()")
        return np.zeros(alphabet.size)

    chars = np.array(bases, dtype='<U1')
    # Anything that is not one of the five known symbols collapses to 'N'.
    chars[~np.isin(chars, alphabet)] = 'N'

    # Broadcast each position against the alphabet -> (len, 5) boolean matrix.
    indicators = chars[:, np.newaxis] == alphabet
    return indicators.astype(int).reshape(-1)
|
|
39
|
+
|
|
40
|
+
def one_hot_decode(ohe_array):
    """
    Decode a flattened one-hot array back into a sequence string.

    Parameters:
        ohe_array (array-like): Flattened, position-major one-hot values with 5 entries per
            position in the order A, C, G, T, N (the layout produced by one_hot_encode).

    Returns:
        sequence (str): Decoded sequence. Positions with no positive indicator (all zeros or
        NaN, e.g. padding or missing data) decode to 'N' rather than defaulting to 'A'.
    """
    # Define the mapping of one-hot encoded indices to DNA bases
    mapping = np.array(['A', 'C', 'G', 'T', 'N'])
    n_index = len(mapping) - 1

    # One row per position, one column per base; NaN is treated as "not set".
    one_hot_matrix = np.asarray(ohe_array, dtype=float).reshape(-1, len(mapping))
    one_hot_matrix = np.nan_to_num(one_hot_matrix, nan=0.0)

    decoded_indices = np.argmax(one_hot_matrix, axis=1)
    # argmax of an all-zero row is 0 ('A'); map rows without any hot value to 'N' instead.
    has_call = one_hot_matrix.max(axis=1) > 0
    decoded_indices = np.where(has_call, decoded_indices, n_index)

    return ''.join(mapping[decoded_indices])
|
|
63
|
+
|
|
64
|
+
def ohe_layers_decode(adata, obs_names):
    """
    Decode per-base one-hot layers of an AnnData object back into sequence strings.

    Reads the layers "A_binary_encoding", "C_binary_encoding", "G_binary_encoding",
    "T_binary_encoding" and "N_binary_encoding" for each requested observation.

    Parameters:
        adata (AnnData): An anndata object containing the five binary-encoding layers.
        obs_names (list): A list of observation name strings to retrieve sequences for.

    Returns:
        sequences (list of str): One decoded sequence per entry of obs_names, in order.
    """
    # Define the mapping of one-hot encoded indices to DNA bases
    mapping = ['A', 'C', 'G', 'T', 'N']
    ohe_layers = [f"{base}_binary_encoding" for base in mapping]
    sequences = []

    for obs_name in obs_names:
        obs_subset = adata[obs_name]
        columns = []
        for layer in ohe_layers:
            values = obs_subset.layers[layer]
            # Sparse layers must be densified before they can be reshaped.
            if hasattr(values, "toarray"):
                values = values.toarray()
            columns.append(np.asarray(values).reshape(-1))

        # Stack as (n_positions, 5) so row i holds the A/C/G/T/N indicators of position i.
        # Flattening that row-major gives the position-major layout one_hot_decode expects;
        # concatenating whole layers instead would mix values from different positions.
        ohe_matrix = np.column_stack(columns)
        sequences.append(one_hot_decode(ohe_matrix.reshape(-1)))

    return sequences
|
|
90
|
+
|
|
91
|
+
def _encode_sequence(args):
    """
    Worker for parallel one-hot encoding.

    Parameters:
        args (tuple): (read_name, seq, device) as built by ohe_batching.

    Returns:
        tuple | None: (read_name, encoded array), or None when encoding raised, so the
        caller can skip invalid sequences.
    """
    name, sequence, dev = args
    try:
        encoded = one_hot_encode(sequence, dev)
    except Exception:
        return None  # Skip invalid sequences
    return name, encoded
|
|
99
|
+
|
|
100
|
+
def _encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number):
|
|
101
|
+
"""Encodes a batch and writes to disk immediately."""
|
|
102
|
+
batch = {read_name: matrix for read_name, matrix in batch_data if matrix is not None}
|
|
103
|
+
|
|
104
|
+
if batch:
|
|
105
|
+
save_name = os.path.join(tmp_dir, f'tmp_{prefix}_{record}_{batch_number}.h5ad')
|
|
106
|
+
tmp_ad = ad.AnnData(X=np.zeros((1, 1)), uns=batch) # Placeholder X
|
|
107
|
+
tmp_ad.write_h5ad(save_name)
|
|
108
|
+
return save_name
|
|
109
|
+
return None
|
|
110
|
+
|
|
111
|
+
def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000, progress_bar=None, device='auto', threads=None):
    """
    One-hot encode sequences in parallel and write them to temporary H5AD files in batches.

    Parameters:
        base_identities (dict): Dictionary mapping read names to sequences; None sequences are skipped.
        tmp_dir (str): Directory for storing temporary files.
        record (str): Record name (used in file names).
        prefix (str): Prefix for file naming.
        batch_size (int): Number of encoded reads per written file.
        progress_bar (tqdm instance, optional): Shared progress bar, advanced once per successfully encoded read.
        device (str): Device passed through to the encoder.
        threads (int, optional): Number of worker threads; defaults to os.cpu_count().

    Returns:
        list: Paths of the H5AD files that were written, in batch order.
    """
    threads = threads or os.cpu_count()  # Default to max available CPU cores
    batch_data = []
    batch_number = 0
    file_names = []

    # Step 1: Prepare Data for Parallel Encoding
    encoding_args = [(read_name, seq, device) for read_name, seq in base_identities.items() if seq is not None]

    # Step 2: Parallel One-Hot Encoding using threads (to avoid nested processes)
    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
        for result in executor.map(_encode_sequence, encoding_args):
            if result is None:
                continue  # encoding failed for this read
            batch_data.append(result)

            if len(batch_data) >= batch_size:
                # Step 3: Write the full batch immediately. The writer does not keep a
                # reference to the list, so we can start a fresh one afterwards.
                file_name = _encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number)
                if file_name:
                    file_names.append(file_name)
                batch_data = []
                batch_number += 1

            # tqdm defines __bool__ (it raises when total is None and is False when total == 0),
            # so test for presence explicitly instead of relying on truthiness.
            if progress_bar is not None:
                progress_bar.update(1)

    # Step 4: Process Remaining Batch
    if batch_data:
        file_name = _encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number)
        if file_name:
            file_names.append(file_name)

    return file_names
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
from ..config import LoadExperimentConfig
|
|
2
|
+
from ..readwrite import make_dirs
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import subprocess
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import pod5 as p5
|
|
9
|
+
|
|
10
|
+
from typing import Union, List
|
|
11
|
+
|
|
12
|
+
def basecall_pod5s(config_path):
    """
    Basecall from POD5s (or FAST5s, converted to POD5 first) given an experiment config file.

    Parameters:
        config_path (str): File path to the basecall configuration file.

    Returns:
        None

    Raises:
        ValueError: If input_data_path, output_directory or model is missing from the config.
        FileNotFoundError: If input_data_path is neither an existing file nor a directory.
    """
    # canoncall/modcall are neither defined nor imported in this module.
    # NOTE(review): assumed to be provided by smftools.informatics.basecalling — confirm.
    from .basecalling import canoncall, modcall

    # Default params
    bam_suffix = '.bam'  # If different, change from here.

    # Load experiment config parameters
    experiment_config = LoadExperimentConfig(config_path)
    var_dict = experiment_config.var_dict

    # Optional values fall back to default_value when absent from the config.
    default_value = None

    # Required config values
    input_data_path = var_dict.get('input_data_path', default_value)  # Directory of POD5s/FAST5s or a single file.
    output_directory = var_dict.get('output_directory', default_value)  # Output directory for the analysis.
    model = var_dict.get('model', default_value)  # Needed for dorado basecaller.
    required = (('input_data_path', input_data_path), ('output_directory', output_directory), ('model', model))
    missing = [name for name, value in required if not value]
    if missing:
        raise ValueError(f"Missing required config value(s): {', '.join(missing)}")
    input_data_path = Path(input_data_path)
    output_directory = Path(output_directory)

    # Optional config values
    barcode_kit = var_dict.get('barcode_kit', default_value)  # needed for dorado basecaller
    barcode_both_ends = var_dict.get('barcode_both_ends', default_value)  # dorado demultiplexing
    trim = var_dict.get('trim', default_value)  # dorado adapter and barcode removal
    device = var_dict.get('device', 'auto')
    mod_list = var_dict.get('mod_list', default_value)

    # Make initial output directory
    make_dirs([output_directory])

    # Determine the input file type(s)
    pod5_suffixes = {'.pod5', '.p5'}
    fast5_suffixes = {'.fast5', '.f5'}
    if input_data_path.is_file():
        input_is_pod5 = input_data_path.suffix in pod5_suffixes
        input_is_fast5 = input_data_path.suffix in fast5_suffixes
    elif input_data_path.is_dir():
        # Collect suffixes once: iterdir() is a one-shot generator, and Path objects
        # must be inspected via .suffix rather than substring tests.
        dir_suffixes = {f.suffix for f in input_data_path.iterdir() if f.is_file()}
        input_is_pod5 = not dir_suffixes.isdisjoint(pod5_suffixes)
        input_is_fast5 = not dir_suffixes.isdisjoint(fast5_suffixes)
    else:
        raise FileNotFoundError(f"input_data_path does not exist: {input_data_path}")

    # If the inputs are FAST5 rather than POD5, convert them to a single POD5 first.
    if input_is_fast5 and not input_is_pod5:
        output_pod5 = output_directory / 'FAST5s_to_POD5.pod5'
        print(f'Input directory contains fast5 files, converting them and concatenating into a single pod5 file in the {output_pod5}')
        fast5_to_pod5(input_data_path, output_pod5)
        # Point the input at the newly written pod5 file.
        input_data_path = output_pod5

    # The model comes from the config as a string; Path(...).name works for names and paths alike.
    model_basename = Path(model).name.replace('.', '_')

    if mod_list:
        mod_string = "_".join(mod_list)
        bam = output_directory / f"{model_basename}_{mod_string}_calls"
        modcall(model, input_data_path, barcode_kit, mod_list, bam, bam_suffix, barcode_both_ends, trim, device)
    else:
        bam = output_directory / f"{model_basename}_canonical_basecalls"
        canoncall(model, input_data_path, barcode_kit, bam, bam_suffix, barcode_both_ends, trim, device)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def fast5_to_pod5(
    fast5_dir: Union[str, Path, List[Union[str, Path]]],
    output_pod5: Union[str, Path] = "FAST5s_to_POD5.pod5"
) -> None:
    """
    Convert Nanopore FAST5 files (single file, list of files, or directory)
    into a single .pod5 output using the 'pod5 convert fast5' CLI tool.

    Parameters:
        fast5_dir: A FAST5 file, a list/tuple of FAST5 files, or a directory whose
            top-level '*.fast5' / '*.f5' files are converted (in sorted order).
        output_pod5: Destination .pod5 path.

    Returns:
        None

    Raises:
        FileNotFoundError: If the input path does not exist, or no FAST5 files were given/found.
        subprocess.CalledProcessError: If the pod5 CLI exits with a non-zero status.
    """
    output_pod5 = str(output_pod5)  # ensure string
    fast5_suffixes = {'.fast5', '.f5'}

    if isinstance(fast5_dir, (list, tuple)):
        # 1) Explicit list of FAST5 files
        fast5_paths = [str(Path(f)) for f in fast5_dir]
        if not fast5_paths:
            raise FileNotFoundError("No FAST5 files were provided")
    else:
        p = Path(fast5_dir)
        if p.is_file():
            # 2) Single file
            fast5_paths = [str(p)]
        elif p.is_dir():
            # 3) Directory -> collect FAST5s (both common extensions)
            fast5_paths = sorted(
                str(f) for f in p.iterdir() if f.is_file() and f.suffix in fast5_suffixes
            )
            if not fast5_paths:
                raise FileNotFoundError(f"No FAST5 files found in {p}")
        else:
            raise FileNotFoundError(f"Input path invalid: {fast5_dir}")

    cmd = ["pod5", "convert", "fast5", *fast5_paths, "--output", output_pod5]
    subprocess.run(cmd, check=True)
|
|
124
|
+
|
|
125
|
+
def subsample_pod5(pod5_path, read_name_path, output_directory):
    """
    Takes a POD5 file and a text file containing read names of interest and writes out a subsampled POD5 for just those reads.
    This is a useful function when you have a list of read names that mapped to a region of interest that you want to reanalyze from the pod5 level.

    Parameters:
        pod5_path (str): File path to the POD5 file (or directory of multiple pod5 files) to subsample.
        read_name_path (str | os.PathLike | int): File path to a text file of read names, one read name
            per line (blank lines are ignored). If an int is passed, that many reads are randomly subsampled.
            For a directory input, reads are taken from the pod5 files in shuffled order until enough are
            collected; for a single file, the sample is drawn from all of its reads.
        output_directory (str): A file path to the directory to output the file.

    Returns:
        None

    Raises:
        TypeError: If read_name_path is neither a path nor an int.
    """
    import random

    if os.path.isdir(pod5_path):
        pod5_path_is_dir = True
        input_pod5_base = 'input_pod5s.pod5'
        pod5_files = sorted(
            os.path.join(pod5_path, file)
            for file in os.listdir(pod5_path)
            if file.endswith('.pod5')
        )
        print(f'Found input pod5s: {pod5_files}')
    elif os.path.exists(pod5_path):
        pod5_path_is_dir = False
        input_pod5_base = os.path.basename(pod5_path)
        # Treat a single file as a one-element list so both inputs share the same loops.
        pod5_files = [pod5_path]
    else:
        print('Error: pod5_path passed does not exist')
        return None

    if isinstance(read_name_path, (str, os.PathLike)):
        input_read_name_base = os.path.basename(os.fspath(read_name_path))
        output_base = input_pod5_base.split('.pod5')[0] + '_' + input_read_name_base.split('.txt')[0] + '_subsampled.pod5'

        # extract read names into a list of strings, skipping blank lines
        with open(read_name_path, 'r') as file:
            read_names = [line.strip() for line in file if line.strip()]

        print(f'Looking for read_ids: {read_names}')
        read_records = []
        for input_pod5 in pod5_files:
            with p5.Reader(input_pod5) as reader:
                try:
                    for read_record in reader.reads(selection=read_names, missing_ok=True):
                        read_records.append(read_record.to_read())
                        print(f'Found read in {input_pod5}: {read_record.read_id}')
                except Exception as err:
                    # Best effort: report and move on to the next pod5 file.
                    print(f'Skipping {input_pod5}, could not read selected reads: {err}')
        print(f'Found {len(read_records)} of {len(read_names)} requested reads')

    elif isinstance(read_name_path, int) and not isinstance(read_name_path, bool):
        n_reads = read_name_path
        output_base = input_pod5_base.split('.pod5')[0] + f'_{n_reads}_randomly_subsampled.pod5'
        all_read_records = []

        if pod5_path_is_dir:
            # Shuffle the list of input pod5 paths
            random.shuffle(pod5_files)

        for input_pod5 in pod5_files:
            print(f'Opening pod5 file {input_pod5}')
            with p5.Reader(input_pod5) as reader:
                for read_record in reader.reads():
                    all_read_records.append(read_record.to_read())
                    # For directories, stop accumulating once enough reads are collected.
                    if pod5_path_is_dir and len(all_read_records) >= n_reads:
                        break
            if pod5_path_is_dir and len(all_read_records) >= n_reads:
                break  # also leave the loop over files

        if n_reads <= len(all_read_records):
            read_records = random.sample(all_read_records, n_reads)
        else:
            print('Trying to sample more reads than are contained in the input pod5s, taking all reads')
            read_records = all_read_records

    else:
        raise TypeError(
            f"read_name_path must be a path to a read-name file or an int, got {type(read_name_path).__name__}"
        )

    output_pod5 = os.path.join(output_directory, output_base)

    # Write the subsampled POD5
    with p5.Writer(output_pod5) as writer:
        writer.add_reads(read_records)
|
|
@@ -9,10 +9,13 @@ def run_multiqc(input_dir, output_dir):
|
|
|
9
9
|
Returns:
|
|
10
10
|
- None: The function executes MultiQC and prints the status.
|
|
11
11
|
"""
|
|
12
|
-
import
|
|
12
|
+
from ..readwrite import make_dirs
|
|
13
13
|
import subprocess
|
|
14
14
|
# Ensure the output directory exists
|
|
15
|
-
|
|
15
|
+
make_dirs(output_dir)
|
|
16
|
+
|
|
17
|
+
input_dir = str(input_dir)
|
|
18
|
+
output_dir = str(output_dir)
|
|
16
19
|
|
|
17
20
|
# Construct MultiQC command
|
|
18
21
|
command = ["multiqc", input_dir, "-o", output_dir]
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from . import models
|
|
2
|
+
from . import data
|
|
3
|
+
from . import utils
|
|
4
|
+
from . import evaluation
|
|
5
|
+
from . import inference
|
|
6
|
+
from . import training
|
|
7
|
+
|
|
8
|
+
# Only names actually bound in this module may appear in __all__; otherwise
# `from smftools.machine_learning import *` raises AttributeError. This module
# binds exactly the submodules imported above.
__all__ = [
    "models",
    "data",
    "utils",
    "evaluation",
    "inference",
    "training",
]
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset, Subset
|
|
3
|
+
import pytorch_lightning as pl
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from .preprocessing import random_fill_nans
|
|
7
|
+
from sklearn.utils.class_weight import compute_class_weight
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AnnDataDataset(Dataset):
    """
    Generic PyTorch Dataset from AnnData.

    Parameters:
        adata (AnnData): Source object.
        tensor_source (str): Where features come from: "X", "layers" or "obsm".
        tensor_key (str, optional): Key into adata.layers / adata.obsm (required for those sources).
        label_col (str, optional): adata.obs column used as the target. Non-numeric labels
            (categorical or strings) are converted to integer category codes.
        window_start (int, optional): First feature column to keep.
        window_size (int, optional): Number of feature columns to keep; windowing is applied
            only when both window_start and window_size are given.

    Raises:
        KeyError: If tensor_key is not present in the chosen layers/obsm mapping.
        ValueError: If tensor_source is not one of the supported values.
    """
    def __init__(self, adata, tensor_source="X", tensor_key=None, label_col=None, window_start=None, window_size=None):
        self.adata = adata
        self.tensor_source = tensor_source
        self.tensor_key = tensor_key
        self.label_col = label_col
        self.window_start = window_start
        self.window_size = window_size

        if tensor_source == "X":
            X = adata.X
        elif tensor_source == "layers":
            if tensor_key not in adata.layers:
                raise KeyError(f"Layer {tensor_key!r} not found in adata.layers")
            X = adata.layers[tensor_key]
        elif tensor_source == "obsm":
            if tensor_key not in adata.obsm:
                raise KeyError(f"Key {tensor_key!r} not found in adata.obsm")
            X = adata.obsm[tensor_key]
        else:
            raise ValueError(f"Invalid tensor_source: {tensor_source}")

        if self.window_start is not None and self.window_size is not None:
            X = X[:, self.window_start : self.window_start + self.window_size]

        # Sparse matrices cannot be NaN-filled or converted to a tensor directly.
        if hasattr(X, "toarray"):
            X = X.toarray()

        X = random_fill_nans(X)

        self.X_tensor = torch.tensor(X, dtype=torch.float32)

        if label_col is not None:
            self.y_tensor = self._encode_labels(adata.obs[label_col])
        else:
            self.y_tensor = None

    @staticmethod
    def _encode_labels(y):
        """Convert a label Series to a long tensor, mapping non-numeric labels to category codes."""
        if not (pd.api.types.is_numeric_dtype(y) or pd.api.types.is_bool_dtype(y)):
            # Categorical columns keep their category order; strings get sorted categories.
            y = y.astype("category").cat.codes
        return torch.tensor(np.asarray(y), dtype=torch.long)

    def numpy(self, indices):
        """Return (X, y) NumPy arrays for the given indices; y is None when there are no labels."""
        X = self.X_tensor[indices].numpy()
        y = None if self.y_tensor is None else self.y_tensor[indices].numpy()
        return X, y

    def __len__(self):
        return len(self.X_tensor)

    def __getitem__(self, idx):
        x = self.X_tensor[idx]
        if self.y_tensor is not None:
            y = self.y_tensor[idx]
            return x, y
        else:
            return (x,)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
                  random_seed=42, split_col="train_val_test_split",
                  load_existing_split=False, split_save_path=None):
    """
    Split ``dataset`` into train/val/test subsets and record the assignment in ``adata.obs[split_col]``.

    Parameters
    ----------
    adata : AnnData
        Object whose ``obs`` receives (or already holds) the split labels
        ``"train"``, ``"val"`` and ``"test"``.
    dataset : Sequence
        Dataset indexed in the same row order as ``adata.obs``.
    train_frac, val_frac : float
        Fractions of rows assigned to train and val. Counts are truncated with ``int()``.
    test_frac : float
        Kept for API symmetry. The test set receives every row not assigned to
        train or val, so this value is not used directly.
    random_seed : int
        Seed for the shuffle. The permutation matches the one obtained from
        ``np.random.seed(random_seed); np.random.shuffle(...)``.
    split_col : str
        Column of ``adata.obs`` used to store the split labels.
    load_existing_split : bool
        If True, reuse ``adata.obs[split_col]`` when present. Otherwise read it
        from ``split_save_path``.
    split_save_path : str, optional
        CSV path. It is written after a fresh split, or read when
        ``load_existing_split`` is True and the column is absent.

    Returns
    -------
    tuple of torch.utils.data.Subset
        ``(train_set, val_set, test_set)``.

    Raises
    ------
    ValueError
        If ``load_existing_split`` is True but neither the column nor a file is available.
    """
    total_len = len(dataset)

    if load_existing_split:
        if split_col in adata.obs:
            pass  # reuse the assignment already stored on adata
        elif split_save_path:
            # Read everything as strings. AnnData obs_names are strings, and numeric-looking
            # names such as "0" would otherwise be parsed as ints and fail the .loc lookup.
            split_df = pd.read_csv(split_save_path, dtype=str, keep_default_na=False)
            split_df = split_df.set_index(split_df.columns[0])
            adata.obs[split_col] = split_df.loc[adata.obs_names, split_col].values
        else:
            raise ValueError("No existing split column found and no file provided.")
    else:
        indices = np.arange(total_len)
        # A private RandomState gives the same permutation as the legacy global
        # np.random.seed + np.random.shuffle, without resetting the caller's global RNG.
        np.random.RandomState(random_seed).shuffle(indices)

        n_train = int(train_frac * total_len)
        n_val = int(val_frac * total_len)

        # Everything not assigned to train/val stays "test".
        split_array = np.full(total_len, "test", dtype=object)
        split_array[indices[:n_train]] = "train"
        split_array[indices[n_train:n_train + n_val]] = "val"
        adata.obs[split_col] = split_array

        if split_save_path:
            adata.obs[[split_col]].to_csv(split_save_path)

    split_labels = np.asarray(adata.obs[split_col])
    train_indices = np.where(split_labels == "train")[0]
    val_indices = np.where(split_labels == "val")[0]
    test_indices = np.where(split_labels == "test")[0]

    train_set = Subset(dataset, train_indices)
    val_set = Subset(dataset, val_indices)
    test_set = Subset(dataset, test_indices)

    return train_set, val_set, test_set
|
|
106
|
+
|
|
107
|
+
class AnnDataModule(pl.LightningDataModule):
    """
    Unified LightningDataModule version of AnnDataDataset + splitting with adata.obs recording.

    Training mode (``inference_mode=False``): ``setup`` builds a labeled
    ``AnnDataDataset`` and splits it with ``split_dataset``, which writes the
    assignment into ``adata.obs[split_col]``.

    Inference mode: ``setup`` builds an unlabeled dataset. It is exposed through
    ``predict_dataloader`` and ``inference_numpy``.

    ``num_workers`` / ``persistent_workers`` are forwarded to the train/val/test
    DataLoaders only when ``num_workers`` is truthy.
    """

    def __init__(self, adata, tensor_source="X", tensor_key=None, label_col="labels",
                 batch_size=64, train_frac=0.6, val_frac=0.1, test_frac=0.3, random_seed=42,
                 inference_mode=False, split_col="train_val_test_split", split_save_path=None,
                 load_existing_split=False, window_start=None, window_size=None, num_workers=None, persistent_workers=False):
        super().__init__()
        self.adata = adata
        self.tensor_source = tensor_source
        self.tensor_key = tensor_key
        self.label_col = label_col
        self.batch_size = batch_size
        self.train_frac = train_frac
        self.val_frac = val_frac
        self.test_frac = test_frac
        self.random_seed = random_seed
        self.inference_mode = inference_mode
        self.split_col = split_col
        self.split_save_path = split_save_path
        self.load_existing_split = load_existing_split
        self.var_names = adata.var_names.copy()
        self.window_start = window_start
        self.window_size = window_size
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers

    def setup(self, stage=None):
        """Build the dataset and either keep it for inference or split it into train/val/test."""
        dataset = AnnDataDataset(self.adata, self.tensor_source, self.tensor_key,
                                 None if self.inference_mode else self.label_col,
                                 window_start=self.window_start, window_size=self.window_size)

        if self.inference_mode:
            self.infer_dataset = dataset
            return

        self.train_set, self.val_set, self.test_set = split_dataset(
            self.adata, dataset, train_frac=self.train_frac, val_frac=self.val_frac,
            test_frac=self.test_frac, random_seed=self.random_seed,
            split_col=self.split_col, split_save_path=self.split_save_path,
            load_existing_split=self.load_existing_split
        )

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader over ``dataset``, adding worker options only when ``num_workers`` is set."""
        if self.num_workers:
            return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                              num_workers=self.num_workers, persistent_workers=self.persistent_workers)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return self._make_loader(self.val_set)

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return self._make_loader(self.test_set)

    def predict_dataloader(self):
        """Loader over the full unlabeled dataset; only valid when ``inference_mode`` is True."""
        if not self.inference_mode:
            raise RuntimeError("Only valid in inference mode")
        return DataLoader(self.infer_dataset, batch_size=self.batch_size)

    def compute_class_weights(self):
        """
        Return 'balanced' class weights computed from the training-split labels as a float32 tensor.

        Weights are ordered by ``np.unique(y_train)``. NOTE(review): a class absent
        from the training split gets no entry.
        """
        train_indices = self.train_set.indices  # row indices of the training Subset
        y_all = self.train_set.dataset.y_tensor  # labels of the full underlying dataset
        y_train = y_all[train_indices].cpu().numpy()

        class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
        return torch.tensor(class_weights, dtype=torch.float32)

    def inference_numpy(self):
        """
        Return inference data as numpy for use in sklearn inference.
        """
        if not self.inference_mode:
            raise RuntimeError("Must be in inference_mode=True to use inference_numpy()")
        X_np = self.infer_dataset.X_tensor.numpy()
        return X_np

    def to_numpy(self):
        """
        Move the AnnDataModule tensors into numpy arrays.

        Returns ``(train_X, train_y, val_X, val_y, test_X, test_y)`` in training
        mode, or only the inference feature array in inference mode.
        """
        if not self.inference_mode:
            train_X, train_y = self.train_set.dataset.numpy(self.train_set.indices)
            val_X, val_y = self.val_set.dataset.numpy(self.val_set.indices)
            test_X, test_Y = self.test_set.dataset.numpy(self.test_set.indices)
            return train_X, train_y, val_X, val_y, test_X, test_Y
        else:
            return self.inference_numpy()
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def build_anndata_loader(
    adata, tensor_source="X", tensor_key=None, label_col=None, train_frac=0.6, val_frac=0.1,
    test_frac=0.3, random_seed=42, batch_size=64, lightning=True, inference_mode=False,
    split_col="train_val_test_split", split_save_path=None, load_existing_split=False,
    window_start=None, window_size=None,
):
    """
    Unified pipeline for both Lightning and raw PyTorch.

    The lightning loader works for both Lightning and the Sklearn wrapper.
    Set lightning to False if you want to make data loaders for base PyTorch or base sklearn models.

    Returns
    -------
    AnnDataModule
        When ``lightning`` is True.
    DataLoader
        When ``lightning`` is False and ``inference_mode`` is True (unlabeled, unshuffled).
    tuple of DataLoader
        ``(train_loader, val_loader, test_loader)`` otherwise. Only the training loader is shuffled.

    ``window_start`` / ``window_size`` (optional) restrict the feature columns to
    ``[window_start, window_start + window_size)`` when both are given.
    """
    if lightning:
        return AnnDataModule(
            adata, tensor_source=tensor_source, tensor_key=tensor_key, label_col=label_col,
            batch_size=batch_size, train_frac=train_frac, val_frac=val_frac, test_frac=test_frac,
            random_seed=random_seed, inference_mode=inference_mode,
            split_col=split_col, split_save_path=split_save_path, load_existing_split=load_existing_split,
            window_start=window_start, window_size=window_size,
        )

    dataset = AnnDataDataset(
        adata, tensor_source, tensor_key, None if inference_mode else label_col,
        window_start=window_start, window_size=window_size,
    )
    if inference_mode:
        return DataLoader(dataset, batch_size=batch_size)

    # Keyword arguments: split_dataset declares load_existing_split *before*
    # split_save_path, so positional passing would swap them.
    train_set, val_set, test_set = split_dataset(
        adata, dataset, train_frac=train_frac, val_frac=val_frac, test_frac=test_frac,
        random_seed=random_seed, split_col=split_col, split_save_path=split_save_path,
        load_existing_split=load_existing_split,
    )
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    return train_loader, val_loader, test_loader
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
|
|
3
|
+
def flatten_sliding_window_results(results_dict):
    """
    Flatten nested sliding window results into pandas DataFrame.

    Expects structure:
        results[model_name][window_size][window_center]['metrics'][metric_name]

    Each (model, window_size, center) entry becomes one row. Columns are
    ``model``, ``window_size`` and ``center_var``, plus one column per metric name.
    ``center_var`` is coerced to numeric (unparseable values become NaN) and rows
    are sorted by model, window size and center.

    If there are no entries, an empty DataFrame with the three key columns is
    returned instead of raising ``KeyError``.
    """
    key_cols = ['model', 'window_size', 'center_var']
    records = []

    for model_name, model_results in results_dict.items():
        for window_size, window_results in model_results.items():
            for center_var, result in window_results.items():
                record = {
                    'model': model_name,
                    'window_size': window_size,
                    'center_var': center_var,
                }
                # Add all metrics
                record.update(result['metrics'])
                records.append(record)

    if not records:
        # from_records([]) has no columns, so the center_var coercion below would raise KeyError.
        return pd.DataFrame(columns=key_cols)

    df = pd.DataFrame.from_records(records)

    # Convert center_var to numeric if possible (helpful for plotting)
    df['center_var'] = pd.to_numeric(df['center_var'], errors='coerce')
    df = df.sort_values(key_cols)

    return df
|