smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +7 -1
- smftools/cli/hmm_adata.py +902 -244
- smftools/cli/load_adata.py +318 -198
- smftools/cli/preprocess_adata.py +285 -171
- smftools/cli/spatial_adata.py +137 -53
- smftools/cli_entry.py +94 -178
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +22 -17
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +505 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2125 -1426
- smftools/hmm/__init__.py +2 -3
- smftools/hmm/archived/call_hmm_peaks.py +16 -1
- smftools/hmm/call_hmm_peaks.py +173 -193
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +379 -156
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +195 -29
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +347 -168
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +145 -85
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +8 -8
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +103 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +688 -271
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.4.dist-info/RECORD +0 -176
- /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
--- a/smftools/machine_learning/models/transformer.py
+++ b/smftools/machine_learning/models/transformer.py
@@ -1,9 +1,11 @@
+import numpy as np
 import torch
 import torch.nn as nn
-
-from .positional import PositionalEncoding
+
 from ..utils.grl import grad_reverse
-
+from .base import BaseTorchModel
+from .positional import PositionalEncoding
+

 class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
     def __init__(self, *args, **kwargs):
@@ -11,12 +13,14 @@ class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):

     def forward(self, src, src_mask=None, is_causal=False, src_key_padding_mask=None):
         self_attn_output, attn_weights = self.self_attn(
-            src,
+            src,
+            src,
+            src,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
             need_weights=True,
             average_attn_weights=False,  # preserve [B, num_heads, S, S]
-            is_causal=is_causal
+            is_causal=is_causal,
         )
         src = src + self.dropout1(self_attn_output)
         src = self.norm1(src)
@@ -27,18 +31,21 @@ class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
         # Save attention weights to module
         self.attn_weights = attn_weights  # Save to layer
         return src
-
+
+
 class BaseTransformer(BaseTorchModel):
-    def __init__(
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        input_dim=1,
+        model_dim=64,
+        num_heads=4,
+        num_layers=2,
+        dropout=0.2,
+        seq_len=None,
+        use_learnable_pos=False,
+        use_cls_token=True,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         # Input FC layer to map D_input to D_model
         self.model_dim = model_dim
@@ -52,7 +59,9 @@ class BaseTransformer(BaseTorchModel):

         if use_learnable_pos:
             assert seq_len is not None, "Must provide seq_len if use_learnable_pos=True"
-            self.pos_embed = nn.Parameter(
+            self.pos_embed = nn.Parameter(
+                torch.randn(seq_len + (1 if use_cls_token else 0), model_dim)
+            )
             self.pos_encoder = None
         else:
             self.pos_encoder = PositionalEncoding(model_dim)
@@ -62,7 +71,13 @@ class BaseTransformer(BaseTorchModel):
             self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))  # (1, 1, D)

         # Specify the transformer encoder structure
-        encoder_layer = TransformerEncoderLayerWithAttn(
+        encoder_layer = TransformerEncoderLayerWithAttn(
+            d_model=model_dim,
+            nhead=num_heads,
+            batch_first=True,
+            dim_feedforward=self.ff_dim,
+            dropout=self.dropout,
+        )
         # Stack the transformer encoder layers
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

@@ -95,7 +110,7 @@ class BaseTransformer(BaseTorchModel):
             x = torch.cat([cls, x], dim=1)  # (B, S+1, D)

         if self.pos_embed is not None:
-            x = x + self.pos_embed.unsqueeze(0)[:, :x.shape[1], :]
+            x = x + self.pos_embed.unsqueeze(0)[:, : x.shape[1], :]
         elif self.pos_encoder is not None:
             x = self.pos_encoder(x)

@@ -106,8 +121,8 @@ class BaseTransformer(BaseTorchModel):

         encoded = self.transformer(x)
         return encoded
-
-    def compute_attn_grad(self, reduction=
+
+    def compute_attn_grad(self, reduction="mean"):
         """
         Computes attention × gradient scores across layers.
         Returns: [B, S] tensor of importance scores
@@ -116,19 +131,19 @@ class BaseTransformer(BaseTorchModel):
         for attn, grad in zip(self.attn_weights, self.attn_grads):
             # attn: [B, H, S, S]
             # grad: [B, S, D]
-            attn = attn.mean(dim=1)
-            grad_norm = grad.norm(dim=-1)
+            attn = attn.mean(dim=1)  # [B, S, S]
+            grad_norm = grad.norm(dim=-1)  # [B, S]
             attn_grad_score = (attn * grad_norm.unsqueeze(1)).sum(dim=-1)  # [B, S]
             scores.append(attn_grad_score)

         # Combine across layers
         stacked = torch.stack(scores, dim=0)  # [L, B, S]
         if reduction == "mean":
-            return stacked.mean(dim=0)
+            return stacked.mean(dim=0)  # [B, S]
         elif reduction == "sum":
-            return stacked.sum(dim=0)
+            return stacked.sum(dim=0)  # [B, S]
         else:
-            return stacked
+            return stacked  # [L, B, S]

     def compute_rollout(self):
         """
@@ -143,9 +158,9 @@ class BaseTransformer(BaseTorchModel):
             attn_heads = attn_heads + torch.eye(S, device=device).unsqueeze(0)  # add residual
             attn_heads = attn_heads / attn_heads.sum(dim=-1, keepdim=True).clamp(min=1e-6)
             rollout = torch.bmm(attn_heads, rollout)  # [B, S, S]
-
+
         return rollout  # [B, S, S]
-
+
     def reset_attn_buffers(self):
         self.attn_weights = []
         self.attn_grads = []
@@ -158,11 +173,15 @@ class BaseTransformer(BaseTorchModel):
         if head_idx is not None:
             attn = attn[:, head_idx]  # [B, S, S]
         return attn
-
-    def apply_attn_interpretations_to_adata(
-
-
-
+
+    def apply_attn_interpretations_to_adata(
+        self,
+        dataloader,
+        adata,
+        obsm_key_grad="attn_grad",
+        obsm_key_rollout="attn_rollout",
+        device="cpu",
+    ):
         self.to(device)
         self.eval()
         grad_maps = []
@@ -193,12 +212,10 @@ class BaseTransformer(BaseTorchModel):
         # add per-row normalized version
         grad_normed = grad_concat / (np.max(grad_concat, axis=1, keepdims=True) + 1e-8)
         adata.obsm[f"{obsm_key_grad}_normalized"] = grad_normed
-
+
+
 class TransformerClassifier(BaseTransformer):
-    def __init__(self,
-                 input_dim,
-                 num_classes,
-                 **kwargs):
+    def __init__(self, input_dim, num_classes, **kwargs):
         super().__init__(input_dim, **kwargs)
         # Classification head
         output_size = 1 if num_classes == 2 else num_classes
@@ -215,7 +232,7 @@ class TransformerClassifier(BaseTransformer):
             x = x.unsqueeze(0).unsqueeze(-1)  # just in case (S,) → (1, S, 1)
         else:
             pass
-        encoded = self.encode(x)
+        encoded = self.encode(x)  # -> (B, S, D_model)
         if self.use_cls_token:
             pooled = encoded[:, 0]  # (B, D)
         else:
@@ -223,14 +240,13 @@ class TransformerClassifier(BaseTransformer):

         out = self.cls_head(pooled)  # (B, C)
         return out
-
+
+
 class DANNTransformerClassifier(TransformerClassifier):
     def __init__(self, input_dim, model_dim, num_classes, n_domains, **kwargs):
         super().__init__(input_dim, model_dim, num_classes, **kwargs)
         self.domain_classifier = nn.Sequential(
-            nn.Linear(model_dim, 128),
-            nn.ReLU(),
-            nn.Linear(128, n_domains)
+            nn.Linear(model_dim, 128), nn.ReLU(), nn.Linear(128, n_domains)
         )

     def forward(self, x, alpha=1.0):
@@ -242,6 +258,7 @@ class DANNTransformerClassifier(TransformerClassifier):

         return class_logits, domain_logits

+
 class MaskedTransformerPretrainer(BaseTransformer):
     def __init__(self, input_dim, model_dim, num_heads=4, num_layers=2, **kwargs):
         super().__init__(input_dim, model_dim, num_heads, num_layers, **kwargs)
@@ -254,12 +271,13 @@ class MaskedTransformerPretrainer(BaseTransformer):
         """
         if x.dim() == 2:
             x = x.unsqueeze(-1)
-        encoded = self.encode(x, mask=mask)
-        return self.decoder(encoded)
-
+        encoded = self.encode(x, mask=mask)  # -> (B, S, D_model)
+        return self.decoder(encoded)  # -> (B, D_input)
+
+
 class DANNTransformer(BaseTransformer):
-    """
-
+    """ """
+
     def __init__(self, seq_len, model_dim, n_heads, n_layers, n_domains):
         super().__init__(
             input_dim=1,  # 1D scalar input per token
@@ -267,7 +285,7 @@ class DANNTransformer(BaseTransformer):
             num_heads=n_heads,
             num_layers=n_layers,
             seq_len=seq_len,
-            use_learnable_pos=True  # enables learnable pos_embed in base
+            use_learnable_pos=True,  # enables learnable pos_embed in base
         )

         # Reconstruction head
@@ -275,9 +293,7 @@ class DANNTransformer(BaseTransformer):

         # Domain classification head
         self.domain_classifier = nn.Sequential(
-            nn.Linear(model_dim, 128),
-            nn.ReLU(),
-            nn.Linear(128, n_domains)
+            nn.Linear(model_dim, 128), nn.ReLU(), nn.Linear(128, n_domains)
         )

     def forward(self, x, alpha=1.0):
@@ -300,4 +316,3 @@ class DANNTransformer(BaseTransformer):
         domain_logits = self.domain_classifier(rev)  # (B, n_batches)

         return recon, domain_logits
-
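The annotated interpretability helpers are easier to follow in the new layout: `compute_attn_grad` averages each layer's attention over heads and weights it by per-token gradient norms, and `compute_rollout` chains the per-layer attention maps with an identity term for the residual connection. The snippet below is a minimal, self-contained sketch of that attention-rollout step in plain PyTorch, assuming a list of per-layer `[B, H, S, S]` attention tensors; it is illustrative only, not the smftools implementation.

```python
import torch

def attention_rollout(attn_per_layer):
    """attn_per_layer: list of [B, H, S, S] attention maps, one per encoder layer."""
    B, _, S, _ = attn_per_layer[0].shape
    rollout = torch.eye(S).repeat(B, 1, 1)  # start from the identity, [B, S, S]
    for attn in attn_per_layer:
        heads = attn.mean(dim=1)  # average over heads -> [B, S, S]
        heads = heads + torch.eye(S).unsqueeze(0)  # add residual connection
        heads = heads / heads.sum(dim=-1, keepdim=True).clamp(min=1e-6)  # re-normalize rows
        rollout = torch.bmm(heads, rollout)  # chain layers
    return rollout  # [B, S, S]

# Toy check with softmax-normalized random attention for three layers.
attn = [torch.softmax(torch.randn(2, 4, 6, 6), dim=-1) for _ in range(3)]
print(attention_rollout(attn).shape)  # torch.Size([2, 6, 6])
```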
--- a/smftools/machine_learning/models/wrappers.py
+++ b/smftools/machine_learning/models/wrappers.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn

+
 class ScaledModel(nn.Module):
     def __init__(self, model, mean, std):
         super().__init__()
@@ -17,4 +18,4 @@ class ScaledModel(nn.Module):
             x = (x - mean[None, None, :]) / std[None, None, :]
         else:
             raise ValueError(f"Unsupported input shape {x.shape}")
-        return self.model(x)
+        return self.model(x)
--- a/smftools/machine_learning/training/__init__.py
+++ b/smftools/machine_learning/training/__init__.py
@@ -1,2 +1,2 @@
-from .train_lightning_model import
-from .train_sklearn_model import
+from .train_lightning_model import run_sliding_window_lightning_training, train_lightning_model
+from .train_sklearn_model import run_sliding_window_sklearn_training, train_sklearn_model
--- a/smftools/machine_learning/training/train_lightning_model.py
+++ b/smftools/machine_learning/training/train_lightning_model.py
@@ -1,9 +1,11 @@
 import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+
 from ..data import AnnDataModule
 from ..models import TorchClassifierWrapper

+
 def train_lightning_model(
     model,
     datamodule,
@@ -12,7 +14,7 @@ def train_lightning_model(
     monitor_metric="val_loss",
     checkpoint_path=None,
     evaluate_test=True,
-    devices=1
+    devices=1,
 ):
     """
     Takes a PyTorch Lightning Model and a Lightning DataLoader module to define a Lightning Trainer.
@@ -39,13 +41,15 @@ def train_lightning_model(
         EarlyStopping(monitor=monitor_metric, patience=patience, mode="min"),
     ]
     if checkpoint_path:
-        callbacks.append(
-
-
-
-
-
-
+        callbacks.append(
+            ModelCheckpoint(
+                dirpath=checkpoint_path,
+                filename="{epoch}-{val_loss:.4f}",
+                monitor=monitor_metric,
+                save_top_k=1,
+                mode="min",
+            )
+        )

     # Trainer setup
     trainer = Trainer(
@@ -54,7 +58,7 @@ def train_lightning_model(
         accelerator=accelerator,
         devices=devices,
         log_every_n_steps=10,
-        enable_progress_bar=False
+        enable_progress_bar=False,
     )

     # Fit model with trainer
@@ -63,7 +67,7 @@ def train_lightning_model(
     # Test model (if applicable)
     if evaluate_test and hasattr(datamodule, "test_dataloader"):
         trainer.test(model, datamodule=datamodule)
-
+
     # Return best checkpoint path
     best_ckpt = None
     for cb in callbacks:
@@ -72,6 +76,7 @@ def train_lightning_model(

     return trainer, best_ckpt

+
 def run_sliding_window_lightning_training(
     adata,
     tensor_source,
@@ -86,13 +91,13 @@ def run_sliding_window_lightning_training(
     stride,
     max_epochs=30,
     patience=5,
-    enforce_eval_balance: bool=False,
-    target_eval_freq: float=0.3,
-    max_eval_positive: int=None
+    enforce_eval_balance: bool = False,
+    target_eval_freq: float = 0.3,
+    max_eval_positive: int = None,
 ):
     input_len = adata.shape[1]
     results = {}
-
+
     for start in range(0, input_len - window_size + 1, stride):
         center_idx = start + window_size // 2
         center_varname = adata.var_names[center_idx]
@@ -106,18 +111,22 @@ def run_sliding_window_lightning_training(
             label_col=label_col,
             batch_size=64,
             window_start=start,
-            window_size=window_size
+            window_size=window_size,
         )
         datamodule.setup()

         # Build model for this window
         model = model_class(window_size, num_classes)
         wrapper = TorchClassifierWrapper(
-            model,
+            model,
+            label_col=label_col,
+            num_classes=num_classes,
             class_names=class_names,
             class_weights=class_weights,
-            focus_class=focus_class,
-
+            focus_class=focus_class,
+            enforce_eval_balance=enforce_eval_balance,
+            target_eval_freq=target_eval_freq,
+            max_eval_positive=max_eval_positive,
         )

         # Train model
@@ -129,7 +138,7 @@ def run_sliding_window_lightning_training(
             "model": wrapper,
             "trainer": trainer,
             "checkpoint": ckpt,
-            "metrics": trainer.callback_metrics
+            "metrics": trainer.callback_metrics,
         }
-
+
     return results
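Both sliding-window trainers key their results by the variable name at the center of each window. A toy sketch of that indexing, with assumed values rather than smftools code:

```python
# Illustration of the window/center indexing used by the sliding-window trainers above.
window_size, stride, input_len = 5, 2, 12  # assumed toy values
for start in range(0, input_len - window_size + 1, stride):
    center_idx = start + window_size // 2
    print(f"window [{start}, {start + window_size}) -> results keyed by column {center_idx}")
```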
--- a/smftools/machine_learning/training/train_sklearn_model.py
+++ b/smftools/machine_learning/training/train_sklearn_model.py
@@ -1,16 +1,12 @@
 from ..data import AnnDataModule
 from ..models import SklearnModelWrapper

-
-
-    datamodule,
-    evaluate_test=True,
-    evaluate_val=False
-):
+
+def train_sklearn_model(model_wrapper, datamodule, evaluate_test=True, evaluate_val=False):
     """
     Fits a SklearnModelWrapper on the train split from datamodule.
     Evaluates on test and/or val set.
-
+
     Parameters:
         model_wrapper: SklearnModelWrapper instance
         datamodule: AnnDataModule instance (with setup() method)
@@ -39,6 +35,7 @@ def train_sklearn_model(

     return metrics

+
 def run_sliding_window_sklearn_training(
     adata,
     tensor_source,
@@ -58,7 +55,7 @@ def run_sliding_window_sklearn_training(
     enforce_eval_balance=False,
     target_eval_freq=0.3,
     max_eval_positive=None,
-    **model_kwargs
+    **model_kwargs,
 ):
     """
     Sliding window training for sklearn models using AnnData.
@@ -86,29 +83,26 @@ def run_sliding_window_sklearn_training(
             train_frac=train_frac,
             val_frac=val_frac,
             test_frac=test_frac,
-            random_seed=random_seed
+            random_seed=random_seed,
         )
         datamodule.setup()

         # Build model wrapper
         sklearn_model = model_class(**model_kwargs)
         wrapper = SklearnModelWrapper(
-            sklearn_model,
+            sklearn_model,
             num_classes=num_classes,
             label_col=label_col,
             class_names=class_names,
             focus_class=focus_class,
             enforce_eval_balance=enforce_eval_balance,
             target_eval_freq=target_eval_freq,
-            max_eval_positive=max_eval_positive
+            max_eval_positive=max_eval_positive,
         )

         # Fit and evaluate
         metrics = train_sklearn_model(wrapper, datamodule, evaluate_test=True, evaluate_val=False)

-        results[center_varname] = {
-            "model": wrapper,
-            "metrics": metrics
-        }
+        results[center_varname] = {"model": wrapper, "metrics": metrics}

     return results
--- a/smftools/machine_learning/utils/__init__.py
+++ b/smftools/machine_learning/utils/__init__.py
@@ -1,2 +1,2 @@
 from .device import detect_device
-from .grl import GradReverse
+from .grl import GradReverse
--- a/smftools/machine_learning/utils/device.py
+++ b/smftools/machine_learning/utils/device.py
@@ -1,10 +1,13 @@
 import torch

+
 def detect_device():
     device = (
-        torch.device(
-
-        torch.device(
+        torch.device("cuda")
+        if torch.cuda.is_available()
+        else torch.device("mps")
+        if torch.backends.mps.is_available()
+        else torch.device("cpu")
     )
     print(f"Detected device: {device}")
-    return device
+    return device
--- a/smftools/machine_learning/utils/grl.py
+++ b/smftools/machine_learning/utils/grl.py
@@ -1,5 +1,6 @@
 import torch

+
 class GradReverse(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, alpha):
@@ -10,5 +11,6 @@ class GradReverse(torch.autograd.Function):
     def backward(ctx, grad_output):
         return -ctx.alpha * grad_output, None

+
 def grad_reverse(x, alpha=1.0):
-    return GradReverse.apply(x, alpha)
+    return GradReverse.apply(x, alpha)
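`grad_reverse` is the gradient-reversal layer the DANN models above pass their pooled features through: the forward pass is an identity, while the backward pass returns `-alpha` times the incoming gradient, so training the domain head pushes the shared encoder toward domain-invariant features. Below is a self-contained sketch of the standard autograd pattern; the forward body is an assumed illustration, not copied from smftools.

```python
import torch

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha    # remember the scale for the backward pass (assumed detail)
        return x.view_as(x)  # identity in the forward direction

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.alpha * grad_output, None  # flip and scale; no gradient for alpha

x = torch.randn(4, 8, requires_grad=True)
GradReverse.apply(x, 0.5).sum().backward()
print(x.grad[0, :3])  # every entry is -0.5 instead of the +1.0 an identity would give
```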