smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/machine_learning/models/transformer.py
@@ -1,22 +1,31 @@
-import torch
-import torch.nn as nn
-from .base import BaseTorchModel
-from .positional import PositionalEncoding
-from ..utils.grl import grad_reverse
+from __future__ import annotations
+
 import numpy as np

+from smftools.optional_imports import require
+
+from ..utils.grl import grad_reverse
+from .base import BaseTorchModel
+from .positional import PositionalEncoding
+
+torch = require("torch", extra="ml-base", purpose="Transformer models")
+nn = torch.nn
+
+
 class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def forward(self, src, src_mask=None, is_causal=False, src_key_padding_mask=None):
         self_attn_output, attn_weights = self.self_attn(
-            src, src, src,
+            src,
+            src,
+            src,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
             need_weights=True,
             average_attn_weights=False,  # preserve [B, num_heads, S, S]
-            is_causal=is_causal
+            is_causal=is_causal,
         )
         src = src + self.dropout1(self_attn_output)
         src = self.norm1(src)
@@ -27,18 +36,21 @@ class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
         # Save attention weights to module
         self.attn_weights = attn_weights  # Save to layer
         return src
-
+
+
 class BaseTransformer(BaseTorchModel):
-    def __init__(self,
-                 input_dim=1,
-                 model_dim=64,
-                 num_heads=4,
-                 num_layers=2,
-                 dropout=0.2,
-                 seq_len=None,
-                 use_learnable_pos=False,
-                 use_cls_token=True,
-                 **kwargs):
+    def __init__(
+        self,
+        input_dim=1,
+        model_dim=64,
+        num_heads=4,
+        num_layers=2,
+        dropout=0.2,
+        seq_len=None,
+        use_learnable_pos=False,
+        use_cls_token=True,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         # Input FC layer to map D_input to D_model
         self.model_dim = model_dim
@@ -52,7 +64,9 @@ class BaseTransformer(BaseTorchModel):

         if use_learnable_pos:
             assert seq_len is not None, "Must provide seq_len if use_learnable_pos=True"
-            self.pos_embed = nn.Parameter(torch.randn(seq_len + (1 if use_cls_token else 0), model_dim))
+            self.pos_embed = nn.Parameter(
+                torch.randn(seq_len + (1 if use_cls_token else 0), model_dim)
+            )
             self.pos_encoder = None
         else:
             self.pos_encoder = PositionalEncoding(model_dim)
@@ -62,7 +76,13 @@ class BaseTransformer(BaseTorchModel):
             self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))  # (1, 1, D)

         # Specify the transformer encoder structure
-        encoder_layer = TransformerEncoderLayerWithAttn(d_model=model_dim, nhead=num_heads, batch_first=True, dim_feedforward=self.ff_dim, dropout=self.dropout)
+        encoder_layer = TransformerEncoderLayerWithAttn(
+            d_model=model_dim,
+            nhead=num_heads,
+            batch_first=True,
+            dim_feedforward=self.ff_dim,
+            dropout=self.dropout,
+        )
         # Stack the transformer encoder layers
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

@@ -95,7 +115,7 @@
             x = torch.cat([cls, x], dim=1)  # (B, S+1, D)

         if self.pos_embed is not None:
-            x = x + self.pos_embed.unsqueeze(0)[:, :x.shape[1], :]
+            x = x + self.pos_embed.unsqueeze(0)[:, : x.shape[1], :]
         elif self.pos_encoder is not None:
             x = self.pos_encoder(x)

@@ -106,8 +126,8 @@

         encoded = self.transformer(x)
         return encoded
-
-    def compute_attn_grad(self, reduction='mean'):
+
+    def compute_attn_grad(self, reduction="mean"):
         """
         Computes attention × gradient scores across layers.
         Returns: [B, S] tensor of importance scores
@@ -116,19 +136,19 @@
         for attn, grad in zip(self.attn_weights, self.attn_grads):
             # attn: [B, H, S, S]
             # grad: [B, S, D]
-            attn = attn.mean(dim=1) # [B, S, S]
-            grad_norm = grad.norm(dim=-1) # [B, S]
+            attn = attn.mean(dim=1)  # [B, S, S]
+            grad_norm = grad.norm(dim=-1)  # [B, S]
             attn_grad_score = (attn * grad_norm.unsqueeze(1)).sum(dim=-1)  # [B, S]
             scores.append(attn_grad_score)

         # Combine across layers
         stacked = torch.stack(scores, dim=0)  # [L, B, S]
         if reduction == "mean":
-            return stacked.mean(dim=0) # [B, S]
+            return stacked.mean(dim=0)  # [B, S]
         elif reduction == "sum":
-            return stacked.sum(dim=0) # [B, S]
+            return stacked.sum(dim=0)  # [B, S]
         else:
-            return stacked # [L, B, S]
+            return stacked  # [L, B, S]

     def compute_rollout(self):
         """
@@ -143,9 +163,9 @@
             attn_heads = attn_heads + torch.eye(S, device=device).unsqueeze(0)  # add residual
             attn_heads = attn_heads / attn_heads.sum(dim=-1, keepdim=True).clamp(min=1e-6)
             rollout = torch.bmm(attn_heads, rollout)  # [B, S, S]
-
+
         return rollout  # [B, S, S]
-
+
     def reset_attn_buffers(self):
         self.attn_weights = []
         self.attn_grads = []
@@ -158,11 +178,15 @@
         if head_idx is not None:
             attn = attn[:, head_idx]  # [B, S, S]
         return attn
-
-    def apply_attn_interpretations_to_adata(self, dataloader, adata,
-                                            obsm_key_grad="attn_grad",
-                                            obsm_key_rollout="attn_rollout",
-                                            device="cpu"):
+
+    def apply_attn_interpretations_to_adata(
+        self,
+        dataloader,
+        adata,
+        obsm_key_grad="attn_grad",
+        obsm_key_rollout="attn_rollout",
+        device="cpu",
+    ):
         self.to(device)
         self.eval()
         grad_maps = []
@@ -193,12 +217,10 @@
         # add per-row normalized version
         grad_normed = grad_concat / (np.max(grad_concat, axis=1, keepdims=True) + 1e-8)
         adata.obsm[f"{obsm_key_grad}_normalized"] = grad_normed
-
+
+
 class TransformerClassifier(BaseTransformer):
-    def __init__(self,
-                 input_dim,
-                 num_classes,
-                 **kwargs):
+    def __init__(self, input_dim, num_classes, **kwargs):
         super().__init__(input_dim, **kwargs)
         # Classification head
         output_size = 1 if num_classes == 2 else num_classes
@@ -215,7 +237,7 @@
             x = x.unsqueeze(0).unsqueeze(-1)  # just in case (S,) → (1, S, 1)
         else:
             pass
-        encoded = self.encode(x) # -> (B, S, D_model)
+        encoded = self.encode(x)  # -> (B, S, D_model)
         if self.use_cls_token:
             pooled = encoded[:, 0]  # (B, D)
         else:
@@ -223,14 +245,13 @@

         out = self.cls_head(pooled)  # (B, C)
         return out
-
+
+
 class DANNTransformerClassifier(TransformerClassifier):
     def __init__(self, input_dim, model_dim, num_classes, n_domains, **kwargs):
         super().__init__(input_dim, model_dim, num_classes, **kwargs)
         self.domain_classifier = nn.Sequential(
-            nn.Linear(model_dim, 128),
-            nn.ReLU(),
-            nn.Linear(128, n_domains)
+            nn.Linear(model_dim, 128), nn.ReLU(), nn.Linear(128, n_domains)
         )

     def forward(self, x, alpha=1.0):
@@ -242,6 +263,7 @@ class DANNTransformerClassifier(TransformerClassifier):

         return class_logits, domain_logits

+
 class MaskedTransformerPretrainer(BaseTransformer):
     def __init__(self, input_dim, model_dim, num_heads=4, num_layers=2, **kwargs):
         super().__init__(input_dim, model_dim, num_heads, num_layers, **kwargs)
@@ -254,12 +276,13 @@
         """
         if x.dim() == 2:
             x = x.unsqueeze(-1)
-        encoded = self.encode(x, mask=mask) # -> (B, S, D_model)
-        return self.decoder(encoded) # -> (B, D_input)
-
+        encoded = self.encode(x, mask=mask)  # -> (B, S, D_model)
+        return self.decoder(encoded)  # -> (B, D_input)
+
+
 class DANNTransformer(BaseTransformer):
-    """
-    """
+    """ """
+
     def __init__(self, seq_len, model_dim, n_heads, n_layers, n_domains):
         super().__init__(
             input_dim=1,  # 1D scalar input per token
@@ -267,7 +290,7 @@
             num_heads=n_heads,
             num_layers=n_layers,
             seq_len=seq_len,
-            use_learnable_pos=True  # enables learnable pos_embed in base
+            use_learnable_pos=True,  # enables learnable pos_embed in base
         )

         # Reconstruction head
@@ -275,9 +298,7 @@

         # Domain classification head
         self.domain_classifier = nn.Sequential(
-            nn.Linear(model_dim, 128),
-            nn.ReLU(),
-            nn.Linear(128, n_domains)
+            nn.Linear(model_dim, 128), nn.ReLU(), nn.Linear(128, n_domains)
        )

     def forward(self, x, alpha=1.0):
@@ -300,4 +321,3 @@
         domain_logits = self.domain_classifier(rev)  # (B, n_batches)

         return recon, domain_logits
-
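Note: the ML modules in this release replace eager `import torch` statements with a lazy `require(...)` call from the new smftools/optional_imports.py (added in 0.3.0, +31 lines, not shown in this excerpt). As a rough illustration of the pattern only, and not the released implementation, such a helper is commonly written like this, matching the call signature used above:

# Hypothetical sketch; the actual smftools.optional_imports.require is not part of this diff.
import importlib


def require(module_name, *, extra=None, purpose=None):
    """Import an optional dependency, pointing at the pip extra that provides it."""
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        for_what = f" (needed for {purpose})" if purpose else ""
        hint = f" Try: pip install 'smftools[{extra}]'" if extra else ""
        raise ImportError(
            f"Missing optional dependency '{module_name}'{for_what}.{hint}"
        ) from err

With this pattern, importing smftools.machine_learning no longer hard-fails when torch or pytorch_lightning is absent; the error is deferred to module load of the components that actually need them.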
smftools/machine_learning/models/wrappers.py
@@ -1,5 +1,10 @@
-import torch
-import torch.nn as nn
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="model wrappers")
+nn = torch.nn
+

 class ScaledModel(nn.Module):
     def __init__(self, model, mean, std):
@@ -17,4 +22,4 @@ class ScaledModel(nn.Module):
             x = (x - mean[None, None, :]) / std[None, None, :]
         else:
             raise ValueError(f"Unsupported input shape {x.shape}")
-        return self.model(x)
+        return self.model(x)
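For context, `ScaledModel` standardizes inputs with stored per-feature mean/std before delegating to the wrapped model. An illustrative use, with hypothetical placeholder tensors and model (requires torch):

import torch
from smftools.machine_learning.models.wrappers import ScaledModel

base_model = torch.nn.Linear(4, 2)          # placeholder: any module over the last feature dim
mean, std = torch.zeros(4), torch.ones(4)   # per-feature normalization statistics
scaled = ScaledModel(base_model, mean, std)
out = scaled(torch.randn(8, 16, 4))         # (B, S, F) input is standardized, then passed through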
smftools/machine_learning/training/__init__.py
@@ -1,2 +1,4 @@
-from .train_lightning_model import train_lightning_model, run_sliding_window_lightning_training
-from .train_sklearn_model import train_sklearn_model, run_sliding_window_sklearn_training
+from __future__ import annotations
+
+from .train_lightning_model import run_sliding_window_lightning_training, train_lightning_model
+from .train_sklearn_model import run_sliding_window_sklearn_training, train_sklearn_model
smftools/machine_learning/training/train_lightning_model.py
@@ -1,9 +1,21 @@
-import torch
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
 from ..data import AnnDataModule
 from ..models import TorchClassifierWrapper

+torch = require("torch", extra="ml-base", purpose="Lightning training")
+pytorch_lightning = require("pytorch_lightning", extra="ml-extended", purpose="Lightning training")
+pl_callbacks = require(
+    "pytorch_lightning.callbacks", extra="ml-extended", purpose="Lightning training"
+)
+
+Trainer = pytorch_lightning.Trainer
+EarlyStopping = pl_callbacks.EarlyStopping
+ModelCheckpoint = pl_callbacks.ModelCheckpoint
+
+
 def train_lightning_model(
     model,
     datamodule,
@@ -12,7 +24,7 @@ def train_lightning_model(
     monitor_metric="val_loss",
     checkpoint_path=None,
     evaluate_test=True,
-    devices=1
+    devices=1,
 ):
     """
     Takes a PyTorch Lightning Model and a Lightning DataLoader module to define a Lightning Trainer.
@@ -39,13 +51,15 @@
         EarlyStopping(monitor=monitor_metric, patience=patience, mode="min"),
     ]
     if checkpoint_path:
-        callbacks.append(ModelCheckpoint(
-            dirpath=checkpoint_path,
-            filename="{epoch}-{val_loss:.4f}",
-            monitor=monitor_metric,
-            save_top_k=1,
-            mode="min",
-        ))
+        callbacks.append(
+            ModelCheckpoint(
+                dirpath=checkpoint_path,
+                filename="{epoch}-{val_loss:.4f}",
+                monitor=monitor_metric,
+                save_top_k=1,
+                mode="min",
+            )
+        )

     # Trainer setup
     trainer = Trainer(
@@ -54,7 +68,7 @@
         accelerator=accelerator,
         devices=devices,
         log_every_n_steps=10,
-        enable_progress_bar=False
+        enable_progress_bar=False,
     )

     # Fit model with trainer
@@ -63,7 +77,7 @@
     # Test model (if applicable)
     if evaluate_test and hasattr(datamodule, "test_dataloader"):
         trainer.test(model, datamodule=datamodule)
-
+
     # Return best checkpoint path
     best_ckpt = None
     for cb in callbacks:
@@ -72,6 +86,7 @@

     return trainer, best_ckpt

+
 def run_sliding_window_lightning_training(
     adata,
     tensor_source,
@@ -86,13 +101,13 @@
     stride,
     max_epochs=30,
     patience=5,
-    enforce_eval_balance: bool=False,
-    target_eval_freq: float=0.3,
-    max_eval_positive: int=None
+    enforce_eval_balance: bool = False,
+    target_eval_freq: float = 0.3,
+    max_eval_positive: int = None,
 ):
     input_len = adata.shape[1]
     results = {}
-
+
     for start in range(0, input_len - window_size + 1, stride):
         center_idx = start + window_size // 2
         center_varname = adata.var_names[center_idx]
@@ -106,18 +121,22 @@
             label_col=label_col,
             batch_size=64,
             window_start=start,
-            window_size=window_size
+            window_size=window_size,
         )
         datamodule.setup()

         # Build model for this window
         model = model_class(window_size, num_classes)
         wrapper = TorchClassifierWrapper(
-            model, label_col=label_col, num_classes=num_classes,
+            model,
+            label_col=label_col,
+            num_classes=num_classes,
             class_names=class_names,
             class_weights=class_weights,
-            focus_class=focus_class, enforce_eval_balance=enforce_eval_balance,
-            target_eval_freq=target_eval_freq, max_eval_positive=max_eval_positive
+            focus_class=focus_class,
+            enforce_eval_balance=enforce_eval_balance,
+            target_eval_freq=target_eval_freq,
+            max_eval_positive=max_eval_positive,
         )

         # Train model
@@ -129,7 +148,7 @@
             "model": wrapper,
             "trainer": trainer,
             "checkpoint": ckpt,
-            "metrics": trainer.callback_metrics
+            "metrics": trainer.callback_metrics,
         }
-
+
     return results
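An illustrative call of the trainer entry point above, using only keyword arguments visible in its signature; `wrapper` and `datamodule` are placeholders standing in for a TorchClassifierWrapper and an AnnDataModule built as in the sliding-window loop:

# Illustrative only; wrapper and datamodule are assumed to exist (datamodule.setup() already called).
trainer, best_ckpt = train_lightning_model(
    wrapper,
    datamodule,
    monitor_metric="val_loss",       # metric watched by EarlyStopping and ModelCheckpoint
    checkpoint_path="checkpoints/",  # enables the ModelCheckpoint callback
    evaluate_test=True,              # runs trainer.test() when a test dataloader exists
    devices=1,
)
print(best_ckpt)  # best checkpoint path, or None when no ModelCheckpoint was attached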
smftools/machine_learning/training/train_sklearn_model.py
@@ -1,16 +1,14 @@
+from __future__ import annotations
+
 from ..data import AnnDataModule
 from ..models import SklearnModelWrapper

-def train_sklearn_model(
-    model_wrapper,
-    datamodule,
-    evaluate_test=True,
-    evaluate_val=False
-):
+
+def train_sklearn_model(model_wrapper, datamodule, evaluate_test=True, evaluate_val=False):
     """
     Fits a SklearnModelWrapper on the train split from datamodule.
     Evaluates on test and/or val set.
-
+
     Parameters:
         model_wrapper: SklearnModelWrapper instance
         datamodule: AnnDataModule instance (with setup() method)
@@ -39,6 +37,7 @@

     return metrics

+
 def run_sliding_window_sklearn_training(
     adata,
     tensor_source,
@@ -58,7 +57,7 @@
     enforce_eval_balance=False,
     target_eval_freq=0.3,
     max_eval_positive=None,
-    **model_kwargs
+    **model_kwargs,
 ):
     """
     Sliding window training for sklearn models using AnnData.
@@ -86,29 +85,26 @@
             train_frac=train_frac,
             val_frac=val_frac,
             test_frac=test_frac,
-            random_seed=random_seed
+            random_seed=random_seed,
         )
         datamodule.setup()

         # Build model wrapper
         sklearn_model = model_class(**model_kwargs)
         wrapper = SklearnModelWrapper(
-            sklearn_model, 
+            sklearn_model,
             num_classes=num_classes,
             label_col=label_col,
             class_names=class_names,
             focus_class=focus_class,
             enforce_eval_balance=enforce_eval_balance,
             target_eval_freq=target_eval_freq,
-            max_eval_positive=max_eval_positive
+            max_eval_positive=max_eval_positive,
         )

         # Fit and evaluate
         metrics = train_sklearn_model(wrapper, datamodule, evaluate_test=True, evaluate_val=False)

-        results[center_varname] = {
-            "model": wrapper,
-            "metrics": metrics
-        }
+        results[center_varname] = {"model": wrapper, "metrics": metrics}

     return results
smftools/machine_learning/utils/__init__.py
@@ -1,2 +1,4 @@
+from __future__ import annotations
+
 from .device import detect_device
-from .grl import GradReverse
+from .grl import GradReverse
smftools/machine_learning/utils/device.py
@@ -1,10 +1,17 @@
-import torch
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="device selection")
+

 def detect_device():
     device = (
-        torch.device('cuda') if torch.cuda.is_available() else
-        torch.device('mps') if torch.backends.mps.is_available() else
-        torch.device('cpu')
+        torch.device("cuda")
+        if torch.cuda.is_available()
+        else torch.device("mps")
+        if torch.backends.mps.is_available()
+        else torch.device("cpu")
     )
     print(f"Detected device: {device}")
-    return device
+    return device
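`detect_device` keeps its CUDA → MPS → CPU preference order; only the torch import became lazy. Typical use (assumes torch is installed, i.e. the ml-base extra):

import torch
from smftools.machine_learning.utils import detect_device

device = detect_device()           # prints and returns cuda, mps, or cpu in that order of preference
x = torch.zeros(4, device=device)  # downstream tensors and models can be placed on the chosen device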
smftools/machine_learning/utils/grl.py
@@ -1,4 +1,9 @@
-import torch
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="gradient reversal layers")
+

 class GradReverse(torch.autograd.Function):
     @staticmethod
@@ -10,5 +15,6 @@ class GradReverse(torch.autograd.Function):
     def backward(ctx, grad_output):
         return -ctx.alpha * grad_output, None

+
 def grad_reverse(x, alpha=1.0):
-    return GradReverse.apply(x, alpha)
+    return GradReverse.apply(x, alpha)
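A quick check of the gradient-reversal behavior the DANN models above rely on: the forward pass is assumed to return `x` unchanged, while `backward` (shown in the diff) negates and scales the incoming gradient by `alpha`:

import torch
from smftools.machine_learning.utils.grl import grad_reverse

x = torch.ones(3, requires_grad=True)
grad_reverse(x, alpha=0.5).sum().backward()
print(x.grad)  # expected tensor([-0.5000, -0.5000, -0.5000]): gradients flipped and scaled by alpha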