smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0

smftools/machine_learning/inference/inference_utils.py

@@ -1,5 +1,6 @@
 import pandas as pd
 
+
 def annotate_split_column(adata, model, split_col="split"):
     """
     Annotate adata.obs with train/val/test/new labels based on model's stored obs_names.
@@ -8,7 +9,7 @@ def annotate_split_column(adata, model, split_col="split"):
     train_set = set(model.train_obs_names)
     val_set = set(model.val_obs_names)
     test_set = set(model.test_obs_names)
-
+
     # Create array for split labels
     split_labels = []
     for obs in adata.obs_names:
@@ -20,8 +21,10 @@ def annotate_split_column(adata, model, split_col="split"):
             split_labels.append("testing")
         else:
             split_labels.append("new")
-
+
     # Store in AnnData.obs
-    adata.obs[split_col] = pd.Categorical(split_labels, categories=["training", "validation", "testing", "new"])
-
+    adata.obs[split_col] = pd.Categorical(
+        split_labels, categories=["training", "validation", "testing", "new"]
+    )
+
     print(f"Annotated {split_col} column with training/validation/testing/new status.")
smftools/machine_learning/inference/lightning_inference.py

@@ -1,17 +1,11 @@
-import torch
-import pandas as pd
 import numpy as np
-from pytorch_lightning import Trainer
+import pandas as pd
+import torch
+
 from .inference_utils import annotate_split_column
 
-def run_lightning_inference(
-    adata,
-    model,
-    datamodule,
-    trainer,
-    prefix="model",
-    devices=1
-):
+
+def run_lightning_inference(adata, model, datamodule, trainer, prefix="model", devices=1):
     """
     Run inference on AnnData using TorchClassifierWrapper + AnnDataModule (in inference mode).
     """
@@ -57,7 +51,9 @@ def run_lightning_inference(
         full_prefix = f"{prefix}_{label_col}"
 
         adata.obs[f"{full_prefix}_pred"] = pred_class_idx
-        adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
+        adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+            pred_class_labels, categories=class_labels
+        )
         adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
 
         for i, class_name in enumerate(class_labels):
@@ -65,4 +61,4 @@
 
     adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all
 
-    print(f"Inference complete: stored under prefix '{full_prefix}'")
+    print(f"Inference complete: stored under prefix '{full_prefix}'")

smftools/machine_learning/inference/sklearn_inference.py

@@ -1,14 +1,10 @@
-import pandas as pd
 import numpy as np
+import pandas as pd
+
 from .inference_utils import annotate_split_column
 
 
-def run_sklearn_inference(
-    adata,
-    model,
-    datamodule,
-    prefix="model"
-):
+def run_sklearn_inference(adata, model, datamodule, prefix="model"):
     """
     Run inference on AnnData using SklearnModelWrapper.
     """
@@ -44,7 +40,9 @@ def run_sklearn_inference(
         full_prefix = f"{prefix}_{label_col}"
 
        adata.obs[f"{full_prefix}_pred"] = pred_class_idx
-        adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
+        adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+            pred_class_labels, categories=class_labels
+        )
         adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
 
         for i, class_name in enumerate(class_labels):
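
Both inference paths write the same per-label outputs into the AnnData object. A read-back sketch (illustrative, not part of the diff), assuming the default prefix="model" and a hypothetical label_col of "activity_status":

    full_prefix = "model_activity_status"
    pred_labels = adata.obs[f"{full_prefix}_pred_label"]  # Categorical over class_labels
    pred_probs = adata.obs[f"{full_prefix}_pred_prob"]    # probability of the predicted class
    # the Lightning hunk additionally shows a full per-class probability matrix:
    all_probs = adata.obsm[f"{full_prefix}_pred_prob_all"]
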
smftools/machine_learning/inference/sliding_window_inference.py

@@ -3,16 +3,17 @@ from ..evaluation import PostInferenceModelEvaluator
 from .lightning_inference import run_lightning_inference
 from .sklearn_inference import run_sklearn_inference
 
+
 def sliding_window_inference(
-    adata,
-    trained_results,
-    tensor_source='X',
+    adata,
+    trained_results,
+    tensor_source="X",
     tensor_key=None,
-    label_col='activity_status',
+    label_col="activity_status",
     batch_size=64,
     cleanup=False,
-    target_eval_freq=None,
-    max_eval_positive=None
+    target_eval_freq=None,
+    max_eval_positive=None,
 ):
     """
     Apply trained sliding window models to an AnnData object (Lightning or Sklearn).
@@ -24,11 +25,11 @@
         for window_size, window_data in model_dict.items():
             for center_varname, run in window_data.items():
                 print(f"\nEvaluating {model_name} window {window_size} around {center_varname}")
-
+
                 # Extract window start from varname
                 center_idx = adata.var_names.get_loc(center_varname)
                 window_start = center_idx - window_size // 2
-
+
                 # Build datamodule for window
                 datamodule = AnnDataModule(
                     adata,
@@ -38,31 +39,31 @@
                     batch_size=batch_size,
                     window_start=window_start,
                     window_size=window_size,
-                    inference_mode=True
+                    inference_mode=True,
                 )
                 datamodule.setup()
 
                 # Extract model + detect type
-                model = run['model']
+                model = run["model"]
 
                 # Lightning models
-                if hasattr(run, 'trainer') or 'trainer' in run:
-                    trainer = run['trainer']
+                if hasattr(run, "trainer") or "trainer" in run:
+                    trainer = run["trainer"]
                     run_lightning_inference(
                         adata,
                         model=model,
                         datamodule=datamodule,
                         trainer=trainer,
-                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                        prefix=f"{model_name}_w{window_size}_c{center_varname}",
                     )
-
+
                 # Sklearn models
                 else:
                     run_sklearn_inference(
                         adata,
                         model=model,
                         datamodule=datamodule,
-                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                        prefix=f"{model_name}_w{window_size}_c{center_varname}",
                     )
 
     print("Inference complete across all models.")
@@ -77,27 +78,36 @@
                 prefix = f"{model_name}_w{window_size}_c{center_varname}"
                 # Use full key for uniqueness
                 key = prefix
-                model_wrappers[key] = run['model']
+                model_wrappers[key] = run["model"]
 
     # Run evaluator
-    evaluator = PostInferenceModelEvaluator(adata, model_wrappers, target_eval_freq=target_eval_freq, max_eval_positive=max_eval_positive)
+    evaluator = PostInferenceModelEvaluator(
+        adata,
+        model_wrappers,
+        target_eval_freq=target_eval_freq,
+        max_eval_positive=max_eval_positive,
+    )
    evaluator.evaluate_all()
 
     # Get results
     df = evaluator.to_dataframe()
 
-    df[['model_name', 'window_size', 'center']] = df['model'].str.extract(r'(\w+)_w(\d+)_c(\d+)_activity_status')
+    df[["model_name", "window_size", "center"]] = df["model"].str.extract(
+        r"(\w+)_w(\d+)_c(\d+)_activity_status"
+    )
 
     # Cast window_size and center to integers for plotting
-    df['window_size'] = df['window_size'].astype(int)
-    df['center'] = df['center'].astype(int)
+    df["window_size"] = df["window_size"].astype(int)
+    df["center"] = df["center"].astype(int)
 
     ## Optional cleanup:
     if cleanup:
-        prefixes = [f"{model_name}_w{window_size}_c{center_varname}"
-                    for model_name, model_dict in trained_results.items()
-                    for window_size, window_data in model_dict.items()
-                    for center_varname in window_data.keys()]
+        prefixes = [
+            f"{model_name}_w{window_size}_c{center_varname}"
+            for model_name, model_dict in trained_results.items()
+            for window_size, window_data in model_dict.items()
+            for center_varname in window_data.keys()
+        ]
 
         # Remove matching obs columns
         for prefix in prefixes:
@@ -111,4 +121,4 @@
 
         print(f"Cleaned up {len(prefixes)} model prefixes from AnnData.")
 
-    return df
+    return df
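
The loops above imply a nested trained_results layout of {model_name: {window_size: {center_varname: run}}}, where each run dict holds "model" and, for Lightning runs, "trainer". A hypothetical call under that assumption (wrapper and trainer stand in for objects produced by training):

    trained_results = {
        "cnn": {                # model_name
            41: {               # window_size
                "1200": {"model": wrapper, "trainer": trainer},  # center_varname -> run
            },
        },
    }
    df = sliding_window_inference(adata, trained_results, label_col="activity_status", cleanup=True)
    # df gains model_name / window_size / center columns parsed from each prefix

Note that the evaluation step extracts those columns with the pattern (\w+)_w(\d+)_c(\d+)_activity_status, so it assumes numeric center var names and the default label column.
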
smftools/machine_learning/models/__init__.py

@@ -1,9 +1,14 @@
 from .base import BaseTorchModel
-from .mlp import MLPClassifier
 from .cnn import CNNClassifier
-from .rnn import RNNClassifier
-from .transformer import BaseTransformer, TransformerClassifier, DANNTransformerClassifier, MaskedTransformerPretrainer
+from .lightning_base import TorchClassifierWrapper
+from .mlp import MLPClassifier
 from .positional import PositionalEncoding
+from .rnn import RNNClassifier
+from .sklearn_models import SklearnModelWrapper
+from .transformer import (
+    BaseTransformer,
+    DANNTransformerClassifier,
+    MaskedTransformerPretrainer,
+    TransformerClassifier,
+)
 from .wrappers import ScaledModel
-from .lightning_base import TorchClassifierWrapper
-from .sklearn_models import SklearnModelWrapper

smftools/machine_learning/models/base.py

@@ -1,17 +1,20 @@
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
+
 from ..utils.device import detect_device
 
+
 class BaseTorchModel(nn.Module):
     """
     Minimal base class for torch models that:
     - Stores device and dropout regularization
     """
+
     def __init__(self, dropout_rate=0.0):
         super().__init__()
-        self.device = detect_device() # detects available devices
-        self.dropout_rate = dropout_rate # default dropout rate to be used in regularization.
+        self.device = detect_device()  # detects available devices
+        self.dropout_rate = dropout_rate  # default dropout rate to be used in regularization.
 
     def compute_saliency(
         self,
@@ -21,11 +24,11 @@ class BaseTorchModel(nn.Module):
         smoothgrad=False,
         smooth_samples=25,
         smooth_noise=0.1,
-        signed=True
+        signed=True,
     ):
         """
         Compute vanilla saliency or SmoothGrad saliency.
-
+
         Arguments:
         ----------
         x : torch.Tensor
@@ -43,7 +46,7 @@ class BaseTorchModel(nn.Module):
         """
         self.eval()
         x = x.clone().detach().requires_grad_(True)
-
+
         if smoothgrad:
             saliency_accum = torch.zeros_like(x)
             for i in range(smooth_samples):
@@ -56,7 +59,7 @@
                 if logits.shape[1] == 1:
                     scores = logits.squeeze(1)
                 else:
-                    scores = logits[torch.arange(x.shape[0]), target_class]
+                    scores = logits[torch.arange(x.shape[0]), target_class]
                 scores.sum().backward()
                 saliency_accum += x_noisy.grad.detach()
             saliency = saliency_accum / smooth_samples
@@ -69,17 +72,17 @@
                 scores = logits[torch.arange(x.shape[0]), target_class]
             scores.sum().backward()
             saliency = x.grad.detach()
-
+
         if not signed:
             saliency = saliency.abs()
-
+
         if reduction == "sum" and x.ndim == 3:
             return saliency.sum(dim=-1)
         elif reduction == "mean" and x.ndim == 3:
             return saliency.mean(dim=-1)
         else:
             return saliency
-
+
     def compute_gradient_x_input(self, x, target_class=None):
         """
         Computes gradient × input attribution.
@@ -118,22 +121,11 @@
             baseline = torch.zeros_like(x)
 
         attributions, delta = ig.attribute(
-            x,
-            baselines=baseline,
-            target=target_class,
-            n_steps=steps,
-            return_convergence_delta=True
+            x, baselines=baseline, target=target_class, n_steps=steps, return_convergence_delta=True
         )
         return attributions, delta
 
-    def compute_deeplift(
-        self,
-        x,
-        baseline=None,
-        target_class=None,
-        reduction="sum",
-        signed=True
-    ):
+    def compute_deeplift(self, x, baseline=None, target_class=None, reduction="sum", signed=True):
         """
         Compute DeepLIFT scores using captum.
 
@@ -158,21 +150,15 @@
 
         if not signed:
             attr = attr.abs()
-
+
         if reduction == "sum" and x.ndim == 3:
             return attr.sum(dim=-1)
         elif reduction == "mean" and x.ndim == 3:
             return attr.mean(dim=-1)
         else:
             return attr
-
-    def compute_occlusion(
-        self,
-        x,
-        target_class=None,
-        window_size=5,
-        baseline=None
-    ):
+
+    def compute_occlusion(self, x, target_class=None, window_size=5, baseline=None):
         """
         Computes per-sample occlusion attribution.
         Supports 2D [B, S] or 3D [B, S, D] inputs.
@@ -208,9 +194,7 @@
             x_occluded[left:right, :] = baseline[left:right, :]
 
             x_tensor = torch.tensor(
-                x_occluded,
-                device=self.device,
-                dtype=torch.float32
+                x_occluded, device=self.device, dtype=torch.float32
             ).unsqueeze(0)
 
             logits = self.forward(x_tensor)
@@ -235,7 +219,7 @@
         device="cpu",
         target_class=None,
         normalize=True,
-        signed=True
+        signed=True,
     ):
         """
         Apply a chosen attribution method to a dataloader and store results in adata.
@@ -252,7 +236,9 @@
             attr = model.compute_saliency(x, target_class=target_class, signed=signed)
 
         elif method == "smoothgrad":
-            attr = model.compute_saliency(x, smoothgrad=True, target_class=target_class, signed=signed)
+            attr = model.compute_saliency(
+                x, smoothgrad=True, target_class=target_class, signed=signed
+            )
 
         elif method == "IG":
             attributions, delta = model.compute_integrated_gradients(
@@ -261,15 +247,15 @@
             attr = attributions
 
         elif method == "deeplift":
-            attr = model.compute_deeplift(x, baseline=baseline, target_class=target_class, signed=signed)
+            attr = model.compute_deeplift(
+                x, baseline=baseline, target_class=target_class, signed=signed
+            )
 
         elif method == "gradxinput":
             attr = model.compute_gradient_x_input(x, target_class=target_class)
 
         elif method == "occlusion":
-            attr = model.compute_occlusion(
-                x, target_class=target_class, baseline=baseline
-            )
+            attr = model.compute_occlusion(x, target_class=target_class, baseline=baseline)
 
         else:
             raise ValueError(f"Unknown method {method}")
@@ -292,4 +278,4 @@
             return target_class
         if logits.shape[1] == 1:
             return (logits > 0).long().squeeze(1)
-        return logits.argmax(dim=1)
+        return logits.argmax(dim=1)
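
A toy sketch of the attribution API above (illustrative; assumes BaseTorchModel is importable as the models __init__ diff exports it, that compute_saliency accepts x plus the keyword arguments visible in the hunks, and that passing an explicit target_class sidesteps any None handling):

    import torch
    import torch.nn as nn

    from smftools.machine_learning.models import BaseTorchModel

    class ToyClassifier(BaseTorchModel):  # hypothetical two-class model
        def __init__(self, seq_len=10):
            super().__init__(dropout_rate=0.0)
            self.fc = nn.Linear(seq_len, 2)

        def forward(self, x):
            return self.fc(x)

    model = ToyClassifier(seq_len=10)
    x = torch.randn(4, 10)  # [B, S]
    sal = model.compute_saliency(x, smoothgrad=True, smooth_samples=5, target_class=0, signed=False)
    # per-position attribution of shape [4, 10], averaged over 5 noisy copies of x
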
smftools/machine_learning/models/cnn.py

@@ -1,7 +1,9 @@
+import numpy as np
 import torch
 import torch.nn as nn
+
 from .base import BaseTorchModel
-import numpy as np
+
 
 class CNNClassifier(BaseTorchModel):
     def __init__(
@@ -15,7 +17,7 @@ class CNNClassifier(BaseTorchModel):
         use_pooling=False,
         dropout=0.2,
         gradcam_layer_idx=-1,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.name = "CNNClassifier"
@@ -30,7 +32,9 @@
 
         # Build conv layers
         for out_channels, ksize in zip(conv_channels, kernel_sizes):
-            layers.append(nn.Conv1d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2))
+            layers.append(
+                nn.Conv1d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2)
+            )
             if use_batchnorm:
                 layers.append(nn.BatchNorm1d(out_channels))
             layers.append(nn.ReLU())
@@ -76,7 +80,7 @@
         x = self.conv(x)
         x = x.view(x.size(0), -1)
         return self.fc(x)
-
+
     def _register_gradcam_hooks(self):
         def forward_hook(module, input, output):
             self.gradcam_activations = output.detach()
@@ -97,15 +101,15 @@
         self.eval()  # disable dropout etc.
 
         output = self.forward(x)  # shape (B, C) or (B, 1)
-
+
         if class_idx is None:
             class_idx = output.argmax(dim=1)
-
+
         if output.shape[1] == 1:
             target = output.view(-1)  # shape (B,)
         else:
             target = output[torch.arange(output.shape[0]), class_idx]
-
+
         target.sum().backward(retain_graph=True)
 
         # restore training mode
@@ -114,16 +118,16 @@
 
         # get activations and gradients (set these via forward hook!)
         activations = self.gradcam_activations  # (B, C, L)
-        gradients = self.gradcam_gradients # (B, C, L)
+        gradients = self.gradcam_gradients  # (B, C, L)
 
         weights = gradients.mean(dim=2, keepdim=True)  # (B, C, 1)
-        cam = (weights * activations).sum(dim=1) # (B, L)
+        cam = (weights * activations).sum(dim=1)  # (B, L)
 
         cam = torch.relu(cam)
         cam = cam / (cam.max(dim=1, keepdim=True).values + 1e-6)
 
         return cam
-
+
     def apply_gradcam_to_adata(self, dataloader, adata, obsm_key="gradcam", device="cpu"):
         self.to(device)
         self.eval()
@@ -135,4 +139,4 @@
             cams.append(cam_batch.cpu().numpy())
 
         cams = np.concatenate(cams, axis=0)  # shape: [n_obs, input_len]
-        adata.obsm[obsm_key] = cams
+        adata.obsm[obsm_key] = cams
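
Grad-CAM is wired through forward/backward hooks on a chosen conv layer and surfaced by apply_gradcam_to_adata. A usage sketch (illustrative; constructor arguments, dataloader, and adata are hypothetical stand-ins):

    cnn = CNNClassifier(input_len=200, n_classes=2)  # hypothetical arguments
    cnn.apply_gradcam_to_adata(dataloader, adata, obsm_key="gradcam", device="cpu")
    cam = adata.obsm["gradcam"]  # [n_obs, input_len]; each row is a ReLU'd, max-normalized map
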