smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,9 +1,11 @@
+ import numpy as np
  import torch
  import torch.nn as nn
- from .base import BaseTorchModel
- from .positional import PositionalEncoding
+
  from ..utils.grl import grad_reverse
- import numpy as np
+ from .base import BaseTorchModel
+ from .positional import PositionalEncoding
+

  class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
  def __init__(self, *args, **kwargs):

@@ -11,12 +13,14 @@ class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):

  def forward(self, src, src_mask=None, is_causal=False, src_key_padding_mask=None):
  self_attn_output, attn_weights = self.self_attn(
- src, src, src,
+ src,
+ src,
+ src,
  attn_mask=src_mask,
  key_padding_mask=src_key_padding_mask,
  need_weights=True,
  average_attn_weights=False, # preserve [B, num_heads, S, S]
- is_causal=is_causal
+ is_causal=is_causal,
  )
  src = src + self.dropout1(self_attn_output)
  src = self.norm1(src)

@@ -27,18 +31,21 @@ class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
  # Save attention weights to module
  self.attn_weights = attn_weights # Save to layer
  return src
-
+
+
  class BaseTransformer(BaseTorchModel):
- def __init__(self,
- input_dim=1,
- model_dim=64,
- num_heads=4,
- num_layers=2,
- dropout=0.2,
- seq_len=None,
- use_learnable_pos=False,
- use_cls_token=True,
- **kwargs):
+ def __init__(
+ self,
+ input_dim=1,
+ model_dim=64,
+ num_heads=4,
+ num_layers=2,
+ dropout=0.2,
+ seq_len=None,
+ use_learnable_pos=False,
+ use_cls_token=True,
+ **kwargs,
+ ):
  super().__init__(**kwargs)
  # Input FC layer to map D_input to D_model
  self.model_dim = model_dim

@@ -52,7 +59,9 @@ class BaseTransformer(BaseTorchModel):

  if use_learnable_pos:
  assert seq_len is not None, "Must provide seq_len if use_learnable_pos=True"
- self.pos_embed = nn.Parameter(torch.randn(seq_len + (1 if use_cls_token else 0), model_dim))
+ self.pos_embed = nn.Parameter(
+ torch.randn(seq_len + (1 if use_cls_token else 0), model_dim)
+ )
  self.pos_encoder = None
  else:
  self.pos_encoder = PositionalEncoding(model_dim)

@@ -62,7 +71,13 @@ class BaseTransformer(BaseTorchModel):
  self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim)) # (1, 1, D)

  # Specify the transformer encoder structure
- encoder_layer = TransformerEncoderLayerWithAttn(d_model=model_dim, nhead=num_heads, batch_first=True, dim_feedforward=self.ff_dim, dropout=self.dropout)
+ encoder_layer = TransformerEncoderLayerWithAttn(
+ d_model=model_dim,
+ nhead=num_heads,
+ batch_first=True,
+ dim_feedforward=self.ff_dim,
+ dropout=self.dropout,
+ )
  # Stack the transformer encoder layers
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

@@ -95,7 +110,7 @@ class BaseTransformer(BaseTorchModel):
  x = torch.cat([cls, x], dim=1) # (B, S+1, D)

  if self.pos_embed is not None:
- x = x + self.pos_embed.unsqueeze(0)[:, :x.shape[1], :]
+ x = x + self.pos_embed.unsqueeze(0)[:, : x.shape[1], :]
  elif self.pos_encoder is not None:
  x = self.pos_encoder(x)

@@ -106,8 +121,8 @@ class BaseTransformer(BaseTorchModel):

  encoded = self.transformer(x)
  return encoded
-
- def compute_attn_grad(self, reduction='mean'):
+
+ def compute_attn_grad(self, reduction="mean"):
  """
  Computes attention × gradient scores across layers.
  Returns: [B, S] tensor of importance scores

@@ -116,19 +131,19 @@ class BaseTransformer(BaseTorchModel):
  for attn, grad in zip(self.attn_weights, self.attn_grads):
  # attn: [B, H, S, S]
  # grad: [B, S, D]
- attn = attn.mean(dim=1) # [B, S, S]
- grad_norm = grad.norm(dim=-1) # [B, S]
+ attn = attn.mean(dim=1) # [B, S, S]
+ grad_norm = grad.norm(dim=-1) # [B, S]
  attn_grad_score = (attn * grad_norm.unsqueeze(1)).sum(dim=-1) # [B, S]
  scores.append(attn_grad_score)

  # Combine across layers
  stacked = torch.stack(scores, dim=0) # [L, B, S]
  if reduction == "mean":
- return stacked.mean(dim=0) # [B, S]
+ return stacked.mean(dim=0) # [B, S]
  elif reduction == "sum":
- return stacked.sum(dim=0) # [B, S]
+ return stacked.sum(dim=0) # [B, S]
  else:
- return stacked # [L, B, S]
+ return stacked # [L, B, S]

  def compute_rollout(self):
  """

@@ -143,9 +158,9 @@ class BaseTransformer(BaseTorchModel):
  attn_heads = attn_heads + torch.eye(S, device=device).unsqueeze(0) # add residual
  attn_heads = attn_heads / attn_heads.sum(dim=-1, keepdim=True).clamp(min=1e-6)
  rollout = torch.bmm(attn_heads, rollout) # [B, S, S]
-
+
  return rollout # [B, S, S]
-
+
  def reset_attn_buffers(self):
  self.attn_weights = []
  self.attn_grads = []

@@ -158,11 +173,15 @@ class BaseTransformer(BaseTorchModel):
  if head_idx is not None:
  attn = attn[:, head_idx] # [B, S, S]
  return attn
-
- def apply_attn_interpretations_to_adata(self, dataloader, adata,
- obsm_key_grad="attn_grad",
- obsm_key_rollout="attn_rollout",
- device="cpu"):
+
+ def apply_attn_interpretations_to_adata(
+ self,
+ dataloader,
+ adata,
+ obsm_key_grad="attn_grad",
+ obsm_key_rollout="attn_rollout",
+ device="cpu",
+ ):
  self.to(device)
  self.eval()
  grad_maps = []

@@ -193,12 +212,10 @@ class BaseTransformer(BaseTorchModel):
  # add per-row normalized version
  grad_normed = grad_concat / (np.max(grad_concat, axis=1, keepdims=True) + 1e-8)
  adata.obsm[f"{obsm_key_grad}_normalized"] = grad_normed
-
+
+
  class TransformerClassifier(BaseTransformer):
- def __init__(self,
- input_dim,
- num_classes,
- **kwargs):
+ def __init__(self, input_dim, num_classes, **kwargs):
  super().__init__(input_dim, **kwargs)
  # Classification head
  output_size = 1 if num_classes == 2 else num_classes

@@ -215,7 +232,7 @@ class TransformerClassifier(BaseTransformer):
  x = x.unsqueeze(0).unsqueeze(-1) # just in case (S,) → (1, S, 1)
  else:
  pass
- encoded = self.encode(x) # -> (B, S, D_model)
+ encoded = self.encode(x) # -> (B, S, D_model)
  if self.use_cls_token:
  pooled = encoded[:, 0] # (B, D)
  else:

@@ -223,14 +240,13 @@ class TransformerClassifier(BaseTransformer):

  out = self.cls_head(pooled) # (B, C)
  return out
-
+
+
  class DANNTransformerClassifier(TransformerClassifier):
  def __init__(self, input_dim, model_dim, num_classes, n_domains, **kwargs):
  super().__init__(input_dim, model_dim, num_classes, **kwargs)
  self.domain_classifier = nn.Sequential(
- nn.Linear(model_dim, 128),
- nn.ReLU(),
- nn.Linear(128, n_domains)
+ nn.Linear(model_dim, 128), nn.ReLU(), nn.Linear(128, n_domains)
  )

  def forward(self, x, alpha=1.0):

@@ -242,6 +258,7 @@ class DANNTransformerClassifier(TransformerClassifier):

  return class_logits, domain_logits

+
  class MaskedTransformerPretrainer(BaseTransformer):
  def __init__(self, input_dim, model_dim, num_heads=4, num_layers=2, **kwargs):
  super().__init__(input_dim, model_dim, num_heads, num_layers, **kwargs)

@@ -254,12 +271,13 @@ class MaskedTransformerPretrainer(BaseTransformer):
  """
  if x.dim() == 2:
  x = x.unsqueeze(-1)
- encoded = self.encode(x, mask=mask) # -> (B, S, D_model)
- return self.decoder(encoded) # -> (B, D_input)
-
+ encoded = self.encode(x, mask=mask) # -> (B, S, D_model)
+ return self.decoder(encoded) # -> (B, D_input)
+
+
  class DANNTransformer(BaseTransformer):
- """
- """
+ """ """
+
  def __init__(self, seq_len, model_dim, n_heads, n_layers, n_domains):
  super().__init__(
  input_dim=1, # 1D scalar input per token

@@ -267,7 +285,7 @@ class DANNTransformer(BaseTransformer):
  num_heads=n_heads,
  num_layers=n_layers,
  seq_len=seq_len,
- use_learnable_pos=True # enables learnable pos_embed in base
+ use_learnable_pos=True, # enables learnable pos_embed in base
  )

  # Reconstruction head

@@ -275,9 +293,7 @@ class DANNTransformer(BaseTransformer):

  # Domain classification head
  self.domain_classifier = nn.Sequential(
- nn.Linear(model_dim, 128),
- nn.ReLU(),
- nn.Linear(128, n_domains)
+ nn.Linear(model_dim, 128), nn.ReLU(), nn.Linear(128, n_domains)
  )

  def forward(self, x, alpha=1.0):

@@ -300,4 +316,3 @@ class DANNTransformer(BaseTransformer):
  domain_logits = self.domain_classifier(rev) # (B, n_batches)

  return recon, domain_logits
-

@@ -1,6 +1,7 @@
  import torch
  import torch.nn as nn

+
  class ScaledModel(nn.Module):
  def __init__(self, model, mean, std):
  super().__init__()

@@ -17,4 +18,4 @@ class ScaledModel(nn.Module):
  x = (x - mean[None, None, :]) / std[None, None, :]
  else:
  raise ValueError(f"Unsupported input shape {x.shape}")
- return self.model(x)
+ return self.model(x)

@@ -1,2 +1,2 @@
- from .train_lightning_model import train_lightning_model, run_sliding_window_lightning_training
- from .train_sklearn_model import train_sklearn_model, run_sliding_window_sklearn_training
+ from .train_lightning_model import run_sliding_window_lightning_training, train_lightning_model
+ from .train_sklearn_model import run_sliding_window_sklearn_training, train_sklearn_model

@@ -1,9 +1,11 @@
  import torch
  from pytorch_lightning import Trainer
  from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+
  from ..data import AnnDataModule
  from ..models import TorchClassifierWrapper

+
  def train_lightning_model(
  model,
  datamodule,

@@ -12,7 +14,7 @@ def train_lightning_model(
  monitor_metric="val_loss",
  checkpoint_path=None,
  evaluate_test=True,
- devices=1
+ devices=1,
  ):
  """
  Takes a PyTorch Lightning Model and a Lightning DataLoader module to define a Lightning Trainer.

@@ -39,13 +41,15 @@ def train_lightning_model(
  EarlyStopping(monitor=monitor_metric, patience=patience, mode="min"),
  ]
  if checkpoint_path:
- callbacks.append(ModelCheckpoint(
- dirpath=checkpoint_path,
- filename="{epoch}-{val_loss:.4f}",
- monitor=monitor_metric,
- save_top_k=1,
- mode="min",
- ))
+ callbacks.append(
+ ModelCheckpoint(
+ dirpath=checkpoint_path,
+ filename="{epoch}-{val_loss:.4f}",
+ monitor=monitor_metric,
+ save_top_k=1,
+ mode="min",
+ )
+ )

  # Trainer setup
  trainer = Trainer(

@@ -54,7 +58,7 @@ def train_lightning_model(
  accelerator=accelerator,
  devices=devices,
  log_every_n_steps=10,
- enable_progress_bar=False
+ enable_progress_bar=False,
  )

  # Fit model with trainer

@@ -63,7 +67,7 @@ def train_lightning_model(
  # Test model (if applicable)
  if evaluate_test and hasattr(datamodule, "test_dataloader"):
  trainer.test(model, datamodule=datamodule)
-
+
  # Return best checkpoint path
  best_ckpt = None
  for cb in callbacks:

@@ -72,6 +76,7 @@ def train_lightning_model(

  return trainer, best_ckpt

+
  def run_sliding_window_lightning_training(
  adata,
  tensor_source,

@@ -86,13 +91,13 @@ def run_sliding_window_lightning_training(
  stride,
  max_epochs=30,
  patience=5,
- enforce_eval_balance: bool=False,
- target_eval_freq: float=0.3,
- max_eval_positive: int=None
+ enforce_eval_balance: bool = False,
+ target_eval_freq: float = 0.3,
+ max_eval_positive: int = None,
  ):
  input_len = adata.shape[1]
  results = {}
-
+
  for start in range(0, input_len - window_size + 1, stride):
  center_idx = start + window_size // 2
  center_varname = adata.var_names[center_idx]

@@ -106,18 +111,22 @@ def run_sliding_window_lightning_training(
  label_col=label_col,
  batch_size=64,
  window_start=start,
- window_size=window_size
+ window_size=window_size,
  )
  datamodule.setup()

  # Build model for this window
  model = model_class(window_size, num_classes)
  wrapper = TorchClassifierWrapper(
- model, label_col=label_col, num_classes=num_classes,
+ model,
+ label_col=label_col,
+ num_classes=num_classes,
  class_names=class_names,
  class_weights=class_weights,
- focus_class=focus_class, enforce_eval_balance=enforce_eval_balance,
- target_eval_freq=target_eval_freq, max_eval_positive=max_eval_positive
+ focus_class=focus_class,
+ enforce_eval_balance=enforce_eval_balance,
+ target_eval_freq=target_eval_freq,
+ max_eval_positive=max_eval_positive,
  )

  # Train model

@@ -129,7 +138,7 @@ def run_sliding_window_lightning_training(
  "model": wrapper,
  "trainer": trainer,
  "checkpoint": ckpt,
- "metrics": trainer.callback_metrics
+ "metrics": trainer.callback_metrics,
  }
-
+
  return results

@@ -1,16 +1,12 @@
  from ..data import AnnDataModule
  from ..models import SklearnModelWrapper

- def train_sklearn_model(
- model_wrapper,
- datamodule,
- evaluate_test=True,
- evaluate_val=False
- ):
+
+ def train_sklearn_model(model_wrapper, datamodule, evaluate_test=True, evaluate_val=False):
  """
  Fits a SklearnModelWrapper on the train split from datamodule.
  Evaluates on test and/or val set.
-
+
  Parameters:
  model_wrapper: SklearnModelWrapper instance
  datamodule: AnnDataModule instance (with setup() method)

@@ -39,6 +35,7 @@ def train_sklearn_model(

  return metrics

+
  def run_sliding_window_sklearn_training(
  adata,
  tensor_source,

@@ -58,7 +55,7 @@ def run_sliding_window_sklearn_training(
  enforce_eval_balance=False,
  target_eval_freq=0.3,
  max_eval_positive=None,
- **model_kwargs
+ **model_kwargs,
  ):
  """
  Sliding window training for sklearn models using AnnData.

@@ -86,29 +83,26 @@ def run_sliding_window_sklearn_training(
  train_frac=train_frac,
  val_frac=val_frac,
  test_frac=test_frac,
- random_seed=random_seed
+ random_seed=random_seed,
  )
  datamodule.setup()

  # Build model wrapper
  sklearn_model = model_class(**model_kwargs)
  wrapper = SklearnModelWrapper(
- sklearn_model,
+ sklearn_model,
  num_classes=num_classes,
  label_col=label_col,
  class_names=class_names,
  focus_class=focus_class,
  enforce_eval_balance=enforce_eval_balance,
  target_eval_freq=target_eval_freq,
- max_eval_positive=max_eval_positive
+ max_eval_positive=max_eval_positive,
  )

  # Fit and evaluate
  metrics = train_sklearn_model(wrapper, datamodule, evaluate_test=True, evaluate_val=False)

- results[center_varname] = {
- "model": wrapper,
- "metrics": metrics
- }
+ results[center_varname] = {"model": wrapper, "metrics": metrics}

  return results

@@ -1,2 +1,2 @@
  from .device import detect_device
- from .grl import GradReverse
+ from .grl import GradReverse

@@ -1,10 +1,13 @@
  import torch

+
  def detect_device():
  device = (
- torch.device('cuda') if torch.cuda.is_available() else
- torch.device('mps') if torch.backends.mps.is_available() else
- torch.device('cpu')
+ torch.device("cuda")
+ if torch.cuda.is_available()
+ else torch.device("mps")
+ if torch.backends.mps.is_available()
+ else torch.device("cpu")
  )
  print(f"Detected device: {device}")
- return device
+ return device

@@ -1,5 +1,6 @@
  import torch

+
  class GradReverse(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x, alpha):

@@ -10,5 +11,6 @@ class GradReverse(torch.autograd.Function):
  def backward(ctx, grad_output):
  return -ctx.alpha * grad_output, None

+
  def grad_reverse(x, alpha=1.0):
- return GradReverse.apply(x, alpha)
+ return GradReverse.apply(x, alpha)