smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +5 -1
- smftools/_version.py +1 -1
- smftools/informatics/__init__.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +63 -10
- smftools/informatics/direct_smf.py +66 -18
- smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
- smftools/informatics/helpers/__init__.py +16 -2
- smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
- smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
- smftools/informatics/helpers/canoncall.py +12 -3
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
- smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +33 -46
- smftools/informatics/helpers/extract_mods.py +55 -23
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/find_conversion_sites.py +33 -44
- smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
- smftools/informatics/helpers/modcall.py +13 -5
- smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
- smftools/informatics/helpers/ohe_batching.py +65 -41
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +45 -9
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/split_and_index_BAM.py +3 -8
- smftools/informatics/load_adata.py +58 -3
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +6 -7
- smftools/preprocessing/append_C_context.py +22 -9
- smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
- smftools/preprocessing/binarize_on_Youden.py +35 -32
- smftools/preprocessing/binary_layers_to_ohe.py +13 -3
- smftools/preprocessing/calculate_complexity.py +3 -2
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
- smftools/preprocessing/calculate_coverage.py +26 -25
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_position_Youden.py +18 -7
- smftools/preprocessing/calculate_read_length_stats.py +39 -46
- smftools/preprocessing/clean_NaN.py +33 -25
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
- smftools/preprocessing/filter_reads_on_length.py +14 -4
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +18 -11
- smftools/preprocessing/load_sample_sheet.py +30 -16
- smftools/preprocessing/recipes.py +22 -20
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +105 -13
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +19 -23
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools/tools/apply_HMM.py +0 -1
- smftools/tools/read_HMM.py +0 -1
- smftools/tools/train_HMM.py +0 -43
- smftools-0.1.3.dist-info/RECORD +0 -84
- /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
- /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/tools/models/transformer.py
ADDED
@@ -0,0 +1,133 @@
+import torch
+import torch.nn as nn
+from .base import BaseTorchModel
+from .positional import PositionalEncoding
+from ..utils.grl import grad_reverse
+
+
+class BaseTransformer(BaseTorchModel):
+    def __init__(self, input_dim, model_dim, num_heads=4, num_layers=2, seq_len=None, use_learnable_pos=False, **kwargs):
+        super().__init__(**kwargs)
+        # Input FC layer to map D_input to D_model
+        self.input_fc = nn.Linear(input_dim, model_dim)
+
+        if use_learnable_pos:
+            assert seq_len is not None, "Must provide seq_len if use_learnable_pos=True"
+            self.pos_embed = nn.Parameter(torch.randn(seq_len, model_dim))  # (S, D)
+            self.pos_encoder = None
+        else:
+            self.pos_encoder = PositionalEncoding(model_dim)
+            self.pos_embed = None
+
+        # Specify the transformer encoder structure
+        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, batch_first=False)
+        # Stack the transformer encoder layers
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+    def encode(self, x, mask=None):
+        """
+        x: (B, S, D_input)
+        mask: (B, S) optional
+        """
+        x = self.input_fc(x)  # (B, S, D_model)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed.unsqueeze(0).to(x.device)  # (B, S, D_model)
+        elif self.pos_encoder is not None:
+            x = self.pos_encoder(x)  # (B, S, D_model)
+        if mask is not None:
+            x = x * mask.unsqueeze(-1)  # (B, S, D_model)
+        x = x.permute(1, 0, 2)  # (S, B, D_model)
+        encoded = self.transformer(x)  # (S, B, D_model)
+        return encoded.permute(1, 0, 2)  # (B, S, D_model)
+
+class TransformerClassifier(BaseTransformer):
+    def __init__(self, input_dim, model_dim, num_classes, num_heads=4, num_layers=2, **kwargs):
+        super().__init__(input_dim, model_dim, num_heads, num_layers, **kwargs)
+        # Classification head
+        self.cls_head = nn.Linear(model_dim, num_classes)
+
+    def forward(self, x):
+        """
+        x: (batch, seq_len, input_dim)
+        """
+        encoded = self.encode(x)  # -> (B, S, D_model)
+        pooled = encoded.mean(dim=1)  # -> (B, D_model)
+        return self.cls_head(pooled)  # -> (B, C)
+
+class DANNTransformerClassifier(TransformerClassifier):
+    def __init__(self, input_dim, model_dim, num_classes, n_domains, **kwargs):
+        super().__init__(input_dim, model_dim, num_classes, **kwargs)
+        self.domain_classifier = nn.Sequential(
+            nn.Linear(model_dim, 128),
+            nn.ReLU(),
+            nn.Linear(128, n_domains)
+        )
+
+    def forward(self, x, alpha=1.0):
+        encoded = self.encode(x)  # (B, S, D_model)
+        pooled = encoded.mean(dim=1)  # (B, D_model)
+
+        class_logits = self.cls_head(pooled)
+        domain_logits = self.domain_classifier(grad_reverse(pooled, alpha))
+
+        return class_logits, domain_logits
+
+class MaskedTransformerPretrainer(BaseTransformer):
+    def __init__(self, input_dim, model_dim, num_heads=4, num_layers=2, **kwargs):
+        super().__init__(input_dim, model_dim, num_heads, num_layers, **kwargs)
+        self.decoder = nn.Linear(model_dim, input_dim)
+
+    def forward(self, x, mask):
+        """
+        x: (batch, seq_len, input_dim)
+        mask: (batch, seq_len) optional
+        """
+        if x.dim() == 2:
+            x = x.unsqueeze(-1)
+        encoded = self.encode(x, mask=mask)  # -> (B, S, D_model)
+        return self.decoder(encoded)  # -> (B, D_input)
+
+class DANNTransformer(BaseTransformer):
+    """
+    """
+    def __init__(self, seq_len, model_dim, n_heads, n_layers, n_domains):
+        super().__init__(
+            input_dim=1,  # 1D scalar input per token
+            model_dim=model_dim,
+            num_heads=n_heads,
+            num_layers=n_layers,
+            seq_len=seq_len,
+            use_learnable_pos=True  # enables learnable pos_embed in base
+        )
+
+        # Reconstruction head
+        self.recon_head = nn.Linear(model_dim, 1)
+
+        # Domain classification head
+        self.domain_classifier = nn.Sequential(
+            nn.Linear(model_dim, 128),
+            nn.ReLU(),
+            nn.Linear(128, n_domains)
+        )
+
+    def forward(self, x, alpha=1.0):
+        """
+        x: Tensor of shape (B, S) or (B, S, 1)
+        alpha: GRL coefficient (float)
+        """
+        if x.dim() == 2:
+            x = x.unsqueeze(-1)  # (B, S, 1)
+
+        # Encode sequence
+        h = self.encode(x)  # (B, S, D_model)
+
+        # Head 1: Reconstruction
+        recon = self.recon_head(h).squeeze(-1)  # (B, S)
+
+        # Head 2: Domain classification via GRL
+        pooled = h.mean(dim=1)  # (B, D_model)
+        rev = grad_reverse(pooled, alpha)
+        domain_logits = self.domain_classifier(rev)  # (B, n_batches)
+
+        return recon, domain_logits
+
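Editor's note: a rough usage sketch for the classifier defined in the hunk above. The import path and the assumption that BaseTorchModel and PositionalEncoding require no extra constructor arguments are not confirmed by this diff; shapes follow the docstrings shown.

import torch
from smftools.tools.models.transformer import TransformerClassifier  # assumed import path

# Binary classification of 400-position single-molecule signals (illustrative sizes)
model = TransformerClassifier(input_dim=1, model_dim=64, num_classes=2,
                              num_heads=4, num_layers=2)
x = torch.randn(8, 400, 1)   # (batch, seq_len, input_dim)
logits = model(x)            # (batch, num_classes), from mean-pooled encoder states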
smftools/tools/models/wrappers.py
ADDED
@@ -0,0 +1,20 @@
+import torch
+import torch.nn as nn
+
+class ScaledModel(nn.Module):
+    def __init__(self, model, mean, std):
+        super().__init__()
+        self.model = model
+        self.register_buffer("mean", torch.tensor(mean, dtype=torch.float32))
+        self.register_buffer("std", torch.tensor(std, dtype=torch.float32))
+
+    def forward(self, x):
+        mean = self.mean.to(x.device)
+        std = torch.clamp(self.std.to(x.device), min=1e-8)
+        if x.dim() == 2:
+            x = (x - mean) / std
+        elif x.dim() == 3:
+            x = (x - mean[None, None, :]) / std[None, None, :]
+        else:
+            raise ValueError(f"Unsupported input shape {x.shape}")
+        return self.model(x)
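Editor's note: ScaledModel standardizes inputs with stored per-feature statistics before delegating to the wrapped module. A minimal sketch with an arbitrary stand-in backbone and an assumed import path:

import torch
import torch.nn as nn
from smftools.tools.models.wrappers import ScaledModel  # assumed import path

backbone = nn.Linear(10, 2)           # any model trained on standardized features
mean, std = [0.5] * 10, [0.1] * 10    # per-feature statistics from training data
scaled = ScaledModel(backbone, mean, std)
out = scaled(torch.rand(4, 10))       # 2-D branch: (x - mean) / std, then backbone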
smftools/tools/nucleosome_hmm_refinement.py
ADDED
@@ -0,0 +1,104 @@
+def refine_nucleosome_calls(adata, layer_name, nan_mask_layer, hexamer_size=120, octamer_size=147, max_wiggle=40, device="cpu"):
+    import numpy as np
+
+    nucleosome_layer = adata.layers[layer_name]
+    nan_mask = adata.layers[nan_mask_layer]  # should be binary mask: 1 = nan region, 0 = valid data
+
+    hexamer_layer = np.zeros_like(nucleosome_layer)
+    octamer_layer = np.zeros_like(nucleosome_layer)
+
+    for read_idx, row in enumerate(nucleosome_layer):
+        in_patch = False
+        start_idx = None
+
+        for pos in range(len(row)):
+            if row[pos] == 1 and not in_patch:
+                in_patch = True
+                start_idx = pos
+            if (row[pos] == 0 or pos == len(row) - 1) and in_patch:
+                in_patch = False
+                end_idx = pos if row[pos] == 0 else pos + 1
+
+                # Expand boundaries into NaNs
+                left_expand = 0
+                right_expand = 0
+
+                # Left
+                for i in range(1, max_wiggle + 1):
+                    if start_idx - i >= 0 and nan_mask[read_idx, start_idx - i] == 1:
+                        left_expand += 1
+                    else:
+                        break
+                # Right
+                for i in range(1, max_wiggle + 1):
+                    if end_idx + i < nucleosome_layer.shape[1] and nan_mask[read_idx, end_idx + i] == 1:
+                        right_expand += 1
+                    else:
+                        break
+
+                expanded_start = start_idx - left_expand
+                expanded_end = end_idx + right_expand
+
+                available_size = expanded_end - expanded_start
+
+                # Octamer placement
+                if available_size >= octamer_size:
+                    center = (expanded_start + expanded_end) // 2
+                    half_oct = octamer_size // 2
+                    octamer_layer[read_idx, center - half_oct: center - half_oct + octamer_size] = 1
+
+                # Hexamer placement
+                elif available_size >= hexamer_size:
+                    center = (expanded_start + expanded_end) // 2
+                    half_hex = hexamer_size // 2
+                    hexamer_layer[read_idx, center - half_hex: center - half_hex + hexamer_size] = 1
+
+    adata.layers[f"{layer_name}_hexamers"] = hexamer_layer
+    adata.layers[f"{layer_name}_octamers"] = octamer_layer
+
+    print(f"✅ Added layers: {layer_name}_hexamers and {layer_name}_octamers")
+    return adata
+
+def infer_nucleosomes_in_large_bound(adata, large_bound_layer, combined_nuc_layer, nan_mask_layer, nuc_size=147, linker_size=50, exclusion_buffer=30, device="cpu"):
+    import numpy as np
+
+    large_bound = adata.layers[large_bound_layer]
+    existing_nucs = adata.layers[combined_nuc_layer]
+    nan_mask = adata.layers[nan_mask_layer]
+
+    inferred_layer = np.zeros_like(large_bound)
+
+    for read_idx, row in enumerate(large_bound):
+        in_patch = False
+        start_idx = None
+
+        for pos in range(len(row)):
+            if row[pos] == 1 and not in_patch:
+                in_patch = True
+                start_idx = pos
+            if (row[pos] == 0 or pos == len(row) - 1) and in_patch:
+                in_patch = False
+                end_idx = pos if row[pos] == 0 else pos + 1
+
+                # Adjust boundaries into flanking NaN regions without getting too close to existing nucleosomes
+                left_expand = start_idx
+                while left_expand > 0 and nan_mask[read_idx, left_expand - 1] == 1 and np.sum(existing_nucs[read_idx, max(0, left_expand - exclusion_buffer):left_expand]) == 0:
+                    left_expand -= 1
+
+                right_expand = end_idx
+                while right_expand < row.shape[0] and nan_mask[read_idx, right_expand] == 1 and np.sum(existing_nucs[read_idx, right_expand:min(row.shape[0], right_expand + exclusion_buffer)]) == 0:
+                    right_expand += 1
+
+                # Phase nucleosomes with linker spacing only
+                region = (left_expand, right_expand)
+                pos_cursor = region[0]
+                while pos_cursor + nuc_size <= region[1]:
+                    if np.all((existing_nucs[read_idx, pos_cursor - exclusion_buffer:pos_cursor + nuc_size + exclusion_buffer] == 0)):
+                        inferred_layer[read_idx, pos_cursor:pos_cursor + nuc_size] = 1
+                        pos_cursor += nuc_size + linker_size
+                    else:
+                        pos_cursor += 1
+
+    adata.layers[f"{large_bound_layer}_phased_nucleosomes"] = inferred_layer
+    print(f"✅ Added layer: {large_bound_layer}_phased_nucleosomes")
+    return adata
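Editor's note: refine_nucleosome_calls grows each called patch into flanking NaN positions (up to max_wiggle) and then centers an octamer- or hexamer-sized footprint if the expanded patch is wide enough, writing the results to new layers. A toy call; the layer names and import path are placeholders, not taken from this diff:

import numpy as np
import anndata as ad
from smftools.tools.nucleosome_hmm_refinement import refine_nucleosome_calls  # assumed path

adata = ad.AnnData(X=np.zeros((2, 300)))
adata.layers["hmm_bound"] = np.zeros((2, 300), dtype=int)
adata.layers["hmm_bound"][0, 50:200] = 1              # one 150-bp bound patch on read 0
adata.layers["nan_mask"] = np.zeros((2, 300), dtype=int)

adata = refine_nucleosome_calls(adata, "hmm_bound", "nan_mask")
print(adata.layers["hmm_bound_octamers"][0].sum())    # 147: an octamer centered in the patch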
smftools/tools/position_stats.py
ADDED
@@ -0,0 +1,239 @@
+# ------------------------- Utilities -------------------------
+def random_fill_nans(X):
+    import numpy as np
+    nan_mask = np.isnan(X)
+    X[nan_mask] = np.random.rand(*X[nan_mask].shape)
+    return X
+
+def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
+    """
+    Perform Bayesian-style methylation vs activity analysis independently within each group.
+
+    Parameters:
+        adata (AnnData): Annotated data matrix.
+        sites (list of str): List of site keys (e.g., ['GpC_site', 'CpG_site']).
+        alpha (float): FDR threshold for significance.
+        groupby (str or list of str): Column(s) in adata.obs to group by.
+
+    Returns:
+        results_dict (dict): Dictionary with structure:
+            results_dict[ref][group_label] = (results_df, sig_df)
+    """
+    import numpy as np
+    import pandas as pd
+    from scipy.stats import fisher_exact
+    from statsmodels.stats.multitest import multipletests
+
+    def compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values):
+        p_adj = multipletests(p_values, method='fdr_bh')[1] if p_values else []
+
+        genomic_positions = np.array(site_subset.var_names)[positions_list]
+        is_gpc_site = site_subset.var[f"{ref}_GpC_site"].values[positions_list]
+        is_cpg_site = site_subset.var[f"{ref}_CpG_site"].values[positions_list]
+
+        results_df = pd.DataFrame({
+            'Feature_Index': positions_list,
+            'Genomic_Position': genomic_positions.astype(int),
+            'Relative_Risk': relative_risks,
+            'Adjusted_P_Value': p_adj,
+            'GpC_Site': is_gpc_site,
+            'CpG_Site': is_cpg_site
+        })
+
+        results_df['log2_Relative_Risk'] = np.log2(results_df['Relative_Risk'].replace(0, 1e-300))
+        results_df['-log10_Adj_P'] = -np.log10(results_df['Adjusted_P_Value'].replace(0, 1e-300))
+        sig_df = results_df[results_df['Adjusted_P_Value'] < alpha]
+        return results_df, sig_df
+
+    results_dict = {}
+
+    for ref in adata.obs['Reference_strand'].unique():
+        ref_subset = adata[adata.obs['Reference_strand'] == ref].copy()
+        if ref_subset.shape[0] == 0:
+            continue
+
+        # Normalize groupby to list
+        if groupby is not None:
+            if isinstance(groupby, str):
+                groupby = [groupby]
+            def format_group_label(row):
+                return ",".join([f"{col}={row[col]}" for col in groupby])
+
+            combined_label = '__'.join(groupby)
+            ref_subset.obs[combined_label] = ref_subset.obs.apply(format_group_label, axis=1)
+            groups = ref_subset.obs[combined_label].unique()
+        else:
+            combined_label = None
+            groups = ['all']
+
+        results_dict[ref] = {}
+
+        for group in groups:
+            if group == 'all':
+                group_subset = ref_subset
+            else:
+                group_subset = ref_subset[ref_subset.obs[combined_label] == group]
+
+            if group_subset.shape[0] == 0:
+                continue
+
+            # Build site mask
+            site_mask = np.zeros(group_subset.shape[1], dtype=bool)
+            for site in sites:
+                site_mask |= group_subset.var[f"{ref}_{site}"]
+            site_subset = group_subset[:, site_mask].copy()
+
+            # Matrix and labels
+            X = random_fill_nans(site_subset.X.copy())
+            y = site_subset.obs['activity_status'].map({'Active': 1, 'Silent': 0}).values
+            P_active = np.mean(y)
+
+            # Analysis
+            positions_list, relative_risks, p_values = [], [], []
+            for pos in range(X.shape[1]):
+                methylation_state = (X[:, pos] > 0).astype(int)
+                table = pd.crosstab(methylation_state, y)
+                if table.shape != (2, 2):
+                    continue
+
+                P_methylated = np.mean(methylation_state)
+                P_methylated_given_active = np.mean(methylation_state[y == 1])
+                P_methylated_given_inactive = np.mean(methylation_state[y == 0])
+
+                if P_methylated_given_inactive == 0 or P_methylated in [0, 1]:
+                    continue
+
+                P_active_given_methylated = (P_methylated_given_active * P_active) / P_methylated
+                P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (1 - P_methylated)
+                RR = P_active_given_methylated / P_active_given_unmethylated
+
+                _, p_value = fisher_exact(table)
+                positions_list.append(pos)
+                relative_risks.append(RR)
+                p_values.append(p_value)
+
+            results_df, sig_df = compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values)
+            results_dict[ref][group] = (results_df, sig_df)
+
+    return results_dict
+
+def compute_positionwise_statistic(
+    adata,
+    layer,
+    method="pearson",
+    groupby=["Reference_strand"],
+    output_key="positionwise_result",
+    site_config=None,
+    encoding="signed",
+    max_threads=None
+):
+    """
+    Computes a position-by-position matrix (correlation, RR, or Chi-squared) from an adata layer.
+
+    Parameters:
+        adata (AnnData): Annotated data matrix.
+        layer (str): Name of the adata layer to use.
+        method (str): 'pearson', 'binary_covariance', 'relative_risk', or 'chi_squared'.
+        groupby (str or list): Column(s) in adata.obs to group by.
+        output_key (str): Key in adata.uns to store results.
+        site_config (dict): Optional {ref: [site_types]} to restrict sites per reference.
+        encoding (str): 'signed' (1/-1/0) or 'binary' (1/0/NaN).
+        max_threads (int): Number of parallel threads to use (joblib).
+    """
+    import numpy as np
+    import pandas as pd
+    from scipy.stats import chi2_contingency
+    from joblib import Parallel, delayed
+    from tqdm import tqdm
+
+    if isinstance(groupby, str):
+        groupby = [groupby]
+
+    adata.uns[output_key] = {}
+    adata.uns[output_key + "_n"] = {}
+
+    label_col = "__".join(groupby)
+    adata.obs[label_col] = adata.obs[groupby].astype(str).agg("_".join, axis=1)
+
+    for group in adata.obs[label_col].unique():
+        subset = adata[adata.obs[label_col] == group].copy()
+        if subset.shape[0] == 0:
+            continue
+
+        ref = subset.obs["Reference_strand"].unique()[0] if "Reference_strand" in groupby else None
+
+        if site_config and ref in site_config:
+            site_mask = np.zeros(subset.shape[1], dtype=bool)
+            for site in site_config[ref]:
+                site_mask |= subset.var[f"{ref}_{site}"]
+            subset = subset[:, site_mask].copy()
+
+        X = subset.layers[layer].copy()
+
+        if encoding == "signed":
+            X_bin = np.where(X == 1, 1, np.where(X == -1, 0, np.nan))
+        else:
+            X_bin = np.where(X == 1, 1, np.where(X == 0, 0, np.nan))
+
+        n_pos = subset.shape[1]
+        mat = np.zeros((n_pos, n_pos))
+
+        if method == "pearson":
+            with np.errstate(invalid='ignore'):
+                mat = np.corrcoef(np.nan_to_num(X_bin).T)
+
+        elif method == "binary_covariance":
+            binary = (X_bin == 1).astype(float)
+            valid = (X_bin == 1) | (X_bin == 0)  # Only consider true binary (ignore NaN)
+            valid = valid.astype(float)
+
+            numerator = np.dot(binary.T, binary)
+            denominator = np.dot(valid.T, valid)
+
+            with np.errstate(divide='ignore', invalid='ignore'):
+                mat = np.true_divide(numerator, denominator)
+                mat[~np.isfinite(mat)] = 0
+
+        elif method in ["relative_risk", "chi_squared"]:
+            def compute_row(i):
+                row = np.zeros(n_pos)
+                xi = X_bin[:, i]
+                for j in range(n_pos):
+                    xj = X_bin[:, j]
+                    mask = ~np.isnan(xi) & ~np.isnan(xj)
+                    if np.sum(mask) < 10:
+                        row[j] = np.nan
+                        continue
+                    if method == "relative_risk":
+                        a = np.sum((xi[mask] == 1) & (xj[mask] == 1))
+                        b = np.sum((xi[mask] == 1) & (xj[mask] == 0))
+                        c = np.sum((xi[mask] == 0) & (xj[mask] == 1))
+                        d = np.sum((xi[mask] == 0) & (xj[mask] == 0))
+                        if (a + b) > 0 and (c + d) > 0 and c > 0:
+                            p1 = a / (a + b)
+                            p2 = c / (c + d)
+                            row[j] = p1 / p2 if p2 > 0 else np.nan
+                        else:
+                            row[j] = np.nan
+                    elif method == "chi_squared":
+                        table = pd.crosstab(xi[mask], xj[mask])
+                        if table.shape != (2, 2):
+                            row[j] = np.nan
+                        else:
+                            chi2, _, _, _ = chi2_contingency(table, correction=False)
+                            row[j] = chi2
+                return row
+
+            mat = np.array(
+                Parallel(n_jobs=max_threads)(
+                    delayed(compute_row)(i) for i in tqdm(range(n_pos), desc=f"{method}: {group}")
+                )
+            )
+
+        else:
+            raise ValueError(f"Unsupported method: {method}")
+
+        var_names = subset.var_names.astype(int)
+        mat_df = pd.DataFrame(mat, index=var_names, columns=var_names)
+        adata.uns[output_key][group] = mat_df
+        adata.uns[output_key + "_n"][group] = subset.shape[0]
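Editor's note: compute_positionwise_statistic stores one position-by-position matrix per group in adata.uns[output_key] and the group sizes in adata.uns[output_key + "_n"]. A self-contained sketch on a toy AnnData; the layer name, reference label, and import path are illustrative:

import numpy as np
import pandas as pd
import anndata as ad
from smftools.tools.position_stats import compute_positionwise_statistic  # assumed path

rng = np.random.default_rng(0)
adata = ad.AnnData(X=np.zeros((50, 20)),
                   obs=pd.DataFrame({"Reference_strand": ["ref1_top"] * 50}),
                   var=pd.DataFrame(index=[str(p) for p in range(100, 120)]))  # positions
adata.layers["binarized_methylation"] = rng.choice([1, -1], size=(50, 20))     # signed calls

compute_positionwise_statistic(adata, layer="binarized_methylation",
                               method="binary_covariance",
                               groupby=["Reference_strand"],
                               output_key="comethylation", encoding="signed")
cov = adata.uns["comethylation"]["ref1_top"]   # 20x20 DataFrame indexed by genomic position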
smftools/tools/read_stats.py
ADDED
@@ -0,0 +1,70 @@
+# ------------------------- Utilities -------------------------
+def random_fill_nans(X):
+    import numpy as np
+    nan_mask = np.isnan(X)
+    X[nan_mask] = np.random.rand(*X[nan_mask].shape)
+    return X
+
+def calculate_row_entropy(
+        adata,
+        layer,
+        output_key="entropy",
+        site_config=None,
+        ref_col="Reference_strand",
+        encoding="signed",
+        max_threads=None):
+    """
+    Adds an obs column to the adata that calculates entropy within each read from a given layer
+    when looking at each site type passed in the site_config list.
+
+    Parameters:
+        adata (AnnData): The annotated data matrix.
+        layer (str): Name of the layer to use for entropy calculation.
+        method (str): Unused currently. Placeholder for potential future methods.
+        output_key (str): Base name for the entropy column in adata.obs.
+        site_config (dict): {ref: [site_types]} for masking relevant sites.
+        ref_col (str): Column in adata.obs denoting reference strands.
+        encoding (str): 'signed' (1/-1/0) or 'binary' (1/0/NaN).
+        max_threads (int): Number of threads for parallel processing.
+    """
+    import numpy as np
+    import pandas as pd
+    from scipy.stats import entropy
+    from joblib import Parallel, delayed
+    from tqdm import tqdm
+
+    entropy_values = []
+    row_indices = []
+
+    for ref in adata.obs[ref_col].cat.categories:
+        subset = adata[adata.obs[ref_col] == ref].copy()
+        if subset.shape[0] == 0:
+            continue
+
+        if site_config and ref in site_config:
+            site_mask = np.zeros(subset.shape[1], dtype=bool)
+            for site in site_config[ref]:
+                site_mask |= subset.var[f"{ref}_{site}"]
+            subset = subset[:, site_mask].copy()
+
+        X = subset.layers[layer].copy()
+
+        if encoding == "signed":
+            X_bin = np.where(X == 1, 1, np.where(X == -1, 0, np.nan))
+        else:
+            X_bin = np.where(X == 1, 1, np.where(X == 0, 0, np.nan))
+
+        def compute_entropy(row):
+            counts = pd.Series(row).value_counts(dropna=True).sort_index()
+            probs = counts / counts.sum()
+            return entropy(probs, base=2)
+
+        entropies = Parallel(n_jobs=max_threads)(
+            delayed(compute_entropy)(X_bin[i, :]) for i in tqdm(range(X_bin.shape[0]), desc=f"Entropy: {ref}")
+        )
+
+        entropy_values.extend(entropies)
+        row_indices.extend(subset.obs_names.tolist())
+
+    entropy_key = f"{output_key}_entropy"
+    adata.obs.loc[row_indices, entropy_key] = entropy_values
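Editor's note: calculate_row_entropy adds a per-read Shannon entropy column named f"{output_key}_entropy" to adata.obs. A small sketch; the layer name and import path are assumptions, and Reference_strand must be categorical because the function iterates over .cat.categories:

import numpy as np
import pandas as pd
import anndata as ad
from smftools.tools.read_stats import calculate_row_entropy  # assumed path

rng = np.random.default_rng(0)
obs = pd.DataFrame({"Reference_strand": pd.Categorical(["ref1_top"] * 30)})
adata = ad.AnnData(X=np.zeros((30, 10)), obs=obs)
adata.layers["binarized_methylation"] = rng.choice([1, -1], size=(30, 10))

calculate_row_entropy(adata, layer="binarized_methylation",
                      output_key="GpC", encoding="signed")
print(adata.obs["GpC_entropy"].describe())     # entropy in bits per read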
smftools/tools/subset_adata.py
CHANGED
@@ -1,32 +1,28 @@
 # subset_adata
 
-def subset_adata(adata,
+def subset_adata(adata, columns, cat_type='obs'):
     """
-
+    Adds subset metadata to an AnnData object based on categorical values in specified .obs or .var columns.
 
     Parameters:
-        adata (AnnData): The AnnData object to subset.
-
+        adata (AnnData): The AnnData object to add subset metadata to.
+        columns (list of str): List of .obs or .var column names to subset by. The order matters.
+        cat_type (str): obs or var. Default is obs
 
     Returns:
-
+        None
     """
+    import pandas as pd
+    import anndata as ad
 
-
-
-
-
-
-
-
-
-
-
-
-
-        return subsets
-
-    # Start the recursive subset process
-    subsets_dict = subset_recursive(adata, obs_columns)
-
-    return subsets_dict
+    subgroup_name = '_'.join(columns)
+    if 'obs' in cat_type:
+        df = adata.obs[columns]
+        adata.obs[subgroup_name] = df.apply(lambda row: '_'.join(row.astype(str)), axis=1)
+        adata.obs[subgroup_name] = adata.obs[subgroup_name].astype('category')
+    elif 'var' in cat_type:
+        df = adata.var[columns]
+        adata.var[subgroup_name] = df.apply(lambda row: '_'.join(row.astype(str)), axis=1)
+        adata.var[subgroup_name] = adata.var[subgroup_name].astype('category')
+
+    return None
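Editor's note: the rewritten subset_adata no longer returns a dict of subsets; it writes a combined categorical column (named by joining the input columns with '_') back onto .obs or .var. A short self-contained example, with an assumed import path and illustrative column names:

import numpy as np
import pandas as pd
import anndata as ad
from smftools.tools.subset_adata import subset_adata  # assumed import path

obs = pd.DataFrame({"Sample": ["a", "a", "b", "b"],
                    "Reference_strand": ["top", "bottom", "top", "bottom"]})
adata = ad.AnnData(X=np.zeros((4, 5)), obs=obs)

subset_adata(adata, columns=["Sample", "Reference_strand"], cat_type="obs")
print(adata.obs["Sample_Reference_strand"].value_counts())  # a_top, a_bottom, b_top, b_bottom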