smftools 0.1.7__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. smftools/__init__.py +9 -4
  2. smftools/_version.py +1 -1
  3. smftools/cli.py +184 -0
  4. smftools/config/__init__.py +1 -0
  5. smftools/config/conversion.yaml +33 -0
  6. smftools/config/deaminase.yaml +56 -0
  7. smftools/config/default.yaml +253 -0
  8. smftools/config/direct.yaml +17 -0
  9. smftools/config/experiment_config.py +1191 -0
  10. smftools/hmm/HMM.py +1576 -0
  11. smftools/hmm/__init__.py +20 -0
  12. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  13. smftools/hmm/call_hmm_peaks.py +106 -0
  14. smftools/{tools → hmm}/display_hmm.py +3 -3
  15. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  16. smftools/{tools → hmm}/train_hmm.py +1 -1
  17. smftools/informatics/__init__.py +0 -2
  18. smftools/informatics/archived/deaminase_smf.py +132 -0
  19. smftools/informatics/fast5_to_pod5.py +4 -1
  20. smftools/informatics/helpers/__init__.py +3 -4
  21. smftools/informatics/helpers/align_and_sort_BAM.py +34 -7
  22. smftools/informatics/helpers/aligned_BAM_to_bed.py +35 -24
  23. smftools/informatics/helpers/binarize_converted_base_identities.py +116 -23
  24. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +365 -42
  25. smftools/informatics/helpers/converted_BAM_to_adata_II.py +165 -29
  26. smftools/informatics/helpers/discover_input_files.py +100 -0
  27. smftools/informatics/helpers/extract_base_identities.py +29 -3
  28. smftools/informatics/helpers/extract_read_features_from_bam.py +4 -2
  29. smftools/informatics/helpers/find_conversion_sites.py +5 -4
  30. smftools/informatics/helpers/modkit_extract_to_adata.py +6 -3
  31. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  32. smftools/informatics/helpers/separate_bam_by_bc.py +2 -2
  33. smftools/informatics/helpers/split_and_index_BAM.py +1 -5
  34. smftools/load_adata.py +1346 -0
  35. smftools/machine_learning/__init__.py +12 -0
  36. smftools/machine_learning/data/__init__.py +2 -0
  37. smftools/machine_learning/data/anndata_data_module.py +234 -0
  38. smftools/machine_learning/evaluation/__init__.py +2 -0
  39. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  40. smftools/machine_learning/evaluation/evaluators.py +223 -0
  41. smftools/machine_learning/inference/__init__.py +3 -0
  42. smftools/machine_learning/inference/inference_utils.py +27 -0
  43. smftools/machine_learning/inference/lightning_inference.py +68 -0
  44. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  45. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  46. smftools/machine_learning/models/base.py +295 -0
  47. smftools/machine_learning/models/cnn.py +138 -0
  48. smftools/machine_learning/models/lightning_base.py +345 -0
  49. smftools/machine_learning/models/mlp.py +26 -0
  50. smftools/{tools → machine_learning}/models/positional.py +3 -2
  51. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  52. smftools/machine_learning/models/sklearn_models.py +273 -0
  53. smftools/machine_learning/models/transformer.py +303 -0
  54. smftools/machine_learning/training/__init__.py +2 -0
  55. smftools/machine_learning/training/train_lightning_model.py +135 -0
  56. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  57. smftools/plotting/__init__.py +4 -1
  58. smftools/plotting/autocorrelation_plotting.py +611 -0
  59. smftools/plotting/general_plotting.py +566 -89
  60. smftools/plotting/hmm_plotting.py +260 -0
  61. smftools/plotting/qc_plotting.py +270 -0
  62. smftools/preprocessing/__init__.py +13 -8
  63. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  64. smftools/preprocessing/append_base_context.py +122 -0
  65. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  66. smftools/preprocessing/calculate_complexity_II.py +248 -0
  67. smftools/preprocessing/calculate_coverage.py +10 -1
  68. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  69. smftools/preprocessing/clean_NaN.py +17 -1
  70. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  71. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  72. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  73. smftools/preprocessing/invert_adata.py +12 -5
  74. smftools/preprocessing/load_sample_sheet.py +19 -4
  75. smftools/readwrite.py +849 -43
  76. smftools/tools/__init__.py +3 -32
  77. smftools/tools/calculate_umap.py +5 -5
  78. smftools/tools/general_tools.py +3 -3
  79. smftools/tools/position_stats.py +468 -106
  80. smftools/tools/read_stats.py +115 -1
  81. smftools/tools/spatial_autocorrelation.py +562 -0
  82. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/METADATA +5 -1
  83. smftools-0.2.1.dist-info/RECORD +161 -0
  84. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  85. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  86. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  87. smftools/informatics/load_adata.py +0 -182
  88. smftools/preprocessing/append_C_context.py +0 -82
  89. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  90. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  91. smftools/preprocessing/filter_reads_on_length.py +0 -51
  92. smftools/tools/call_hmm_peaks.py +0 -105
  93. smftools/tools/data/__init__.py +0 -2
  94. smftools/tools/data/anndata_data_module.py +0 -90
  95. smftools/tools/evaluation/__init__.py +0 -0
  96. smftools/tools/inference/__init__.py +0 -1
  97. smftools/tools/inference/lightning_inference.py +0 -41
  98. smftools/tools/models/base.py +0 -14
  99. smftools/tools/models/cnn.py +0 -34
  100. smftools/tools/models/lightning_base.py +0 -41
  101. smftools/tools/models/mlp.py +0 -17
  102. smftools/tools/models/sklearn_models.py +0 -40
  103. smftools/tools/models/transformer.py +0 -133
  104. smftools/tools/training/__init__.py +0 -1
  105. smftools/tools/training/train_lightning_model.py +0 -47
  106. smftools-0.1.7.dist-info/RECORD +0 -136
  107. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  108. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  109. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  110. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  111. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  112. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  113. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  114. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  115. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  116. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  117. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  118. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  119. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  120. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,114 @@
1
+ from ..data import AnnDataModule
2
+ from ..evaluation import PostInferenceModelEvaluator
3
+ from .lightning_inference import run_lightning_inference
4
+ from .sklearn_inference import run_sklearn_inference
5
+
6
def sliding_window_inference(
    adata,
    trained_results,
    tensor_source='X',
    tensor_key=None,
    label_col='activity_status',
    batch_size=64,
    cleanup=False,
    target_eval_freq=None,
    max_eval_positive=None
):
    """
    Apply trained sliding-window models to an AnnData object (Lightning or sklearn).

    For every (model_name, window_size, center) combination in ``trained_results``,
    build an AnnDataModule over the corresponding window, run inference (appending
    prediction columns to ``adata.obs`` / ``adata.obsm``), then evaluate every
    model with PostInferenceModelEvaluator and return a tidy DataFrame.
    Optionally remove the appended inference columns afterwards.

    Parameters
    ----------
    adata : AnnData
        Object to run inference on; ``obs``/``obsm`` are modified in place.
    trained_results : dict
        Nested mapping ``model_name -> window_size -> center_varname -> run``.
        Each run dict holds ``'model'`` and, for Lightning runs, ``'trainer'``.
    tensor_source, tensor_key :
        Where AnnDataModule pulls the input tensor from.
    label_col : str
        Obs column with the ground-truth label; also used to parse the
        inference column names during evaluation.
    batch_size : int
        Batch size for the inference dataloader.
    cleanup : bool
        If True, drop the appended obs columns and obsm entries before returning.
    target_eval_freq, max_eval_positive :
        Passed through to PostInferenceModelEvaluator.

    Returns
    -------
    pandas.DataFrame
        Evaluation metrics with parsed ``model_name``, ``window_size``, ``center``.
    """
    import re

    ## Inference using trained models
    for model_name, model_dict in trained_results.items():
        for window_size, window_data in model_dict.items():
            for center_varname, run in window_data.items():
                print(f"\nEvaluating {model_name} window {window_size} around {center_varname}")

                # Window start from the center variable's positional index.
                center_idx = adata.var_names.get_loc(center_varname)
                window_start = center_idx - window_size // 2

                # Build datamodule for this window.
                datamodule = AnnDataModule(
                    adata,
                    tensor_source=tensor_source,
                    tensor_key=tensor_key,
                    label_col=label_col,
                    batch_size=batch_size,
                    window_start=window_start,
                    window_size=window_size,
                    inference_mode=True
                )
                datamodule.setup()

                model = run['model']
                prefix = f"{model_name}_w{window_size}_c{center_varname}"

                # Lightning runs carry a 'trainer' entry; sklearn runs do not.
                # (run is a dict, so the previous hasattr(run, 'trainer') check
                # was always False and is dropped with identical behavior.)
                if 'trainer' in run:
                    run_lightning_inference(
                        adata,
                        model=model,
                        datamodule=datamodule,
                        trainer=run['trainer'],
                        prefix=prefix
                    )
                else:
                    run_sklearn_inference(
                        adata,
                        model=model,
                        datamodule=datamodule,
                        prefix=prefix
                    )

    print("Inference complete across all models.")

    ## Post-inference model evaluation
    model_wrappers = {}
    for model_name, model_dict in trained_results.items():
        for window_size, window_data in model_dict.items():
            for center_varname, run in window_data.items():
                # Reconstruct the prefix used during inference; unique per combo.
                model_wrappers[f"{model_name}_w{window_size}_c{center_varname}"] = run['model']

    evaluator = PostInferenceModelEvaluator(
        adata,
        model_wrappers,
        target_eval_freq=target_eval_freq,
        max_eval_positive=max_eval_positive
    )
    evaluator.evaluate_all()

    # Get results
    df = evaluator.to_dataframe()

    # Parse "<model>_w<window>_c<center>_<label_col>" back into columns.
    # BUGFIX: the label suffix was hard-coded to 'activity_status'; use the
    # actual label_col so runs trained on other labels parse correctly.
    pattern = rf'(\w+)_w(\d+)_c(\d+)_{re.escape(label_col)}'
    df[['model_name', 'window_size', 'center']] = df['model'].str.extract(pattern)

    # Cast window_size and center to integers for plotting.
    df['window_size'] = df['window_size'].astype(int)
    df['center'] = df['center'].astype(int)

    ## Optional cleanup of appended inference outputs.
    if cleanup:
        prefixes = [f"{model_name}_w{window_size}_c{center_varname}"
                    for model_name, model_dict in trained_results.items()
                    for window_size, window_data in model_dict.items()
                    for center_varname in window_data.keys()]

        for prefix in prefixes:
            # Remove matching obs columns.
            to_remove = [col for col in adata.obs.columns if col.startswith(prefix)]
            adata.obs.drop(columns=to_remove, inplace=True)

            # BUGFIX: remove the obsm entry for EVERY prefix — previously this
            # ran after the loop, so only the last prefix's entry was deleted.
            obsm_key = f"{prefix}_pred_prob_all"
            if obsm_key in adata.obsm:
                del adata.obsm[obsm_key]

        print(f"Cleaned up {len(prefixes)} model prefixes from AnnData.")

    return df
@@ -0,0 +1,295 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from ..utils.device import detect_device
5
+
6
+ class BaseTorchModel(nn.Module):
7
+ """
8
+ Minimal base class for torch models that:
9
+ - Stores device and dropout regularization
10
+ """
11
+ def __init__(self, dropout_rate=0.0):
12
+ super().__init__()
13
+ self.device = detect_device() # detects available devices
14
+ self.dropout_rate = dropout_rate # default dropout rate to be used in regularization.
15
+
16
+ def compute_saliency(
17
+ self,
18
+ x,
19
+ target_class=None,
20
+ reduction="sum",
21
+ smoothgrad=False,
22
+ smooth_samples=25,
23
+ smooth_noise=0.1,
24
+ signed=True
25
+ ):
26
+ """
27
+ Compute vanilla saliency or SmoothGrad saliency.
28
+
29
+ Arguments:
30
+ ----------
31
+ x : torch.Tensor
32
+ Input tensor [B, S, D].
33
+ target_class : int or list, optional
34
+ If None, uses model predicted class.
35
+ reduction : str
36
+ 'sum' or 'mean' across channels
37
+ smoothgrad : bool
38
+ Whether to apply SmoothGrad.
39
+ smooth_samples : int
40
+ Number of noisy samples for SmoothGrad
41
+ smooth_noise : float
42
+ Standard deviation of noise added to input
43
+ """
44
+ self.eval()
45
+ x = x.clone().detach().requires_grad_(True)
46
+
47
+ if smoothgrad:
48
+ saliency_accum = torch.zeros_like(x)
49
+ for i in range(smooth_samples):
50
+ noise = torch.normal(mean=0.0, std=smooth_noise, size=x.shape).to(x.device)
51
+ x_noisy = x + noise
52
+ x_noisy.requires_grad_(True)
53
+ x_noisy.retain_grad() # <<< fixes the issue
54
+ logits = self.forward(x_noisy)
55
+ target_class = self._resolve_target_class(logits, target_class)
56
+ if logits.shape[1] == 1:
57
+ scores = logits.squeeze(1)
58
+ else:
59
+ scores = logits[torch.arange(x.shape[0]), target_class]
60
+ scores.sum().backward()
61
+ saliency_accum += x_noisy.grad.detach()
62
+ saliency = saliency_accum / smooth_samples
63
+ else:
64
+ logits = self.forward(x)
65
+ target_class = self._resolve_target_class(logits, target_class)
66
+ if logits.shape[1] == 1:
67
+ scores = logits.squeeze(1)
68
+ else:
69
+ scores = logits[torch.arange(x.shape[0]), target_class]
70
+ scores.sum().backward()
71
+ saliency = x.grad.detach()
72
+
73
+ if not signed:
74
+ saliency = saliency.abs()
75
+
76
+ if reduction == "sum" and x.ndim == 3:
77
+ return saliency.sum(dim=-1)
78
+ elif reduction == "mean" and x.ndim == 3:
79
+ return saliency.mean(dim=-1)
80
+ else:
81
+ return saliency
82
+
83
+ def compute_gradient_x_input(self, x, target_class=None):
84
+ """
85
+ Computes gradient × input attribution.
86
+ """
87
+ x = x.clone().detach().requires_grad_(True)
88
+ logits = self.forward(x)
89
+ target_class = self._resolve_target_class(logits, target_class)
90
+ if logits.shape[1] == 1:
91
+ scores = logits.squeeze(1)
92
+ else:
93
+ scores = logits[torch.arange(x.shape[0]), target_class]
94
+ scores.sum().backward()
95
+ grads = x.grad
96
+ return grads * x
97
+
98
+ def compute_integrated_gradients(self, x, target_class=None, steps=50, baseline=None):
99
+ """
100
+ Compute Integrated Gradients for a batch of x.
101
+ If target=None, uses the predicted class.
102
+ Returns: [B, seq_len, channels] attribution tensor
103
+ """
104
+ from captum.attr import IntegratedGradients
105
+
106
+ ig = IntegratedGradients(self)
107
+ self.eval()
108
+ x = x.requires_grad_(True)
109
+
110
+ with torch.no_grad():
111
+ logits = self.forward(x)
112
+ if logits.shape[1] == 1:
113
+ target_class = 0 # only one column exists, representing class 1 logit
114
+ else:
115
+ target_class = self._resolve_target_class(logits, target_class)
116
+
117
+ if baseline is None:
118
+ baseline = torch.zeros_like(x)
119
+
120
+ attributions, delta = ig.attribute(
121
+ x,
122
+ baselines=baseline,
123
+ target=target_class,
124
+ n_steps=steps,
125
+ return_convergence_delta=True
126
+ )
127
+ return attributions, delta
128
+
129
+ def compute_deeplift(
130
+ self,
131
+ x,
132
+ baseline=None,
133
+ target_class=None,
134
+ reduction="sum",
135
+ signed=True
136
+ ):
137
+ """
138
+ Compute DeepLIFT scores using captum.
139
+
140
+ baseline:
141
+ reference input for DeepLIFT.
142
+ """
143
+ from captum.attr import DeepLift
144
+
145
+ self.eval()
146
+ deeplift = DeepLift(self)
147
+
148
+ logits = self.forward(x)
149
+ if logits.shape[1] == 1:
150
+ target_class = 0 # only one column exists, representing class 1 logit
151
+ else:
152
+ target_class = self._resolve_target_class(logits, target_class)
153
+
154
+ if baseline is None:
155
+ baseline = torch.zeros_like(x)
156
+
157
+ attr = deeplift.attribute(x, target=target_class, baselines=baseline)
158
+
159
+ if not signed:
160
+ attr = attr.abs()
161
+
162
+ if reduction == "sum" and x.ndim == 3:
163
+ return attr.sum(dim=-1)
164
+ elif reduction == "mean" and x.ndim == 3:
165
+ return attr.mean(dim=-1)
166
+ else:
167
+ return attr
168
+
169
+ def compute_occlusion(
170
+ self,
171
+ x,
172
+ target_class=None,
173
+ window_size=5,
174
+ baseline=None
175
+ ):
176
+ """
177
+ Computes per-sample occlusion attribution.
178
+ Supports 2D [B, S] or 3D [B, S, D] inputs.
179
+ Returns: [B, S] occlusion scores
180
+ """
181
+ self.eval()
182
+
183
+ x_np = x.detach().cpu().numpy()
184
+ ndim = x_np.ndim
185
+ if ndim == 2:
186
+ B, S = x_np.shape
187
+ D = 1
188
+ elif ndim == 3:
189
+ B, S, D = x_np.shape
190
+ else:
191
+ raise ValueError(f"Unsupported input shape {x_np.shape}")
192
+
193
+ # if no baseline provided, fallback to mean
194
+ if baseline is None:
195
+ baseline = np.mean(x_np, axis=0)
196
+
197
+ occlusion_scores = np.zeros((B, S))
198
+
199
+ for b in range(B):
200
+ for i in range(S):
201
+ x_occluded = x_np[b].copy()
202
+ left = max(0, i - window_size // 2)
203
+ right = min(S, i + window_size // 2)
204
+
205
+ if ndim == 2:
206
+ x_occluded[left:right] = baseline[left:right]
207
+ else:
208
+ x_occluded[left:right, :] = baseline[left:right, :]
209
+
210
+ x_tensor = torch.tensor(
211
+ x_occluded,
212
+ device=self.device,
213
+ dtype=torch.float32
214
+ ).unsqueeze(0)
215
+
216
+ logits = self.forward(x_tensor)
217
+ target_class = self._resolve_target_class(logits, target_class)
218
+
219
+ if logits.shape[1] == 1:
220
+ scores = logits.squeeze(1)
221
+ else:
222
+ scores = logits[torch.arange(x.shape[0]), target_class]
223
+
224
+ occlusion_scores[b, i] = scores.mean().item()
225
+
226
+ return occlusion_scores
227
+
228
+ def apply_attributions_to_adata(
229
+ model,
230
+ dataloader,
231
+ adata,
232
+ method="saliency", # saliency, smoothgrad, IG, deeplift, gradxinput, occlusion
233
+ adata_key="attributions",
234
+ baseline=None,
235
+ device="cpu",
236
+ target_class=None,
237
+ normalize=True,
238
+ signed=True
239
+ ):
240
+ """
241
+ Apply a chosen attribution method to a dataloader and store results in adata.
242
+ """
243
+
244
+ results = []
245
+ model.to(device)
246
+ model.eval()
247
+
248
+ for batch in dataloader:
249
+ x = batch[0].to(device)
250
+
251
+ if method == "saliency":
252
+ attr = model.compute_saliency(x, target_class=target_class, signed=signed)
253
+
254
+ elif method == "smoothgrad":
255
+ attr = model.compute_saliency(x, smoothgrad=True, target_class=target_class, signed=signed)
256
+
257
+ elif method == "IG":
258
+ attributions, delta = model.compute_integrated_gradients(
259
+ x, target_class=target_class, baseline=baseline
260
+ )
261
+ attr = attributions
262
+
263
+ elif method == "deeplift":
264
+ attr = model.compute_deeplift(x, baseline=baseline, target_class=target_class, signed=signed)
265
+
266
+ elif method == "gradxinput":
267
+ attr = model.compute_gradient_x_input(x, target_class=target_class)
268
+
269
+ elif method == "occlusion":
270
+ attr = model.compute_occlusion(
271
+ x, target_class=target_class, baseline=baseline
272
+ )
273
+
274
+ else:
275
+ raise ValueError(f"Unknown method {method}")
276
+
277
+ # ensure numpy
278
+ attr = attr.detach().cpu().numpy() if torch.is_tensor(attr) else attr
279
+ results.append(attr)
280
+
281
+ results_stacked = np.concatenate(results, axis=0)
282
+ adata.obsm[adata_key] = results_stacked
283
+
284
+ if normalize:
285
+ row_max = np.max(np.abs(results_stacked), axis=1, keepdims=True)
286
+ row_max[row_max == 0] = 1 # avoid divide by zero
287
+ results_normalized = results_stacked / row_max
288
+ adata.obsm[adata_key + "_normalized"] = results_normalized
289
+
290
+ def _resolve_target_class(self, logits, target_class):
291
+ if target_class is not None:
292
+ return target_class
293
+ if logits.shape[1] == 1:
294
+ return (logits > 0).long().squeeze(1)
295
+ return logits.argmax(dim=1)
@@ -0,0 +1,138 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .base import BaseTorchModel
4
+ import numpy as np
5
+
6
class CNNClassifier(BaseTorchModel):
    """
    1D convolutional classifier over [B, L] inputs, with optional batch
    normalization, max pooling and dropout, plus built-in Grad-CAM support
    via forward/backward hooks on a chosen conv layer.

    Binary classification (num_classes == 2) uses a single output logit;
    otherwise one logit per class.
    """
    def __init__(
        self,
        input_size,
        num_classes=2,
        conv_channels=[16, 32],
        kernel_sizes=[3, 3],
        fc_dims=[64],
        use_batchnorm=False,
        use_pooling=False,
        dropout=0.2,
        gradcam_layer_idx=-1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.name = "CNNClassifier"

        # A single int kernel size applies to every conv layer.
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * len(conv_channels)
        assert len(conv_channels) == len(kernel_sizes)

        self.conv = self._build_conv_stack(
            conv_channels, kernel_sizes, use_batchnorm, use_pooling, dropout
        )

        # Probe the conv stack with a dummy input to size the first FC layer.
        with torch.no_grad():
            probe_out = self.conv(torch.zeros(1, 1, input_size))
        flattened_size = probe_out.view(1, -1).shape[1]

        self.fc = self._build_head(flattened_size, fc_dims, num_classes, dropout)

        # Grad-CAM state + hooks on the chosen conv layer.
        self.gradcam_layer_idx = gradcam_layer_idx
        self.gradcam_activations = None
        self.gradcam_gradients = None
        if not hasattr(self, "_hooks_registered"):
            self._register_gradcam_hooks()
            self._hooks_registered = True

    def _build_conv_stack(self, conv_channels, kernel_sizes,
                          use_batchnorm, use_pooling, dropout):
        """Assemble the Conv1d → (BN) → ReLU → (pool) → (dropout) stack."""
        blocks = []
        prev_channels = 1
        for channels, ksize in zip(conv_channels, kernel_sizes):
            blocks.append(nn.Conv1d(prev_channels, channels,
                                    kernel_size=ksize, padding=ksize // 2))
            if use_batchnorm:
                blocks.append(nn.BatchNorm1d(channels))
            blocks.append(nn.ReLU())
            if use_pooling:
                blocks.append(nn.MaxPool1d(kernel_size=2, stride=2))
            if dropout > 0:
                blocks.append(nn.Dropout(dropout))
            prev_channels = channels
        return nn.Sequential(*blocks)

    def _build_head(self, in_dim, fc_dims, num_classes, dropout):
        """Assemble the fully-connected classification head."""
        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(in_dim, dim))
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_dim = dim
        # One logit for the binary case, one per class otherwise.
        out_dim = 1 if num_classes == 2 else num_classes
        layers.append(nn.Linear(in_dim, out_dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        # [B, L] -> [B, 1, L] for Conv1d, then flatten into the head.
        feats = self.conv(x.unsqueeze(1))
        return self.fc(feats.view(feats.size(0), -1))

    def _register_gradcam_hooks(self):
        """Attach hooks that capture activations/gradients of the target layer."""
        def _save_activations(module, inputs, output):
            self.gradcam_activations = output.detach()

        def _save_gradients(module, grad_input, grad_output):
            self.gradcam_gradients = grad_output[0].detach()

        target_layer = list(self.conv.children())[self.gradcam_layer_idx]
        target_layer.register_forward_hook(_save_activations)
        target_layer.register_full_backward_hook(_save_gradients)

    def compute_gradcam(self, x, class_idx=None):
        """Return per-position Grad-CAM maps, normalized to [0, 1] per sample."""
        self.zero_grad()

        inputs = x.detach().clone().requires_grad_().to(self.device)

        was_training = self.training
        self.eval()  # disable dropout etc.

        output = self.forward(inputs)  # shape (B, C) or (B, 1)

        if class_idx is None:
            class_idx = output.argmax(dim=1)

        if output.shape[1] == 1:
            target = output.view(-1)  # shape (B,)
        else:
            target = output[torch.arange(output.shape[0]), class_idx]

        target.sum().backward(retain_graph=True)

        # restore training mode
        if was_training:
            self.train()

        # Captured by the hooks during forward/backward above.
        activations = self.gradcam_activations  # (B, C, L)
        gradients = self.gradcam_gradients      # (B, C, L)

        # Channel weights = mean gradient per channel; weighted sum → CAM.
        weights = gradients.mean(dim=2, keepdim=True)        # (B, C, 1)
        cam = torch.relu((weights * activations).sum(dim=1))  # (B, L)
        cam = cam / (cam.max(dim=1, keepdim=True).values + 1e-6)
        return cam

    def apply_gradcam_to_adata(self, dataloader, adata, obsm_key="gradcam", device="cpu"):
        """Run Grad-CAM over a dataloader and store the maps in adata.obsm."""
        self.to(device)
        self.eval()

        cam_batches = [self.compute_gradcam(batch[0].to(device)).cpu().numpy()
                       for batch in dataloader]

        adata.obsm[obsm_key] = np.concatenate(cam_batches, axis=0)  # [n_obs, input_len]