smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,16 @@
-import torch
-import pytorch_lightning as pl
 import matplotlib.pyplot as plt
+import numpy as np
+import pytorch_lightning as pl
+import torch
 from sklearn.metrics import (
-    roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
+    auc,
+    confusion_matrix,
+    f1_score,
+    precision_recall_curve,
+    roc_auc_score,
+    roc_curve,
 )
-import numpy as np
+

 class TorchClassifierWrapper(pl.LightningModule):
     """
@@ -16,25 +22,26 @@ class TorchClassifierWrapper(pl.LightningModule):
     - Can pass the index of the class label to use as the focus class when calculating precision/recall.
     - Contains a prediction step to run inference with.
     """
+
     def __init__(
         self,
        model: torch.nn.Module,
         label_col: str,
         num_classes: int,
-        class_names: list=None,
+        class_names: list = None,
         optimizer_cls=torch.optim.AdamW,
         optimizer_kwargs=None,
         criterion_kwargs=None,
         lr: float = 1e-3,
         focus_class: int = 1,  # used for binary or multiclass precision-recall
         class_weights=None,
-        enforce_eval_balance: bool=False,
-        target_eval_freq: float=0.3,
-        max_eval_positive: int=None
+        enforce_eval_balance: bool = False,
+        target_eval_freq: float = 0.3,
+        max_eval_positive: int = None,
     ):
         super().__init__()
         self.model = model
-        self.save_hyperparameters(ignore=['model'])  # logs all except actual model instance
+        self.save_hyperparameters(ignore=["model"])  # logs all except actual model instance
         self.optimizer_cls = optimizer_cls
         self.optimizer_kwargs = optimizer_kwargs or {"weight_decay": 1e-4}
         self.criterion = None
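For context on the widened signature above, a minimal usage sketch. The argument values and the availability of `TorchClassifierWrapper` in the current namespace are illustrative assumptions, not taken from the package docs:

```python
import torch

# Hypothetical instantiation of the wrapper shown in this diff; `backbone`
# stands in for any torch.nn.Module that maps a batch of features to logits.
backbone = torch.nn.Linear(128, 2)

wrapper = TorchClassifierWrapper(
    model=backbone,
    label_col="label",                     # obs column holding labels (assumed name)
    num_classes=2,
    class_names=["negative", "positive"],  # assumed names
    focus_class="positive",                # resolved to index 1 via class_names
    lr=1e-3,
    enforce_eval_balance=True,             # rebalance the test split to target_eval_freq
    target_eval_freq=0.3,
)
```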
@@ -57,14 +64,17 @@ class TorchClassifierWrapper(pl.LightningModule):
                 if torch.is_tensor(class_weights[self.focus_class]):
                     self.criterion_kwargs["pos_weight"] = class_weights[self.focus_class]
                 else:
-                    self.criterion_kwargs["pos_weight"] = torch.tensor(class_weights[self.focus_class], dtype=torch.float32, device=self.device)
+                    self.criterion_kwargs["pos_weight"] = torch.tensor(
+                        class_weights[self.focus_class], dtype=torch.float32, device=self.device
+                    )
             else:
                 # CrossEntropyLoss expects weight tensor of size C
                 if torch.is_tensor(class_weights):
                     self.criterion_kwargs["weight"] = class_weights
                 else:
-                    self.criterion_kwargs["weight"] = torch.tensor(class_weights, dtype=torch.float32)
-
+                    self.criterion_kwargs["weight"] = torch.tensor(
+                        class_weights, dtype=torch.float32
+                    )

         self._val_outputs = []
         self._test_outputs = []
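The two branches above mirror how the stock PyTorch losses consume class weights: `BCEWithLogitsLoss` takes a `pos_weight` that rescales the positive-class term, while `CrossEntropyLoss` takes a `weight` vector of size C. A standalone illustration (not smftools code):

```python
import torch

# Binary: pos_weight multiplies the loss contribution of positive targets.
bce = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(4.0))
logits = torch.randn(8)
targets = torch.randint(0, 2, (8,)).float()
loss_bin = bce(logits, targets)

# Multiclass: weight is a per-class vector of size C.
ce = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 0.5]))
logits_mc = torch.randn(8, 3)
targets_mc = torch.randint(0, 3, (8,))
loss_mc = ce(logits_mc, targets_mc)
```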
@@ -78,12 +88,20 @@ class TorchClassifierWrapper(pl.LightningModule):

     def _init_criterion(self):
         if self.num_classes == 2:
-            if "pos_weight" in self.criterion_kwargs and not torch.is_tensor(self.criterion_kwargs["pos_weight"]):
-                self.criterion_kwargs["pos_weight"] = torch.tensor(self.criterion_kwargs["pos_weight"], dtype=torch.float32, device=self.device)
+            if "pos_weight" in self.criterion_kwargs and not torch.is_tensor(
+                self.criterion_kwargs["pos_weight"]
+            ):
+                self.criterion_kwargs["pos_weight"] = torch.tensor(
+                    self.criterion_kwargs["pos_weight"], dtype=torch.float32, device=self.device
+                )
             self.criterion = torch.nn.BCEWithLogitsLoss(**self.criterion_kwargs)
         else:
-            if "weight" in self.criterion_kwargs and not torch.is_tensor(self.criterion_kwargs["weight"]):
-                self.criterion_kwargs["weight"] = torch.tensor(self.criterion_kwargs["weight"], dtype=torch.float32, device=self.device)
+            if "weight" in self.criterion_kwargs and not torch.is_tensor(
+                self.criterion_kwargs["weight"]
+            ):
+                self.criterion_kwargs["weight"] = torch.tensor(
+                    self.criterion_kwargs["weight"], dtype=torch.float32, device=self.device
+                )
             self.criterion = torch.nn.CrossEntropyLoss(**self.criterion_kwargs)

     def _resolve_focus_class(self, focus_class):
@@ -93,11 +111,13 @@ class TorchClassifierWrapper(pl.LightningModule):
             if self.class_names is None:
                 raise ValueError("class_names must be provided if focus_class is a string.")
             if focus_class not in self.class_names:
-                raise ValueError(f"focus_class '{focus_class}' not found in class_names {self.class_names}.")
+                raise ValueError(
+                    f"focus_class '{focus_class}' not found in class_names {self.class_names}."
+                )
             return self.class_names.index(focus_class)
         else:
             raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")
-
+
     def set_training_indices(self, datamodule):
         """
         Store obs_names for train/val/test subsets used during training.
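The string branch of `_resolve_focus_class` boils down to a list lookup; for example (names are made up):

```python
class_names = ["unmodified", "modified"]    # example labels, not from smftools
focus_idx = class_names.index("modified")   # -> 1, the index used for PR metrics
```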
@@ -140,7 +160,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=False)
         self._val_outputs.append((logits.detach(), y.detach()))
         return loss
-
+
     def test_step(self, batch, batch_idx):
         """
         Test step for a batch through the Lightning Trainer.
@@ -189,7 +209,7 @@ class TorchClassifierWrapper(pl.LightningModule):
             return self.criterion(logits.view(-1, 1), y)
         else:
             return self.criterion(logits, y)
-
+
     def _get_probs(self, logits):
         """
         A helper function for getting class probabilities for binary vs multiclass classifications.
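The binary/multiclass split in these helpers follows the usual convention: one logit per sample pushed through a sigmoid and thresholded at 0.5 for binary, one row of logits per sample with argmax otherwise. A small standalone check:

```python
import torch

logits_bin = torch.tensor([0.8, -1.2, 0.1])      # binary: one logit per sample
probs_bin = torch.sigmoid(logits_bin)            # P(positive) per sample
preds_bin = (probs_bin >= 0.5).long()            # tensor([1, 0, 1])

logits_mc = torch.tensor([[2.0, 0.1, -1.0],
                          [0.0, 1.5, 0.3]])      # multiclass: one row per sample
probs_mc = torch.softmax(logits_mc, dim=1)       # rows sum to 1
preds_mc = logits_mc.argmax(dim=1)               # tensor([0, 1])
```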
@@ -207,8 +227,10 @@ class TorchClassifierWrapper(pl.LightningModule):
             return (torch.sigmoid(logits.view(-1)) >= 0.5).long()
         else:
             return logits.argmax(dim=1)
-
-    def _subsample_for_fixed_positive_frequency(self, y_true, probs, target_freq=0.3, max_positive=None):
+
+    def _subsample_for_fixed_positive_frequency(
+        self, y_true, probs, target_freq=0.3, max_positive=None
+    ):
         pos_idx = np.where(y_true == self.focus_class)[0]
         neg_idx = np.where(y_true != self.focus_class)[0]

@@ -216,16 +238,20 @@ class TorchClassifierWrapper(pl.LightningModule):
         max_positives_possible = len(pos_idx)

         # maximum achievable positive class frequency
-        max_possible_freq = max_positives_possible / (max_positives_possible + max_negatives_possible)
+        max_possible_freq = max_positives_possible / (
+            max_positives_possible + max_negatives_possible
+        )

         if target_freq > max_possible_freq:
             target_freq = max_possible_freq  # clip if you ask for impossible freq

         # now calculate positive count
-        num_pos_target = min(int(target_freq * max_negatives_possible / (1 - target_freq)), max_positives_possible)
+        num_pos_target = min(
+            int(target_freq * max_negatives_possible / (1 - target_freq)), max_positives_possible
+        )
         num_neg_target = int(num_pos_target * (1 - target_freq) / target_freq)
         num_neg_target = min(num_neg_target, max_negatives_possible)
-
+
         pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
         neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)

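The arithmetic above pins the positive count to the requested class frequency: keeping all negatives, `p / (p + n_neg) = f` solves to `p = f * n_neg / (1 - f)`. A worked example with illustrative counts:

```python
# Worked example of the target-frequency arithmetic (numbers are illustrative).
max_negatives_possible = 700
max_positives_possible = 300
target_freq = 0.3

# p / (p + 700) = 0.3  =>  p = 0.3 * 700 / 0.7 = 300
num_pos_target = min(
    int(target_freq * max_negatives_possible / (1 - target_freq)),
    max_positives_possible,
)  # -> 300
num_neg_target = min(
    int(num_pos_target * (1 - target_freq) / target_freq),
    max_negatives_possible,
)  # -> 700

achieved = num_pos_target / (num_pos_target + num_neg_target)
assert abs(achieved - target_freq) < 1e-9  # 300 / 1000 == 0.3
```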
@@ -235,7 +261,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         actual_freq = len(pos_sampled) / len(sampled_idx)

         return sampled_idx
-
+
     def _log_classification_metrics(self, logits, targets, prefix="val"):
         """
         A helper function for logging validation and testing split model evaluations.
@@ -252,9 +278,12 @@ class TorchClassifierWrapper(pl.LightningModule):
         num_pos = binary_focus.sum()

         # Subsample if you want to enforce a fixed proportion of the positive class
-        if prefix == 'test' and self.enforce_eval_balance:
+        if prefix == "test" and self.enforce_eval_balance:
             sampled_idx = self._subsample_for_fixed_positive_frequency(
-                y_true, probs, target_freq=self.target_eval_freq, max_positive=self.max_eval_positive
+                y_true,
+                probs,
+                target_freq=self.target_eval_freq,
+                max_positive=self.max_eval_positive,
             )
             y_true = y_true[sampled_idx]
             probs = probs[sampled_idx]
@@ -289,7 +318,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         cm = confusion_matrix(y_true, preds)

         # Save attributes for later plotting
-        if prefix == 'test':
+        if prefix == "test":
             self.test_roc_curve = (fpr, tpr)
             self.test_pr_curve = (rc, pr)
             self.test_roc_auc = roc_auc
@@ -298,19 +327,21 @@ class TorchClassifierWrapper(pl.LightningModule):
             self.test_num_pos = num_pos
             self.test_acc = acc
             self.test_f1 = f1
-        elif prefix == 'val':
+        elif prefix == "val":
             pass

         # Logging
-        self.log_dict({
-            f"{prefix}_acc": acc,
-            f"{prefix}_f1": f1,
-            f"{prefix}_auc": roc_auc,
-            f"{prefix}_pr_auc": pr_auc,
-            f"{prefix}_pr_auc_norm": pr_auc_norm,
-            f"{prefix}_pos_freq": pos_freq,
-            f"{prefix}_num_pos": num_pos
-        })
+        self.log_dict(
+            {
+                f"{prefix}_acc": acc,
+                f"{prefix}_f1": f1,
+                f"{prefix}_auc": roc_auc,
+                f"{prefix}_pr_auc": pr_auc,
+                f"{prefix}_pr_auc_norm": pr_auc_norm,
+                f"{prefix}_pos_freq": pos_freq,
+                f"{prefix}_num_pos": num_pos,
+            }
+        )
         setattr(self, f"{prefix}_confusion_matrix", cm)

     def _plot_roc_pr_curves(self, logits, targets):
  def _plot_roc_pr_curves(self, logits, targets):
@@ -334,7 +365,7 @@ class TorchClassifierWrapper(pl.LightningModule):
334
365
  pos_freq = self.test_pos_freq
335
366
  plt.subplot(1, 2, 2)
336
367
  plt.plot(rc, pr, label=f"PR AUC={pr_auc:.3f}")
337
- plt.axhline(pos_freq, linestyle='--', color="gray")
368
+ plt.axhline(pos_freq, linestyle="--", color="gray")
338
369
  plt.xlabel("Recall")
339
370
  plt.ylabel("Precision")
340
371
  plt.ylim(0, 1.05)
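The dashed gray line marks the precision of a chance classifier, which equals the positive-class frequency; the PR AUC itself comes from the sklearn primitives imported at the top of the file. A standalone sketch on dummy data (variable names `pr`/`rc` follow the diff; the score model is invented):

```python
import numpy as np
from sklearn.metrics import auc, precision_recall_curve

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)
scores = np.clip(0.6 * y_true + rng.normal(0.2, 0.2, size=200), 0, 1)

pr, rc, _ = precision_recall_curve(y_true, scores)
pr_auc = auc(rc, pr)        # area under precision(recall)
pos_freq = y_true.mean()    # chance-level precision baseline (the dashed line)
```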
@@ -1,9 +1,18 @@
-import torch
 import torch.nn as nn
+
 from .base import BaseTorchModel
-
+
+
 class MLPClassifier(BaseTorchModel):
-    def __init__(self, input_dim, num_classes=2, hidden_dims=[64, 64], dropout=0.2, use_batchnorm=True, **kwargs):
+    def __init__(
+        self,
+        input_dim,
+        num_classes=2,
+        hidden_dims=[64, 64],
+        dropout=0.2,
+        use_batchnorm=True,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         layers = []
         in_dim = input_dim
@@ -23,4 +32,4 @@ class MLPClassifier(BaseTorchModel):
         self.model = nn.Sequential(*layers)

     def forward(self, x):
-        return self.model(x)
+        return self.model(x)
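The loop that fills `layers` sits between these two hunks and is not shown. A plausible reconstruction of the pattern the signature implies (Linear, optional BatchNorm, ReLU, optional Dropout per hidden width; a single-logit head for the binary case, consistent with the wrapper's `logits.view(-1)`); the exact layer order in smftools is an assumption:

```python
import torch.nn as nn

def build_mlp(input_dim, num_classes=2, hidden_dims=(64, 64), dropout=0.2, use_batchnorm=True):
    """Assumed reconstruction of the MLPClassifier layer stack, not the actual source."""
    layers, in_dim = [], input_dim
    for h in hidden_dims:
        layers.append(nn.Linear(in_dim, h))
        if use_batchnorm:
            layers.append(nn.BatchNorm1d(h))
        layers.append(nn.ReLU())
        if dropout:
            layers.append(nn.Dropout(dropout))
        in_dim = h
    out_dim = 1 if num_classes == 2 else num_classes  # single logit for binary (assumed)
    layers.append(nn.Linear(in_dim, out_dim))
    return nn.Sequential(*layers)
```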
@@ -2,6 +2,7 @@ import numpy as np
 import torch
 import torch.nn as nn

+
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model, max_len=5000):
         super().__init__()
@@ -14,5 +15,5 @@ class PositionalEncoding(nn.Module):
         self.register_buffer("pe", pe)

     def forward(self, x):
-        x = x + self.pe[:, :x.size(1)]
-        return x
+        x = x + self.pe[:, : x.size(1)]
+        return x
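The slice `self.pe[:, : x.size(1)]` trims a precomputed `[1, max_len, d_model]` buffer to the input's sequence length. For reference, the standard sinusoidal construction of such a buffer, as a generic sketch (not necessarily smftools' exact code; assumes even `d_model`):

```python
import math
import torch
import torch.nn as nn

class SinusoidalPE(nn.Module):
    """Generic stand-in for the PositionalEncoding class above."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)          # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(1, max_len, d_model)                  # batch dim first
        pe[0, :, 0::2] = torch.sin(position * div_term)        # even channels
        pe[0, :, 1::2] = torch.cos(position * div_term)        # odd channels
        self.register_buffer("pe", pe)

    def forward(self, x):                    # x: [B, L, d_model]
        return x + self.pe[:, : x.size(1)]   # broadcasts over the batch
```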
@@ -1,7 +1,8 @@
-import torch
 import torch.nn as nn
+
 from .base import BaseTorchModel

+
 class RNNClassifier(BaseTorchModel):
     def __init__(self, input_size, hidden_dim, num_classes, **kwargs):
         super().__init__(**kwargs)
@@ -14,4 +15,4 @@ class RNNClassifier(BaseTorchModel):
     def forward(self, x):
         x = x.unsqueeze(1)  # [B, 1, L] → for LSTM expecting batch_first
         _, (h_n, _) = self.lstm(x)  # h_n: [1, B, H]
-        return self.fc(h_n.squeeze(0))  # [B, H] → [B, num_classes]
+        return self.fc(h_n.squeeze(0))  # [B, H] → [B, num_classes]
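The shape comments in `forward()` can be verified with a tiny trace (sizes are arbitrary; the LSTM/Linear construction mirrors the signature above):

```python
import torch
import torch.nn as nn

B, L, H, C = 4, 128, 32, 2
lstm = nn.LSTM(input_size=L, hidden_size=H, batch_first=True)
fc = nn.Linear(H, C)

x = torch.randn(B, L)        # one flat feature vector per read
x = x.unsqueeze(1)           # [B, 1, L]: a length-1 sequence of L features
_, (h_n, _) = lstm(x)        # h_n: [1, B, H] (num_layers * num_directions = 1)
out = fc(h_n.squeeze(0))     # [B, H] -> [B, C]
print(out.shape)             # torch.Size([4, 2])
```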
@@ -1,23 +1,30 @@
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 from sklearn.metrics import (
-    roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
+    auc,
+    confusion_matrix,
+    f1_score,
+    precision_recall_curve,
+    roc_auc_score,
+    roc_curve,
 )

+
 class SklearnModelWrapper:
     """
     Unified sklearn wrapper matching TorchClassifierWrapper interface.
     """
+
     def __init__(
-        self, 
-        model, 
+        self,
+        model,
         label_col: str,
-        num_classes: int,
-        class_names=None,
-        focus_class: int=1,
-        enforce_eval_balance: bool=False,
-        target_eval_freq: float=0.3,
-        max_eval_positive=None
+        num_classes: int,
+        class_names=None,
+        focus_class: int = 1,
+        enforce_eval_balance: bool = False,
+        target_eval_freq: float = 0.3,
+        max_eval_positive=None,
     ):
         self.model = model
         self.label_col = label_col
@@ -37,7 +44,9 @@ class SklearnModelWrapper:
         if self.class_names is None:
             raise ValueError("class_names must be provided if focus_class is a string.")
         if focus_class not in self.class_names:
-            raise ValueError(f"focus_class '{focus_class}' not found in class_names {self.class_names}.")
+            raise ValueError(
+                f"focus_class '{focus_class}' not found in class_names {self.class_names}."
+            )
         return self.class_names.index(focus_class)
     else:
         raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")
@@ -130,7 +139,7 @@
             f"{prefix}_pr_auc": pr_auc,
             f"{prefix}_pr_auc_norm": pr_auc_norm,
             f"{prefix}_pos_freq": pos_freq,
-            f"{prefix}_num_pos": num_pos
+            f"{prefix}_num_pos": num_pos,
         }

         return self.metrics
@@ -166,7 +175,10 @@

     def fit_from_datamodule(self, datamodule):
         datamodule.setup()
-        X_tensor, y_tensor = datamodule.train_set.dataset.X_tensor, datamodule.train_set.dataset.y_tensor
+        X_tensor, y_tensor = (
+            datamodule.train_set.dataset.X_tensor,
+            datamodule.train_set.dataset.y_tensor,
+        )
         indices = datamodule.train_set.indices
         X_train = X_tensor[indices].numpy()
         y_train = y_tensor[indices].numpy()
@@ -190,11 +202,11 @@
         y_eval = y_tensor[indices].numpy()

         return self.evaluate(X_eval, y_eval, prefix=split)
-
+
     def compute_shap(self, X, background=None, nsamples=100, target_class=None):
         """
         Compute SHAP values on input X, optionally for a specified target class.
-
+
         Parameters
         ----------
         X : array-like
@@ -225,7 +237,7 @@
             shap_values = explainer.shap_values(X)
         else:
             shap_values = explainer.shap_values(X, nsamples=nsamples)
-
+
         if isinstance(shap_values, np.ndarray):
             if shap_values.ndim == 3:
                 if isinstance(target_class, int):
@@ -234,10 +246,7 @@
                     # target_class is per-sample
                     if np.any(target_class >= shap_values.shape[2]):
                         raise ValueError(f"target_class values exceed {shap_values.shape[2]}")
-                    selected = np.array([
-                        shap_values[i, :, c]
-                        for i, c in enumerate(target_class)
-                    ])
+                    selected = np.array([shap_values[i, :, c] for i, c in enumerate(target_class)])
                     return selected
                 else:
                     # fallback to class 0
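The collapsed list comprehension picks one class slice per sample from a `(samples, features, classes)` array; NumPy advanced indexing gives the same result without the Python loop. An equivalence check on dummy data:

```python
import numpy as np

n, f, c = 5, 10, 3
shap_values = np.random.randn(n, f, c)          # (samples, features, classes)
target_class = np.array([0, 2, 1, 1, 0])        # one class index per sample

loop = np.array([shap_values[i, :, cls] for i, cls in enumerate(target_class)])
vectorized = shap_values[np.arange(n), :, target_class]  # same (n, f) result
assert np.allclose(loop, vectorized)
```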
@@ -246,7 +255,15 @@
             # 2D shape (samples, features), no class dimension
             return shap_values

-    def apply_shap_to_adata(self, dataloader, adata, background=None, adata_key="shap_values", target_class=None, normalize=True):
+    def apply_shap_to_adata(
+        self,
+        dataloader,
+        adata,
+        background=None,
+        adata_key="shap_values",
+        target_class=None,
+        normalize=True,
+    ):
         """
         Compute SHAP from a DataLoader and store in AnnData if provided.
         """
@@ -270,4 +287,4 @@
             row_max[row_max == 0] = 1  # avoid divide by zero
             normalized = arr / row_max

-            adata.obsm[f"{adata_key}_normalized"] = normalized
+            adata.obsm[f"{adata_key}_normalized"] = normalized
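The normalization above divides each read's SHAP row by its own maximum, guarding all-zero rows. A standalone sketch, assuming `row_max` is the per-row max of absolute values (its actual computation sits outside this hunk):

```python
import numpy as np

arr = np.array([[0.2, -0.4, 0.1],
                [0.0,  0.0, 0.0]])                # an all-zero row would divide by zero
row_max = np.abs(arr).max(axis=1, keepdims=True)  # assumed definition of row_max
row_max[row_max == 0] = 1                         # avoid divide by zero (as above)
normalized = arr / row_max                        # first row now peaks at magnitude 1
```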