smftools-0.1.7-py3-none-any.whl → smftools-0.2.3-py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
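
Beyond the new cli/, config/, and hmm/ packages, the file list documents a broad reorganization: modeling, data, and training code moves from smftools.tools into a new smftools.machine_learning package, HMM utilities move into smftools.hmm, and most of smftools.informatics.helpers is retired under archived/. Downstream imports need updating to the new paths. A hedged before/after sketch follows; the module paths come from the renames listed above, but the assumption that each module re-exports a same-named function or class is ours:

# smftools 0.1.7 import paths (now removed or archived)
from smftools.tools.models.cnn import CNNClassifier
from smftools.tools.train_hmm import train_hmm

# smftools 0.2.3 equivalents, following the renames in this diff
from smftools.machine_learning.models.cnn import CNNClassifier
from smftools.hmm.train_hmm import train_hmm
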
smftools/machine_learning/models/base.py
@@ -0,0 +1,295 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from ..utils.device import detect_device
+
+ class BaseTorchModel(nn.Module):
+     """
+     Minimal base class for torch models that stores the compute device
+     and a default dropout rate for regularization.
+     """
+     def __init__(self, dropout_rate=0.0):
+         super().__init__()
+         self.device = detect_device()  # detect the available compute device
+         self.dropout_rate = dropout_rate  # default dropout rate used for regularization
+
+     def compute_saliency(
+         self,
+         x,
+         target_class=None,
+         reduction="sum",
+         smoothgrad=False,
+         smooth_samples=25,
+         smooth_noise=0.1,
+         signed=True
+     ):
+         """
+         Compute vanilla saliency or SmoothGrad saliency.
+
+         Arguments:
+         ----------
+         x : torch.Tensor
+             Input tensor [B, S, D].
+         target_class : int or list, optional
+             If None, uses the model's predicted class.
+         reduction : str
+             'sum' or 'mean' across channels.
+         smoothgrad : bool
+             Whether to apply SmoothGrad.
+         smooth_samples : int
+             Number of noisy samples for SmoothGrad.
+         smooth_noise : float
+             Standard deviation of the noise added to the input.
+         """
+         self.eval()
+         x = x.clone().detach().requires_grad_(True)
+
+         if smoothgrad:
+             saliency_accum = torch.zeros_like(x)
+             for i in range(smooth_samples):
+                 noise = torch.normal(mean=0.0, std=smooth_noise, size=x.shape).to(x.device)
+                 x_noisy = x + noise
+                 x_noisy.requires_grad_(True)
+                 x_noisy.retain_grad()  # x_noisy is a non-leaf tensor, so .grad is only populated if retained
+                 logits = self.forward(x_noisy)
+                 target_class = self._resolve_target_class(logits, target_class)
+                 if logits.shape[1] == 1:
+                     scores = logits.squeeze(1)
+                 else:
+                     scores = logits[torch.arange(x.shape[0]), target_class]
+                 scores.sum().backward()
+                 saliency_accum += x_noisy.grad.detach()
+             saliency = saliency_accum / smooth_samples
+         else:
+             logits = self.forward(x)
+             target_class = self._resolve_target_class(logits, target_class)
+             if logits.shape[1] == 1:
+                 scores = logits.squeeze(1)
+             else:
+                 scores = logits[torch.arange(x.shape[0]), target_class]
+             scores.sum().backward()
+             saliency = x.grad.detach()
+
+         if not signed:
+             saliency = saliency.abs()
+
+         if reduction == "sum" and x.ndim == 3:
+             return saliency.sum(dim=-1)
+         elif reduction == "mean" and x.ndim == 3:
+             return saliency.mean(dim=-1)
+         else:
+             return saliency
+
+     def compute_gradient_x_input(self, x, target_class=None):
+         """
+         Compute gradient × input attribution.
+         """
+         x = x.clone().detach().requires_grad_(True)
+         logits = self.forward(x)
+         target_class = self._resolve_target_class(logits, target_class)
+         if logits.shape[1] == 1:
+             scores = logits.squeeze(1)
+         else:
+             scores = logits[torch.arange(x.shape[0]), target_class]
+         scores.sum().backward()
+         grads = x.grad
+         return grads * x
+
+     def compute_integrated_gradients(self, x, target_class=None, steps=50, baseline=None):
+         """
+         Compute Integrated Gradients for a batch of x.
+         If target_class is None, uses the predicted class.
+         Returns: [B, seq_len, channels] attribution tensor.
+         """
+         from captum.attr import IntegratedGradients
+
+         ig = IntegratedGradients(self)
+         self.eval()
+         x = x.requires_grad_(True)
+
+         with torch.no_grad():
+             logits = self.forward(x)
+             if logits.shape[1] == 1:
+                 target_class = 0  # only one column exists, representing the class-1 logit
+             else:
+                 target_class = self._resolve_target_class(logits, target_class)
+
+         if baseline is None:
+             baseline = torch.zeros_like(x)
+
+         attributions, delta = ig.attribute(
+             x,
+             baselines=baseline,
+             target=target_class,
+             n_steps=steps,
+             return_convergence_delta=True
+         )
+         return attributions, delta
+
+     def compute_deeplift(
+         self,
+         x,
+         baseline=None,
+         target_class=None,
+         reduction="sum",
+         signed=True
+     ):
+         """
+         Compute DeepLIFT scores using captum.
+
+         baseline:
+             Reference input for DeepLIFT.
+         """
+         from captum.attr import DeepLift
+
+         self.eval()
+         deeplift = DeepLift(self)
+
+         logits = self.forward(x)
+         if logits.shape[1] == 1:
+             target_class = 0  # only one column exists, representing the class-1 logit
+         else:
+             target_class = self._resolve_target_class(logits, target_class)
+
+         if baseline is None:
+             baseline = torch.zeros_like(x)
+
+         attr = deeplift.attribute(x, target=target_class, baselines=baseline)
+
+         if not signed:
+             attr = attr.abs()
+
+         if reduction == "sum" and x.ndim == 3:
+             return attr.sum(dim=-1)
+         elif reduction == "mean" and x.ndim == 3:
+             return attr.mean(dim=-1)
+         else:
+             return attr
+
+     def compute_occlusion(
+         self,
+         x,
+         target_class=None,
+         window_size=5,
+         baseline=None
+     ):
+         """
+         Compute per-sample occlusion attribution.
+         Supports 2D [B, S] or 3D [B, S, D] inputs.
+         Returns: [B, S] occlusion scores.
+         """
+         self.eval()
+
+         x_np = x.detach().cpu().numpy()
+         ndim = x_np.ndim
+         if ndim == 2:
+             B, S = x_np.shape
+             D = 1
+         elif ndim == 3:
+             B, S, D = x_np.shape
+         else:
+             raise ValueError(f"Unsupported input shape {x_np.shape}")
+
+         # if no baseline is provided, fall back to the per-position mean across the batch
+         if baseline is None:
+             baseline = np.mean(x_np, axis=0)
+
+         occlusion_scores = np.zeros((B, S))
+
+         for b in range(B):
+             for i in range(S):
+                 x_occluded = x_np[b].copy()
+                 left = max(0, i - window_size // 2)  # occluded window spans [i - w//2, i + w//2)
+                 right = min(S, i + window_size // 2)
+
+                 if ndim == 2:
+                     x_occluded[left:right] = baseline[left:right]
+                 else:
+                     x_occluded[left:right, :] = baseline[left:right, :]
+
+                 x_tensor = torch.tensor(
+                     x_occluded,
+                     device=next(self.parameters()).device,  # use the model's current device
+                     dtype=torch.float32
+                 ).unsqueeze(0)
+
+                 logits = self.forward(x_tensor)
+                 target_class = self._resolve_target_class(logits, target_class)
+
+                 if logits.shape[1] == 1:
+                     scores = logits.squeeze(1)
+                 else:
+                     scores = logits[torch.arange(logits.shape[0]), target_class]  # logits holds one occluded sample
+
+                 occlusion_scores[b, i] = scores.mean().item()
+
+         return occlusion_scores
+
+     def apply_attributions_to_adata(
+         self,
+         dataloader,
+         adata,
+         method="saliency",  # saliency, smoothgrad, IG, deeplift, gradxinput, occlusion
+         adata_key="attributions",
+         baseline=None,
+         device="cpu",
+         target_class=None,
+         normalize=True,
+         signed=True
+     ):
+         """
+         Apply a chosen attribution method to a dataloader and store the results in adata.
+         """
+
+         results = []
+         self.to(device)
+         self.eval()
+
+         for batch in dataloader:
+             x = batch[0].to(device)
+
+             if method == "saliency":
+                 attr = self.compute_saliency(x, target_class=target_class, signed=signed)
+
+             elif method == "smoothgrad":
+                 attr = self.compute_saliency(x, smoothgrad=True, target_class=target_class, signed=signed)
+
+             elif method == "IG":
+                 attributions, delta = self.compute_integrated_gradients(
+                     x, target_class=target_class, baseline=baseline
+                 )
+                 attr = attributions
+
+             elif method == "deeplift":
+                 attr = self.compute_deeplift(x, baseline=baseline, target_class=target_class, signed=signed)
+
+             elif method == "gradxinput":
+                 attr = self.compute_gradient_x_input(x, target_class=target_class)
+
+             elif method == "occlusion":
+                 attr = self.compute_occlusion(
+                     x, target_class=target_class, baseline=baseline
+                 )
+
+             else:
+                 raise ValueError(f"Unknown method {method}")
+
+             # ensure numpy
+             attr = attr.detach().cpu().numpy() if torch.is_tensor(attr) else attr
+             results.append(attr)
+
+         results_stacked = np.concatenate(results, axis=0)
+         adata.obsm[adata_key] = results_stacked
+
+         if normalize:
+             row_max = np.max(np.abs(results_stacked), axis=1, keepdims=True)
+             row_max[row_max == 0] = 1  # avoid divide by zero
+             results_normalized = results_stacked / row_max
+             adata.obsm[adata_key + "_normalized"] = results_normalized
+
+     def _resolve_target_class(self, logits, target_class):
+         if target_class is not None:
+             return target_class
+         if logits.shape[1] == 1:
+             return (logits > 0).long().squeeze(1)
+         return logits.argmax(dim=1)
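
The new BaseTorchModel centralizes gradient-based attribution (saliency, SmoothGrad, Integrated Gradients, DeepLIFT, gradient × input, occlusion) behind one interface. A minimal usage sketch follows, using the CNNClassifier defined in the next hunk; the data, loader, and variable names here are illustrative, not part of the package:

import numpy as np
import torch
import anndata as ad
from torch.utils.data import DataLoader, TensorDataset

from smftools.machine_learning.models.cnn import CNNClassifier

# Illustrative data: 100 reads x 500 positions of binarized SMF calls.
X = np.random.randint(0, 2, size=(100, 500)).astype(np.float32)
adata = ad.AnnData(X)
loader = DataLoader(TensorDataset(torch.from_numpy(X)), batch_size=32)

model = CNNClassifier(input_size=500, num_classes=2)

# SmoothGrad attributions for every read; results land in
# adata.obsm["attributions"] and, since normalize=True by default,
# adata.obsm["attributions_normalized"].
model.apply_attributions_to_adata(loader, adata, method="smoothgrad")
print(adata.obsm["attributions"].shape)  # (100, 500)
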
smftools/machine_learning/models/cnn.py
@@ -0,0 +1,138 @@
+ import torch
+ import torch.nn as nn
+ from .base import BaseTorchModel
+ import numpy as np
+
+ class CNNClassifier(BaseTorchModel):
+     def __init__(
+         self,
+         input_size,
+         num_classes=2,
+         conv_channels=[16, 32],
+         kernel_sizes=[3, 3],
+         fc_dims=[64],
+         use_batchnorm=False,
+         use_pooling=False,
+         dropout=0.2,
+         gradcam_layer_idx=-1,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.name = "CNNClassifier"
+
+         # Normalize input
+         if isinstance(kernel_sizes, int):
+             kernel_sizes = [kernel_sizes] * len(conv_channels)
+         assert len(conv_channels) == len(kernel_sizes)
+
+         layers = []
+         in_channels = 1
+
+         # Build conv layers
+         for out_channels, ksize in zip(conv_channels, kernel_sizes):
+             layers.append(nn.Conv1d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2))
+             if use_batchnorm:
+                 layers.append(nn.BatchNorm1d(out_channels))
+             layers.append(nn.ReLU())
+             if use_pooling:
+                 layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
+             if dropout > 0:
+                 layers.append(nn.Dropout(dropout))
+             in_channels = out_channels
+
+         self.conv = nn.Sequential(*layers)
+
+         # Determine flattened size
+         with torch.no_grad():
+             dummy = torch.zeros(1, 1, input_size)
+             conv_out = self.conv(dummy)
+             flattened_size = conv_out.view(1, -1).shape[1]
+
+         # Build FC layers
+         fc_layers = []
+         in_dim = flattened_size
+         for dim in fc_dims:
+             fc_layers.append(nn.Linear(in_dim, dim))
+             fc_layers.append(nn.ReLU())
+             if dropout > 0:
+                 fc_layers.append(nn.Dropout(dropout))
+             in_dim = dim
+
+         output_size = 1 if num_classes == 2 else num_classes
+         fc_layers.append(nn.Linear(in_dim, output_size))
+
+         self.fc = nn.Sequential(*fc_layers)
+
+         # Build gradcam hooks
+         self.gradcam_layer_idx = gradcam_layer_idx
+         self.gradcam_activations = None
+         self.gradcam_gradients = None
+         if not hasattr(self, "_hooks_registered"):
+             self._register_gradcam_hooks()
+             self._hooks_registered = True
+
+     def forward(self, x):
+         x = x.unsqueeze(1)  # [B, 1, L]
+         x = self.conv(x)
+         x = x.view(x.size(0), -1)
+         return self.fc(x)
+
+     def _register_gradcam_hooks(self):
+         def forward_hook(module, input, output):
+             self.gradcam_activations = output.detach()
+
+         def backward_hook(module, grad_input, grad_output):
+             self.gradcam_gradients = grad_output[0].detach()
+
+         target_layer = list(self.conv.children())[self.gradcam_layer_idx]
+         target_layer.register_forward_hook(forward_hook)
+         target_layer.register_full_backward_hook(backward_hook)
+
+     def compute_gradcam(self, x, class_idx=None):
+         self.zero_grad()
+
+         x = x.detach().clone().requires_grad_().to(next(self.parameters()).device)  # match the model's current device
+
+         was_training = self.training
+         self.eval()  # disable dropout etc.
+
+         output = self.forward(x)  # shape (B, C) or (B, 1)
+
+         if class_idx is None:
+             class_idx = output.argmax(dim=1)
+
+         if output.shape[1] == 1:
+             target = output.view(-1)  # shape (B,)
+         else:
+             target = output[torch.arange(output.shape[0]), class_idx]
+
+         target.sum().backward(retain_graph=True)
+
+         # restore training mode
+         if was_training:
+             self.train()
+
+         # activations and gradients captured by the hooks registered in __init__
+         activations = self.gradcam_activations  # (B, C, L)
+         gradients = self.gradcam_gradients  # (B, C, L)
+
+         weights = gradients.mean(dim=2, keepdim=True)  # (B, C, 1)
+         cam = (weights * activations).sum(dim=1)  # (B, L)
+
+         cam = torch.relu(cam)
+         cam = cam / (cam.max(dim=1, keepdim=True).values + 1e-6)
+
+         return cam
+
+     def apply_gradcam_to_adata(self, dataloader, adata, obsm_key="gradcam", device="cpu"):
+         self.to(device)
+         self.eval()
+         cams = []
+
+         for batch in dataloader:
+             x = batch[0].to(device)
+             cam_batch = self.compute_gradcam(x)
+             cams.append(cam_batch.cpu().numpy())
+
+         cams = np.concatenate(cams, axis=0)  # shape: [n_obs, input_len]
+         adata.obsm[obsm_key] = cams
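
Grad-CAM complements the gradient attributions above: the hooks registered in __init__ cache the target conv layer's activations and gradients on each forward/backward pass, and compute_gradcam folds them into a per-position relevance map (ReLU'd, then max-normalized per read). A short sketch, continuing the illustrative model, loader, and adata from the previous example:

# Grad-CAM relevance maps for every read, stored in adata.obsm["gradcam"].
# With the default same-padding and no pooling, the map length matches the input length.
model.apply_gradcam_to_adata(loader, adata, obsm_key="gradcam")
cam = adata.obsm["gradcam"]  # shape (100, 500), values in [0, 1]
print(cam.shape, float(cam.min()), float(cam.max()))
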