smftools 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +29 -0
- smftools/_settings.py +20 -0
- smftools/_version.py +1 -0
- smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
- smftools/datasets/F1_sample_sheet.csv +5 -0
- smftools/datasets/__init__.py +9 -0
- smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
- smftools/datasets/datasets.py +28 -0
- smftools/informatics/__init__.py +16 -0
- smftools/informatics/archived/bam_conversion.py +59 -0
- smftools/informatics/archived/bam_direct.py +63 -0
- smftools/informatics/archived/basecalls_to_adata.py +71 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +132 -0
- smftools/informatics/direct_smf.py +137 -0
- smftools/informatics/fast5_to_pod5.py +21 -0
- smftools/informatics/helpers/LoadExperimentConfig.py +75 -0
- smftools/informatics/helpers/__init__.py +74 -0
- smftools/informatics/helpers/align_and_sort_BAM.py +59 -0
- smftools/informatics/helpers/aligned_BAM_to_bed.py +74 -0
- smftools/informatics/helpers/archived/informatics.py +260 -0
- smftools/informatics/helpers/archived/load_adata.py +516 -0
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/bed_to_bigwig.py +39 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +79 -0
- smftools/informatics/helpers/canoncall.py +34 -0
- smftools/informatics/helpers/complement_base_list.py +21 -0
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +55 -0
- smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/count_aligned_reads.py +43 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +44 -0
- smftools/informatics/helpers/extract_mods.py +83 -0
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
- smftools/informatics/helpers/find_conversion_sites.py +50 -0
- smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
- smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
- smftools/informatics/helpers/get_native_references.py +28 -0
- smftools/informatics/helpers/index_fasta.py +12 -0
- smftools/informatics/helpers/make_dirs.py +21 -0
- smftools/informatics/helpers/make_modbed.py +27 -0
- smftools/informatics/helpers/modQC.py +27 -0
- smftools/informatics/helpers/modcall.py +36 -0
- smftools/informatics/helpers/modkit_extract_to_adata.py +884 -0
- smftools/informatics/helpers/ohe_batching.py +76 -0
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +57 -0
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +53 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
- smftools/informatics/helpers/split_and_index_BAM.py +36 -0
- smftools/informatics/load_adata.py +182 -0
- smftools/informatics/readwrite.py +106 -0
- smftools/informatics/subsample_fasta_from_bed.py +47 -0
- smftools/informatics/subsample_pod5.py +104 -0
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +33 -0
- smftools/preprocessing/append_C_context.py +82 -0
- smftools/preprocessing/archives/mark_duplicates.py +146 -0
- smftools/preprocessing/archives/preprocessing.py +614 -0
- smftools/preprocessing/archives/remove_duplicates.py +21 -0
- smftools/preprocessing/binarize_on_Youden.py +45 -0
- smftools/preprocessing/binary_layers_to_ohe.py +40 -0
- smftools/preprocessing/calculate_complexity.py +72 -0
- smftools/preprocessing/calculate_consensus.py +47 -0
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +94 -0
- smftools/preprocessing/calculate_coverage.py +42 -0
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
- smftools/preprocessing/calculate_position_Youden.py +115 -0
- smftools/preprocessing/calculate_read_length_stats.py +79 -0
- smftools/preprocessing/clean_NaN.py +46 -0
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +44 -0
- smftools/preprocessing/filter_reads_on_length.py +51 -0
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +30 -0
- smftools/preprocessing/load_sample_sheet.py +38 -0
- smftools/preprocessing/make_dirs.py +21 -0
- smftools/preprocessing/min_non_diagonal.py +25 -0
- smftools/preprocessing/recipes.py +127 -0
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +198 -0
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/evaluation/__init__.py +0 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +28 -0
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/METADATA +5 -2
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools-0.1.6.dist-info/RECORD +0 -4
- {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/clean_NaN.py
@@ -0,0 +1,46 @@
def clean_NaN(adata, layer=None):
    """
    Append layers to adata that contain NaN cleaning strategies.

    Parameters:
        adata (AnnData): An AnnData object.
        layer (str, optional): Name of the layer to fill NaN values in. If None, uses adata.X.

    Modifies:
        - Adds new layers to `adata.layers` with different NaN-filling strategies.
    """
    import numpy as np
    import pandas as pd
    import anndata as ad
    from ..readwrite import adata_to_df

    # Ensure the specified layer exists
    if layer and layer not in adata.layers:
        raise ValueError(f"Layer '{layer}' not found in adata.layers.")

    # Convert to DataFrame
    df = adata_to_df(adata, layer=layer)

    # Fill NaN with the closest SMF value (forward then backward fill)
    print('Making layer: fill_nans_closest')
    adata.layers['fill_nans_closest'] = df.ffill(axis=1).bfill(axis=1).values

    # Replace NaN with 0, and 0 with -1
    print('Making layer: nan0_0minus1')
    df_nan0_0minus1 = df.replace(0, -1).fillna(0)
    adata.layers['nan0_0minus1'] = df_nan0_0minus1.values

    # Replace NaN with 1, and 1 with 2
    print('Making layer: nan1_12')
    df_nan1_12 = df.replace(1, 2).fillna(1)
    adata.layers['nan1_12'] = df_nan1_12.values

    # Replace NaN with -1
    print('Making layer: nan_minus_1')
    df_nan_minus_1 = df.fillna(-1)
    adata.layers['nan_minus_1'] = df_nan_minus_1.values

    # Replace NaN with 0.5
    print('Making layer: nan_half')
    df_nan_half = df.fillna(0.5)
    adata.layers['nan_half'] = df_nan_half.values
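
A minimal usage sketch for the added clean_NaN helper (a toy example, assuming adata.X holds binarized SMF calls with NaN at uncalled positions; the layer names come from the function above):

    import anndata as ad
    import numpy as np
    from smftools.preprocessing.clean_NaN import clean_NaN

    toy = ad.AnnData(X=np.array([[1.0, np.nan, 0.0], [np.nan, 1.0, 1.0]]))
    clean_NaN(toy)                                  # adds the NaN-filled layers in place
    print(toy.layers['fill_nans_closest'])
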
smftools/preprocessing/filter_adata_by_nan_proportion.py
@@ -0,0 +1,31 @@
## filter_adata_by_nan_proportion

def filter_adata_by_nan_proportion(adata, threshold, axis='obs'):
    """
    Filters an AnnData object on a NaN proportion threshold along a given matrix axis.

    Parameters:
        adata (AnnData): The AnnData object to filter.
        threshold (float): The maximum np.nan proportion to allow along the given axis.
        axis (str): Whether to filter the adata based on obs or var np.nan content.
    Returns:
        filtered_adata
    """
    import numpy as np
    import anndata as ad

    if axis == 'obs':
        # Calculate the proportion of NaN values in each read
        nan_proportion = np.isnan(adata.X).mean(axis=1)
        # Keep reads with at most the allowed NaN proportion
        filtered_indices = np.where(nan_proportion <= threshold)[0]
        filtered_adata = adata[filtered_indices, :].copy()
    elif axis == 'var':
        # Calculate the proportion of NaN values at each position
        nan_proportion = np.isnan(adata.X).mean(axis=0)
        # Keep positions with at most the allowed NaN proportion
        filtered_indices = np.where(nan_proportion <= threshold)[0]
        filtered_adata = adata[:, filtered_indices].copy()
    else:
        raise ValueError("Axis must be either 'obs' or 'var'")
    return filtered_adata
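
A small self-contained example (a toy 2×3 matrix; the second read is 2/3 NaN and is dropped at a 0.4 threshold):

    import anndata as ad
    import numpy as np
    from smftools.preprocessing.filter_adata_by_nan_proportion import filter_adata_by_nan_proportion

    toy = ad.AnnData(X=np.array([[1.0, np.nan, 0.0], [np.nan, np.nan, 1.0]]))
    kept = filter_adata_by_nan_proportion(toy, threshold=0.4, axis='obs')
    kept.shape   # (1, 3)
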
smftools/preprocessing/filter_converted_reads_on_methylation.py
@@ -0,0 +1,44 @@
## filter_converted_reads_on_methylation

## Conversion SMF Specific
def filter_converted_reads_on_methylation(adata, valid_SMF_site_threshold=0.8, min_SMF_threshold=0.025, max_SMF_threshold=0.975):
    """
    Filter the adata object using a minimum threshold for the fraction of valid SMF sites in a read, as well as minimum and maximum read methylation levels.

    Parameters:
        adata (AnnData): An adata object.
        valid_SMF_site_threshold (float): The minimum proportion of valid SMF sites that must be present in the read. Default is 0.8.
        min_SMF_threshold (float): The minimum read methylation level. Default is 0.025.
        max_SMF_threshold (float): The maximum read methylation level. Default is 0.975.
    Returns:
        AnnData
    """
    import numpy as np
    import anndata as ad
    import pandas as pd

    if valid_SMF_site_threshold:
        # Keep reads that have over a given valid GpC site content
        adata = adata[adata.obs['fraction_valid_GpC_site_in_range'] > valid_SMF_site_threshold].copy()

    if min_SMF_threshold:
        # Keep reads with SMF methylation over background methylation.
        below_background = (~adata.obs['GpC_above_other_C']).sum()
        print(f'Removing {below_background} reads that have GpC conversion below background conversion rate')
        adata = adata[adata.obs['GpC_above_other_C'] == True].copy()
        # Keep reads over a defined methylation threshold
        s0 = adata.shape[0]
        adata = adata[adata.obs['GpC_site_row_methylation_means'] > min_SMF_threshold].copy()
        s1 = adata.shape[0]
        below_threshold = s0 - s1
        print(f'Removing {below_threshold} reads that have GpC conversion below a minimum threshold conversion rate')

    if max_SMF_threshold:
        # Keep reads below a defined methylation threshold
        s0 = adata.shape[0]
        adata = adata[adata.obs['GpC_site_row_methylation_means'] < max_SMF_threshold].copy()
        s1 = adata.shape[0]
        above_threshold = s0 - s1
        print(f'Removing {above_threshold} reads that have GpC conversion above a maximum threshold conversion rate')

    return adata
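
A hedged usage sketch (assuming adata.obs already carries the 'fraction_valid_GpC_site_in_range', 'GpC_above_other_C', and 'GpC_site_row_methylation_means' columns referenced above, e.g. after calculate_converted_read_methylation_stats):

    # drop reads with sparse GpC coverage, background-level methylation, or near-saturated methylation
    adata = filter_converted_reads_on_methylation(
        adata,
        valid_SMF_site_threshold=0.8,
        min_SMF_threshold=0.025,
        max_SMF_threshold=0.975,
    )
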
smftools/preprocessing/filter_reads_on_length.py
@@ -0,0 +1,51 @@
## filter_reads_on_length

def filter_reads_on_length(adata, filter_on_coordinates=False, min_read_length=2700, max_read_length=3200):
    """
    Filters the adata object to keep a defined coordinate window, as well as reads that fall within minimum and maximum length thresholds.

    Parameters:
        adata (AnnData): An adata object.
        filter_on_coordinates (bool | list): If False, skips filtering. Otherwise, provide a list containing integers representing the lower and upper bound coordinates to filter on. Default is False.
        min_read_length (int): The minimum read length to keep in the filtered dataset. Default is 2700.
        max_read_length (int): The maximum query read length to keep in the filtered dataset. Default is 3200.

    Returns:
        adata
    """
    import numpy as np
    import anndata as ad
    import pandas as pd

    if filter_on_coordinates:
        lower_bound, upper_bound = filter_on_coordinates
        # Extract the position information from the adata object as an np array
        var_names_arr = adata.var_names.astype(int).to_numpy()
        # Find the upper bound coordinate that is closest to the specified value
        closest_end_index = np.argmin(np.abs(var_names_arr - upper_bound))
        upper_bound = int(adata.var_names[closest_end_index])
        # Find the lower bound coordinate that is closest to the specified value
        closest_start_index = np.argmin(np.abs(var_names_arr - lower_bound))
        lower_bound = int(adata.var_names[closest_start_index])
        # Get a list of positional indexes that encompass the lower and upper bounds of the dataset
        position_list = list(range(lower_bound, upper_bound + 1))
        position_list = [str(pos) for pos in position_list]
        position_set = set(position_list)
        print(f'Subsetting adata to keep data between coordinates {lower_bound} and {upper_bound}')
        adata = adata[:, adata.var_names.isin(position_set)].copy()

    if min_read_length:
        print(f'Subsetting adata to keep reads longer than {min_read_length}')
        s0 = adata.shape[0]
        adata = adata[adata.obs['read_length'] > min_read_length].copy()
        s1 = adata.shape[0]
        print(f'Removed {s0-s1} reads less than {min_read_length} basepairs in length')

    if max_read_length:
        print(f'Subsetting adata to keep reads shorter than {max_read_length}')
        s0 = adata.shape[0]
        adata = adata[adata.obs['read_length'] < max_read_length].copy()
        s1 = adata.shape[0]
        print(f'Removed {s0-s1} reads greater than {max_read_length} basepairs in length')

    return adata
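
A usage sketch with hypothetical coordinates (var_names are position labels stored as strings, and adata.obs['read_length'] must already exist):

    # keep positions between two coordinates and reads between 2700 and 3200 bp long
    adata = filter_reads_on_length(
        adata,
        filter_on_coordinates=[1000, 4000],  # hypothetical lower/upper bounds
        min_read_length=2700,
        max_read_length=3200,
    )
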
smftools/preprocessing/flag_duplicate_reads.py
@@ -0,0 +1,149 @@
import torch
from tqdm import tqdm

class UnionFind:
    def __init__(self, size):
        self.parent = torch.arange(size)

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_y] = root_x


def flag_duplicate_reads(adata, var_filters_sets, distance_threshold=0.05, obs_reference_col='Reference_strand'):
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt

    all_hamming_dists = []
    merged_results = []

    references = adata.obs[obs_reference_col].cat.categories

    for ref in references:
        print(f'🔹 Processing reference: {ref}')

        ref_mask = adata.obs[obs_reference_col] == ref
        adata_subset = adata[ref_mask].copy()
        N = adata_subset.shape[0]

        combined_mask = torch.zeros(len(adata.var), dtype=torch.bool)
        for var_set in var_filters_sets:
            if any(ref in v for v in var_set):
                set_mask = torch.ones(len(adata.var), dtype=torch.bool)
                for key in var_set:
                    set_mask &= torch.from_numpy(adata.var[key].values)
                combined_mask |= set_mask

        selected_cols = adata.var.index[combined_mask.numpy()].to_list()
        col_indices = [adata.var.index.get_loc(col) for col in selected_cols]

        print(f"Selected {len(col_indices)} columns out of {adata.var.shape[0]} for {ref}")

        X = adata_subset.X
        if not isinstance(X, np.ndarray):
            X = X.toarray()
        X_subset = X[:, col_indices]
        X_tensor = torch.from_numpy(X_subset.astype(np.float32))

        fwd_hamming_to_next = torch.full((N,), float('nan'))
        rev_hamming_to_prev = torch.full((N,), float('nan'))

        def cluster_pass(X_tensor, reverse=False, window_size=50, record_distances=False):
            N_local = X_tensor.shape[0]
            X_sortable = X_tensor.nan_to_num(-1)
            sort_keys = X_sortable.tolist()
            sorted_idx = sorted(range(N_local), key=lambda i: sort_keys[i], reverse=reverse)
            sorted_X = X_tensor[sorted_idx]

            cluster_pairs = []

            for i in tqdm(range(len(sorted_X)), desc=f"Pass {'rev' if reverse else 'fwd'} ({ref})"):
                row_i = sorted_X[i]
                j_range = range(i + 1, min(i + 1 + window_size, len(sorted_X)))

                if len(j_range) > 0:
                    row_i_exp = row_i.unsqueeze(0)
                    block_rows = sorted_X[j_range]
                    valid_mask = (~torch.isnan(row_i_exp)) & (~torch.isnan(block_rows))
                    valid_counts = valid_mask.sum(dim=1)
                    diffs = (row_i_exp != block_rows) & valid_mask
                    hamming_dists = diffs.sum(dim=1) / valid_counts.clamp(min=1)
                    all_hamming_dists.extend(hamming_dists.cpu().numpy().tolist())

                    matches = (hamming_dists < distance_threshold) & (valid_counts > 0)
                    for offset_idx, m in zip(j_range, matches):
                        if m:
                            cluster_pairs.append((sorted_idx[i], sorted_idx[offset_idx]))

                if record_distances and i + 1 < len(sorted_X):
                    next_idx = sorted_idx[i + 1]
                    valid_mask_pair = (~torch.isnan(row_i)) & (~torch.isnan(sorted_X[i + 1]))
                    if valid_mask_pair.sum() > 0:
                        d = (row_i[valid_mask_pair] != sorted_X[i + 1][valid_mask_pair]).sum()
                        norm_d = d.item() / valid_mask_pair.sum().item()
                        if reverse:
                            rev_hamming_to_prev[next_idx] = norm_d
                        else:
                            fwd_hamming_to_next[sorted_idx[i]] = norm_d

            return cluster_pairs

        pairs_fwd = cluster_pass(X_tensor, reverse=False, record_distances=True)
        involved_in_fwd = set([p[0] for p in pairs_fwd] + [p[1] for p in pairs_fwd])
        mask_for_rev = torch.ones(N, dtype=torch.bool)
        mask_for_rev[list(involved_in_fwd)] = False
        pairs_rev = cluster_pass(X_tensor[mask_for_rev], reverse=True, record_distances=True)

        all_pairs = pairs_fwd + [(list(mask_for_rev.nonzero(as_tuple=True)[0])[i], list(mask_for_rev.nonzero(as_tuple=True)[0])[j]) for i, j in pairs_rev]

        uf = UnionFind(N)
        for i, j in all_pairs:
            uf.union(i, j)

        merged_cluster = torch.zeros(N, dtype=torch.long)
        for i in range(N):
            merged_cluster[i] = uf.find(i)

        cluster_sizes = torch.zeros_like(merged_cluster)
        for cid in merged_cluster.unique():
            members = (merged_cluster == cid).nonzero(as_tuple=True)[0]
            cluster_sizes[members] = len(members)

        is_duplicate = torch.zeros(N, dtype=torch.bool)
        for cid in merged_cluster.unique():
            members = (merged_cluster == cid).nonzero(as_tuple=True)[0]
            if len(members) > 1:
                is_duplicate[members[1:]] = True

        adata_subset.obs['is_duplicate'] = is_duplicate.numpy()
        adata_subset.obs['merged_cluster_id'] = merged_cluster.numpy()
        adata_subset.obs['cluster_size'] = cluster_sizes.numpy()
        adata_subset.obs['fwd_hamming_to_next'] = fwd_hamming_to_next.numpy()
        adata_subset.obs['rev_hamming_to_prev'] = rev_hamming_to_prev.numpy()

        merged_results.append(adata_subset.obs)

    merged_obs = pd.concat(merged_results)
    adata.obs = adata.obs.join(merged_obs[['is_duplicate', 'merged_cluster_id', 'cluster_size', 'fwd_hamming_to_next', 'rev_hamming_to_prev']])

    adata_unique = adata[~adata.obs['is_duplicate']].copy()

    plt.figure(figsize=(5, 4))
    plt.hist(all_hamming_dists, bins=50, alpha=0.75)
    plt.axvline(distance_threshold, color="red", linestyle="--", label=f"threshold = {distance_threshold}")
    plt.xlabel("Hamming Distance")
    plt.ylabel("Frequency")
    plt.title("Histogram of Pairwise Hamming Distances")
    plt.legend()
    plt.show()

    return adata_unique, adata
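
In short, reads are processed per reference: positions are selected by AND-ing the boolean adata.var columns in each var_filters_sets entry (entries are OR-ed, and an entry is only used when one of its column names contains the reference label), reads are sorted and compared against a sliding window of neighbors by NaN-aware Hamming distance, and matching pairs are merged with union-find. A hedged call sketch with hypothetical var column names:

    var_filters_sets = [['GpC_site', 'ref1_position'], ['GpC_site', 'ref2_position']]
    adata_unique, adata_flagged = flag_duplicate_reads(
        adata,
        var_filters_sets,
        distance_threshold=0.05,
        obs_reference_col='Reference_strand',
    )
    adata_flagged.obs[['is_duplicate', 'merged_cluster_id', 'cluster_size']].head()
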
smftools/preprocessing/invert_adata.py
@@ -0,0 +1,30 @@
## invert_adata

# Optional inversion of the adata

def invert_adata(adata):
    """
    Inverts the AnnData object along the column (variable) axis.

    Parameters:
        adata (AnnData): An AnnData object.

    Returns:
        AnnData: A new AnnData object with inverted column ordering.
    """
    import numpy as np
    import anndata as ad

    print("🔄 Inverting AnnData along the column axis...")

    # Reverse the order of columns (variables)
    inverted_adata = adata[:, ::-1].copy()

    # Reassign var_names with new order
    inverted_adata.var_names = adata.var_names

    # Optional: Store original coordinates for reference
    inverted_adata.var["Original_var_names"] = adata.var_names[::-1]

    print("✅ Inversion complete!")
    return inverted_adata
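
Usage is a one-liner; the original coordinate labels are kept in .var for reference:

    flipped = invert_adata(adata)
    flipped.var['Original_var_names'].head()
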
smftools/preprocessing/load_sample_sheet.py
@@ -0,0 +1,38 @@
def load_sample_sheet(adata, sample_sheet_path, mapping_key_column='obs_names', as_category=True):
    """
    Loads a sample sheet CSV and maps metadata into the AnnData object as categorical columns.

    Parameters:
        adata (AnnData): The AnnData object to append sample information to.
        sample_sheet_path (str): Path to the CSV file.
        mapping_key_column (str): Column name in the CSV to map against adata.obs_names or an existing obs column.
        as_category (bool): If True, added columns will be cast as pandas Categorical.

    Returns:
        AnnData: Updated AnnData object.
    """
    import pandas as pd

    print('🔹 Loading sample sheet...')
    df = pd.read_csv(sample_sheet_path)
    df[mapping_key_column] = df[mapping_key_column].astype(str)

    # If matching against obs_names directly
    if mapping_key_column == 'obs_names':
        key_series = adata.obs_names.astype(str)
    else:
        key_series = adata.obs[mapping_key_column].astype(str)

    value_columns = [col for col in df.columns if col != mapping_key_column]

    print(f'🔹 Appending metadata columns: {value_columns}')
    df = df.set_index(mapping_key_column)

    for col in value_columns:
        mapped = key_series.map(df[col])
        if as_category:
            mapped = mapped.astype('category')
        adata.obs[col] = mapped

    print('✅ Sample sheet metadata successfully added as categories.' if as_category else '✅ Metadata added.')
    return adata
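
A usage sketch against a hypothetical CSV whose 'Sample' column matches an existing adata.obs['Sample'] column (with the default mapping_key_column='obs_names', the CSV is matched against adata.obs_names instead):

    adata = load_sample_sheet(adata, 'sample_sheet.csv', mapping_key_column='Sample')
    adata.obs.head()
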
smftools/preprocessing/make_dirs.py
@@ -0,0 +1,21 @@
## make_dirs

# General
def make_dirs(directories):
    """
    Takes a list of file paths and makes new directories if the directory does not already exist.

    Parameters:
        directories (list): A list of directories to make

    Returns:
        None
    """
    import os

    for directory in directories:
        if not os.path.isdir(directory):
            os.mkdir(directory)
            print(f"Directory '{directory}' created successfully.")
        else:
            print(f"Directory '{directory}' already exists.")
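
Note that make_dirs uses os.mkdir, so parent directories must be listed first (or already exist), as in:

    from smftools.preprocessing.make_dirs import make_dirs
    make_dirs(['output', 'output/plots'])
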
smftools/preprocessing/min_non_diagonal.py
@@ -0,0 +1,25 @@
## min_non_diagonal

def min_non_diagonal(matrix):
    """
    Takes a matrix and returns the smallest value from each row with the diagonal masked.

    Parameters:
        matrix (ndarray): A 2D ndarray.

    Returns:
        min_values (list): A list of minimum values from each row of the matrix
    """
    import numpy as np

    n = matrix.shape[0]
    min_values = []
    for i in range(n):
        # Mask to exclude the diagonal element
        row_mask = np.ones(n, dtype=bool)
        row_mask[i] = False
        # Extract the row excluding the diagonal element
        row = matrix[i, row_mask]
        # Find the minimum value in the row
        min_values.append(np.min(row))
    return min_values
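
A small self-contained example of min_non_diagonal on a symmetric distance matrix:

    import numpy as np
    from smftools.preprocessing.min_non_diagonal import min_non_diagonal

    m = np.array([[0.0, 0.3, 0.1],
                  [0.3, 0.0, 0.2],
                  [0.1, 0.2, 0.0]])
    min_non_diagonal(m)   # -> [0.1, 0.2, 0.1]
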
smftools/preprocessing/recipes.py
@@ -0,0 +1,127 @@
# recipes

def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory, mapping_key_column='Sample', reference_column='Reference', sample_names_col='Sample_names', invert=True):
    """
    The first part of the preprocessing workflow applied to the smf.inform.pod_to_adata() output derived from Kissiov_and_McKenna_2025.

    Performs the following tasks:
        1) Loads a sample CSV to append metadata mappings to the adata object.
        2) Appends a boolean indicating whether each position in var_names is within a given reference.
        3) Appends the cytosine context to each position from each reference.
        4) Calculates read level methylation statistics.
        5) Calculates read length statistics (start position, end position, read length).
        6) Optionally inverts the adata to flip the position coordinate orientation.
        7) Adds new layers containing NaN replaced variants of adata.X (fill_closest, nan0_0minus1, nan1_12).
        8) Returns a dictionary to pass the variable namespace to the parent scope.

    Parameters:
        adata (AnnData): The AnnData object to use as input.
        sample_sheet_path (str): String representing the path to the sample sheet csv containing the sample metadata.
        output_directory (str): String representing the path to the output directory for plots.
        mapping_key_column (str): The column name to use as the mapping keys for applying the sample sheet metadata.
        reference_column (str): The name of the reference column to use.
        sample_names_col (str): The name of the sample name column to use.
        invert (bool): Whether to invert the positional coordinates of the adata object.

    Returns:
        variables (dict): A dictionary of variables to append to the parent scope.
    """
    import anndata as ad
    import pandas as pd
    import numpy as np
    from .load_sample_sheet import load_sample_sheet
    from .calculate_coverage import calculate_coverage
    from .append_C_context import append_C_context
    from .calculate_converted_read_methylation_stats import calculate_converted_read_methylation_stats
    from .invert_adata import invert_adata
    from .calculate_read_length_stats import calculate_read_length_stats
    from .clean_NaN import clean_NaN

    # Clean up some of the Reference metadata and save variable names that point to sets of values in the column.
    adata.obs[reference_column] = adata.obs[reference_column].astype('category')
    references = adata.obs[reference_column].cat.categories
    split_references = [(reference, reference.split('_')[0][1:]) for reference in references]
    reference_mapping = {k: v for k, v in split_references}
    adata.obs[f'{reference_column}_short'] = adata.obs[reference_column].map(reference_mapping)
    short_references = set(adata.obs[f'{reference_column}_short'])
    binary_layers = list(adata.layers.keys())

    # Load sample sheet metadata
    load_sample_sheet(adata, sample_sheet_path, mapping_key_column)

    # Hold the sample names set
    adata.obs[sample_names_col] = adata.obs[sample_names_col].astype('category')
    sample_names = adata.obs[sample_names_col].cat.categories

    # Add position level metadata
    calculate_coverage(adata, obs_column=reference_column)
    adata.var['SNP_position'] = (adata.var[f'N_{reference_column}_with_position'] > 0) & (adata.var[f'N_{reference_column}_with_position'] < len(references)).astype(bool)

    # Append cytosine context to the reference positions based on the conversion strand.
    append_C_context(adata, obs_column=reference_column, use_consensus=False)

    # Calculate read level methylation statistics. Assess if GpC methylation level is above other_C methylation level as a QC.
    calculate_converted_read_methylation_stats(adata, reference_column, sample_names_col)

    # Calculate read length statistics
    upper_bound, lower_bound = calculate_read_length_stats(adata, reference_column, sample_names_col)

    # Invert the adata object (ie flip the strand orientation for visualization)
    if invert:
        adata = invert_adata(adata)
    else:
        pass

    # NaN replacement strategies stored in additional layers. Passing layer=None uses adata.X
    clean_NaN(adata, layer=None)

    variables = {
        "short_references": short_references,
        "binary_layers": binary_layers,
        "sample_names": sample_names,
        "upper_bound": upper_bound,
        "lower_bound": lower_bound,
        "references": references
    }
    return variables

def recipe_2_Kissiov_and_McKenna_2025(adata, output_directory, binary_layers, distance_thresholds={}, reference_column='Reference', sample_names_col='Sample_names'):
    """
    The second part of the preprocessing workflow, applied to an adata that has already been preprocessed by recipe_1_Kissiov_and_McKenna_2025.

    Performs the following tasks:
        1) Marks putative PCR duplicates using pairwise Hamming distance metrics.
        2) Performs a complexity analysis of the library based on the PCR duplicate detection rate.
        3) Removes PCR duplicates from the adata.
        4) Returns two adata objects: one for the filtered adata and one for the duplicate adata.

    Parameters:
        adata (AnnData): The AnnData object to use as input.
        output_directory (str): String representing the path to the output directory for plots.
        binary_layers (list): A list of layers to use for the binary encoding of read sequences. Used for duplicate detection.
        distance_thresholds (dict): A dictionary keyed by obs_column categories that points to a float corresponding to the distance threshold to apply. Default is an empty dict.
        reference_column (str): The name of the reference column to use.
        sample_names_col (str): The name of the sample name column to use.

    Returns:
        filtered_adata (AnnData): An AnnData object containing the filtered reads
        duplicates (AnnData): An AnnData object containing the duplicate reads
    """
    import anndata as ad
    import pandas as pd
    import numpy as np
    from .mark_duplicates import mark_duplicates
    from .calculate_complexity import calculate_complexity
    from .remove_duplicates import remove_duplicates

    # TODO: add a way to remove reads below a given read quality (based on NaN content), and a way to pull per-read quality from the BAM files.

    # Duplicate detection using pairwise Hamming distance across reads
    mark_duplicates(adata, binary_layers, obs_column=reference_column, sample_col=sample_names_col, distance_thresholds=distance_thresholds, method='N_masked_distances')

    # Complexity analysis using the marked duplicates and the Lander-Waterman model
    calculate_complexity(adata, output_directory, obs_column=reference_column, sample_col=sample_names_col, plot=True, save_plot=False)

    # Remove duplicate reads and store the duplicate reads in a new AnnData object named duplicates.
    filtered_adata, duplicates = remove_duplicates(adata)
    return filtered_adata, duplicates
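
A hedged end-to-end sketch of the two recipes (paths and column names are placeholders; recipe_1 annotates adata in place and returns bookkeeping variables, which are then fed to recipe_2):

    vars1 = recipe_1_Kissiov_and_McKenna_2025(
        adata, 'sample_sheet.csv', './plots',
        mapping_key_column='Sample', reference_column='Reference', sample_names_col='Sample_names',
    )
    filtered_adata, duplicates = recipe_2_Kissiov_and_McKenna_2025(
        adata, './plots', vars1['binary_layers'],
    )
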
smftools/preprocessing/subsample_adata.py
@@ -0,0 +1,58 @@
def subsample_adata(adata, obs_columns=None, max_samples=2000, random_seed=42):
    """
    Subsamples an AnnData object so that each unique combination of categories
    in the given `obs_columns` has at most `max_samples` observations.
    If `obs_columns` is None or empty, the function randomly subsamples the entire dataset.

    Parameters:
        adata (AnnData): The AnnData object to subsample.
        obs_columns (list of str, optional): List of observation column names to group by.
        max_samples (int): The maximum number of observations per category combination.
        random_seed (int): Random seed for reproducibility.

    Returns:
        AnnData: A new AnnData object with subsampled observations.
    """
    import anndata as ad
    import numpy as np

    np.random.seed(random_seed)  # Ensure reproducibility

    if not obs_columns:  # If no obs columns are given, sample globally
        if adata.shape[0] > max_samples:
            sampled_indices = np.random.choice(adata.obs.index, max_samples, replace=False)
        else:
            sampled_indices = adata.obs.index  # Keep all if fewer than max_samples

        return adata[sampled_indices].copy()

    sampled_indices = []

    # Get unique category combinations from all specified obs columns
    unique_combinations = adata.obs[obs_columns].drop_duplicates()

    for _, row in unique_combinations.iterrows():
        # Build filter condition dynamically for multiple columns
        condition = (adata.obs[obs_columns] == row.values).all(axis=1)

        # Get indices for the current category combination
        subset_indices = adata.obs[condition].index.to_numpy()

        # Subsample or take all
        if len(subset_indices) > max_samples:
            sampled = np.random.choice(subset_indices, max_samples, replace=False)
        else:
            sampled = subset_indices  # Keep all if fewer than max_samples

        sampled_indices.extend(sampled)

    # ⚠ Handle backed mode detection
    if adata.isbacked:
        print("⚠ Detected backed mode. Subset will be loaded fully into memory.")
        subset = adata[sampled_indices]
        subset = subset.to_memory()
    else:
        subset = adata[sampled_indices]

    # Return a new AnnData object containing only the selected observations
    return subset.copy()
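
A usage sketch capping each sample/reference combination at 1000 reads (column names are hypothetical; pass obs_columns=None to subsample the whole dataset instead):

    small = subsample_adata(adata, obs_columns=['Sample_names', 'Reference'], max_samples=1000, random_seed=42)
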