smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +5 -1
- smftools/_version.py +1 -1
- smftools/informatics/__init__.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +63 -10
- smftools/informatics/direct_smf.py +66 -18
- smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
- smftools/informatics/helpers/__init__.py +16 -2
- smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
- smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
- smftools/informatics/helpers/canoncall.py +12 -3
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
- smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +33 -46
- smftools/informatics/helpers/extract_mods.py +55 -23
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/find_conversion_sites.py +33 -44
- smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
- smftools/informatics/helpers/modcall.py +13 -5
- smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
- smftools/informatics/helpers/ohe_batching.py +65 -41
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +45 -9
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/split_and_index_BAM.py +3 -8
- smftools/informatics/load_adata.py +58 -3
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +6 -7
- smftools/preprocessing/append_C_context.py +22 -9
- smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
- smftools/preprocessing/binarize_on_Youden.py +35 -32
- smftools/preprocessing/binary_layers_to_ohe.py +13 -3
- smftools/preprocessing/calculate_complexity.py +3 -2
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
- smftools/preprocessing/calculate_coverage.py +26 -25
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_position_Youden.py +18 -7
- smftools/preprocessing/calculate_read_length_stats.py +39 -46
- smftools/preprocessing/clean_NaN.py +33 -25
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
- smftools/preprocessing/filter_reads_on_length.py +14 -4
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +18 -11
- smftools/preprocessing/load_sample_sheet.py +30 -16
- smftools/preprocessing/recipes.py +22 -20
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +105 -13
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +19 -23
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools/tools/apply_HMM.py +0 -1
- smftools/tools/read_HMM.py +0 -1
- smftools/tools/train_HMM.py +0 -43
- smftools-0.1.3.dist-info/RECORD +0 -84
- /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
- /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/calculate_read_length_stats.py

@@ -1,7 +1,7 @@
 ## calculate_read_length_stats

 # Read length QC
-def calculate_read_length_stats(adata, reference_column, sample_names_col
+def calculate_read_length_stats(adata, reference_column='', sample_names_col=''):
     """
     Append first valid position in a read and last valid position in the read. From this determine and append the read length.

@@ -9,9 +9,6 @@ def calculate_read_length_stats(adata, reference_column, sample_names_col, outpu
         adata (AnnData): An adata object
         reference_column (str): String representing the name of the Reference column to use
         sample_names_col (str): String representing the name of the sample name column to use
-        output_directory (str): String representing the output directory to make and write out the histograms.
-        show_read_length_histogram (bool): Whether to display the histograms.
-        save_read_length_histogram (bool): Whether to save the histograms.

     Returns:
         upper_bound (int): last valid position in the dataset
@@ -20,11 +17,8 @@ def calculate_read_length_stats(adata, reference_column, sample_names_col, outpu
     import numpy as np
     import anndata as ad
     import pandas as pd
-    import matplotlib.pyplot as plt
-    from .. import readwrite
-    from .make_dirs import make_dirs

-
+    print('Calculating read length statistics')

     references = set(adata.obs[reference_column])
     sample_names = set(adata.obs[sample_names_col])
@@ -44,43 +38,42 @@ def calculate_read_length_stats(adata, reference_column, sample_names_col, outpu
     upper_bound = int(np.nanmax(adata.obs['last_valid_position']))
     lower_bound = int(np.nanmin(adata.obs['first_valid_position']))

-
+    return upper_bound, lower_bound

-
-
+    # # Add an unstructured element to the anndata object which points to a dictionary of read lengths keyed by reference and sample name. Points to a tuple containing (mean, median, stdev) of the read lengths of the sample for the given reference strand
+    # ## Plot histogram of read length data and save the median and stdev of the read lengths for each sample.
+    # adata.uns['read_length_dict'] = {}

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    return upper_bound, lower_bound
+    # for reference in references:
+    # temp_reference_adata = adata[adata.obs[reference_column] == reference].copy()
+    # split_reference = reference.split('_')[0][1:]
+    # for sample in sample_names:
+    # temp_sample_adata = temp_reference_adata[temp_reference_adata.obs[sample_names_col] == sample].copy()
+    # temp_data = temp_sample_adata.obs['read_length']
+    # max_length = np.max(temp_data)
+    # mean = np.mean(temp_data)
+    # median = np.median(temp_data)
+    # stdev = np.std(temp_data)
+    # adata.uns['read_length_dict'][f'{reference}_{sample}'] = [mean, median, stdev]
+    # if not np.isnan(max_length):
+    # n_bins = int(max_length // 100)
+    # else:
+    # n_bins = 1
+    # if show_read_length_histogram or save_read_length_histogram:
+    # plt.figure(figsize=(10, 6))
+    # plt.text(median + 0.5, max(plt.hist(temp_data, bins=n_bins)[0]) / 2, f'Median: {median:.2f}', color='red')
+    # plt.hist(temp_data, bins=n_bins, alpha=0.7, color='blue', edgecolor='black')
+    # plt.xlabel('Read Length')
+    # plt.ylabel('Count')
+    # title = f'Read length distribution of {temp_sample_adata.shape[0]} total reads from {sample} sample on {split_reference} allele'
+    # plt.title(title)
+    # # Add a vertical line at the median
+    # plt.axvline(median, color='red', linestyle='dashed', linewidth=1)
+    # # Annotate the median
+    # plt.xlim(lower_bound - 100, upper_bound + 100)
+    # if save_read_length_histogram:
+    # save_name = output_directory + f'/{readwrite.date_string()} {title}'
+    # plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
+    # plt.close()
+    # else:
+    # plt.show()
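The returned bounds reduce to two numpy calls over the per-read obs columns named in the hunk above. A minimal standalone sketch of that computation (the toy values are illustrative assumptions, not output from the package):

```python
import numpy as np
import anndata as ad

# Toy AnnData carrying the two obs columns referenced in the hunk above.
adata = ad.AnnData(np.zeros((4, 10)))
adata.obs['first_valid_position'] = [0, 2, 1, 3]
adata.obs['last_valid_position'] = [9, 7, 8, 9]

# Same computation as the function's return values.
upper_bound = int(np.nanmax(adata.obs['last_valid_position']))
lower_bound = int(np.nanmin(adata.obs['first_valid_position']))
print(lower_bound, upper_bound)  # 0 9
```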
smftools/preprocessing/clean_NaN.py

@@ -1,38 +1,46 @@
-## clean_NaN
-
 def clean_NaN(adata, layer=None):
     """
     Append layers to adata that contain NaN cleaning strategies.

     Parameters:
-        adata (AnnData): an
-        layer (str):
+        adata (AnnData): an anndata object
+        layer (str, optional): Name of the layer to fill NaN values in. If None, uses adata.X.

-
-
+    Modifies:
+        - Adds new layers to `adata.layers` with different NaN-filling strategies.
     """
     import numpy as np
-    import anndata as ad
     import pandas as pd
-
+    import anndata as ad
+    from ..readwrite import adata_to_df

-    #
-
-
-    adata.layers['fill_nans_closest'] = df.values
+    # Ensure the specified layer exists
+    if layer and layer not in adata.layers:
+        raise ValueError(f"Layer '{layer}' not found in adata.layers.")

-    #
-    old_value, new_value = [0, -1]
+    # Convert to DataFrame
     df = adata_to_df(adata, layer=layer)
-    df = df.replace(old_value, new_value)
-    old_value, new_value = [np.nan, 0]
-    df = df.replace(old_value, new_value)
-    adata.layers['nan0_0minus1'] = df.values

-    #
-
-
-
-
-
-
+    # Fill NaN with closest SMF value (forward then backward fill)
+    print('Making layer: fill_nans_closest')
+    adata.layers['fill_nans_closest'] = df.ffill(axis=1).bfill(axis=1).values
+
+    # Replace NaN with 0, and 0 with -1
+    print('Making layer: nan0_0minus1')
+    df_nan0_0minus1 = df.replace(0, -1).fillna(0)
+    adata.layers['nan0_0minus1'] = df_nan0_0minus1.values
+
+    # Replace NaN with 1, and 1 with 2
+    print('Making layer: nan1_12')
+    df_nan1_12 = df.replace(1, 2).fillna(1)
+    adata.layers['nan1_12'] = df_nan1_12.values
+
+    # Replace NaN with -1
+    print('Making layer: nan_minus_1')
+    df_nan_minus_1 = df.fillna(-1)
+    adata.layers['nan_minus_1'] = df_nan_minus_1.values
+
+    # Replace NaN with -1
+    print('Making layer: nan_half')
+    df_nan_half = df.fillna(0.5)
+    adata.layers['nan_half'] = df_nan_half.values
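A hedged usage sketch of the reworked helper. It assumes smftools 0.1.7 is importable from this module path and that `adata_to_df` returns the selected matrix as a DataFrame; the toy matrix is illustrative.

```python
import numpy as np
import anndata as ad
from smftools.preprocessing.clean_NaN import clean_NaN  # assumed import path

X = np.array([[1.0, np.nan, 0.0],
              [np.nan, 1.0, 1.0]])
adata = ad.AnnData(X)

clean_NaN(adata)  # adds the NaN-filled layers in place
print(sorted(adata.layers.keys()))
# ['fill_nans_closest', 'nan0_0minus1', 'nan1_12', 'nan_half', 'nan_minus_1']
```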
smftools/preprocessing/filter_adata_by_nan_proportion.py

@@ -0,0 +1,31 @@
+## filter_adata_by_nan_proportion
+
+def filter_adata_by_nan_proportion(adata, threshold, axis='obs'):
+    """
+    Filters an anndata object on a nan proportion threshold in a given matrix axis.
+
+    Parameters:
+        adata (AnnData):
+        threshold (float): The max np.nan content to allow in the given axis.
+        axis (str): Whether to filter the adata based on obs or var np.nan content
+    Returns:
+        filtered_adata
+    """
+    import numpy as np
+    import anndata as ad
+
+    if axis == 'obs':
+        # Calculate the proportion of NaN values in each read
+        nan_proportion = np.isnan(adata.X).mean(axis=1)
+        # Filter reads to keep reads with less than a certain NaN proportion
+        filtered_indices = np.where(nan_proportion <= threshold)[0]
+        filtered_adata = adata[filtered_indices, :].copy()
+    elif axis == 'var':
+        # Calculate the proportion of NaN values at a given position
+        nan_proportion = np.isnan(adata.X).mean(axis=0)
+        # Filter positions to keep positions with less than a certain NaN proportion
+        filtered_indices = np.where(nan_proportion <= threshold)[0]
+        filtered_adata = adata[:, filtered_indices].copy()
+    else:
+        raise ValueError("Axis must be either 'obs' or 'var'")
+    return filtered_adata
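A hedged usage sketch for this new helper; the import path and toy matrix are assumptions for illustration.

```python
import numpy as np
import anndata as ad
from smftools.preprocessing.filter_adata_by_nan_proportion import filter_adata_by_nan_proportion  # assumed path

X = np.array([[1.0, np.nan, 0.0, 1.0],
              [np.nan, np.nan, np.nan, 0.0],
              [1.0, 0.0, 1.0, 1.0]])
adata = ad.AnnData(X)

# Keep reads (obs axis) with at most 50% missing positions.
filtered = filter_adata_by_nan_proportion(adata, threshold=0.5, axis='obs')
print(filtered.shape)  # (2, 4): the read that is 75% NaN is dropped
```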
smftools/preprocessing/filter_converted_reads_on_methylation.py

@@ -1,8 +1,7 @@
 ## filter_converted_reads_on_methylation

 ## Conversion SMF Specific
-
-def filter_converted_reads_on_methylation(adata, valid_SMF_site_threshold=0.8, min_SMF_threshold=0.025):
+def filter_converted_reads_on_methylation(adata, valid_SMF_site_threshold=0.8, min_SMF_threshold=0.025, max_SMF_threshold=0.975):
     """
     Filter adata object using minimum thresholds for valid SMF site fraction in read, as well as minimum methylation content in read.

@@ -11,7 +10,7 @@ def filter_converted_reads_on_methylation(adata, valid_SMF_site_threshold=0.8, m
         valid_SMF_site_threshold (float): A minimum proportion of valid SMF sites that must be present in the read. Default is 0.8
         min_SMF_threshold (float): A minimum read methylation level. Default is 0.025
     Returns:
-
+        Anndata
     """
     import numpy as np
     import anndata as ad
@@ -20,10 +19,26 @@ def filter_converted_reads_on_methylation(adata, valid_SMF_site_threshold=0.8, m
     if valid_SMF_site_threshold:
         # Keep reads that have over a given valid GpC site content
         adata = adata[adata.obs['fraction_valid_GpC_site_in_range'] > valid_SMF_site_threshold].copy()
+
     if min_SMF_threshold:
         # Keep reads with SMF methylation over background methylation.
+        below_background = (~adata.obs['GpC_above_other_C']).sum()
+        print(f'Removing {below_background} reads that have GpC conversion below background conversion rate')
         adata = adata[adata.obs['GpC_above_other_C'] == True].copy()
         # Keep reads over a defined methylation threshold
+        s0 = adata.shape[0]
         adata = adata[adata.obs['GpC_site_row_methylation_means'] > min_SMF_threshold].copy()
-
-
+        s1 = adata.shape[0]
+        below_threshold = s0 - s1
+        print(f'Removing {below_threshold} reads that have GpC conversion below a minimum threshold conversion rate')
+
+    if max_SMF_threshold:
+        # Keep reads below a defined methylation threshold
+        s0 = adata.shape[0]
+        adata = adata[adata.obs['GpC_site_row_methylation_means'] < max_SMF_threshold].copy()
+        s1 = adata.shape[0]
+        above_threshold = s0 - s1
+        print(f'Removing {above_threshold} reads that have GpC conversion above a maximum threshold conversion rate')
+
+    return adata
+
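A hedged usage sketch of the extended filter. The obs columns are the ones referenced in the hunk above; the import path and values are illustrative assumptions.

```python
import numpy as np
import anndata as ad
from smftools.preprocessing.filter_converted_reads_on_methylation import filter_converted_reads_on_methylation  # assumed path

adata = ad.AnnData(np.zeros((4, 5)))
adata.obs['fraction_valid_GpC_site_in_range'] = [0.95, 0.90, 0.60, 0.99]
adata.obs['GpC_above_other_C'] = [True, True, True, False]
adata.obs['GpC_site_row_methylation_means'] = [0.30, 0.01, 0.40, 0.50]

filtered = filter_converted_reads_on_methylation(
    adata,
    valid_SMF_site_threshold=0.8,
    min_SMF_threshold=0.025,
    max_SMF_threshold=0.975,
)
print(filtered.shape[0], 'reads pass all three filters')
```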
smftools/preprocessing/filter_reads_on_length.py

@@ -1,6 +1,6 @@
 ## filter_reads_on_length

-def filter_reads_on_length(adata, filter_on_coordinates=False, min_read_length=2700):
+def filter_reads_on_length(adata, filter_on_coordinates=False, min_read_length=2700, max_read_length=3200):
     """
     Filters the adata object to keep a defined coordinate window, as well as reads that are over a minimum threshold in length.

@@ -8,15 +8,15 @@ def filter_reads_on_length(adata, filter_on_coordinates=False, min_read_length=2
         adata (AnnData): An adata object.
         filter_on_coordinates (bool | list): If False, skips filtering. Otherwise, provide a list containing integers representing the lower and upper bound coordinates to filter on. Default is False.
         min_read_length (int): The minimum read length to keep in the filtered dataset. Default is 2700.
+        max_read_length (int): The maximum query read length to keep in the filtered dataset. Default is 3200.

     Returns:
-        adata
-    Input: Adata object. a list of lower and upper bound (set to False or None if not wanted), and a minimum read length integer.
-
+        adata
     """
     import numpy as np
     import anndata as ad
     import pandas as pd
+
     if filter_on_coordinates:
         lower_bound, upper_bound = filter_on_coordinates
         # Extract the position information from the adata object as an np array
@@ -36,6 +36,16 @@ def filter_reads_on_length(adata, filter_on_coordinates=False, min_read_length=2

     if min_read_length:
         print(f'Subsetting adata to keep reads longer than {min_read_length}')
+        s0 = adata.shape[0]
         adata = adata[adata.obs['read_length'] > min_read_length].copy()
+        s1 = adata.shape[0]
+        print(f'Removed {s0-s1} reads less than {min_read_length} basepairs in length')
+
+    if max_read_length:
+        print(f'Subsetting adata to keep reads shorter than {max_read_length}')
+        s0 = adata.shape[0]
+        adata = adata[adata.obs['read_length'] < max_read_length].copy()
+        s1 = adata.shape[0]
+        print(f'Removed {s0-s1} reads greater than {max_read_length} basepairs in length')

     return adata
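The added branches reduce to two boolean subsets on `adata.obs['read_length']`. A standalone sketch of that logic with toy values (not a call into the package):

```python
import numpy as np
import anndata as ad

adata = ad.AnnData(np.zeros((5, 3)))
adata.obs['read_length'] = [2500, 2800, 3000, 3100, 3400]

min_read_length, max_read_length = 2700, 3200
adata = adata[adata.obs['read_length'] > min_read_length].copy()   # drop short reads
adata = adata[adata.obs['read_length'] < max_read_length].copy()   # drop long reads
print(adata.obs['read_length'].tolist())  # [2800, 3000, 3100]
```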
smftools/preprocessing/flag_duplicate_reads.py

@@ -0,0 +1,149 @@
+import torch
+from tqdm import tqdm
+
+class UnionFind:
+    def __init__(self, size):
+        self.parent = torch.arange(size)
+
+    def find(self, x):
+        while self.parent[x] != x:
+            self.parent[x] = self.parent[self.parent[x]]
+            x = self.parent[x]
+        return x
+
+    def union(self, x, y):
+        root_x = self.find(x)
+        root_y = self.find(y)
+        if root_x != root_y:
+            self.parent[root_y] = root_x
+
+
+def flag_duplicate_reads(adata, var_filters_sets, distance_threshold=0.05, obs_reference_col='Reference_strand'):
+    import numpy as np
+    import pandas as pd
+    import matplotlib.pyplot as plt
+
+    all_hamming_dists = []
+    merged_results = []
+
+    references = adata.obs[obs_reference_col].cat.categories
+
+    for ref in references:
+        print(f'🔹 Processing reference: {ref}')
+
+        ref_mask = adata.obs[obs_reference_col] == ref
+        adata_subset = adata[ref_mask].copy()
+        N = adata_subset.shape[0]
+
+        combined_mask = torch.zeros(len(adata.var), dtype=torch.bool)
+        for var_set in var_filters_sets:
+            if any(ref in v for v in var_set):
+                set_mask = torch.ones(len(adata.var), dtype=torch.bool)
+                for key in var_set:
+                    set_mask &= torch.from_numpy(adata.var[key].values)
+                combined_mask |= set_mask
+
+        selected_cols = adata.var.index[combined_mask.numpy()].to_list()
+        col_indices = [adata.var.index.get_loc(col) for col in selected_cols]
+
+        print(f"Selected {len(col_indices)} columns out of {adata.var.shape[0]} for {ref}")
+
+        X = adata_subset.X
+        if not isinstance(X, np.ndarray):
+            X = X.toarray()
+        X_subset = X[:, col_indices]
+        X_tensor = torch.from_numpy(X_subset.astype(np.float32))
+
+        fwd_hamming_to_next = torch.full((N,), float('nan'))
+        rev_hamming_to_prev = torch.full((N,), float('nan'))
+
+        def cluster_pass(X_tensor, reverse=False, window_size=50, record_distances=False):
+            N_local = X_tensor.shape[0]
+            X_sortable = X_tensor.nan_to_num(-1)
+            sort_keys = X_sortable.tolist()
+            sorted_idx = sorted(range(N_local), key=lambda i: sort_keys[i], reverse=reverse)
+            sorted_X = X_tensor[sorted_idx]
+
+            cluster_pairs = []
+
+            for i in tqdm(range(len(sorted_X)), desc=f"Pass {'rev' if reverse else 'fwd'} ({ref})"):
+                row_i = sorted_X[i]
+                j_range = range(i + 1, min(i + 1 + window_size, len(sorted_X)))
+
+                if len(j_range) > 0:
+                    row_i_exp = row_i.unsqueeze(0)
+                    block_rows = sorted_X[j_range]
+                    valid_mask = (~torch.isnan(row_i_exp)) & (~torch.isnan(block_rows))
+                    valid_counts = valid_mask.sum(dim=1)
+                    diffs = (row_i_exp != block_rows) & valid_mask
+                    hamming_dists = diffs.sum(dim=1) / valid_counts.clamp(min=1)
+                    all_hamming_dists.extend(hamming_dists.cpu().numpy().tolist())
+
+                    matches = (hamming_dists < distance_threshold) & (valid_counts > 0)
+                    for offset_idx, m in zip(j_range, matches):
+                        if m:
+                            cluster_pairs.append((sorted_idx[i], sorted_idx[offset_idx]))
+
+                if record_distances and i + 1 < len(sorted_X):
+                    next_idx = sorted_idx[i + 1]
+                    valid_mask_pair = (~torch.isnan(row_i)) & (~torch.isnan(sorted_X[i + 1]))
+                    if valid_mask_pair.sum() > 0:
+                        d = (row_i[valid_mask_pair] != sorted_X[i + 1][valid_mask_pair]).sum()
+                        norm_d = d.item() / valid_mask_pair.sum().item()
+                        if reverse:
+                            rev_hamming_to_prev[next_idx] = norm_d
+                        else:
+                            fwd_hamming_to_next[sorted_idx[i]] = norm_d
+
+            return cluster_pairs
+
+        pairs_fwd = cluster_pass(X_tensor, reverse=False, record_distances=True)
+        involved_in_fwd = set([p[0] for p in pairs_fwd] + [p[1] for p in pairs_fwd])
+        mask_for_rev = torch.ones(N, dtype=torch.bool)
+        mask_for_rev[list(involved_in_fwd)] = False
+        pairs_rev = cluster_pass(X_tensor[mask_for_rev], reverse=True, record_distances=True)
+
+        all_pairs = pairs_fwd + [(list(mask_for_rev.nonzero(as_tuple=True)[0])[i], list(mask_for_rev.nonzero(as_tuple=True)[0])[j]) for i, j in pairs_rev]
+
+        uf = UnionFind(N)
+        for i, j in all_pairs:
+            uf.union(i, j)
+
+        merged_cluster = torch.zeros(N, dtype=torch.long)
+        for i in range(N):
+            merged_cluster[i] = uf.find(i)
+
+        cluster_sizes = torch.zeros_like(merged_cluster)
+        for cid in merged_cluster.unique():
+            members = (merged_cluster == cid).nonzero(as_tuple=True)[0]
+            cluster_sizes[members] = len(members)
+
+        is_duplicate = torch.zeros(N, dtype=torch.bool)
+        for cid in merged_cluster.unique():
+            members = (merged_cluster == cid).nonzero(as_tuple=True)[0]
+            if len(members) > 1:
+                is_duplicate[members[1:]] = True
+
+        adata_subset.obs['is_duplicate'] = is_duplicate.numpy()
+        adata_subset.obs['merged_cluster_id'] = merged_cluster.numpy()
+        adata_subset.obs['cluster_size'] = cluster_sizes.numpy()
+        adata_subset.obs['fwd_hamming_to_next'] = fwd_hamming_to_next.numpy()
+        adata_subset.obs['rev_hamming_to_prev'] = rev_hamming_to_prev.numpy()
+
+        merged_results.append(adata_subset.obs)
+
+    merged_obs = pd.concat(merged_results)
+    adata.obs = adata.obs.join(merged_obs[['is_duplicate', 'merged_cluster_id', 'cluster_size', 'fwd_hamming_to_next', 'rev_hamming_to_prev']])
+
+    adata_unique = adata[~adata.obs['is_duplicate']].copy()
+
+    plt.figure(figsize=(5, 4))
+    plt.hist(all_hamming_dists, bins=50, alpha=0.75)
+    plt.axvline(distance_threshold, color="red", linestyle="--", label=f"threshold = {distance_threshold}")
+    plt.xlabel("Hamming Distance")
+    plt.ylabel("Frequency")
+    plt.title("Histogram of Pairwise Hamming Distances")
+    plt.legend()
+    plt.show()
+
+    return adata_unique, adata
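A hedged usage sketch for the new duplicate-flagging helper. The import path, reference label, column names, and toy matrix below are assumptions for illustration; the only requirements taken from the code above are a categorical reference column, boolean var columns whose names contain the reference label, and a float matrix with optional NaNs.

```python
import numpy as np
import pandas as pd
import anndata as ad
from smftools.preprocessing.flag_duplicate_reads import flag_duplicate_reads  # assumed path

rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(20, 12)).astype(float)
X[1] = X[0]            # plant an exact duplicate read
X[::5, 3] = np.nan     # sprinkle a few missing calls

adata = ad.AnnData(X)
adata.obs_names = [f"read_{i}" for i in range(20)]
adata.obs['Reference_strand'] = pd.Categorical(['refA_top'] * 20)
# Boolean var column whose name contains the reference label, used by var_filters_sets.
adata.var['refA_top_GpC_site'] = np.ones(12, dtype=bool)

adata_unique, adata_flagged = flag_duplicate_reads(
    adata,
    var_filters_sets=[['refA_top_GpC_site']],
    distance_threshold=0.05,
    obs_reference_col='Reference_strand',
)
print(int(adata_flagged.obs['is_duplicate'].sum()), 'reads flagged as duplicates')
```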
smftools/preprocessing/invert_adata.py

@@ -1,23 +1,30 @@
 ## invert_adata

 # Optional inversion of the adata
+
 def invert_adata(adata):
     """
-    Inverts the
+    Inverts the AnnData object along the column (variable) axis.

     Parameters:
-        adata (AnnData): An
+        adata (AnnData): An AnnData object.

     Returns:
-
+        AnnData: A new AnnData object with inverted column ordering.
     """
     import numpy as np
     import anndata as ad
-
-
-
-
-    adata
-
-    #
-
+
+    print("🔄 Inverting AnnData along the column axis...")
+
+    # Reverse the order of columns (variables)
+    inverted_adata = adata[:, ::-1].copy()
+
+    # Reassign var_names with new order
+    inverted_adata.var_names = adata.var_names
+
+    # Optional: Store original coordinates for reference
+    inverted_adata.var["Original_var_names"] = adata.var_names[::-1]
+
+    print("✅ Inversion complete!")
+    return inverted_adata
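A hedged usage sketch; the import path and toy data are assumptions for illustration.

```python
import numpy as np
import anndata as ad
from smftools.preprocessing.invert_adata import invert_adata  # assumed path

adata = ad.AnnData(np.arange(6, dtype=float).reshape(2, 3))
adata.var_names = ['pos1', 'pos2', 'pos3']

inverted = invert_adata(adata)
print(inverted.X[0])                                # column order reversed
print(inverted.var['Original_var_names'].tolist())  # original names, reversed
```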
smftools/preprocessing/load_sample_sheet.py

@@ -1,24 +1,38 @@
-
-
-def load_sample_sheet(adata, sample_sheet_path, mapping_key_column):
+def load_sample_sheet(adata, sample_sheet_path, mapping_key_column='obs_names', as_category=True):
     """
-    Loads a sample sheet
+    Loads a sample sheet CSV and maps metadata into the AnnData object as categorical columns.

     Parameters:
-        adata (AnnData): The
-        sample_sheet_path (str):
-        mapping_key_column (str):
+        adata (AnnData): The AnnData object to append sample information to.
+        sample_sheet_path (str): Path to the CSV file.
+        mapping_key_column (str): Column name in the CSV to map against adata.obs_names or an existing obs column.
+        as_category (bool): If True, added columns will be cast as pandas Categorical.

     Returns:
-
+        AnnData: Updated AnnData object.
     """
     import pandas as pd
-
+
+    print('🔹 Loading sample sheet...')
     df = pd.read_csv(sample_sheet_path)
-
-
-
-
-
-
-
+    df[mapping_key_column] = df[mapping_key_column].astype(str)
+
+    # If matching against obs_names directly
+    if mapping_key_column == 'obs_names':
+        key_series = adata.obs_names.astype(str)
+    else:
+        key_series = adata.obs[mapping_key_column].astype(str)
+
+    value_columns = [col for col in df.columns if col != mapping_key_column]
+
+    print(f'🔹 Appending metadata columns: {value_columns}')
+    df = df.set_index(mapping_key_column)
+
+    for col in value_columns:
+        mapped = key_series.map(df[col])
+        if as_category:
+            mapped = mapped.astype('category')
+        adata.obs[col] = mapped
+
+    print('✅ Sample sheet metadata successfully added as categories.' if as_category else '✅ Metadata added.')
+    return adata