smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
smftools/machine_learning/models/transformer.py
@@ -0,0 +1,303 @@
+ import torch
+ import torch.nn as nn
+ from .base import BaseTorchModel
+ from .positional import PositionalEncoding
+ from ..utils.grl import grad_reverse
+ import numpy as np
+
+ class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def forward(self, src, src_mask=None, is_causal=False, src_key_padding_mask=None):
+         self_attn_output, attn_weights = self.self_attn(
+             src, src, src,
+             attn_mask=src_mask,
+             key_padding_mask=src_key_padding_mask,
+             need_weights=True,
+             average_attn_weights=False, # preserve [B, num_heads, S, S]
+             is_causal=is_causal
+         )
+         src = src + self.dropout1(self_attn_output)
+         src = self.norm1(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+         src = src + self.dropout2(src2)
+         src = self.norm2(src)
+
+         # Save attention weights to module
+         self.attn_weights = attn_weights # Save to layer
+         return src
+
+ class BaseTransformer(BaseTorchModel):
+     def __init__(self,
+                  input_dim=1,
+                  model_dim=64,
+                  num_heads=4,
+                  num_layers=2,
+                  dropout=0.2,
+                  seq_len=None,
+                  use_learnable_pos=False,
+                  use_cls_token=True,
+                  **kwargs):
+         super().__init__(**kwargs)
+         # Input FC layer to map D_input to D_model
+         self.model_dim = model_dim
+         self.input_fc = nn.Linear(input_dim, model_dim)
+         self.ff_dim = model_dim * 4
+         self.dropout = dropout
+         self.use_cls_token = use_cls_token
+
+         self.attn_weights = []
+         self.attn_grads = []
+
+         if use_learnable_pos:
+             assert seq_len is not None, "Must provide seq_len if use_learnable_pos=True"
+             self.pos_embed = nn.Parameter(torch.randn(seq_len + (1 if use_cls_token else 0), model_dim))
+             self.pos_encoder = None
+         else:
+             self.pos_encoder = PositionalEncoding(model_dim)
+             self.pos_embed = None
+
+         if self.use_cls_token:
+             self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim)) # (1, 1, D)
+
+         # Specify the transformer encoder structure
+         encoder_layer = TransformerEncoderLayerWithAttn(d_model=model_dim, nhead=num_heads, batch_first=True, dim_feedforward=self.ff_dim, dropout=self.dropout)
+         # Stack the transformer encoder layers
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+         # Register hooks
+         for layer in self.transformer.layers:
+             layer.self_attn.register_forward_hook(self._save_attn_weights)
+             layer.self_attn.register_full_backward_hook(self._save_attn_grads)
+
+     def _save_attn_weights(self, module, input, output):
+         self.attn_weights.append(output[1].detach())
+
+     def _save_attn_grads(self, module, grad_input, grad_output):
+         self.attn_grads.append(grad_output[0].detach())
+
+     def encode(self, x, mask=None):
+         if x.dim() == 2: # (B, S)
+             x = x.unsqueeze(-1)
+         elif x.dim() == 1: # (S,)
+             x = x.unsqueeze(0).unsqueeze(-1)
+         elif x.dim() == 3:
+             pass
+         else:
+             raise ValueError(f"Unexpected input shape: {x.shape}")
+
+         x = self.input_fc(x) # (B, S, D)
+
+         B, S, D = x.shape
+         if self.use_cls_token:
+             cls = self.cls_token.expand(B, -1, -1) # (B, 1, D)
+             x = torch.cat([cls, x], dim=1) # (B, S+1, D)
+
+         if self.pos_embed is not None:
+             x = x + self.pos_embed.unsqueeze(0)[:, :x.shape[1], :]
+         elif self.pos_encoder is not None:
+             x = self.pos_encoder(x)
+
+         if mask is not None:
+             pad = torch.ones(B, 1, device=mask.device) if self.use_cls_token else 0
+             mask = torch.cat([pad, mask], dim=1) if self.use_cls_token else mask
+             x = x * mask.unsqueeze(-1)
+
+         encoded = self.transformer(x)
+         return encoded
+
+     def compute_attn_grad(self, reduction='mean'):
+         """
+         Computes attention × gradient scores across layers.
+         Returns: [B, S] tensor of importance scores
+         """
+         scores = []
+         for attn, grad in zip(self.attn_weights, self.attn_grads):
+             # attn: [B, H, S, S]
+             # grad: [B, S, D]
+             attn = attn.mean(dim=1) # [B, S, S]
+             grad_norm = grad.norm(dim=-1) # [B, S]
+             attn_grad_score = (attn * grad_norm.unsqueeze(1)).sum(dim=-1) # [B, S]
+             scores.append(attn_grad_score)
+
+         # Combine across layers
+         stacked = torch.stack(scores, dim=0) # [L, B, S]
+         if reduction == "mean":
+             return stacked.mean(dim=0) # [B, S]
+         elif reduction == "sum":
+             return stacked.sum(dim=0) # [B, S]
+         else:
+             return stacked # [L, B, S]
+
+     def compute_rollout(self):
+         """
+         Computes attention rollout: [B, S, S] final attention influence map
+         """
+         device = self.attn_weights[0].device
+         B, S = self.attn_weights[0].shape[0], self.attn_weights[0].shape[-1]
+         rollout = torch.eye(S, device=device).unsqueeze(0).repeat(B, 1, 1) # [B, S, S]
+
+         for attn in self.attn_weights:
+             attn_heads = attn.mean(dim=1) # [B, S, S]
+             attn_heads = attn_heads + torch.eye(S, device=device).unsqueeze(0) # add residual
+             attn_heads = attn_heads / attn_heads.sum(dim=-1, keepdim=True).clamp(min=1e-6)
+             rollout = torch.bmm(attn_heads, rollout) # [B, S, S]
+
+         return rollout # [B, S, S]
+
+     def reset_attn_buffers(self):
+         self.attn_weights = []
+         self.attn_grads = []
+
+     def get_attn_layer(self, layer_idx=0, head_idx=None):
+         """
+         Returns attention map from a specific layer (and optionally head).
+         """
+         attn = self.attn_weights[layer_idx] # [B, H, S, S]
+         if head_idx is not None:
+             attn = attn[:, head_idx] # [B, S, S]
+         return attn
+
+     def apply_attn_interpretations_to_adata(self, dataloader, adata,
+                                             obsm_key_grad="attn_grad",
+                                             obsm_key_rollout="attn_rollout",
+                                             device="cpu"):
+         self.to(device)
+         self.eval()
+         grad_maps = []
+         rollout_maps = []
+
+         for batch in dataloader:
+             x = batch[0].to(device)
+             x.requires_grad_()
+
+             self.reset_attn_buffers()
+             logits = self(x)
+
+             if logits.shape[1] == 1:
+                 target_score = logits.squeeze()
+             else:
+                 target_score = logits.max(dim=1).values
+
+             target_score.sum().backward()
+
+             grad = self.compute_attn_grad() # [B, S+1]
+             if self.use_cls_token:
+                 grad = grad[:, 1:] # ignore CLS token
+             grad_maps.append(grad.detach().cpu().numpy())
+
+         grad_concat = np.concatenate(grad_maps, axis=0)
+         adata.obsm[obsm_key_grad] = grad_concat
+
+         # add per-row normalized version
+         grad_normed = grad_concat / (np.max(grad_concat, axis=1, keepdims=True) + 1e-8)
+         adata.obsm[f"{obsm_key_grad}_normalized"] = grad_normed
+
+ class TransformerClassifier(BaseTransformer):
+     def __init__(self,
+                  input_dim,
+                  num_classes,
+                  **kwargs):
+         super().__init__(input_dim, **kwargs)
+         # Classification head
+         output_size = 1 if num_classes == 2 else num_classes
+         self.cls_head = nn.Linear(self.model_dim, output_size)
+
+     def forward(self, x):
+         """
+         x: (batch, seq_len, input_dim)
+         """
+         self.reset_attn_buffers()
+         if x.dim() == 2: # shape (B, S)
+             x = x.unsqueeze(-1) # → (B, S, 1)
+         elif x.dim() == 1:
+             x = x.unsqueeze(0).unsqueeze(-1) # just in case (S,) → (1, S, 1)
+         else:
+             pass
+         encoded = self.encode(x) # -> (B, S, D_model)
+         if self.use_cls_token:
+             pooled = encoded[:, 0] # (B, D)
+         else:
+             pooled = encoded.mean(dim=1) # (B, D)
+
+         out = self.cls_head(pooled) # (B, C)
+         return out
+
+ class DANNTransformerClassifier(TransformerClassifier):
+     def __init__(self, input_dim, model_dim, num_classes, n_domains, **kwargs):
+         super().__init__(input_dim, model_dim, num_classes, **kwargs)
+         self.domain_classifier = nn.Sequential(
+             nn.Linear(model_dim, 128),
+             nn.ReLU(),
+             nn.Linear(128, n_domains)
+         )
+
+     def forward(self, x, alpha=1.0):
+         encoded = self.encode(x) # (B, S, D_model)
+         pooled = encoded.mean(dim=1) # (B, D_model)
+
+         class_logits = self.cls_head(pooled)
+         domain_logits = self.domain_classifier(grad_reverse(pooled, alpha))
+
+         return class_logits, domain_logits
+
+ class MaskedTransformerPretrainer(BaseTransformer):
+     def __init__(self, input_dim, model_dim, num_heads=4, num_layers=2, **kwargs):
+         super().__init__(input_dim, model_dim, num_heads, num_layers, **kwargs)
+         self.decoder = nn.Linear(model_dim, input_dim)
+
+     def forward(self, x, mask):
+         """
+         x: (batch, seq_len, input_dim)
+         mask: (batch, seq_len) optional
+         """
+         if x.dim() == 2:
+             x = x.unsqueeze(-1)
+         encoded = self.encode(x, mask=mask) # -> (B, S, D_model)
+         return self.decoder(encoded) # -> (B, D_input)
+
+ class DANNTransformer(BaseTransformer):
+     """
+     """
+     def __init__(self, seq_len, model_dim, n_heads, n_layers, n_domains):
+         super().__init__(
+             input_dim=1, # 1D scalar input per token
+             model_dim=model_dim,
+             num_heads=n_heads,
+             num_layers=n_layers,
+             seq_len=seq_len,
+             use_learnable_pos=True # enables learnable pos_embed in base
+         )
+
+         # Reconstruction head
+         self.recon_head = nn.Linear(model_dim, 1)
+
+         # Domain classification head
+         self.domain_classifier = nn.Sequential(
+             nn.Linear(model_dim, 128),
+             nn.ReLU(),
+             nn.Linear(128, n_domains)
+         )
+
+     def forward(self, x, alpha=1.0):
+         """
+         x: Tensor of shape (B, S) or (B, S, 1)
+         alpha: GRL coefficient (float)
+         """
+         if x.dim() == 2:
+             x = x.unsqueeze(-1) # (B, S, 1)
+
+         # Encode sequence
+         h = self.encode(x) # (B, S, D_model)
+
+         # Head 1: Reconstruction
+         recon = self.recon_head(h).squeeze(-1) # (B, S)
+
+         # Head 2: Domain classification via GRL
+         pooled = h.mean(dim=1) # (B, D_model)
+         rev = grad_reverse(pooled, alpha)
+         domain_logits = self.domain_classifier(rev) # (B, n_batches)
+
+         return recon, domain_logits
+
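Note (not part of the diff): TransformerClassifier records per-layer attention through the forward/backward hooks registered in BaseTransformer, so importance maps can be pulled out right after a backward pass. A minimal sketch of that flow, assuming the module resolves at its path in the wheel and that BaseTorchModel needs no required constructor arguments:

import torch
from smftools.machine_learning.models.transformer import TransformerClassifier

model = TransformerClassifier(input_dim=1, num_classes=2, model_dim=64, num_heads=4, num_layers=2)
x = torch.randn(8, 128)             # (B, S) scalar tokens; forward() adds the feature dim
logits = model(x)                   # (B, 1), since num_classes == 2 collapses to a single logit
logits.sum().backward()             # backward hooks populate model.attn_grads
scores = model.compute_attn_grad()  # [B, S+1]; column 0 is the CLS token
rollout = model.compute_rollout()   # [B, S+1, S+1] attention rollout

The example shapes follow the comments in the code above; whether the wrapper classes elsewhere in the package expect this exact calling convention is not shown in this diff.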
smftools/machine_learning/training/__init__.py
@@ -0,0 +1,2 @@
+ from .train_lightning_model import train_lightning_model, run_sliding_window_lightning_training
+ from .train_sklearn_model import train_sklearn_model, run_sliding_window_sklearn_training
smftools/machine_learning/training/train_lightning_model.py
@@ -0,0 +1,135 @@
+ import torch
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+ from ..data import AnnDataModule
+ from ..models import TorchClassifierWrapper
+
+ def train_lightning_model(
+     model,
+     datamodule,
+     max_epochs=30,
+     patience=5,
+     monitor_metric="val_loss",
+     checkpoint_path=None,
+     evaluate_test=True,
+     devices=1
+ ):
+     """
+     Takes a PyTorch Lightning model and a Lightning DataModule and defines a Lightning Trainer.
+     - The trainer fits the model to the training split of the datamodule.
+     - The trainer monitors loss on the validation split of the datamodule.
+     - Optionally evaluates the trained model on the test split when evaluate_test is True.
+     - When using CUDA, the devices parameter can be 1, [0, 1], "all", or "auto", depending on which devices you want to use.
+     """
+     # Device logic
+     if torch.cuda.is_available():
+         accelerator = "gpu"
+     elif torch.backends.mps.is_available():
+         accelerator = "mps"
+         devices = 1
+     else:
+         accelerator = "cpu"
+         devices = 1
+
+     # Adds the train/val/test indices from the datamodule to the model class.
+     model.set_training_indices(datamodule)
+
+     # Callbacks
+     callbacks = [
+         EarlyStopping(monitor=monitor_metric, patience=patience, mode="min"),
+     ]
+     if checkpoint_path:
+         callbacks.append(ModelCheckpoint(
+             dirpath=checkpoint_path,
+             filename="{epoch}-{val_loss:.4f}",
+             monitor=monitor_metric,
+             save_top_k=1,
+             mode="min",
+         ))
+
+     # Trainer setup
+     trainer = Trainer(
+         max_epochs=max_epochs,
+         callbacks=callbacks,
+         accelerator=accelerator,
+         devices=devices,
+         log_every_n_steps=10,
+         enable_progress_bar=False
+     )
+
+     # Fit model with trainer
+     trainer.fit(model, datamodule=datamodule)
+
+     # Test model (if applicable)
+     if evaluate_test and hasattr(datamodule, "test_dataloader"):
+         trainer.test(model, datamodule=datamodule)
+
+     # Return best checkpoint path
+     best_ckpt = None
+     for cb in callbacks:
+         if isinstance(cb, ModelCheckpoint):
+             best_ckpt = cb.best_model_path
+
+     return trainer, best_ckpt
+
+ def run_sliding_window_lightning_training(
+     adata,
+     tensor_source,
+     tensor_key,
+     label_col,
+     model_class,
+     num_classes,
+     class_names,
+     class_weights,
+     focus_class,
+     window_size,
+     stride,
+     max_epochs=30,
+     patience=5,
+     enforce_eval_balance: bool=False,
+     target_eval_freq: float=0.3,
+     max_eval_positive: int=None
+ ):
+     input_len = adata.shape[1]
+     results = {}
+
+     for start in range(0, input_len - window_size + 1, stride):
+         center_idx = start + window_size // 2
+         center_varname = adata.var_names[center_idx]
+         print(f"\nTraining window around {center_varname}")
+
+         # Build datamodule for this window
+         datamodule = AnnDataModule(
+             adata,
+             tensor_source=tensor_source,
+             tensor_key=tensor_key,
+             label_col=label_col,
+             batch_size=64,
+             window_start=start,
+             window_size=window_size
+         )
+         datamodule.setup()
+
+         # Build model for this window
+         model = model_class(window_size, num_classes)
+         wrapper = TorchClassifierWrapper(
+             model, label_col=label_col, num_classes=num_classes,
+             class_names=class_names,
+             class_weights=class_weights,
+             focus_class=focus_class, enforce_eval_balance=enforce_eval_balance,
+             target_eval_freq=target_eval_freq, max_eval_positive=max_eval_positive
+         )
+
+         # Train model
+         trainer, ckpt = train_lightning_model(
+             wrapper, datamodule, max_epochs=max_epochs, patience=patience
+         )
+
+         results[center_varname] = {
+             "model": wrapper,
+             "trainer": trainer,
+             "checkpoint": ckpt,
+             "metrics": trainer.callback_metrics
+         }
+
+     return results
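Note (not part of the diff): a hedged sketch of how run_sliding_window_lightning_training might be called. The AnnData file, obs column, class names, and window geometry are placeholders, and MyWindowClassifier is a hypothetical stand-in for any module constructed as model_class(window_size, num_classes); how TorchClassifierWrapper treats a single-logit binary head is assumed here, not shown in the diff.

import anndata as ad
import torch.nn as nn
from smftools.machine_learning.training import run_sliding_window_lightning_training

class MyWindowClassifier(nn.Module):
    # Hypothetical stand-in model: maps a (B, window_size) slice to class logits.
    def __init__(self, window_size, num_classes):
        super().__init__()
        out_dim = 1 if num_classes == 2 else num_classes
        self.net = nn.Sequential(nn.Linear(window_size, 64), nn.ReLU(), nn.Linear(64, out_dim))
    def forward(self, x):
        return self.net(x)

adata = ad.read_h5ad("experiment.h5ad")   # placeholder path

results = run_sliding_window_lightning_training(
    adata,
    tensor_source="X",                    # assumed to mean "use adata.X"
    tensor_key=None,
    label_col="activity_status",          # placeholder obs column
    model_class=MyWindowClassifier,
    num_classes=2,
    class_names=["closed", "open"],
    class_weights=None,
    focus_class=1,
    window_size=101,
    stride=25,
    max_epochs=20,
    patience=3,
)
for center, res in results.items():
    print(center, res["metrics"].get("val_loss"))  # trainer.callback_metrics per window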
smftools/machine_learning/training/train_sklearn_model.py
@@ -0,0 +1,114 @@
+ from ..data import AnnDataModule
+ from ..models import SklearnModelWrapper
+
+ def train_sklearn_model(
+     model_wrapper,
+     datamodule,
+     evaluate_test=True,
+     evaluate_val=False
+ ):
+     """
+     Fits a SklearnModelWrapper on the train split from datamodule.
+     Evaluates on test and/or val set.
+
+     Parameters:
+         model_wrapper: SklearnModelWrapper instance
+         datamodule: AnnDataModule instance (with setup() method)
+         evaluate_test: whether to evaluate on test split
+         evaluate_val: whether to evaluate on validation split
+
+     Returns:
+         metrics: dictionary containing evaluation metrics
+     """
+     # Fit model
+     model_wrapper.fit_from_datamodule(datamodule)
+
+     # Evaluate
+     metrics = {}
+
+     if evaluate_val:
+         val_metrics = model_wrapper.evaluate_from_datamodule(datamodule, split="val")
+         metrics.update({f"{k}": v for k, v in val_metrics.items()})
+
+     if evaluate_test:
+         test_metrics = model_wrapper.evaluate_from_datamodule(datamodule, split="test")
+         metrics.update({f"{k}": v for k, v in test_metrics.items()})
+
+     # Plot evaluations
+     model_wrapper.plot_roc_pr_curves()
+
+     return metrics
+
+ def run_sliding_window_sklearn_training(
+     adata,
+     tensor_source,
+     tensor_key,
+     label_col,
+     model_class,
+     num_classes,
+     class_names,
+     focus_class,
+     window_size,
+     stride,
+     batch_size=64,
+     train_frac=0.6,
+     val_frac=0.1,
+     test_frac=0.3,
+     random_seed=42,
+     enforce_eval_balance=False,
+     target_eval_freq=0.3,
+     max_eval_positive=None,
+     **model_kwargs
+ ):
+     """
+     Sliding window training for sklearn models using AnnData.
+
+     Returns dict keyed by window center.
+     """
+
+     input_len = adata.shape[1]
+     results = {}
+
+     for start in range(0, input_len - window_size + 1, stride):
+         center_idx = start + window_size // 2
+         center_varname = adata.var_names[center_idx]
+         print(f"\nTraining window around {center_varname}")
+
+         # Build datamodule for this window
+         datamodule = AnnDataModule(
+             adata,
+             tensor_source=tensor_source,
+             tensor_key=tensor_key,
+             label_col=label_col,
+             batch_size=batch_size,
+             window_start=start,
+             window_size=window_size,
+             train_frac=train_frac,
+             val_frac=val_frac,
+             test_frac=test_frac,
+             random_seed=random_seed
+         )
+         datamodule.setup()
+
+         # Build model wrapper
+         sklearn_model = model_class(**model_kwargs)
+         wrapper = SklearnModelWrapper(
+             sklearn_model,
+             num_classes=num_classes,
+             label_col=label_col,
+             class_names=class_names,
+             focus_class=focus_class,
+             enforce_eval_balance=enforce_eval_balance,
+             target_eval_freq=target_eval_freq,
+             max_eval_positive=max_eval_positive
+         )
+
+         # Fit and evaluate
+         metrics = train_sklearn_model(wrapper, datamodule, evaluate_test=True, evaluate_val=False)
+
+         results[center_varname] = {
+             "model": wrapper,
+             "metrics": metrics
+         }
+
+     return results
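Note (not part of the diff): the sklearn path mirrors the Lightning flow above but builds the estimator from model_class(**model_kwargs). A hedged sketch using scikit-learn's LogisticRegression; the AnnData file, obs column, class names, and window geometry are placeholders, and the structure of the returned metrics dict is not specified by this diff.

import anndata as ad
from sklearn.linear_model import LogisticRegression
from smftools.machine_learning.training import run_sliding_window_sklearn_training

adata = ad.read_h5ad("experiment.h5ad")   # placeholder path

results = run_sliding_window_sklearn_training(
    adata,
    tensor_source="X",                    # assumed to mean "use adata.X"
    tensor_key=None,
    label_col="activity_status",          # placeholder obs column
    model_class=LogisticRegression,
    num_classes=2,
    class_names=["closed", "open"],
    focus_class=1,
    window_size=101,
    stride=25,
    max_iter=1000,                        # forwarded to LogisticRegression via **model_kwargs
)

# Each entry holds the fitted SklearnModelWrapper and its evaluation metrics dict.
for center, res in results.items():
    print(center, res["metrics"])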
smftools/plotting/__init__.py
@@ -1,6 +1,9 @@
+ from .autocorrelation_plotting import *
+ from .hmm_plotting import *
  from .position_stats import plot_bar_relative_risk, plot_volcano_relative_risk, plot_positionwise_matrix, plot_positionwise_matrix_grid
- from .general_plotting import combined_hmm_raw_clustermap
+ from .general_plotting import combined_hmm_raw_clustermap, combined_raw_clustermap, plot_hmm_layers_rolling_by_sample_ref
  from .classifiers import plot_model_performance, plot_feature_importances_or_saliency, plot_model_curves_from_adata, plot_model_curves_from_adata_with_frequency_grid
+ from .qc_plotting import *

  __all__ = [
      "combined_hmm_raw_clustermap",