smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. smftools/__init__.py +5 -1
  2. smftools/_version.py +1 -1
  3. smftools/informatics/__init__.py +2 -0
  4. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  5. smftools/informatics/basecall_pod5s.py +80 -0
  6. smftools/informatics/conversion_smf.py +63 -10
  7. smftools/informatics/direct_smf.py +66 -18
  8. smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
  9. smftools/informatics/helpers/__init__.py +16 -2
  10. smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
  11. smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
  12. smftools/informatics/helpers/bam_qc.py +66 -0
  13. smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
  14. smftools/informatics/helpers/canoncall.py +12 -3
  15. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
  16. smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
  17. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  18. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  19. smftools/informatics/helpers/extract_base_identities.py +33 -46
  20. smftools/informatics/helpers/extract_mods.py +55 -23
  21. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  22. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  23. smftools/informatics/helpers/find_conversion_sites.py +33 -44
  24. smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
  25. smftools/informatics/helpers/modcall.py +13 -5
  26. smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
  27. smftools/informatics/helpers/ohe_batching.py +65 -41
  28. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  29. smftools/informatics/helpers/one_hot_decode.py +27 -0
  30. smftools/informatics/helpers/one_hot_encode.py +45 -9
  31. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
  32. smftools/informatics/helpers/run_multiqc.py +28 -0
  33. smftools/informatics/helpers/split_and_index_BAM.py +3 -8
  34. smftools/informatics/load_adata.py +58 -3
  35. smftools/plotting/__init__.py +15 -0
  36. smftools/plotting/classifiers.py +355 -0
  37. smftools/plotting/general_plotting.py +205 -0
  38. smftools/plotting/position_stats.py +462 -0
  39. smftools/preprocessing/__init__.py +6 -7
  40. smftools/preprocessing/append_C_context.py +22 -9
  41. smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
  42. smftools/preprocessing/binarize_on_Youden.py +35 -32
  43. smftools/preprocessing/binary_layers_to_ohe.py +13 -3
  44. smftools/preprocessing/calculate_complexity.py +3 -2
  45. smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
  46. smftools/preprocessing/calculate_coverage.py +26 -25
  47. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  48. smftools/preprocessing/calculate_position_Youden.py +18 -7
  49. smftools/preprocessing/calculate_read_length_stats.py +39 -46
  50. smftools/preprocessing/clean_NaN.py +33 -25
  51. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  52. smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
  53. smftools/preprocessing/filter_reads_on_length.py +14 -4
  54. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  55. smftools/preprocessing/invert_adata.py +18 -11
  56. smftools/preprocessing/load_sample_sheet.py +30 -16
  57. smftools/preprocessing/recipes.py +22 -20
  58. smftools/preprocessing/subsample_adata.py +58 -0
  59. smftools/readwrite.py +105 -13
  60. smftools/tools/__init__.py +49 -0
  61. smftools/tools/apply_hmm.py +202 -0
  62. smftools/tools/apply_hmm_batched.py +241 -0
  63. smftools/tools/archived/classify_methylated_features.py +66 -0
  64. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  65. smftools/tools/archived/subset_adata_v1.py +32 -0
  66. smftools/tools/archived/subset_adata_v2.py +46 -0
  67. smftools/tools/calculate_distances.py +18 -0
  68. smftools/tools/calculate_umap.py +62 -0
  69. smftools/tools/call_hmm_peaks.py +105 -0
  70. smftools/tools/classifiers.py +787 -0
  71. smftools/tools/cluster_adata_on_methylation.py +105 -0
  72. smftools/tools/data/__init__.py +2 -0
  73. smftools/tools/data/anndata_data_module.py +90 -0
  74. smftools/tools/data/preprocessing.py +6 -0
  75. smftools/tools/display_hmm.py +18 -0
  76. smftools/tools/general_tools.py +69 -0
  77. smftools/tools/hmm_readwrite.py +16 -0
  78. smftools/tools/inference/__init__.py +1 -0
  79. smftools/tools/inference/lightning_inference.py +41 -0
  80. smftools/tools/models/__init__.py +9 -0
  81. smftools/tools/models/base.py +14 -0
  82. smftools/tools/models/cnn.py +34 -0
  83. smftools/tools/models/lightning_base.py +41 -0
  84. smftools/tools/models/mlp.py +17 -0
  85. smftools/tools/models/positional.py +17 -0
  86. smftools/tools/models/rnn.py +16 -0
  87. smftools/tools/models/sklearn_models.py +40 -0
  88. smftools/tools/models/transformer.py +133 -0
  89. smftools/tools/models/wrappers.py +20 -0
  90. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  91. smftools/tools/position_stats.py +239 -0
  92. smftools/tools/read_stats.py +70 -0
  93. smftools/tools/subset_adata.py +19 -23
  94. smftools/tools/train_hmm.py +78 -0
  95. smftools/tools/training/__init__.py +1 -0
  96. smftools/tools/training/train_lightning_model.py +47 -0
  97. smftools/tools/utils/__init__.py +2 -0
  98. smftools/tools/utils/device.py +10 -0
  99. smftools/tools/utils/grl.py +14 -0
  100. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
  101. smftools-0.1.7.dist-info/RECORD +136 -0
  102. smftools/tools/apply_HMM.py +0 -1
  103. smftools/tools/read_HMM.py +0 -1
  104. smftools/tools/train_HMM.py +0 -43
  105. smftools-0.1.3.dist-info/RECORD +0 -84
  106. /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
  107. /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
  108. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  109. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/tools/classifiers.py ADDED
@@ -0,0 +1,787 @@
+ ## Train CNN, RNN, and Random Forest models on double-barcoded, low-contamination datasets
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, TensorDataset
+ from sklearn.ensemble import RandomForestClassifier
+ from sklearn.naive_bayes import GaussianNB
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import roc_curve, auc, precision_recall_curve, f1_score, confusion_matrix
+ from sklearn.utils.class_weight import compute_class_weight
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import warnings
+ 
+ # Device detection
+ device = (
+     torch.device('cuda') if torch.cuda.is_available() else
+     torch.device('mps') if torch.backends.mps.is_available() else
+     torch.device('cpu')
+ )
+ 
+ # ------------------------- Utilities -------------------------
+ def random_fill_nans(X):
+     # Replace NaNs in X (in place) with uniform random values in [0, 1).
+     nan_mask = np.isnan(X)
+     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
+     return X
+ 
+ # ------------------------- Model Definitions -------------------------
+ class CNNClassifier(nn.Module):
+     def __init__(self, input_size, num_classes):
+         super().__init__()
+         self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
+         self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
+         self.relu = nn.ReLU()
+         dummy_input = torch.zeros(1, 1, input_size)
+         dummy_output = self._forward_conv(dummy_input)
+         flattened_size = dummy_output.view(1, -1).shape[1]
+         self.fc1 = nn.Linear(flattened_size, 64)
+         self.fc2 = nn.Linear(64, num_classes)
+ 
+     def _forward_conv(self, x):
+         x = self.relu(self.conv1(x))
+         x = self.relu(self.conv2(x))
+         return x
+ 
+     def forward(self, x):
+         x = x.unsqueeze(1)
+         x = self._forward_conv(x)
+         x = x.view(x.size(0), -1)
+         x = self.relu(self.fc1(x))
+         return self.fc2(x)
+ 
+ class MLPClassifier(nn.Module):
+     def __init__(self, input_dim, num_classes):
+         super().__init__()
+         self.model = nn.Sequential(
+             nn.Linear(input_dim, 128),
+             nn.ReLU(),
+             nn.Dropout(0.2),
+             nn.Linear(128, 64),
+             nn.ReLU(),
+             nn.Dropout(0.2),
+             nn.Linear(64, num_classes)
+         )
+ 
+     def forward(self, x):
+         return self.model(x)
+ 
+ class RNNClassifier(nn.Module):
+     def __init__(self, input_size, hidden_dim, num_classes):
+         super().__init__()
+         # Define LSTM layer
+         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
+         # Define fully connected output layer
+         self.fc = nn.Linear(hidden_dim, num_classes)
+ 
+     def forward(self, x):
+         x = x.unsqueeze(1)
+         _, (h_n, _) = self.lstm(x)
+         return self.fc(h_n.squeeze(0))
+ 
+ class AttentionRNNClassifier(nn.Module):
+     def __init__(self, input_size, hidden_dim, num_classes):
+         super().__init__()
+         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
+         self.attn = nn.Linear(hidden_dim, 1)  # Simple attention scores
+         self.fc = nn.Linear(hidden_dim, num_classes)
+ 
+     def forward(self, x):
+         x = x.unsqueeze(1)  # shape: (batch, 1, seq_len)
+         lstm_out, _ = self.lstm(x)  # shape: (batch, 1, hidden_dim)
+         attn_weights = torch.softmax(self.attn(lstm_out), dim=1)  # (batch, 1, 1)
+         context = (attn_weights * lstm_out).sum(dim=1)  # weighted sum
+         return self.fc(context)
+ 
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super().__init__()
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len).unsqueeze(1).float()
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         self.pe = pe.unsqueeze(0)  # (1, max_len, d_model)
+ 
+     def forward(self, x):
+         x = x + self.pe[:, :x.size(1)].to(x.device)
+         return x
+ 
+ class TransformerClassifier(nn.Module):
+     def __init__(self, input_dim, model_dim, num_classes, num_heads=4, num_layers=2):
+         super().__init__()
+         self.input_fc = nn.Linear(input_dim, model_dim)
+         self.pos_encoder = PositionalEncoding(model_dim)
+ 
+         encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads)
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+ 
+         self.cls_head = nn.Linear(model_dim, num_classes)
+ 
+     def forward(self, x):
+         # x: [batch_size, input_dim]
+         x = self.input_fc(x).unsqueeze(1)  # -> [batch_size, 1, model_dim]
+         x = self.pos_encoder(x)
+         x = x.permute(1, 0, 2)  # -> [seq_len=1, batch_size, model_dim]
+         encoded = self.transformer(x)
+         pooled = encoded.mean(dim=0)  # -> [batch_size, model_dim]
+         return self.cls_head(pooled)
+ 
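+ # Shape sanity check (editor's sketch with hypothetical sizes): each classifier above
+ # maps a (batch, n_features) float tensor to (batch, num_classes) logits, e.g.
+ #
+ #     m = CNNClassifier(input_size=100, num_classes=2)
+ #     print(m(torch.randn(8, 100)).shape)  # torch.Size([8, 2])
+ 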
+ def train_model(model, loader, optimizer, criterion, device, ref_name="", model_name="", epochs=20, patience=5):
+     model.train()
+     best_loss = float('inf')
+     trigger_times = 0
+ 
+     for epoch in range(epochs):
+         running_loss = 0
+         for batch_X, batch_y in loader:
+             batch_X, batch_y = batch_X.to(device), batch_y.to(device)
+             optimizer.zero_grad()
+             loss = criterion(model(batch_X), batch_y)
+             loss.backward()
+             optimizer.step()
+             running_loss += loss.item()
+         average_loss = running_loss / len(loader)
+         print(f"{ref_name} {model_name} Epoch {epoch+1} Loss: {average_loss:.4f}")
+ 
+         if average_loss < best_loss:
+             best_loss = average_loss
+             trigger_times = 0
+         else:
+             trigger_times += 1
+             if trigger_times >= patience:
+                 print(f"Early stopping {model_name} for {ref_name} at epoch {epoch+1}")
+                 break
+ 
+ def evaluate_model(model, X_tensor, y_encoded, device):
+     model.eval()
+     with torch.no_grad():
+         outputs = model(X_tensor.to(device))
+         probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
+         preds = outputs.argmax(dim=1).cpu().numpy()
+     fpr, tpr, _ = roc_curve(y_encoded, probs)
+     precision, recall, _ = precision_recall_curve(y_encoded, probs)
+     f1 = f1_score(y_encoded, preds)
+     cm = confusion_matrix(y_encoded, preds)
+     roc_auc = auc(fpr, tpr)
+     pr_auc = auc(recall, precision)
+     # positive-class frequency; dividing AUPRC by it normalizes against the
+     # random-classifier baseline. Guard against division by zero when no positives are present.
+     pos_freq = np.mean(y_encoded == 1)
+     pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
+     return {
+         'fpr': fpr, 'tpr': tpr,
+         'precision': precision, 'recall': recall,
+         'f1': f1, 'auc': roc_auc, 'pr_auc': pr_auc,
+         'confusion_matrix': cm, 'pos_freq': pos_freq, 'pr_auc_norm': pr_auc_norm
+     }, preds, probs
+ 
+ def train_rf(X_tensor, y_tensor, train_indices, test_indices, n_estimators=500):
+     model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
+     model.fit(X_tensor[train_indices].numpy(), y_tensor[train_indices].numpy())
+     probs = model.predict_proba(X_tensor[test_indices].cpu().numpy())[:, 1]
+     preds = model.predict(X_tensor[test_indices].cpu().numpy())
+     return model, preds, probs
+ 
+ # ------------------------- Main Training Loop -------------------------
+ def run_training_loop(adata, site_config, layer_name=None,
+                       mlp=False, cnn=False, rnn=False, arnn=False, transformer=False, rf=False, nb=False, rr_bayes=False,
+                       max_epochs=10, max_patience=5, n_estimators=500, training_split=0.5):
+     device = (
+         torch.device('cuda') if torch.cuda.is_available() else
+         torch.device('mps') if torch.backends.mps.is_available() else
+         torch.device('cpu'))
+     metrics, models, positions, tensors = {}, {}, {}, {}
+     adata.obs["used_for_training"] = False  # Initialize column to False
+ 
+     for ref in adata.obs['Reference_strand'].cat.categories:
+         ref_subset = adata[adata.obs['Reference_strand'] == ref].copy()
+         if ref_subset.shape[0] == 0:
+             continue
+ 
+         # Get matrix and coordinates
+         if layer_name:
+             matrix = ref_subset.layers[layer_name].copy()
+             coords = ref_subset.var_names
+             suffix = layer_name
+         else:
+             site_mask = np.zeros(ref_subset.shape[1], dtype=bool)
+             if ref in site_config:
+                 for site in site_config[ref]:
+                     site_mask |= ref_subset.var[f'{ref}_{site}']
+                 suffix = "_".join(site_config[ref])
+             else:
+                 site_mask = np.ones(ref_subset.shape[1], dtype=bool)
+                 suffix = "full"
+             site_subset = ref_subset[:, site_mask].copy()
+             matrix = site_subset.X
+             coords = site_subset.var_names
+ 
+         positions.setdefault(ref, {})[suffix] = coords
+ 
+         # Fill and encode
+         X = random_fill_nans(matrix)
+         y = ref_subset.obs["activity_status"]
+         y_encoded = y.map({'Active': 1, 'Silent': 0})
+         X_tensor = torch.tensor(X, dtype=torch.float32)
+         y_tensor = torch.tensor(y_encoded.values, dtype=torch.long)
+         tensors.setdefault(ref, {})[suffix] = X_tensor
+ 
+         # Setup datasets
+         dataset = TensorDataset(X_tensor, y_tensor)
+         train_size = int(training_split * len(dataset))
+         train_dataset, test_dataset = torch.utils.data.random_split(
+             dataset, [train_size, len(dataset) - train_size]
+         )
+         train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
+         train_indices = adata.obs[adata.obs['Reference_strand'] == ref].index[train_dataset.indices]
+         adata.obs.loc[train_indices, "used_for_training"] = True
+ 
+         test_X = X_tensor[test_dataset.indices]
+         test_y = y_encoded.iloc[test_dataset.indices] if hasattr(y_encoded, 'iloc') else y_encoded[test_dataset.indices]
+ 
+         class_weights = compute_class_weight('balanced', classes=np.unique(y_encoded), y=y_encoded)
+         class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
+ 
+         metrics[ref], models[ref] = {}, {}
+ 
+         # MLP
+         if mlp:
+             mlp_model = MLPClassifier(X.shape[1], len(np.unique(y_encoded))).to(device)
+             optimizer_mlp = optim.Adam(mlp_model.parameters(), lr=0.001)
+             criterion_mlp = nn.CrossEntropyLoss(weight=class_weights)
+             train_model(mlp_model, train_loader, optimizer_mlp, criterion_mlp, device, ref, 'MLP', epochs=max_epochs, patience=max_patience)
+             mlp_metrics, mlp_preds, mlp_probs = evaluate_model(mlp_model, test_X, test_y, device)
+             metrics[ref][f'mlp_{suffix}'] = mlp_metrics
+             models[ref][f'mlp_{suffix}'] = mlp_model
+ 
+         # CNN
+         if cnn:
+             cnn_model = CNNClassifier(X.shape[1], len(np.unique(y_encoded))).to(device)
+             optimizer_cnn = optim.Adam(cnn_model.parameters(), lr=0.001)
+             criterion_cnn = nn.CrossEntropyLoss(weight=class_weights)
+             train_model(cnn_model, train_loader, optimizer_cnn, criterion_cnn, device, ref, 'CNN', epochs=max_epochs, patience=max_patience)
+             cnn_metrics, cnn_preds, cnn_probs = evaluate_model(cnn_model, test_X, test_y, device)
+             metrics[ref][f'cnn_{suffix}'] = cnn_metrics
+             models[ref][f'cnn_{suffix}'] = cnn_model
+ 
+         # RNN
+         if rnn:
+             rnn_model = RNNClassifier(X.shape[1], 64, len(np.unique(y_encoded))).to(device)
+             optimizer_rnn = optim.Adam(rnn_model.parameters(), lr=0.001)
+             criterion_rnn = nn.CrossEntropyLoss(weight=class_weights)
+             train_model(rnn_model, train_loader, optimizer_rnn, criterion_rnn, device, ref, 'RNN', epochs=max_epochs, patience=max_patience)
+             rnn_metrics, rnn_preds, rnn_probs = evaluate_model(rnn_model, test_X, test_y, device)
+             metrics[ref][f'rnn_{suffix}'] = rnn_metrics
+             models[ref][f'rnn_{suffix}'] = rnn_model
+ 
+         # Attention RNN
+         if arnn:
+             arnn_model = AttentionRNNClassifier(X.shape[1], 64, len(np.unique(y_encoded))).to(device)
+             optimizer_arnn = optim.Adam(arnn_model.parameters(), lr=0.001)
+             criterion_arnn = nn.CrossEntropyLoss(weight=class_weights)
+             train_model(arnn_model, train_loader, optimizer_arnn, criterion_arnn, device, ref, 'aRNN', epochs=max_epochs, patience=max_patience)
+             arnn_metrics, arnn_preds, arnn_probs = evaluate_model(arnn_model, test_X, test_y, device)
+             metrics[ref][f'arnn_{suffix}'] = arnn_metrics
+             models[ref][f'arnn_{suffix}'] = arnn_model
+ 
+         # Transformer
+         if transformer:
+             t_model = TransformerClassifier(X.shape[1], 64, len(np.unique(y_encoded))).to(device)
+             optimizer_t = optim.Adam(t_model.parameters(), lr=0.001)
+             criterion_t = nn.CrossEntropyLoss(weight=class_weights)
+             train_model(t_model, train_loader, optimizer_t, criterion_t, device, ref, 'Transformer', epochs=max_epochs, patience=max_patience)
+             t_metrics, t_preds, t_probs = evaluate_model(t_model, test_X, test_y, device)
+             metrics[ref][f'transformer_{suffix}'] = t_metrics
+             models[ref][f'transformer_{suffix}'] = t_model
+ 
+         # RF
+         if rf:
+             rf_model, rf_preds, rf_probs = train_rf(X_tensor, y_tensor, train_dataset.indices, test_dataset.indices, n_estimators)
+             fpr, tpr, _ = roc_curve(test_y, rf_probs)
+             precision, recall, _ = precision_recall_curve(test_y, rf_probs)
+             f1 = f1_score(test_y, rf_preds)
+             cm = confusion_matrix(test_y, rf_preds)
+             roc_auc = auc(fpr, tpr)
+             pr_auc = auc(recall, precision)
+             metrics[ref][f'rf_{suffix}'] = {
+                 'fpr': fpr, 'tpr': tpr,
+                 'precision': precision, 'recall': recall,
+                 'f1': f1, 'auc': roc_auc, 'confusion_matrix': cm, 'pr_auc': pr_auc
+             }
+             models[ref][f'rf_{suffix}'] = rf_model
+ 
+         # Naive Bayes
+         if nb:
+             nb_model = GaussianNB()
+             nb_model.fit(X_tensor[train_dataset.indices].numpy(), y_tensor[train_dataset.indices].numpy())
+             nb_probs = nb_model.predict_proba(test_X.numpy())[:, 1]
+             nb_preds = nb_model.predict(test_X.numpy())
+             fpr_nb, tpr_nb, _ = roc_curve(test_y, nb_probs)
+             prec_nb, rec_nb, _ = precision_recall_curve(test_y, nb_probs)
+             f1_nb = f1_score(test_y, nb_preds)
+             cm_nb = confusion_matrix(test_y, nb_preds)
+             auc_nb = auc(fpr_nb, tpr_nb)
+             pr_auc_nb = auc(rec_nb, prec_nb)
+             metrics[ref][f'nb_{suffix}'] = {
+                 'fpr': fpr_nb, 'tpr': tpr_nb,
+                 'precision': prec_nb, 'recall': rec_nb,
+                 'f1': f1_nb, 'auc': auc_nb, 'confusion_matrix': cm_nb, 'pr_auc': pr_auc_nb
+             }
+             models[ref][f'nb_{suffix}'] = nb_model
+ 
+         # Relative-Risk Bayesian
+         if rr_bayes:
+             # compute relative risks from training
+             X_train = X_tensor[train_dataset.indices].numpy()
+             y_train_arr = y_tensor[train_dataset.indices].numpy()
+             n_a = (y_train_arr == 1).sum()
+             n_s = (y_train_arr == 0).sum()
+             p_a = (X_train[y_train_arr == 1] == 1).sum(axis=0) / (n_a + 1e-6)
+             p_s = (X_train[y_train_arr == 0] == 1).sum(axis=0) / (n_s + 1e-6)
+             rr = (p_a + 1e-6) / (p_s + 1e-6)
+             log_rr = np.log(rr)
+             # score test
+             scores = test_X.numpy().dot(log_rr)
+             probs = 1 / (1 + np.exp(-scores))
+             preds = (probs >= 0.5).astype(int)
+             fpr, tpr, _ = roc_curve(test_y, probs)
+             pr, rc, _ = precision_recall_curve(test_y, probs)
+             roc_auc = auc(fpr, tpr)
+             pr_auc = auc(rc, pr)
+             pos_freq = np.mean(test_y == 1)
+             pr_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
+             metrics[ref][f'rr_bayes_{suffix}'] = {
+                 'fpr': fpr, 'tpr': tpr,
+                 'precision': pr, 'recall': rc,
+                 'auc': roc_auc, 'pr_auc': pr_auc, 'pr_auc_norm': pr_norm
+             }
+             # save rr_bayes parameters as a pseudo-model
+             models[ref][f'rr_bayes_{suffix}'] = {
+                 'log_rr': log_rr,
+                 'p_active': p_a,
+                 'p_silent': p_s
+             }
+ 
+     return metrics, models, positions, tensors
+ 
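+ # Usage sketch for run_training_loop (hypothetical names; assumes adata.obs carries
+ # 'Reference_strand' and 'activity_status', and adata.var has boolean site columns
+ # such as '<ref>_GpC_site' / '<ref>_CpG_site' produced upstream):
+ #
+ #     site_config = {'my_reference_top': ['GpC_site', 'CpG_site']}
+ #     metrics, models, positions, tensors = run_training_loop(
+ #         adata, site_config, mlp=True, rf=True, max_epochs=10, training_split=0.5
+ #     )
+ #     print(metrics['my_reference_top']['rf_GpC_site_CpG_site']['auc'])
+ 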
+ def sliding_window_train_test(
+     adata,
+     site_config,
+     layer_name,
+     window_sizes,
+     step_size,
+     training_split=0.7,
+     mlp=False,
+     cnn=False,
+     rnn=False,
+     arnn=False,
+     transformer=False,
+     rf=False,
+     nb=False,
+     rr_bayes=False,
+     epochs=10,
+     patience=5,
+     batch_size=64,
+     n_estimators=500,
+     balance_rf_class_weights=True,
+     positive_amount=None,
+     positive_freq=None,
+     bins=None
+ ):
+     """
+     Slide a window along features, train/test selected models at each position,
+     and append ROC-AUC and AUPRC values to the window center in adata.var in a single pass.
+ 
+     Torch models use GPU (CUDA > MPS > CPU).
+ 
+     bins: dict mapping bin_name -> boolean mask over adata.obs
+     """
+     # device detection
+     device = (
+         torch.device('cuda') if torch.cuda.is_available() else
+         torch.device('mps') if torch.backends.mps.is_available() else
+         torch.device('cpu')
+     )
+ 
+     # define bins: default single bin 'all'
+     bin_dict = bins if bins is not None else {'all': np.ones(adata.n_obs, dtype=bool)}
+ 
+     for bin_name, mask_obs in bin_dict.items():
+         mask_array = np.asarray(mask_obs)
+         if mask_array.shape[0] != adata.n_obs:
+             raise ValueError(f"Mask for bin '{bin_name}' length {mask_array.shape[0]} != n_obs {adata.n_obs}")
+         adata_bin = adata[mask_array]
+         if adata_bin.n_obs == 0:
+             continue
+ 
+         for ref in adata_bin.obs['Reference_strand'].cat.categories:
+             subset = adata_bin[adata_bin.obs['Reference_strand'] == ref]
+             if subset.n_obs == 0:
+                 continue
+ 
+             # build full feature matrix and var names
+             if layer_name:
+                 suffix = layer_name
+                 X_full = subset.layers[layer_name].copy()
+                 var_names = subset.var_names
+             else:
+                 mask_vars = np.zeros(subset.n_vars, dtype=bool)
+                 if ref in site_config:
+                     for site in site_config[ref]:
+                         mask_vars |= subset.var[f'{ref}_{site}']
+                     suffix = "_".join(site_config[ref])
+                 else:
+                     mask_vars[:] = True
+                     suffix = "full"
+                 X_full = subset[:, mask_vars].X.copy()
+                 var_names = subset.var_names[mask_vars]
+ 
+             for window_size in window_sizes:
+                 # prepare global arrays for each model/metric
+                 n_vars_global = adata.n_vars
+                 arrays = {}
+                 enabled = {'mlp': mlp, 'cnn': cnn, 'rnn': rnn, 'arnn': arnn,
+                            'transformer': transformer, 'rf': rf, 'nb': nb, 'rr_bayes': rr_bayes}
+                 for mname, is_enabled in enabled.items():
+                     if is_enabled:  # if model enabled
+                         arrays[(mname, 'roc')] = np.full(n_vars_global, np.nan)
+                         arrays[(mname, 'pr')] = np.full(n_vars_global, np.nan)
+                         arrays[(mname, 'pr_norm')] = np.full(n_vars_global, np.nan)
+ 
+                 # fill missing and labels
+                 X_full = random_fill_nans(X_full)
+                 y_full = subset.obs['activity_status'].map({'Active': 1, 'Silent': 0}).values
+                 n_feats = X_full.shape[1]
+ 
+                 # sliding windows
+                 for start in range(0, n_feats - window_size + 1, step_size):
+                     end = start + window_size
+                     X_win = X_full[:, start:end]
+                     center_idx = start + window_size // 2
+                     center_var = var_names[center_idx]
+                     var_idx_global = adata.var_names.get_loc(center_var)
+ 
+                     # train/test split
+                     X_train, X_pool, y_train, y_pool = train_test_split(
+                         X_win, y_full, train_size=training_split,
+                         stratify=y_full, random_state=42
+                     )
+ 
+                     # Optional test sampling with fallback
+                     try:
+                         if positive_amount is not None and positive_freq is not None:
+                             pos_idx = np.where(y_pool == 1)[0]
+                             if positive_amount > len(pos_idx):
+                                 raise ValueError("positive_amount exceeds available positives")
+                             chosen_pos = np.random.choice(pos_idx, positive_amount, replace=False)
+ 
+                             neg_amount = int(round(positive_amount * (1 - positive_freq) / positive_freq))
+                             neg_idx = np.where(y_pool == 0)[0]
+                             if neg_amount > len(neg_idx):
+                                 raise ValueError("negative_amount exceeds available negatives")
+                             chosen_neg = np.random.choice(neg_idx, neg_amount, replace=False)
+ 
+                             sel = np.concatenate([chosen_pos, chosen_neg])
+                             X_test, y_test = X_pool[sel], y_pool[sel]
+                         else:
+                             X_test, y_test = X_pool, y_pool
+                     except ValueError as e:
+                         warnings.warn(
+                             f"Falling back to full pool for window {start}:{end} ({e})"
+                         )
+                         X_test, y_test = X_pool, y_pool
+ 
+                     # prepare DataLoader
+                     train_ds = TensorDataset(
+                         torch.tensor(X_train, dtype=torch.float32).to(device),
+                         torch.tensor(y_train, dtype=torch.long).to(device)
+                     )
+                     loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
+                     class_w = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
+                     class_w = torch.tensor(class_w, dtype=torch.float32).to(device)
+ 
+                     # container for this window's metrics
+                     results = {}
+                     # train and evaluate models
+                     if mlp:
+                         model = MLPClassifier(window_size, 2).to(device)
+                         train_model(model, loader, torch.optim.Adam(model.parameters()), nn.CrossEntropyLoss(weight=class_w), device, ref, f"{bin_name}_MLP", epochs, patience)
+                         mets, _, _ = evaluate_model(model, torch.tensor(X_test, dtype=torch.float32).to(device), y_test, device)
+                         results['mlp'] = mets
+                     if cnn:
+                         model = CNNClassifier(window_size, 2).to(device)
+                         train_model(model, loader, torch.optim.Adam(model.parameters()), nn.CrossEntropyLoss(weight=class_w), device, ref, f"{bin_name}_CNN", epochs, patience)
+                         mets, _, _ = evaluate_model(model, torch.tensor(X_test, dtype=torch.float32).to(device), y_test, device)
+                         results['cnn'] = mets
+                     if rnn:
+                         model = RNNClassifier(window_size, 64, 2).to(device)
+                         train_model(model, loader, torch.optim.Adam(model.parameters()), nn.CrossEntropyLoss(weight=class_w), device, ref, f"{bin_name}_RNN", epochs, patience)
+                         mets, _, _ = evaluate_model(model, torch.tensor(X_test, dtype=torch.float32).to(device), y_test, device)
+                         results['rnn'] = mets
+                     if arnn:
+                         model = AttentionRNNClassifier(window_size, 64, 2).to(device)
+                         train_model(model, loader, torch.optim.Adam(model.parameters()), nn.CrossEntropyLoss(weight=class_w), device, ref, f"{bin_name}_aRNN", epochs, patience)
+                         mets, _, _ = evaluate_model(model, torch.tensor(X_test, dtype=torch.float32).to(device), y_test, device)
+                         results['arnn'] = mets
+                     if transformer:
+                         model = TransformerClassifier(window_size, 64, 2).to(device)
+                         train_model(model, loader, torch.optim.Adam(model.parameters()), nn.CrossEntropyLoss(weight=class_w), device, ref, f"{bin_name}_Transformer", epochs, patience)
+                         mets, _, _ = evaluate_model(model, torch.tensor(X_test, dtype=torch.float32).to(device), y_test, device)
+                         results['transformer'] = mets
+                     if rf:
+                         if balance_rf_class_weights:
+                             rf_mod = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
+                         else:
+                             rf_mod = RandomForestClassifier(n_estimators=n_estimators, random_state=42)
+                         rf_mod.fit(X_train, y_train)
+                         probs = rf_mod.predict_proba(X_test)[:, 1]
+                         fpr, tpr, _ = roc_curve(y_test, probs)
+                         pr, rc, _ = precision_recall_curve(y_test, probs)
+                         pr_auc = auc(rc, pr)
+                         # positive-class frequency; guard against an all-negative test set
+                         pos_freq = np.mean(y_test == 1)
+                         pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
+                         results['rf'] = {'auc': auc(fpr, tpr), 'pr_auc': pr_auc, 'pr_auc_norm': pr_auc_norm}
+                     if nb:
+                         nb_mod = GaussianNB()
+                         nb_mod.fit(X_train, y_train)
+                         probs = nb_mod.predict_proba(X_test)[:, 1]
+                         fpr, tpr, _ = roc_curve(y_test, probs)
+                         pr, rc, _ = precision_recall_curve(y_test, probs)
+                         pr_auc = auc(rc, pr)
+                         # positive-class frequency; guard against an all-negative test set
+                         pos_freq = np.mean(y_test == 1)
+                         pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
+                         results['nb'] = {'auc': auc(fpr, tpr), 'pr_auc': pr_auc, 'pr_auc_norm': pr_auc_norm}
+                     if rr_bayes:
+                         # Relative-risk Bayesian classifier
+                         n_active = np.sum(y_train == 1)
+                         n_silent = np.sum(y_train == 0)
+                         # compute feature-wise rates
+                         p_a = (X_train[y_train == 1] == 1).sum(axis=0) / (n_active + 1e-6)
+                         p_s = (X_train[y_train == 0] == 1).sum(axis=0) / (n_silent + 1e-6)
+                         rr = (p_a + 1e-6) / (p_s + 1e-6)
+                         log_rr = np.log(rr)
+                         # score samples
+                         scores = X_test.dot(log_rr)
+                         probs = 1 / (1 + np.exp(-scores))
+                         preds = (probs >= 0.5).astype(int)
+                         # metrics
+                         fpr, tpr, _ = roc_curve(y_test, probs)
+                         pr, rc, _ = precision_recall_curve(y_test, probs)
+                         roc_auc = auc(fpr, tpr)
+                         pr_auc = auc(rc, pr)
+                         pos_freq = np.mean(y_test == 1)
+                         pr_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
+                         results['rr_bayes'] = {'auc': roc_auc, 'pr_auc': pr_auc, 'pr_auc_norm': pr_norm}
+ 
+                     # assign metrics into arrays
+                     for mname, mets in results.items():
+                         arrays[(mname, 'roc')][var_idx_global] = mets['auc']
+                         arrays[(mname, 'pr')][var_idx_global] = mets['pr_auc']
+                         arrays[(mname, 'pr_norm')][var_idx_global] = mets['pr_auc_norm']
+ 
+                 # after all windows, write arrays to adata.var
+                 for (mname, metric), arr in arrays.items():
+                     col = f"{bin_name}_{mname}_{suffix}_w{window_size}_{metric}"
+                     adata.var[col] = arr
+ 
+     print("✅ Sliding-window training/testing complete. Metrics stored at window centers in adata.var.")
+ 
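+ # Usage sketch for sliding_window_train_test (hypothetical window/step values):
+ #
+ #     sliding_window_train_test(adata, site_config={'my_reference_top': ['GpC_site']},
+ #                               layer_name=None, window_sizes=[10, 20], step_size=5,
+ #                               rf=True, nb=True)
+ #     # per-window ROC-AUC lands in columns such as 'all_rf_GpC_site_w10_roc' in adata.var
+ 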
+ # ------------------------- Apply models to input adata -------------------------
+ def run_inference(adata, models, site_config, layer_name=None, model_names=["cnn", "mlp", "rf", "nb", "rr_bayes"]):
+     """
+     Perform inference on the full AnnData object using pre-trained models.
+ 
+     Parameters:
+         adata (AnnData): The full AnnData object.
+         models (dict): Dictionary of trained models keyed by reference strand.
+         site_config (dict): Configuration dictionary for subsetting features by site.
+         layer_name (str, optional): Name of the layer to use if applicable.
+         model_names (list, optional): List of model names to run inference on.
+             Defaults to ["cnn", "mlp", "rf", "nb", "rr_bayes"].
+ 
+     Returns:
+         None. The function updates adata.obs with predicted class labels, predicted probabilities,
+         and active probabilities for each model.
+     """
+     import numpy as np
+     import pandas as pd
+ 
+     device = (
+         torch.device('cuda') if torch.cuda.is_available() else
+         torch.device('mps') if torch.backends.mps.is_available() else
+         torch.device('cpu')
+     )
+ 
+     # Loop over each reference key in the models dictionary
+     for ref in models.keys():
+         # Subset the full AnnData by the reference strand
+         full_subset = adata[adata.obs['Reference_strand'] == ref]
+         if full_subset.shape[0] == 0:
+             continue
+ 
+         # Reconstruct the same layer or site mask used during training
+         if layer_name:
+             suffix = layer_name
+             full_matrix = full_subset.layers[layer_name].copy()
+         else:
+             site_mask = np.zeros(full_subset.shape[1], dtype=bool)
+             if ref in site_config:
+                 for site in site_config[ref]:
+                     site_mask |= full_subset.var[f'{ref}_{site}']
+                 suffix = "_".join(site_config[ref])
+             else:
+                 site_mask = np.ones(full_subset.shape[1], dtype=bool)
+                 suffix = "full"
+             full_matrix = full_subset[:, site_mask].X.copy()
+ 
+         # Fill any NaNs in the feature matrix
+         full_matrix = random_fill_nans(full_matrix)
+ 
+         # Convert to a torch tensor; for torch models we use the specified device
+         full_tensor = torch.tensor(full_matrix, dtype=torch.float32)
+         full_tensor_device = full_tensor.to(device)
+ 
+         for model_name in model_names:
+             model_key = f"{model_name}_{suffix}"
+             pred_col = f"{model_name}_activity_prediction_{suffix}"
+             pred_prob_col = f"{model_name}_prediction_prob_{suffix}"
+             active_prob_col = f"{model_name}_active_prob_{suffix}"
+ 
+             if model_key in models[ref]:
+                 model = models[ref][model_key]
+                 if model_name in ["rf", "nb"]:
+                     # For scikit-learn based models, work on CPU using NumPy arrays
+                     X_input = full_tensor.cpu().numpy()
+                     preds = model.predict(X_input)
+                     probs = model.predict_proba(X_input)
+                 elif model_name == 'rr_bayes':
+                     # model is a dict of parameters
+                     log_rr = model['log_rr']
+                     scores = full_tensor.cpu().numpy().dot(log_rr)
+                     probs1 = 1 / (1 + np.exp(-scores))
+                     preds = (probs1 >= 0.5).astype(int)
+                     probs = np.vstack([1 - probs1, probs1]).T
+                 else:
+                     # For torch models, perform inference on the specified device
+                     model.eval()
+                     with torch.no_grad():
+                         logits = model(full_tensor_device)
+                         preds = logits.argmax(dim=1).cpu().numpy()
+                         probs = torch.softmax(logits, dim=1).cpu().numpy()
+ 
+                 pred_probs = probs[np.arange(len(preds)), preds]
+                 active_probs = probs[:, 1]  # class 1 is assumed to be "Active"
+ 
+                 # Store predictions in the AnnData object
+                 adata.obs.loc[full_subset.obs.index, pred_col] = preds
+                 adata.obs.loc[full_subset.obs.index, pred_prob_col] = pred_probs
+                 adata.obs.loc[full_subset.obs.index, active_prob_col] = active_probs
+ 
+         # Ensure that the prediction columns are of categorical type
+         # (done per reference, since the column suffix is reference-specific).
+         # Replace non-finite values with 0 before converting.
+         for model_name in model_names:
+             pred_col = f"{model_name}_activity_prediction_{suffix}"
+             if pred_col in adata.obs.columns:
+                 # Convert to numeric (non-finite become NaN), fill NaNs with 0, then cast
+                 adata.obs[pred_col] = pd.to_numeric(adata.obs[pred_col], errors='coerce').fillna(0).astype(int).astype("category")
+ 
+     print("✅ Inference complete: stored predicted class, predicted probability, and active probability for each model.")
+ 
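+ # Usage sketch for run_inference (hypothetical inputs; reuses the models dict
+ # returned by run_training_loop above):
+ #
+ #     run_inference(adata, models, site_config, model_names=["mlp", "rf"])
+ #     adata.obs[['mlp_activity_prediction_GpC_site_CpG_site',
+ #                'mlp_active_prob_GpC_site_CpG_site']].head()
+ 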
+ # ------------------------- Evaluate model activity predictions within categorical subgroups -------------------------
+ 
+ def evaluate_model_by_subgroups(
+     adata,
+     model_prefix="mlp",
+     suffix="GpC_site_CpG_site",
+     groupby_cols=["Sample_Names_Full", "Enhancer_Open", "Promoter_Open"],
+     label_col="activity_status",
+     min_samples=10,
+     exclude_training_data=True):
+     import pandas as pd
+     from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
+ 
+     results = []
+ 
+     if exclude_training_data:
+         test_subset = adata[adata.obs['used_for_training'] == False]
+     else:
+         test_subset = adata
+ 
+     df = test_subset.obs.copy()
+     df[label_col] = df[label_col].astype('category').cat.codes
+ 
+     pred_col = f"{model_prefix}_activity_prediction_{suffix}"
+     prob_col = f"{model_prefix}_active_prob_{suffix}"  # matches the column written by run_inference
+ 
+     if pred_col not in df or prob_col not in df:
+         raise ValueError(f"Missing prediction/probability columns for {model_prefix}")
+ 
+     for group_vals, group_df in df.groupby(groupby_cols):
+         if len(group_df) < min_samples:
+             continue  # skip small groups
+ 
+         y_true = group_df[label_col].values
+         y_pred = group_df[pred_col].astype(int).values
+         y_prob = group_df[prob_col].astype(float).values
+ 
+         if len(set(y_true)) < 2:
+             group_auc = float('nan')  # undefined if only one class present
+         else:
+             group_auc = roc_auc_score(y_true, y_prob)
+ 
+         results.append({
+             **dict(zip(groupby_cols, group_vals)),
+             "model": model_prefix,
+             "n_samples": len(group_df),
+             "accuracy": accuracy_score(y_true, y_pred),
+             "f1": f1_score(y_true, y_pred),
+             "auc": group_auc,
+         })
+ 
+     return pd.DataFrame(results)
+ 
+ def evaluate_models_by_subgroup(adata, model_prefixes, groupby_cols, label_col, suffix="GpC_site_CpG_site", exclude_training_data=True):
+     import pandas as pd
+     all_metrics = []
+     for model_prefix in model_prefixes:
+         try:
+             df = evaluate_model_by_subgroups(adata, model_prefix=model_prefix, suffix=suffix, groupby_cols=groupby_cols, label_col=label_col, exclude_training_data=exclude_training_data)
+             all_metrics.append(df)
+         except Exception as e:
+             print(f"Skipping {model_prefix} due to error: {e}")
+ 
+     final_df = pd.concat(all_metrics, ignore_index=True)
+     return final_df
+ 
+ def prepare_melted_model_data(adata, outkey='melted_model_df', groupby=['Enhancer_Open', 'Promoter_Open'], label_col='activity_status', model_names=['cnn', 'mlp', 'rf'], suffix='GpC_site_CpG_site', omit_training=True):
+     import pandas as pd
+     # Build the id-variable list without mutating the caller's groupby list
+     # (list.append returns None, so the original `cols = groupby.append(label_col)` was a bug).
+     cols = groupby + [label_col]
+     if omit_training:
+         subset = adata[adata.obs['used_for_training'] == False]
+     else:
+         subset = adata
+     df = subset.obs[cols].copy()
+     df[label_col] = df[label_col].astype('category').cat.codes
+ 
+     for model in model_names:
+         col = f"{model}_active_prob_{suffix}"
+         if col in subset.obs.columns:
+             df[f"{model}_prob"] = subset.obs[col].astype(float)
+ 
+     # Melt into long format
+     melted = df.melt(
+         id_vars=cols,
+         value_vars=[f"{m}_prob" for m in model_names if f"{m}_active_prob_{suffix}" in subset.obs.columns],
+         var_name='model',
+         value_name='prob'
+     )
+     melted['model'] = melted['model'].str.replace('_prob', '', regex=False)
+ 
+     adata.uns[outkey] = melted
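+ # Usage sketch (hypothetical): melt per-model active probabilities for plotting,
+ # then retrieve the long-format frame from adata.uns:
+ #
+ #     prepare_melted_model_data(adata, model_names=['mlp', 'rf'])
+ #     melted = adata.uns['melted_model_df']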