smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- smftools/__init__.py +7 -6
- smftools/_version.py +1 -1
- smftools/cli/cli_flows.py +94 -0
- smftools/cli/hmm_adata.py +338 -0
- smftools/cli/load_adata.py +577 -0
- smftools/cli/preprocess_adata.py +363 -0
- smftools/cli/spatial_adata.py +564 -0
- smftools/cli_entry.py +435 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +38 -0
- smftools/config/deaminase.yaml +61 -0
- smftools/config/default.yaml +264 -0
- smftools/config/direct.yaml +41 -0
- smftools/config/discover_input_files.py +115 -0
- smftools/config/experiment_config.py +1288 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +13 -9
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/fast5_to_pod5.py +43 -0
- smftools/informatics/archived/helpers/archived/__init__.py +71 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
- smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
- smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
- smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
- smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
- smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
- smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
- smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
- smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
- smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
- smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
- smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
- smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
- smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
- smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
- smftools/informatics/bam_functions.py +812 -0
- smftools/informatics/basecalling.py +67 -0
- smftools/informatics/bed_functions.py +366 -0
- smftools/informatics/binarize_converted_base_identities.py +172 -0
- smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
- smftools/informatics/fasta_functions.py +255 -0
- smftools/informatics/h5ad_functions.py +197 -0
- smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
- smftools/informatics/modkit_functions.py +129 -0
- smftools/informatics/ohe.py +160 -0
- smftools/informatics/pod5_functions.py +224 -0
- smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +609 -0
- smftools/plotting/general_plotting.py +1292 -140
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +15 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/binarize.py +17 -0
- smftools/preprocessing/binarize_on_Youden.py +2 -2
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_position_Youden.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +1021 -89
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
- smftools-0.2.3.dist-info/RECORD +173 -0
- smftools-0.2.3.dist-info/entry_points.txt +2 -0
- smftools/informatics/fast5_to_pod5.py +0 -21
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/__init__.py +0 -74
- smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
- smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
- smftools/informatics/helpers/bam_qc.py +0 -66
- smftools/informatics/helpers/bed_to_bigwig.py +0 -39
- smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
- smftools/informatics/helpers/index_fasta.py +0 -12
- smftools/informatics/helpers/make_dirs.py +0 -21
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/informatics/readwrite.py +0 -106
- smftools/informatics/subsample_fasta_from_bed.py +0 -47
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools/evaluation → cli}/__init__.py +0 -0
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
- /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
- /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
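The headline change in this release is a reorganization: the neural-network code moves from smftools.tools into a new smftools.machine_learning package, the HMM utilities move into smftools.hmm, and most of the old informatics helpers are retired under informatics/archived. For downstream code this is mostly a matter of updated import paths; a minimal sketch of the change implied by the renames above, assuming the moved modules keep their contents and no compatibility shims are added:

# smftools 0.1.7
from smftools.tools.models import positional, rnn
from smftools.tools import apply_hmm_batched, train_hmm

# smftools 0.2.3
from smftools.machine_learning.models import positional, rnn
from smftools.hmm import apply_hmm_batched, train_hmm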
smftools/machine_learning/models/transformer.py
ADDED
@@ -0,0 +1,303 @@
+import torch
+import torch.nn as nn
+from .base import BaseTorchModel
+from .positional import PositionalEncoding
+from ..utils.grl import grad_reverse
+import numpy as np
+
+class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, src, src_mask=None, is_causal=False, src_key_padding_mask=None):
+        self_attn_output, attn_weights = self.self_attn(
+            src, src, src,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+            need_weights=True,
+            average_attn_weights=False,  # preserve [B, num_heads, S, S]
+            is_causal=is_causal
+        )
+        src = src + self.dropout1(self_attn_output)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+
+        # Save attention weights to module
+        self.attn_weights = attn_weights  # Save to layer
+        return src
+
+class BaseTransformer(BaseTorchModel):
+    def __init__(self,
+                 input_dim=1,
+                 model_dim=64,
+                 num_heads=4,
+                 num_layers=2,
+                 dropout=0.2,
+                 seq_len=None,
+                 use_learnable_pos=False,
+                 use_cls_token=True,
+                 **kwargs):
+        super().__init__(**kwargs)
+        # Input FC layer to map D_input to D_model
+        self.model_dim = model_dim
+        self.input_fc = nn.Linear(input_dim, model_dim)
+        self.ff_dim = model_dim * 4
+        self.dropout = dropout
+        self.use_cls_token = use_cls_token
+
+        self.attn_weights = []
+        self.attn_grads = []
+
+        if use_learnable_pos:
+            assert seq_len is not None, "Must provide seq_len if use_learnable_pos=True"
+            self.pos_embed = nn.Parameter(torch.randn(seq_len + (1 if use_cls_token else 0), model_dim))
+            self.pos_encoder = None
+        else:
+            self.pos_encoder = PositionalEncoding(model_dim)
+            self.pos_embed = None
+
+        if self.use_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))  # (1, 1, D)
+
+        # Specify the transformer encoder structure
+        encoder_layer = TransformerEncoderLayerWithAttn(d_model=model_dim, nhead=num_heads, batch_first=True, dim_feedforward=self.ff_dim, dropout=self.dropout)
+        # Stack the transformer encoder layers
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+        # Register hooks
+        for layer in self.transformer.layers:
+            layer.self_attn.register_forward_hook(self._save_attn_weights)
+            layer.self_attn.register_full_backward_hook(self._save_attn_grads)
+
+    def _save_attn_weights(self, module, input, output):
+        self.attn_weights.append(output[1].detach())
+
+    def _save_attn_grads(self, module, grad_input, grad_output):
+        self.attn_grads.append(grad_output[0].detach())
+
+    def encode(self, x, mask=None):
+        if x.dim() == 2:  # (B, S)
+            x = x.unsqueeze(-1)
+        elif x.dim() == 1:  # (S,)
+            x = x.unsqueeze(0).unsqueeze(-1)
+        elif x.dim() == 3:
+            pass
+        else:
+            raise ValueError(f"Unexpected input shape: {x.shape}")
+
+        x = self.input_fc(x)  # (B, S, D)
+
+        B, S, D = x.shape
+        if self.use_cls_token:
+            cls = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
+            x = torch.cat([cls, x], dim=1)  # (B, S+1, D)
+
+        if self.pos_embed is not None:
+            x = x + self.pos_embed.unsqueeze(0)[:, :x.shape[1], :]
+        elif self.pos_encoder is not None:
+            x = self.pos_encoder(x)
+
+        if mask is not None:
+            pad = torch.ones(B, 1, device=mask.device) if self.use_cls_token else 0
+            mask = torch.cat([pad, mask], dim=1) if self.use_cls_token else mask
+            x = x * mask.unsqueeze(-1)
+
+        encoded = self.transformer(x)
+        return encoded
+
+    def compute_attn_grad(self, reduction='mean'):
+        """
+        Computes attention × gradient scores across layers.
+        Returns: [B, S] tensor of importance scores
+        """
+        scores = []
+        for attn, grad in zip(self.attn_weights, self.attn_grads):
+            # attn: [B, H, S, S]
+            # grad: [B, S, D]
+            attn = attn.mean(dim=1)  # [B, S, S]
+            grad_norm = grad.norm(dim=-1)  # [B, S]
+            attn_grad_score = (attn * grad_norm.unsqueeze(1)).sum(dim=-1)  # [B, S]
+            scores.append(attn_grad_score)
+
+        # Combine across layers
+        stacked = torch.stack(scores, dim=0)  # [L, B, S]
+        if reduction == "mean":
+            return stacked.mean(dim=0)  # [B, S]
+        elif reduction == "sum":
+            return stacked.sum(dim=0)  # [B, S]
+        else:
+            return stacked  # [L, B, S]
+
+    def compute_rollout(self):
+        """
+        Computes attention rollout: [B, S, S] final attention influence map
+        """
+        device = self.attn_weights[0].device
+        B, S = self.attn_weights[0].shape[0], self.attn_weights[0].shape[-1]
+        rollout = torch.eye(S, device=device).unsqueeze(0).repeat(B, 1, 1)  # [B, S, S]
+
+        for attn in self.attn_weights:
+            attn_heads = attn.mean(dim=1)  # [B, S, S]
+            attn_heads = attn_heads + torch.eye(S, device=device).unsqueeze(0)  # add residual
+            attn_heads = attn_heads / attn_heads.sum(dim=-1, keepdim=True).clamp(min=1e-6)
+            rollout = torch.bmm(attn_heads, rollout)  # [B, S, S]
+
+        return rollout  # [B, S, S]
+
+    def reset_attn_buffers(self):
+        self.attn_weights = []
+        self.attn_grads = []
+
+    def get_attn_layer(self, layer_idx=0, head_idx=None):
+        """
+        Returns attention map from a specific layer (and optionally head).
+        """
+        attn = self.attn_weights[layer_idx]  # [B, H, S, S]
+        if head_idx is not None:
+            attn = attn[:, head_idx]  # [B, S, S]
+        return attn
+
+    def apply_attn_interpretations_to_adata(self, dataloader, adata,
+                                            obsm_key_grad="attn_grad",
+                                            obsm_key_rollout="attn_rollout",
+                                            device="cpu"):
+        self.to(device)
+        self.eval()
+        grad_maps = []
+        rollout_maps = []
+
+        for batch in dataloader:
+            x = batch[0].to(device)
+            x.requires_grad_()
+
+            self.reset_attn_buffers()
+            logits = self(x)
+
+            if logits.shape[1] == 1:
+                target_score = logits.squeeze()
+            else:
+                target_score = logits.max(dim=1).values
+
+            target_score.sum().backward()
+
+            grad = self.compute_attn_grad()  # [B, S+1]
+            if self.use_cls_token:
+                grad = grad[:, 1:]  # ignore CLS token
+            grad_maps.append(grad.detach().cpu().numpy())
+
+        grad_concat = np.concatenate(grad_maps, axis=0)
+        adata.obsm[obsm_key_grad] = grad_concat
+
+        # add per-row normalized version
+        grad_normed = grad_concat / (np.max(grad_concat, axis=1, keepdims=True) + 1e-8)
+        adata.obsm[f"{obsm_key_grad}_normalized"] = grad_normed
+
+class TransformerClassifier(BaseTransformer):
+    def __init__(self,
+                 input_dim,
+                 num_classes,
+                 **kwargs):
+        super().__init__(input_dim, **kwargs)
+        # Classification head
+        output_size = 1 if num_classes == 2 else num_classes
+        self.cls_head = nn.Linear(self.model_dim, output_size)
+
+    def forward(self, x):
+        """
+        x: (batch, seq_len, input_dim)
+        """
+        self.reset_attn_buffers()
+        if x.dim() == 2:  # shape (B, S)
+            x = x.unsqueeze(-1)  # → (B, S, 1)
+        elif x.dim() == 1:
+            x = x.unsqueeze(0).unsqueeze(-1)  # just in case (S,) → (1, S, 1)
+        else:
+            pass
+        encoded = self.encode(x)  # -> (B, S, D_model)
+        if self.use_cls_token:
+            pooled = encoded[:, 0]  # (B, D)
+        else:
+            pooled = encoded.mean(dim=1)  # (B, D) out = self.cls_head(pooled) # -> (B, C)
+
+        out = self.cls_head(pooled)  # (B, C)
+        return out
+
+class DANNTransformerClassifier(TransformerClassifier):
+    def __init__(self, input_dim, model_dim, num_classes, n_domains, **kwargs):
+        super().__init__(input_dim, model_dim, num_classes, **kwargs)
+        self.domain_classifier = nn.Sequential(
+            nn.Linear(model_dim, 128),
+            nn.ReLU(),
+            nn.Linear(128, n_domains)
+        )
+
+    def forward(self, x, alpha=1.0):
+        encoded = self.encode(x)  # (B, S, D_model)
+        pooled = encoded.mean(dim=1)  # (B, D_model)
+
+        class_logits = self.cls_head(pooled)
+        domain_logits = self.domain_classifier(grad_reverse(pooled, alpha))
+
+        return class_logits, domain_logits
+
+class MaskedTransformerPretrainer(BaseTransformer):
+    def __init__(self, input_dim, model_dim, num_heads=4, num_layers=2, **kwargs):
+        super().__init__(input_dim, model_dim, num_heads, num_layers, **kwargs)
+        self.decoder = nn.Linear(model_dim, input_dim)
+
+    def forward(self, x, mask):
+        """
+        x: (batch, seq_len, input_dim)
+        mask: (batch, seq_len) optional
+        """
+        if x.dim() == 2:
+            x = x.unsqueeze(-1)
+        encoded = self.encode(x, mask=mask)  # -> (B, S, D_model)
+        return self.decoder(encoded)  # -> (B, D_input)
+
+class DANNTransformer(BaseTransformer):
+    """
+    """
+    def __init__(self, seq_len, model_dim, n_heads, n_layers, n_domains):
+        super().__init__(
+            input_dim=1,  # 1D scalar input per token
+            model_dim=model_dim,
+            num_heads=n_heads,
+            num_layers=n_layers,
+            seq_len=seq_len,
+            use_learnable_pos=True  # enables learnable pos_embed in base
+        )
+
+        # Reconstruction head
+        self.recon_head = nn.Linear(model_dim, 1)
+
+        # Domain classification head
+        self.domain_classifier = nn.Sequential(
+            nn.Linear(model_dim, 128),
+            nn.ReLU(),
+            nn.Linear(128, n_domains)
+        )
+
+    def forward(self, x, alpha=1.0):
+        """
+        x: Tensor of shape (B, S) or (B, S, 1)
+        alpha: GRL coefficient (float)
+        """
+        if x.dim() == 2:
+            x = x.unsqueeze(-1)  # (B, S, 1)
+
+        # Encode sequence
+        h = self.encode(x)  # (B, S, D_model)
+
+        # Head 1: Reconstruction
+        recon = self.recon_head(h).squeeze(-1)  # (B, S)
+
+        # Head 2: Domain classification via GRL
+        pooled = h.mean(dim=1)  # (B, D_model)
+        rev = grad_reverse(pooled, alpha)
+        domain_logits = self.domain_classifier(rev)  # (B, n_batches)
+
+        return recon, domain_logits
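For orientation, a minimal sketch of how the new TransformerClassifier and its attention-interpretation helpers compose. The synthetic data, the DataLoader/AnnData wiring, and the assumption that BaseTorchModel can be constructed without extra keyword arguments are illustrative, not fixed by this diff:

import numpy as np
import torch
import anndata as ad
from torch.utils.data import DataLoader, TensorDataset
from smftools.machine_learning.models.transformer import TransformerClassifier

X = torch.rand(32, 100)                    # 32 reads x 100 positions, one scalar per position
adata = ad.AnnData(X.numpy())              # AnnData with matching n_obs for the obsm outputs
loader = DataLoader(TensorDataset(X), batch_size=16, shuffle=False)

model = TransformerClassifier(input_dim=1, num_classes=2, model_dim=64,
                              num_heads=4, num_layers=2,
                              seq_len=100, use_learnable_pos=True)
logits = model(X)                          # (32, 1) for the binary case

# writes adata.obsm["attn_grad"] and adata.obsm["attn_grad_normalized"]
model.apply_attn_interpretations_to_adata(loader, adata, device="cpu")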
smftools/machine_learning/training/train_lightning_model.py
ADDED
@@ -0,0 +1,135 @@
+import torch
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from ..data import AnnDataModule
+from ..models import TorchClassifierWrapper
+
+def train_lightning_model(
+    model,
+    datamodule,
+    max_epochs=30,
+    patience=5,
+    monitor_metric="val_loss",
+    checkpoint_path=None,
+    evaluate_test=True,
+    devices=1
+):
+    """
+    Takes a PyTorch Lightning Model and a Lightning DataLoader module to define a Lightning Trainer.
+    - The Lightning trainer fits the model to the training split of the datamodule.
+    - The Lightning trainer uses the validation split of the datamodule for monitoring training loss.
+    - Option of evaluating the trained model on a test set when evaluate_test is True.
+    - When using cuda, devices parameter can be: 1, [0,1], "all", "auto". Depending on what devices you want to use.
+    """
+    # Device logic
+    if torch.cuda.is_available():
+        accelerator = "gpu"
+    elif torch.backends.mps.is_available():
+        accelerator = "mps"
+        devices = 1
+    else:
+        accelerator = "cpu"
+        devices = 1
+
+    # adds the train/val/test indices from the datamodule to the model class.
+    model.set_training_indices(datamodule)
+
+    # Callbacks
+    callbacks = [
+        EarlyStopping(monitor=monitor_metric, patience=patience, mode="min"),
+    ]
+    if checkpoint_path:
+        callbacks.append(ModelCheckpoint(
+            dirpath=checkpoint_path,
+            filename="{epoch}-{val_loss:.4f}",
+            monitor=monitor_metric,
+            save_top_k=1,
+            mode="min",
+        ))
+
+    # Trainer setup
+    trainer = Trainer(
+        max_epochs=max_epochs,
+        callbacks=callbacks,
+        accelerator=accelerator,
+        devices=devices,
+        log_every_n_steps=10,
+        enable_progress_bar=False
+    )
+
+    # Fit model with trainer
+    trainer.fit(model, datamodule=datamodule)
+
+    # Test model (if applicable)
+    if evaluate_test and hasattr(datamodule, "test_dataloader"):
+        trainer.test(model, datamodule=datamodule)
+
+    # Return best checkpoint path
+    best_ckpt = None
+    for cb in callbacks:
+        if isinstance(cb, ModelCheckpoint):
+            best_ckpt = cb.best_model_path
+
+    return trainer, best_ckpt
+
+def run_sliding_window_lightning_training(
+    adata,
+    tensor_source,
+    tensor_key,
+    label_col,
+    model_class,
+    num_classes,
+    class_names,
+    class_weights,
+    focus_class,
+    window_size,
+    stride,
+    max_epochs=30,
+    patience=5,
+    enforce_eval_balance: bool=False,
+    target_eval_freq: float=0.3,
+    max_eval_positive: int=None
+):
+    input_len = adata.shape[1]
+    results = {}
+
+    for start in range(0, input_len - window_size + 1, stride):
+        center_idx = start + window_size // 2
+        center_varname = adata.var_names[center_idx]
+        print(f"\nTraining window around {center_varname}")
+
+        # Build datamodule for this window
+        datamodule = AnnDataModule(
+            adata,
+            tensor_source=tensor_source,
+            tensor_key=tensor_key,
+            label_col=label_col,
+            batch_size=64,
+            window_start=start,
+            window_size=window_size
+        )
+        datamodule.setup()
+
+        # Build model for this window
+        model = model_class(window_size, num_classes)
+        wrapper = TorchClassifierWrapper(
+            model, label_col=label_col, num_classes=num_classes,
+            class_names=class_names,
+            class_weights=class_weights,
+            focus_class=focus_class, enforce_eval_balance=enforce_eval_balance,
+            target_eval_freq=target_eval_freq, max_eval_positive=max_eval_positive
+        )
+
+        # Train model
+        trainer, ckpt = train_lightning_model(
+            wrapper, datamodule, max_epochs=max_epochs, patience=patience
+        )
+
+        results[center_varname] = {
+            "model": wrapper,
+            "trainer": trainer,
+            "checkpoint": ckpt,
+            "metrics": trainer.callback_metrics
+        }
+
+    return results
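A rough usage sketch of how these helpers compose, mirroring the calls that run_sliding_window_lightning_training itself makes. The AnnDataModule and TorchClassifierWrapper keyword values below (tensor_source="X", the label column name, the class names) are illustrative assumptions inferred from this file, not values fixed by the diff:

import numpy as np
import anndata as ad
from smftools.machine_learning.data import AnnDataModule
from smftools.machine_learning.models import TorchClassifierWrapper
from smftools.machine_learning.models.transformer import TransformerClassifier
from smftools.machine_learning.training.train_lightning_model import train_lightning_model

# synthetic AnnData: 200 reads x 100 positions with a binary label per read
adata = ad.AnnData(np.random.rand(200, 100).astype(np.float32))
adata.obs["label"] = np.random.randint(0, 2, size=200)

datamodule = AnnDataModule(adata, tensor_source="X", tensor_key=None,
                           label_col="label", batch_size=64)
datamodule.setup()

model = TransformerClassifier(input_dim=1, num_classes=2,
                              seq_len=adata.shape[1], use_learnable_pos=True)
wrapper = TorchClassifierWrapper(model, label_col="label", num_classes=2,
                                 class_names=["unmethylated", "methylated"])

trainer, best_ckpt = train_lightning_model(wrapper, datamodule,
                                           max_epochs=10, patience=3,
                                           checkpoint_path="checkpoints/")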
smftools/machine_learning/training/train_sklearn_model.py
ADDED
@@ -0,0 +1,114 @@
+from ..data import AnnDataModule
+from ..models import SklearnModelWrapper
+
+def train_sklearn_model(
+    model_wrapper,
+    datamodule,
+    evaluate_test=True,
+    evaluate_val=False
+):
+    """
+    Fits a SklearnModelWrapper on the train split from datamodule.
+    Evaluates on test and/or val set.
+
+    Parameters:
+        model_wrapper: SklearnModelWrapper instance
+        datamodule: AnnDataModule instance (with setup() method)
+        evaluate_test: whether to evaluate on test split
+        evaluate_val: whether to evaluate on validation split
+
+    Returns:
+        metrics: dictionary containing evaluation metrics
+    """
+    # Fit model
+    model_wrapper.fit_from_datamodule(datamodule)
+
+    # Evaluate
+    metrics = {}
+
+    if evaluate_val:
+        val_metrics = model_wrapper.evaluate_from_datamodule(datamodule, split="val")
+        metrics.update({f"{k}": v for k, v in val_metrics.items()})
+
+    if evaluate_test:
+        test_metrics = model_wrapper.evaluate_from_datamodule(datamodule, split="test")
+        metrics.update({f"{k}": v for k, v in test_metrics.items()})
+
+    # Plot evaluations
+    model_wrapper.plot_roc_pr_curves()
+
+    return metrics
+
+def run_sliding_window_sklearn_training(
+    adata,
+    tensor_source,
+    tensor_key,
+    label_col,
+    model_class,
+    num_classes,
+    class_names,
+    focus_class,
+    window_size,
+    stride,
+    batch_size=64,
+    train_frac=0.6,
+    val_frac=0.1,
+    test_frac=0.3,
+    random_seed=42,
+    enforce_eval_balance=False,
+    target_eval_freq=0.3,
+    max_eval_positive=None,
+    **model_kwargs
+):
+    """
+    Sliding window training for sklearn models using AnnData.
+
+    Returns dict keyed by window center.
+    """
+
+    input_len = adata.shape[1]
+    results = {}
+
+    for start in range(0, input_len - window_size + 1, stride):
+        center_idx = start + window_size // 2
+        center_varname = adata.var_names[center_idx]
+        print(f"\nTraining window around {center_varname}")
+
+        # Build datamodule for this window
+        datamodule = AnnDataModule(
+            adata,
+            tensor_source=tensor_source,
+            tensor_key=tensor_key,
+            label_col=label_col,
+            batch_size=batch_size,
+            window_start=start,
+            window_size=window_size,
+            train_frac=train_frac,
+            val_frac=val_frac,
+            test_frac=test_frac,
+            random_seed=random_seed
+        )
+        datamodule.setup()
+
+        # Build model wrapper
+        sklearn_model = model_class(**model_kwargs)
+        wrapper = SklearnModelWrapper(
+            sklearn_model,
+            num_classes=num_classes,
+            label_col=label_col,
+            class_names=class_names,
+            focus_class=focus_class,
+            enforce_eval_balance=enforce_eval_balance,
+            target_eval_freq=target_eval_freq,
+            max_eval_positive=max_eval_positive
+        )
+
+        # Fit and evaluate
+        metrics = train_sklearn_model(wrapper, datamodule, evaluate_test=True, evaluate_val=False)
+
+        results[center_varname] = {
+            "model": wrapper,
+            "metrics": metrics
+        }
+
+    return results
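A comparable sketch for the sklearn path, again inferring the AnnDataModule and SklearnModelWrapper signatures from the calls in this file; the synthetic data, the tensor_source/label_col values, and the choice of LogisticRegression are illustrative:

import numpy as np
import anndata as ad
from sklearn.linear_model import LogisticRegression
from smftools.machine_learning.training.train_sklearn_model import run_sliding_window_sklearn_training

adata = ad.AnnData(np.random.rand(200, 100).astype(np.float32))
adata.obs["label"] = np.random.randint(0, 2, size=200)

results = run_sliding_window_sklearn_training(
    adata,
    tensor_source="X",
    tensor_key=None,
    label_col="label",
    model_class=LogisticRegression,
    num_classes=2,
    class_names=["unmethylated", "methylated"],
    focus_class=1,
    window_size=20,
    stride=10,
    max_iter=500,            # forwarded to LogisticRegression via **model_kwargs
)

for center, res in results.items():
    print(center, res["metrics"])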
smftools/plotting/__init__.py
CHANGED
@@ -1,6 +1,9 @@
+from .autocorrelation_plotting import *
+from .hmm_plotting import *
 from .position_stats import plot_bar_relative_risk, plot_volcano_relative_risk, plot_positionwise_matrix, plot_positionwise_matrix_grid
-from .general_plotting import combined_hmm_raw_clustermap
+from .general_plotting import combined_hmm_raw_clustermap, combined_raw_clustermap, plot_hmm_layers_rolling_by_sample_ref
 from .classifiers import plot_model_performance, plot_feature_importances_or_saliency, plot_model_curves_from_adata, plot_model_curves_from_adata_with_frequency_grid
+from .qc_plotting import *
 
 __all__ = [
     "combined_hmm_raw_clustermap",
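With the wildcard re-exports added above, the autocorrelation, HMM, and QC plotting helpers become importable from the package root, and the two functions newly added to the general_plotting import can be pulled in directly, e.g.:

from smftools.plotting import combined_raw_clustermap, plot_hmm_layers_rolling_by_sample_ref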