smftools-0.2.4-py3-none-any.whl → smftools-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/machine_learning/models/transformer.py

@@ -1,22 +1,31 @@
-import
-
-from .base import BaseTorchModel
-from .positional import PositionalEncoding
-from ..utils.grl import grad_reverse
+from __future__ import annotations
+
 import numpy as np
 
+from smftools.optional_imports import require
+
+from ..utils.grl import grad_reverse
+from .base import BaseTorchModel
+from .positional import PositionalEncoding
+
+torch = require("torch", extra="ml-base", purpose="Transformer models")
+nn = torch.nn
+
+
 class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def forward(self, src, src_mask=None, is_causal=False, src_key_padding_mask=None):
         self_attn_output, attn_weights = self.self_attn(
-            src,
+            src,
+            src,
+            src,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
             need_weights=True,
             average_attn_weights=False,  # preserve [B, num_heads, S, S]
-            is_causal=is_causal
+            is_causal=is_causal,
         )
         src = src + self.dropout1(self_attn_output)
         src = self.norm1(src)
@@ -27,18 +36,21 @@ class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
         # Save attention weights to module
         self.attn_weights = attn_weights  # Save to layer
         return src
-
+
+
 class BaseTransformer(BaseTorchModel):
-    def __init__(
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        input_dim=1,
+        model_dim=64,
+        num_heads=4,
+        num_layers=2,
+        dropout=0.2,
+        seq_len=None,
+        use_learnable_pos=False,
+        use_cls_token=True,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         # Input FC layer to map D_input to D_model
         self.model_dim = model_dim
@@ -52,7 +64,9 @@ class BaseTransformer(BaseTorchModel):
 
         if use_learnable_pos:
             assert seq_len is not None, "Must provide seq_len if use_learnable_pos=True"
-            self.pos_embed = nn.Parameter(
+            self.pos_embed = nn.Parameter(
+                torch.randn(seq_len + (1 if use_cls_token else 0), model_dim)
+            )
             self.pos_encoder = None
         else:
             self.pos_encoder = PositionalEncoding(model_dim)
@@ -62,7 +76,13 @@ class BaseTransformer(BaseTorchModel):
         self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))  # (1, 1, D)
 
         # Specify the transformer encoder structure
-        encoder_layer = TransformerEncoderLayerWithAttn(
+        encoder_layer = TransformerEncoderLayerWithAttn(
+            d_model=model_dim,
+            nhead=num_heads,
+            batch_first=True,
+            dim_feedforward=self.ff_dim,
+            dropout=self.dropout,
+        )
         # Stack the transformer encoder layers
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
 
@@ -95,7 +115,7 @@ class BaseTransformer(BaseTorchModel):
             x = torch.cat([cls, x], dim=1)  # (B, S+1, D)
 
         if self.pos_embed is not None:
-            x = x + self.pos_embed.unsqueeze(0)[:, :x.shape[1], :]
+            x = x + self.pos_embed.unsqueeze(0)[:, : x.shape[1], :]
         elif self.pos_encoder is not None:
             x = self.pos_encoder(x)
 
@@ -106,8 +126,8 @@ class BaseTransformer(BaseTorchModel):
 
         encoded = self.transformer(x)
         return encoded
-
-    def compute_attn_grad(self, reduction=
+
+    def compute_attn_grad(self, reduction="mean"):
         """
         Computes attention × gradient scores across layers.
         Returns: [B, S] tensor of importance scores
@@ -116,19 +136,19 @@ class BaseTransformer(BaseTorchModel):
         for attn, grad in zip(self.attn_weights, self.attn_grads):
             # attn: [B, H, S, S]
             # grad: [B, S, D]
-            attn = attn.mean(dim=1)
-            grad_norm = grad.norm(dim=-1)
+            attn = attn.mean(dim=1)  # [B, S, S]
+            grad_norm = grad.norm(dim=-1)  # [B, S]
             attn_grad_score = (attn * grad_norm.unsqueeze(1)).sum(dim=-1)  # [B, S]
             scores.append(attn_grad_score)
 
         # Combine across layers
         stacked = torch.stack(scores, dim=0)  # [L, B, S]
         if reduction == "mean":
-            return stacked.mean(dim=0)
+            return stacked.mean(dim=0)  # [B, S]
         elif reduction == "sum":
-            return stacked.sum(dim=0)
+            return stacked.sum(dim=0)  # [B, S]
         else:
-            return stacked
+            return stacked  # [L, B, S]
 
     def compute_rollout(self):
         """
@@ -143,9 +163,9 @@ class BaseTransformer(BaseTorchModel):
             attn_heads = attn_heads + torch.eye(S, device=device).unsqueeze(0)  # add residual
             attn_heads = attn_heads / attn_heads.sum(dim=-1, keepdim=True).clamp(min=1e-6)
             rollout = torch.bmm(attn_heads, rollout)  # [B, S, S]
-
+
         return rollout  # [B, S, S]
-
+
     def reset_attn_buffers(self):
         self.attn_weights = []
         self.attn_grads = []
@@ -158,11 +178,15 @@ class BaseTransformer(BaseTorchModel):
         if head_idx is not None:
             attn = attn[:, head_idx]  # [B, S, S]
         return attn
-
-    def apply_attn_interpretations_to_adata(
-
-
-
+
+    def apply_attn_interpretations_to_adata(
+        self,
+        dataloader,
+        adata,
+        obsm_key_grad="attn_grad",
+        obsm_key_rollout="attn_rollout",
+        device="cpu",
+    ):
         self.to(device)
         self.eval()
         grad_maps = []
@@ -193,12 +217,10 @@ class BaseTransformer(BaseTorchModel):
         # add per-row normalized version
         grad_normed = grad_concat / (np.max(grad_concat, axis=1, keepdims=True) + 1e-8)
         adata.obsm[f"{obsm_key_grad}_normalized"] = grad_normed
-
+
+
 class TransformerClassifier(BaseTransformer):
-    def __init__(self,
-        input_dim,
-        num_classes,
-        **kwargs):
+    def __init__(self, input_dim, num_classes, **kwargs):
         super().__init__(input_dim, **kwargs)
         # Classification head
         output_size = 1 if num_classes == 2 else num_classes
@@ -215,7 +237,7 @@ class TransformerClassifier(BaseTransformer):
             x = x.unsqueeze(0).unsqueeze(-1)  # just in case (S,) → (1, S, 1)
         else:
             pass
-        encoded = self.encode(x)
+        encoded = self.encode(x)  # -> (B, S, D_model)
         if self.use_cls_token:
             pooled = encoded[:, 0]  # (B, D)
         else:
@@ -223,14 +245,13 @@ class TransformerClassifier(BaseTransformer):
 
         out = self.cls_head(pooled)  # (B, C)
         return out
-
+
+
 class DANNTransformerClassifier(TransformerClassifier):
     def __init__(self, input_dim, model_dim, num_classes, n_domains, **kwargs):
         super().__init__(input_dim, model_dim, num_classes, **kwargs)
         self.domain_classifier = nn.Sequential(
-            nn.Linear(model_dim, 128),
-            nn.ReLU(),
-            nn.Linear(128, n_domains)
+            nn.Linear(model_dim, 128), nn.ReLU(), nn.Linear(128, n_domains)
         )
 
     def forward(self, x, alpha=1.0):
@@ -242,6 +263,7 @@ class DANNTransformerClassifier(TransformerClassifier):
 
         return class_logits, domain_logits
 
+
 class MaskedTransformerPretrainer(BaseTransformer):
     def __init__(self, input_dim, model_dim, num_heads=4, num_layers=2, **kwargs):
         super().__init__(input_dim, model_dim, num_heads, num_layers, **kwargs)
@@ -254,12 +276,13 @@ class MaskedTransformerPretrainer(BaseTransformer):
         """
         if x.dim() == 2:
             x = x.unsqueeze(-1)
-        encoded = self.encode(x, mask=mask)
-        return self.decoder(encoded)
-
+        encoded = self.encode(x, mask=mask)  # -> (B, S, D_model)
+        return self.decoder(encoded)  # -> (B, D_input)
+
+
 class DANNTransformer(BaseTransformer):
-    """
-
+    """ """
+
     def __init__(self, seq_len, model_dim, n_heads, n_layers, n_domains):
         super().__init__(
             input_dim=1,  # 1D scalar input per token
@@ -267,7 +290,7 @@ class DANNTransformer(BaseTransformer):
             num_heads=n_heads,
             num_layers=n_layers,
             seq_len=seq_len,
-            use_learnable_pos=True  # enables learnable pos_embed in base
+            use_learnable_pos=True,  # enables learnable pos_embed in base
         )
 
         # Reconstruction head
@@ -275,9 +298,7 @@ class DANNTransformer(BaseTransformer):
 
         # Domain classification head
         self.domain_classifier = nn.Sequential(
-            nn.Linear(model_dim, 128),
-            nn.ReLU(),
-            nn.Linear(128, n_domains)
+            nn.Linear(model_dim, 128), nn.ReLU(), nn.Linear(128, n_domains)
        )
 
     def forward(self, x, alpha=1.0):
@@ -300,4 +321,3 @@ class DANNTransformer(BaseTransformer):
         domain_logits = self.domain_classifier(rev)  # (B, n_batches)
 
         return recon, domain_logits
-
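Across these machine_learning modules the direct torch imports are replaced by the new smftools.optional_imports.require helper (smftools/optional_imports.py is added in this release), so torch is only pulled in when the ml-base extra is installed. The helper's implementation is not part of this hunk; below is a minimal sketch of how such a lazy-import helper could behave, offered as an assumption rather than the packaged code:

import importlib

def require(module, extra=None, purpose=None):
    # Hypothetical stand-in for smftools.optional_imports.require; the shipped
    # implementation may differ. Imports an optional dependency, or raises an
    # error pointing at the extra that provides it.
    try:
        return importlib.import_module(module)
    except ImportError as err:
        reason = f" (needed for {purpose})" if purpose else ""
        hint = f"; install with: pip install 'smftools[{extra}]'" if extra else ""
        raise ImportError(f"Optional dependency '{module}' is missing{reason}{hint}") from err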
smftools/machine_learning/models/wrappers.py

@@ -1,5 +1,10 @@
-import
-
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="model wrappers")
+nn = torch.nn
+
 
 class ScaledModel(nn.Module):
     def __init__(self, model, mean, std):
@@ -17,4 +22,4 @@ class ScaledModel(nn.Module):
             x = (x - mean[None, None, :]) / std[None, None, :]
         else:
             raise ValueError(f"Unsupported input shape {x.shape}")
-        return self.model(x)
+        return self.model(x)
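ScaledModel standardizes its input with the supplied per-feature mean and std before delegating to the wrapped model (return self.model(x) above). A hedged usage sketch; the backbone, tensor shapes, and direct module import below are illustrative assumptions, not code from this package:

import torch

from smftools.machine_learning.models.wrappers import ScaledModel

backbone = torch.nn.Linear(10, 2)            # stand-in for a trained smftools model
mean, std = torch.zeros(10), torch.ones(10)  # per-feature statistics from the training data
scaled = ScaledModel(backbone, mean, std)
out = scaled(torch.randn(4, 25, 10))         # (batch, seq, features) is normalized, then passed through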
smftools/machine_learning/training/__init__.py

@@ -1,2 +1,4 @@
-from
-
+from __future__ import annotations
+
+from .train_lightning_model import run_sliding_window_lightning_training, train_lightning_model
+from .train_sklearn_model import run_sliding_window_sklearn_training, train_sklearn_model
smftools/machine_learning/training/train_lightning_model.py

@@ -1,9 +1,21 @@
-import
-
-from
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
 from ..data import AnnDataModule
 from ..models import TorchClassifierWrapper
 
+torch = require("torch", extra="ml-base", purpose="Lightning training")
+pytorch_lightning = require("pytorch_lightning", extra="ml-extended", purpose="Lightning training")
+pl_callbacks = require(
+    "pytorch_lightning.callbacks", extra="ml-extended", purpose="Lightning training"
+)
+
+Trainer = pytorch_lightning.Trainer
+EarlyStopping = pl_callbacks.EarlyStopping
+ModelCheckpoint = pl_callbacks.ModelCheckpoint
+
+
 def train_lightning_model(
     model,
     datamodule,
@@ -12,7 +24,7 @@ def train_lightning_model(
     monitor_metric="val_loss",
     checkpoint_path=None,
     evaluate_test=True,
-    devices=1
+    devices=1,
 ):
     """
     Takes a PyTorch Lightning Model and a Lightning DataLoader module to define a Lightning Trainer.
@@ -39,13 +51,15 @@ def train_lightning_model(
         EarlyStopping(monitor=monitor_metric, patience=patience, mode="min"),
     ]
     if checkpoint_path:
-        callbacks.append(
-
-
-
-
-
-
+        callbacks.append(
+            ModelCheckpoint(
+                dirpath=checkpoint_path,
+                filename="{epoch}-{val_loss:.4f}",
+                monitor=monitor_metric,
+                save_top_k=1,
+                mode="min",
+            )
+        )
 
     # Trainer setup
     trainer = Trainer(
@@ -54,7 +68,7 @@ def train_lightning_model(
         accelerator=accelerator,
         devices=devices,
         log_every_n_steps=10,
-        enable_progress_bar=False
+        enable_progress_bar=False,
     )
 
     # Fit model with trainer
@@ -63,7 +77,7 @@ def train_lightning_model(
     # Test model (if applicable)
     if evaluate_test and hasattr(datamodule, "test_dataloader"):
         trainer.test(model, datamodule=datamodule)
-
+
     # Return best checkpoint path
     best_ckpt = None
     for cb in callbacks:
@@ -72,6 +86,7 @@ def train_lightning_model(
 
     return trainer, best_ckpt
 
+
 def run_sliding_window_lightning_training(
     adata,
     tensor_source,
@@ -86,13 +101,13 @@ def run_sliding_window_lightning_training(
     stride,
     max_epochs=30,
     patience=5,
-    enforce_eval_balance: bool=False,
-    target_eval_freq: float=0.3,
-    max_eval_positive: int=None
+    enforce_eval_balance: bool = False,
+    target_eval_freq: float = 0.3,
+    max_eval_positive: int = None,
 ):
     input_len = adata.shape[1]
     results = {}
-
+
     for start in range(0, input_len - window_size + 1, stride):
         center_idx = start + window_size // 2
         center_varname = adata.var_names[center_idx]
@@ -106,18 +121,22 @@ def run_sliding_window_lightning_training(
             label_col=label_col,
             batch_size=64,
             window_start=start,
-            window_size=window_size
+            window_size=window_size,
         )
         datamodule.setup()
 
         # Build model for this window
         model = model_class(window_size, num_classes)
         wrapper = TorchClassifierWrapper(
-            model,
+            model,
+            label_col=label_col,
+            num_classes=num_classes,
             class_names=class_names,
             class_weights=class_weights,
-            focus_class=focus_class,
-
+            focus_class=focus_class,
+            enforce_eval_balance=enforce_eval_balance,
+            target_eval_freq=target_eval_freq,
+            max_eval_positive=max_eval_positive,
         )
 
         # Train model
@@ -129,7 +148,7 @@ def run_sliding_window_lightning_training(
             "model": wrapper,
             "trainer": trainer,
             "checkpoint": ckpt,
-            "metrics": trainer.callback_metrics
+            "metrics": trainer.callback_metrics,
         }
-
+
     return results
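The checkpointing added to train_lightning_model follows the standard PyTorch Lightning callback pattern: an EarlyStopping watching monitor_metric plus a ModelCheckpoint that keeps only the best checkpoint. A self-contained sketch of the same wiring outside smftools (model and datamodule are assumed to be supplied by the caller):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Same callback wiring as train_lightning_model(): stop on a plateau of the
# monitored metric and keep the single best checkpoint.
callbacks = [
    EarlyStopping(monitor="val_loss", patience=5, mode="min"),
    ModelCheckpoint(
        dirpath="checkpoints/",
        filename="{epoch}-{val_loss:.4f}",
        monitor="val_loss",
        save_top_k=1,
        mode="min",
    ),
]
trainer = pl.Trainer(max_epochs=30, callbacks=callbacks, log_every_n_steps=10, enable_progress_bar=False)
# trainer.fit(model, datamodule=datamodule)   # fit, as train_lightning_model() does
# best_ckpt = callbacks[1].best_model_path    # recover the best checkpoint afterwards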
smftools/machine_learning/training/train_sklearn_model.py

@@ -1,16 +1,14 @@
+from __future__ import annotations
+
 from ..data import AnnDataModule
 from ..models import SklearnModelWrapper
 
-
-
-    datamodule,
-    evaluate_test=True,
-    evaluate_val=False
-):
+
+def train_sklearn_model(model_wrapper, datamodule, evaluate_test=True, evaluate_val=False):
     """
     Fits a SklearnModelWrapper on the train split from datamodule.
     Evaluates on test and/or val set.
-
+
     Parameters:
         model_wrapper: SklearnModelWrapper instance
         datamodule: AnnDataModule instance (with setup() method)
@@ -39,6 +37,7 @@ def train_sklearn_model(
 
     return metrics
 
+
 def run_sliding_window_sklearn_training(
     adata,
     tensor_source,
@@ -58,7 +57,7 @@ def run_sliding_window_sklearn_training(
     enforce_eval_balance=False,
     target_eval_freq=0.3,
     max_eval_positive=None,
-    **model_kwargs
+    **model_kwargs,
 ):
     """
     Sliding window training for sklearn models using AnnData.
@@ -86,29 +85,26 @@ def run_sliding_window_sklearn_training(
            train_frac=train_frac,
            val_frac=val_frac,
            test_frac=test_frac,
-            random_seed=random_seed
+            random_seed=random_seed,
        )
        datamodule.setup()
 
        # Build model wrapper
        sklearn_model = model_class(**model_kwargs)
        wrapper = SklearnModelWrapper(
-            sklearn_model,
+            sklearn_model,
            num_classes=num_classes,
            label_col=label_col,
            class_names=class_names,
            focus_class=focus_class,
            enforce_eval_balance=enforce_eval_balance,
            target_eval_freq=target_eval_freq,
-            max_eval_positive=max_eval_positive
+            max_eval_positive=max_eval_positive,
        )
 
        # Fit and evaluate
        metrics = train_sklearn_model(wrapper, datamodule, evaluate_test=True, evaluate_val=False)
 
-        results[center_varname] = {
-            "model": wrapper,
-            "metrics": metrics
-        }
+        results[center_varname] = {"model": wrapper, "metrics": metrics}
 
     return results
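Both sliding-window trainers step a fixed-size window across the positions with range(0, input_len - window_size + 1, stride) and key each result by the variable name at center_idx = start + window_size // 2. A tiny standalone illustration of how those windows tile a sequence (the numbers here are chosen arbitrarily):

input_len, window_size, stride = 12, 5, 3
for start in range(0, input_len - window_size + 1, stride):
    center_idx = start + window_size // 2
    print(f"window [{start}, {start + window_size - 1}] centered at {center_idx}")
# window [0, 4] centered at 2
# window [3, 7] centered at 5
# window [6, 10] centered at 8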
smftools/machine_learning/utils/device.py

@@ -1,10 +1,17 @@
-import
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="device selection")
+
 
 def detect_device():
     device = (
-        torch.device(
-
-        torch.device(
+        torch.device("cuda")
+        if torch.cuda.is_available()
+        else torch.device("mps")
+        if torch.backends.mps.is_available()
+        else torch.device("cpu")
     )
     print(f"Detected device: {device}")
-    return device
+    return device
smftools/machine_learning/utils/grl.py

@@ -1,4 +1,9 @@
-import
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="gradient reversal layers")
+
 
 class GradReverse(torch.autograd.Function):
     @staticmethod
@@ -10,5 +15,6 @@ class GradReverse(torch.autograd.Function):
     def backward(ctx, grad_output):
         return -ctx.alpha * grad_output, None
 
+
 def grad_reverse(x, alpha=1.0):
-    return GradReverse.apply(x, alpha)
+    return GradReverse.apply(x, alpha)
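Gradient reversal layers conventionally pass values through unchanged on the forward pass and multiply incoming gradients by -alpha on the backward pass (the backward above returns -ctx.alpha * grad_output); this is what lets the DANN transformer models train their domain classifiers adversarially. A small sanity-check sketch, assuming torch is installed via the ml-base extra:

import torch

from smftools.machine_learning.utils.grl import grad_reverse

x = torch.ones(3, requires_grad=True)
y = grad_reverse(x, alpha=1.0)  # forward: values pass through unchanged
y.sum().backward()              # backward: gradients arrive multiplied by -alpha
print(x.grad)                   # tensor([-1., -1., -1.])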