smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +5 -1
- smftools/_version.py +1 -1
- smftools/informatics/__init__.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +63 -10
- smftools/informatics/direct_smf.py +66 -18
- smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
- smftools/informatics/helpers/__init__.py +16 -2
- smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
- smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
- smftools/informatics/helpers/canoncall.py +12 -3
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
- smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +33 -46
- smftools/informatics/helpers/extract_mods.py +55 -23
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/find_conversion_sites.py +33 -44
- smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
- smftools/informatics/helpers/modcall.py +13 -5
- smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
- smftools/informatics/helpers/ohe_batching.py +65 -41
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +45 -9
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/split_and_index_BAM.py +3 -8
- smftools/informatics/load_adata.py +58 -3
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +6 -7
- smftools/preprocessing/append_C_context.py +22 -9
- smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
- smftools/preprocessing/binarize_on_Youden.py +35 -32
- smftools/preprocessing/binary_layers_to_ohe.py +13 -3
- smftools/preprocessing/calculate_complexity.py +3 -2
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
- smftools/preprocessing/calculate_coverage.py +26 -25
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_position_Youden.py +18 -7
- smftools/preprocessing/calculate_read_length_stats.py +39 -46
- smftools/preprocessing/clean_NaN.py +33 -25
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
- smftools/preprocessing/filter_reads_on_length.py +14 -4
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +18 -11
- smftools/preprocessing/load_sample_sheet.py +30 -16
- smftools/preprocessing/recipes.py +22 -20
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +105 -13
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +19 -23
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools/tools/apply_HMM.py +0 -1
- smftools/tools/read_HMM.py +0 -1
- smftools/tools/train_HMM.py +0 -43
- smftools-0.1.3.dist-info/RECORD +0 -84
- /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
- /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/tools/classifiers.py
@@ -0,0 +1,787 @@
## Train CNN, RNN, Random Forest models on double-barcoded, low-contamination datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, precision_recall_curve, f1_score, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import numpy as np
import warnings

# Device detection (CUDA > MPS > CPU)
device = (
    torch.device('cuda') if torch.cuda.is_available() else
    torch.device('mps') if torch.backends.mps.is_available() else
    torch.device('cpu')
)

# ------------------------- Utilities -------------------------
def random_fill_nans(X):
    # Replace NaNs in-place with uniform random values in [0, 1)
    nan_mask = np.isnan(X)
    X[nan_mask] = np.random.rand(*X[nan_mask].shape)
    return X

# ------------------------- Model Definitions -------------------------
class CNNClassifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        dummy_input = torch.zeros(1, 1, input_size)
        dummy_output = self._forward_conv(dummy_input)
        flattened_size = dummy_output.view(1, -1).shape[1]
        self.fc1 = nn.Linear(flattened_size, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def _forward_conv(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self._forward_conv(x)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

class MLPClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        return self.model(x)

class RNNClassifier(nn.Module):
    def __init__(self, input_size, hidden_dim, num_classes):
        super().__init__()
        # Define LSTM layer
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
        # Define fully connected output layer
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x.unsqueeze(1)
        _, (h_n, _) = self.lstm(x)
        return self.fc(h_n.squeeze(0))

class AttentionRNNClassifier(nn.Module):
    def __init__(self, input_size, hidden_dim, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
        self.attn = nn.Linear(hidden_dim, 1)  # Simple attention scores
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x.unsqueeze(1)  # shape: (batch, 1, seq_len)
        lstm_out, _ = self.lstm(x)  # shape: (batch, 1, hidden_dim)
        attn_weights = torch.softmax(self.attn(lstm_out), dim=1)  # (batch, 1, 1)
        context = (attn_weights * lstm_out).sum(dim=1)  # weighted sum
        return self.fc(context)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)].to(x.device)
        return x

class TransformerClassifier(nn.Module):
    def __init__(self, input_dim, model_dim, num_classes, num_heads=4, num_layers=2):
        super().__init__()
        self.input_fc = nn.Linear(input_dim, model_dim)
        self.pos_encoder = PositionalEncoding(model_dim)

        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.cls_head = nn.Linear(model_dim, num_classes)

    def forward(self, x):
        # x: [batch_size, input_dim]
        x = self.input_fc(x).unsqueeze(1)  # -> [batch_size, 1, model_dim]
        x = self.pos_encoder(x)
        x = x.permute(1, 0, 2)  # -> [seq_len=1, batch_size, model_dim]
        encoded = self.transformer(x)
        pooled = encoded.mean(dim=0)  # -> [batch_size, model_dim]
        return self.cls_head(pooled)

def train_model(model, loader, optimizer, criterion, device, ref_name="", model_name="", epochs=20, patience=5):
    model.train()
    best_loss = float('inf')
    trigger_times = 0

    for epoch in range(epochs):
        running_loss = 0
        for batch_X, batch_y in loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(batch_X), batch_y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        average_loss = running_loss / len(loader)
        print(f"{ref_name} {model_name} Epoch {epoch+1} Loss: {average_loss:.4f}")

        if average_loss < best_loss:
            best_loss = average_loss
            trigger_times = 0
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print(f"Early stopping {model_name} for {ref_name} at epoch {epoch+1}")
                break

def evaluate_model(model, X_tensor, y_encoded, device):
    model.eval()
    with torch.no_grad():
        outputs = model(X_tensor.to(device))
        probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
        preds = outputs.argmax(dim=1).cpu().numpy()
    fpr, tpr, _ = roc_curve(y_encoded, probs)
    precision, recall, _ = precision_recall_curve(y_encoded, probs)
    f1 = f1_score(y_encoded, preds)
    cm = confusion_matrix(y_encoded, preds)
    roc_auc = auc(fpr, tpr)
    pr_auc = auc(recall, precision)
    # positive-class frequency (guards the normalization below)
    pos_freq = np.mean(y_encoded == 1)
    pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
    return {
        'fpr': fpr, 'tpr': tpr,
        'precision': precision, 'recall': recall,
        'f1': f1, 'auc': roc_auc, 'pr_auc': pr_auc,
        'confusion_matrix': cm, 'pos_freq': pos_freq, 'pr_auc_norm': pr_auc_norm
    }, preds, probs

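Since a no-skill classifier's AUPRC equals the positive-class frequency, the pr_auc_norm value returned above reads as a fold-improvement over random guessing: for example, pr_auc = 0.30 with pos_freq = 0.10 gives pr_auc_norm = 3.0, i.e. three times better than chance.
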
def train_rf(X_tensor, y_tensor, train_indices, test_indices, n_estimators=500):
    model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
    model.fit(X_tensor[train_indices].numpy(), y_tensor[train_indices].numpy())
    probs = model.predict_proba(X_tensor[test_indices].cpu().numpy())[:, 1]
    preds = model.predict(X_tensor[test_indices].cpu().numpy())
    return model, preds, probs

# ------------------------- Main Training Loop -------------------------
def run_training_loop(adata, site_config, layer_name=None,
                      mlp=False, cnn=False, rnn=False, arnn=False, transformer=False, rf=False, nb=False, rr_bayes=False,
                      max_epochs=10, max_patience=5, n_estimators=500, training_split=0.5):
    device = (
        torch.device('cuda') if torch.cuda.is_available() else
        torch.device('mps') if torch.backends.mps.is_available() else
        torch.device('cpu'))
    metrics, models, positions, tensors = {}, {}, {}, {}
    adata.obs["used_for_training"] = False  # Initialize column to False

    for ref in adata.obs['Reference_strand'].cat.categories:
        ref_subset = adata[adata.obs['Reference_strand'] == ref].copy()
        if ref_subset.shape[0] == 0:
            continue

        # Get matrix and coordinates
        if layer_name:
            matrix = ref_subset.layers[layer_name].copy()
            coords = ref_subset.var_names
            suffix = layer_name
        else:
            site_mask = np.zeros(ref_subset.shape[1], dtype=bool)
            if ref in site_config:
                for site in site_config[ref]:
                    site_mask |= ref_subset.var[f'{ref}_{site}']
                suffix = "_".join(site_config[ref])
            else:
                site_mask = np.ones(ref_subset.shape[1], dtype=bool)
                suffix = "full"
            site_subset = ref_subset[:, site_mask].copy()
            matrix = site_subset.X
            coords = site_subset.var_names

        positions.setdefault(ref, {})[suffix] = coords

        # Fill and encode
        X = random_fill_nans(matrix)
        y = ref_subset.obs["activity_status"]
        y_encoded = y.map({'Active': 1, 'Silent': 0})
        X_tensor = torch.tensor(X, dtype=torch.float32)
        y_tensor = torch.tensor(y_encoded.values, dtype=torch.long)
        tensors.setdefault(ref, {})[suffix] = X_tensor

        # Setup datasets
        dataset = TensorDataset(X_tensor, y_tensor)
        train_size = int(training_split * len(dataset))
        train_dataset, test_dataset = torch.utils.data.random_split(
            dataset, [train_size, len(dataset) - train_size]
        )
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        train_indices = adata.obs[adata.obs['Reference_strand'] == ref].index[train_dataset.indices]
        adata.obs.loc[train_indices, "used_for_training"] = True

        test_X = X_tensor[test_dataset.indices]
        test_y = y_encoded.iloc[test_dataset.indices] if hasattr(y_encoded, 'iloc') else y_encoded[test_dataset.indices]

        class_weights = compute_class_weight('balanced', classes=np.unique(y_encoded), y=y_encoded)
        class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)

        metrics[ref], models[ref] = {}, {}

        # MLP
        if mlp:
            mlp_model = MLPClassifier(X.shape[1], len(np.unique(y_encoded))).to(device)
            optimizer_mlp = optim.Adam(mlp_model.parameters(), lr=0.001)
            criterion_mlp = nn.CrossEntropyLoss(weight=class_weights)
            train_model(mlp_model, train_loader, optimizer_mlp, criterion_mlp, device, ref, 'MLP', epochs=max_epochs, patience=max_patience)
            mlp_metrics, mlp_preds, mlp_probs = evaluate_model(mlp_model, test_X, test_y, device)
            metrics[ref][f'mlp_{suffix}'] = mlp_metrics
            models[ref][f'mlp_{suffix}'] = mlp_model

        # CNN
        if cnn:
            cnn_model = CNNClassifier(X.shape[1], len(np.unique(y_encoded))).to(device)
            optimizer_cnn = optim.Adam(cnn_model.parameters(), lr=0.001)
            criterion_cnn = nn.CrossEntropyLoss(weight=class_weights)
            train_model(cnn_model, train_loader, optimizer_cnn, criterion_cnn, device, ref, 'CNN', epochs=max_epochs, patience=max_patience)
            cnn_metrics, cnn_preds, cnn_probs = evaluate_model(cnn_model, test_X, test_y, device)
            metrics[ref][f'cnn_{suffix}'] = cnn_metrics
            models[ref][f'cnn_{suffix}'] = cnn_model

        # RNN
        if rnn:
            rnn_model = RNNClassifier(X.shape[1], 64, len(np.unique(y_encoded))).to(device)
            optimizer_rnn = optim.Adam(rnn_model.parameters(), lr=0.001)
            criterion_rnn = nn.CrossEntropyLoss(weight=class_weights)
            train_model(rnn_model, train_loader, optimizer_rnn, criterion_rnn, device, ref, 'RNN', epochs=max_epochs, patience=max_patience)
            rnn_metrics, rnn_preds, rnn_probs = evaluate_model(rnn_model, test_X, test_y, device)
            metrics[ref][f'rnn_{suffix}'] = rnn_metrics
            models[ref][f'rnn_{suffix}'] = rnn_model

        # Attention RNN
        if arnn:
            arnn_model = AttentionRNNClassifier(X.shape[1], 64, len(np.unique(y_encoded))).to(device)
            optimizer_arnn = optim.Adam(arnn_model.parameters(), lr=0.001)
            criterion_arnn = nn.CrossEntropyLoss(weight=class_weights)
            train_model(arnn_model, train_loader, optimizer_arnn, criterion_arnn, device, ref, 'aRNN', epochs=max_epochs, patience=max_patience)
            arnn_metrics, arnn_preds, arnn_probs = evaluate_model(arnn_model, test_X, test_y, device)
            metrics[ref][f'arnn_{suffix}'] = arnn_metrics
            models[ref][f'arnn_{suffix}'] = arnn_model

        # Transformer
        if transformer:
            t_model = TransformerClassifier(X.shape[1], 64, len(np.unique(y_encoded))).to(device)
            optimizer_t = optim.Adam(t_model.parameters(), lr=0.001)
            criterion_t = nn.CrossEntropyLoss(weight=class_weights)
            train_model(t_model, train_loader, optimizer_t, criterion_t, device, ref, 'Transformer', epochs=max_epochs, patience=max_patience)
            t_metrics, t_preds, t_probs = evaluate_model(t_model, test_X, test_y, device)
            metrics[ref][f'transformer_{suffix}'] = t_metrics
            models[ref][f'transformer_{suffix}'] = t_model

        # RF
        if rf:
            rf_model, rf_preds, rf_probs = train_rf(X_tensor, y_tensor, train_dataset.indices, test_dataset.indices, n_estimators)
            fpr, tpr, _ = roc_curve(test_y, rf_probs)
            precision, recall, _ = precision_recall_curve(test_y, rf_probs)
            f1 = f1_score(test_y, rf_preds)
            cm = confusion_matrix(test_y, rf_preds)
            roc_auc = auc(fpr, tpr)
            pr_auc = auc(recall, precision)
            metrics[ref][f'rf_{suffix}'] = {
                'fpr': fpr, 'tpr': tpr,
                'precision': precision, 'recall': recall,
                'f1': f1, 'auc': roc_auc, 'confusion_matrix': cm, 'pr_auc': pr_auc
            }
            models[ref][f'rf_{suffix}'] = rf_model

        # Naive Bayes
        if nb:
            nb_model = GaussianNB()
            nb_model.fit(X_tensor[train_dataset.indices].numpy(), y_tensor[train_dataset.indices].numpy())
            nb_probs = nb_model.predict_proba(test_X.numpy())[:, 1]
            nb_preds = nb_model.predict(test_X.numpy())
            fpr_nb, tpr_nb, _ = roc_curve(test_y, nb_probs)
            prec_nb, rec_nb, _ = precision_recall_curve(test_y, nb_probs)
            f1_nb = f1_score(test_y, nb_preds)
            cm_nb = confusion_matrix(test_y, nb_preds)
            auc_nb = auc(fpr_nb, tpr_nb)
            pr_auc_nb = auc(rec_nb, prec_nb)
            metrics[ref][f'nb_{suffix}'] = {
                'fpr': fpr_nb, 'tpr': tpr_nb,
                'precision': prec_nb, 'recall': rec_nb,
                'f1': f1_nb, 'auc': auc_nb, 'confusion_matrix': cm_nb, 'pr_auc': pr_auc_nb
            }
            models[ref][f'nb_{suffix}'] = nb_model

        # Relative-Risk Bayesian
        if rr_bayes:
            # compute relative risks from training data
            X_train = X_tensor[train_dataset.indices].numpy()
            y_train_arr = y_tensor[train_dataset.indices].numpy()
            n_a = (y_train_arr == 1).sum()
            n_s = (y_train_arr == 0).sum()
            p_a = (X_train[y_train_arr == 1] == 1).sum(axis=0) / (n_a + 1e-6)
            p_s = (X_train[y_train_arr == 0] == 1).sum(axis=0) / (n_s + 1e-6)
            rr = (p_a + 1e-6) / (p_s + 1e-6)
            log_rr = np.log(rr)
            # score the test set
            scores = test_X.numpy().dot(log_rr)
            probs = 1 / (1 + np.exp(-scores))
            preds = (probs >= 0.5).astype(int)
            fpr, tpr, _ = roc_curve(test_y, probs)
            pr, rc, _ = precision_recall_curve(test_y, probs)
            roc_auc = auc(fpr, tpr)
            pr_auc = auc(rc, pr)
            pos_freq = np.mean(test_y == 1)
            pr_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
            metrics[ref][f'rr_bayes_{suffix}'] = {
                'fpr': fpr, 'tpr': tpr,
                'precision': pr, 'recall': rc,
                'auc': roc_auc, 'pr_auc': pr_auc, 'pr_auc_norm': pr_norm
            }
            # save rr_bayes parameters as a pseudo-model
            models[ref][f'rr_bayes_{suffix}'] = {
                'log_rr': log_rr,
                'p_active': p_a,
                'p_silent': p_s
            }

    return metrics, models, positions, tensors

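For orientation, a minimal usage sketch for run_training_loop (illustrative only, not part of the released file): the synthetic AnnData below is an assumption that mirrors the layout the function expects, with Reference_strand and activity_status in .obs and per-site boolean masks in .var; the reference and site names are hypothetical.

    import anndata as ad
    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    n_reads, n_sites = 200, 50
    obs = pd.DataFrame({
        'Reference_strand': pd.Categorical(['ref1_top'] * n_reads),
        'activity_status': rng.choice(['Active', 'Silent'], n_reads),
    })
    var = pd.DataFrame(index=[str(i) for i in range(n_sites)])
    var['ref1_top_GpC_site'] = rng.random(n_sites) > 0.5  # boolean site mask
    adata = ad.AnnData(X=rng.random((n_reads, n_sites)), obs=obs, var=var)

    site_config = {'ref1_top': ['GpC_site']}  # hypothetical reference/site names
    metrics, models, positions, tensors = run_training_loop(
        adata, site_config, mlp=True, rf=True, max_epochs=5
    )
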
def sliding_window_train_test(
    adata,
    site_config,
    layer_name,
    window_sizes,
    step_size,
    training_split=0.7,
    mlp=False,
    cnn=False,
    rnn=False,
    arnn=False,
    transformer=False,
    rf=False,
    nb=False,
    rr_bayes=False,
    epochs=10,
    patience=5,
    batch_size=64,
    n_estimators=500,
    balance_rf_class_weights=True,
    positive_amount=None,
    positive_freq=None,
    bins=None
):
    """
    Slide a window along features, train/test selected models at each position,
    and append ROC-AUC and AUPRC values at the window center in adata.var in a single pass.

    Torch models use GPU (CUDA > MPS > CPU).

    bins: dict mapping bin_name -> boolean mask over adata.obs
    """
    # device detection
    device = (
        torch.device('cuda') if torch.cuda.is_available() else
        torch.device('mps') if torch.backends.mps.is_available() else
        torch.device('cpu')
    )

    # define bins: default to a single bin 'all'
    bin_dict = bins if bins is not None else {'all': np.ones(adata.n_obs, dtype=bool)}

    for bin_name, mask_obs in bin_dict.items():
        mask_array = np.asarray(mask_obs)
        if mask_array.shape[0] != adata.n_obs:
            raise ValueError(f"Mask for bin '{bin_name}' length {mask_array.shape[0]} != n_obs {adata.n_obs}")
        adata_bin = adata[mask_array]
        if adata_bin.n_obs == 0:
            continue

        for ref in adata_bin.obs['Reference_strand'].cat.categories:
            subset = adata_bin[adata_bin.obs['Reference_strand'] == ref]
            if subset.n_obs == 0:
                continue

            # build full feature matrix and var names
            if layer_name:
                suffix = layer_name
                X_full = subset.layers[layer_name].copy()
                var_names = subset.var_names
            else:
                mask_vars = np.zeros(subset.n_vars, dtype=bool)
                if ref in site_config:
                    for site in site_config[ref]:
                        mask_vars |= subset.var[f'{ref}_{site}']
                    suffix = "_".join(site_config[ref])
                else:
                    mask_vars[:] = True
                    suffix = "full"
                X_full = subset[:, mask_vars].X.copy()
                var_names = subset.var_names[mask_vars]

            for window_size in window_sizes:
                # prepare global arrays for each enabled model/metric
                # (column names are assembled when the arrays are written out below)
                n_vars_global = adata.n_vars
                arrays = {}
                for mname in ['mlp', 'cnn', 'rnn', 'arnn', 'transformer', 'rf', 'nb', 'rr_bayes']:
                    if locals()[mname]:  # if model enabled
                        arrays[(mname, 'roc')] = np.full(n_vars_global, np.nan)
                        arrays[(mname, 'pr')] = np.full(n_vars_global, np.nan)
                        arrays[(mname, 'pr_norm')] = np.full(n_vars_global, np.nan)

                # fill missing values and build labels
                X_full = random_fill_nans(X_full)
                y_full = subset.obs['activity_status'].map({'Active': 1, 'Silent': 0}).values
                n_feats = X_full.shape[1]

                # sliding windows
                for start in range(0, n_feats - window_size + 1, step_size):
                    end = start + window_size
                    X_win = X_full[:, start:end]
                    center_idx = start + window_size // 2
                    center_var = var_names[center_idx]
                    var_idx_global = adata.var_names.get_loc(center_var)

                    # train/test split
                    X_train, X_pool, y_train, y_pool = train_test_split(
                        X_win, y_full, train_size=training_split,
                        stratify=y_full, random_state=42
                    )

                    # Optional test sampling with fallback
                    try:
                        if positive_amount is not None and positive_freq is not None:
                            pos_idx = np.where(y_pool == 1)[0]
                            if positive_amount > len(pos_idx):
                                raise ValueError("positive_amount exceeds available positives")
                            chosen_pos = np.random.choice(pos_idx, positive_amount, replace=False)

                            neg_amount = int(round(positive_amount * (1 - positive_freq) / positive_freq))
                            neg_idx = np.where(y_pool == 0)[0]
                            if neg_amount > len(neg_idx):
                                raise ValueError("negative_amount exceeds available negatives")
                            chosen_neg = np.random.choice(neg_idx, neg_amount, replace=False)

                            sel = np.concatenate([chosen_pos, chosen_neg])
                            X_test, y_test = X_pool[sel], y_pool[sel]
                        else:
                            X_test, y_test = X_pool, y_pool
                    except ValueError as e:
                        warnings.warn(
                            f"Falling back to full pool for window {start}:{end} ({e})"
                        )
                        X_test, y_test = X_pool, y_pool

                    # prepare DataLoader
                    train_ds = TensorDataset(
                        torch.tensor(X_train, dtype=torch.float32).to(device),
                        torch.tensor(y_train, dtype=torch.long).to(device)
                    )
                    loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
                    class_w = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
                    class_w = torch.tensor(class_w, dtype=torch.float32).to(device)

                    # container for this window's metrics
                    results = {}
                    # train and evaluate the enabled models
                    if mlp:
                        model = MLPClassifier(window_size, 2).to(device)
                        train_model(model, loader, torch.optim.Adam(model.parameters()), nn.CrossEntropyLoss(weight=class_w), device, ref, f"{bin_name}_MLP", epochs, patience)
                        mets, _, _ = evaluate_model(model, torch.tensor(X_test, dtype=torch.float32).to(device), y_test, device)
                        results['mlp'] = mets
                    if cnn:
                        model = CNNClassifier(window_size, 2).to(device)
                        train_model(model, loader, torch.optim.Adam(model.parameters()), nn.CrossEntropyLoss(weight=class_w), device, ref, f"{bin_name}_CNN", epochs, patience)
                        mets, _, _ = evaluate_model(model, torch.tensor(X_test, dtype=torch.float32).to(device), y_test, device)
                        results['cnn'] = mets
                    if rnn:
                        model = RNNClassifier(window_size, 64, 2).to(device)
                        train_model(model, loader, torch.optim.Adam(model.parameters()), nn.CrossEntropyLoss(weight=class_w), device, ref, f"{bin_name}_RNN", epochs, patience)
                        mets, _, _ = evaluate_model(model, torch.tensor(X_test, dtype=torch.float32).to(device), y_test, device)
                        results['rnn'] = mets
                    if arnn:
                        model = AttentionRNNClassifier(window_size, 64, 2).to(device)
                        train_model(model, loader, torch.optim.Adam(model.parameters()), nn.CrossEntropyLoss(weight=class_w), device, ref, f"{bin_name}_aRNN", epochs, patience)
                        mets, _, _ = evaluate_model(model, torch.tensor(X_test, dtype=torch.float32).to(device), y_test, device)
                        results['arnn'] = mets
                    if transformer:
                        model = TransformerClassifier(window_size, 64, 2).to(device)
                        train_model(model, loader, torch.optim.Adam(model.parameters()), nn.CrossEntropyLoss(weight=class_w), device, ref, f"{bin_name}_Transformer", epochs, patience)
                        mets, _, _ = evaluate_model(model, torch.tensor(X_test, dtype=torch.float32).to(device), y_test, device)
                        results['transformer'] = mets
                    if rf:
                        if balance_rf_class_weights:
                            rf_mod = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
                        else:
                            rf_mod = RandomForestClassifier(n_estimators=n_estimators, random_state=42)
                        rf_mod.fit(X_train, y_train)
                        probs = rf_mod.predict_proba(X_test)[:, 1]
                        fpr, tpr, _ = roc_curve(y_test, probs)
                        pr, rc, _ = precision_recall_curve(y_test, probs)
                        pr_auc = auc(rc, pr)
                        # positive-class frequency
                        pos_freq = np.mean(y_test == 1)
                        pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
                        results['rf'] = {'auc': auc(fpr, tpr), 'pr_auc': pr_auc, 'pr_auc_norm': pr_auc_norm}
                    if nb:
                        nb_mod = GaussianNB()
                        nb_mod.fit(X_train, y_train)
                        probs = nb_mod.predict_proba(X_test)[:, 1]
                        fpr, tpr, _ = roc_curve(y_test, probs)
                        pr, rc, _ = precision_recall_curve(y_test, probs)
                        pr_auc = auc(rc, pr)
                        # positive-class frequency
                        pos_freq = np.mean(y_test == 1)
                        pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
                        results['nb'] = {'auc': auc(fpr, tpr), 'pr_auc': pr_auc, 'pr_auc_norm': pr_auc_norm}
                    if rr_bayes:
                        # Relative-risk Bayesian classifier
                        n_active = np.sum(y_train == 1)
                        n_silent = np.sum(y_train == 0)
                        # compute feature-wise rates
                        p_a = (X_train[y_train == 1] == 1).sum(axis=0) / (n_active + 1e-6)
                        p_s = (X_train[y_train == 0] == 1).sum(axis=0) / (n_silent + 1e-6)
                        rr = (p_a + 1e-6) / (p_s + 1e-6)
                        log_rr = np.log(rr)
                        # score samples
                        scores = X_test.dot(log_rr)
                        probs = 1 / (1 + np.exp(-scores))
                        preds = (probs >= 0.5).astype(int)
                        # metrics
                        fpr, tpr, _ = roc_curve(y_test, probs)
                        pr, rc, _ = precision_recall_curve(y_test, probs)
                        roc_auc = auc(fpr, tpr)
                        pr_auc = auc(rc, pr)
                        pos_freq = np.mean(y_test == 1)
                        pr_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
                        results['rr_bayes'] = {'auc': roc_auc, 'pr_auc': pr_auc, 'pr_auc_norm': pr_norm}

                    # assign this window's metrics into the global arrays at the window center
                    for mname, mets in results.items():
                        arrays[(mname, 'roc')][var_idx_global] = mets['auc']
                        arrays[(mname, 'pr')][var_idx_global] = mets['pr_auc']
                        arrays[(mname, 'pr_norm')][var_idx_global] = mets['pr_auc_norm']

                # after all windows, write arrays to adata.var
                for (mname, metric), arr in arrays.items():
                    col = f"{bin_name}_{mname}_{suffix}_w{window_size}_{metric}"
                    adata.var[col] = arr

    print("✅ Sliding-window training/testing complete. Metrics stored at window centers in adata.var.")

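Continuing the sketch above, a hypothetical sliding-window call (the window and step values are placeholders, not recommended settings):

    sliding_window_train_test(
        adata, site_config, layer_name=None,
        window_sizes=[10], step_size=5,
        rf=True, nb=True, epochs=5,
    )
    # per-window ROC-AUC / AUPRC now sit at the window centers, e.g.:
    print([c for c in adata.var.columns if c.endswith('_roc')])
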
# ------------------------- Apply models to input adata -------------------------
def run_inference(adata, models, site_config, layer_name=None, model_names=["cnn", "mlp", "rf", "nb", "rr_bayes"]):
    """
    Perform inference on the full AnnData object using pre-trained models.

    Parameters:
        adata (AnnData): The full AnnData object.
        models (dict): Dictionary of trained models keyed by reference strand.
        site_config (dict): Configuration dictionary for subsetting features by site.
        layer_name (str, optional): Name of the layer to use, if applicable.
        model_names (list, optional): List of model names to run inference with.
            Defaults to ["cnn", "mlp", "rf", "nb", "rr_bayes"].

    Returns:
        None. The function updates adata.obs with predicted class labels, predicted probabilities,
        and active probabilities for each model.
    """
    import numpy as np
    import pandas as pd

    device = (
        torch.device('cuda') if torch.cuda.is_available() else
        torch.device('mps') if torch.backends.mps.is_available() else
        torch.device('cpu')
    )

    # Loop over each reference key in the models dictionary
    for ref in models.keys():
        # Subset the full AnnData by the reference strand
        full_subset = adata[adata.obs['Reference_strand'] == ref]
        if full_subset.shape[0] == 0:
            continue

        # Reconstruct the same layer or site mask used during training
        if layer_name:
            suffix = layer_name
            full_matrix = full_subset.layers[layer_name].copy()
        else:
            site_mask = np.zeros(full_subset.shape[1], dtype=bool)
            if ref in site_config:
                for site in site_config[ref]:
                    site_mask |= full_subset.var[f'{ref}_{site}']
                suffix = "_".join(site_config[ref])
            else:
                site_mask = np.ones(full_subset.shape[1], dtype=bool)
                suffix = "full"
            full_matrix = full_subset[:, site_mask].X.copy()

        # Fill any NaNs in the feature matrix
        full_matrix = random_fill_nans(full_matrix)

        # Convert to a torch tensor; torch models use the detected device
        full_tensor = torch.tensor(full_matrix, dtype=torch.float32)
        full_tensor_device = full_tensor.to(device)

        for model_name in model_names:
            model_key = f"{model_name}_{suffix}"
            pred_col = f"{model_name}_activity_prediction_{suffix}"
            pred_prob_col = f"{model_name}_prediction_prob_{suffix}"
            active_prob_col = f"{model_name}_active_prob_{suffix}"

            if model_key in models[ref]:
                model = models[ref][model_key]
                if model_name in ["rf", "nb"]:
                    # For scikit-learn models, work on CPU with NumPy arrays
                    X_input = full_tensor.cpu().numpy()
                    preds = model.predict(X_input)
                    probs = model.predict_proba(X_input)
                elif model_name == 'rr_bayes':
                    # model is a dict of parameters
                    log_rr = model['log_rr']
                    scores = full_tensor.cpu().numpy().dot(log_rr)
                    probs1 = 1 / (1 + np.exp(-scores))
                    preds = (probs1 >= 0.5).astype(int)
                    probs = np.vstack([1 - probs1, probs1]).T
                else:
                    # For torch models, perform inference on the detected device
                    model.eval()
                    with torch.no_grad():
                        logits = model(full_tensor_device)
                        preds = logits.argmax(dim=1).cpu().numpy()
                        probs = torch.softmax(logits, dim=1).cpu().numpy()

                pred_probs = probs[np.arange(len(preds)), preds]
                active_probs = probs[:, 1]  # class 1 is assumed to be "Active"

                # Store predictions in the AnnData object
                adata.obs.loc[full_subset.obs.index, pred_col] = preds
                adata.obs.loc[full_subset.obs.index, pred_prob_col] = pred_probs
                adata.obs.loc[full_subset.obs.index, active_prob_col] = active_probs

    # Ensure that the prediction columns are of categorical type.
    # Replace non-finite values with 0 before converting.
    # Note: suffix here is carried over from the last reference processed above.
    for model_name in model_names:
        pred_col = f"{model_name}_activity_prediction_{suffix}"
        if pred_col in adata.obs.columns:
            # Convert to numeric (non-finite become NaN), fill NaNs with 0, then cast
            adata.obs[pred_col] = pd.to_numeric(adata.obs[pred_col], errors='coerce').fillna(0).astype(int).astype("category")

    print("✅ Inference complete: stored predicted class, predicted probability, and active probability for each model.")

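And a matching inference call, reusing the models dictionary returned by run_training_loop in the earlier sketch; the column names printed below are the ones the function writes for the hypothetical GpC_site suffix:

    run_inference(adata, models, site_config, model_names=['mlp', 'rf'])
    print(adata.obs[['mlp_activity_prediction_GpC_site',
                     'mlp_active_prob_GpC_site']].head())
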
# ------------------------- Evaluate model activity predictions within categorical subgroups -------------------------

def evaluate_model_by_subgroups(
    adata,
    model_prefix="mlp",
    suffix="GpC_site_CpG_site",
    groupby_cols=["Sample_Names_Full", "Enhancer_Open", "Promoter_Open"],
    label_col="activity_status",
    min_samples=10,
    exclude_training_data=True):
    import pandas as pd
    from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

    results = []

    if exclude_training_data:
        test_subset = adata[adata.obs['used_for_training'] == False]
    else:
        test_subset = adata

    df = test_subset.obs.copy()
    df[label_col] = df[label_col].astype('category').cat.codes

    pred_col = f"{model_prefix}_activity_prediction_{suffix}"
    prob_col = f"{model_prefix}_active_prob_{suffix}"  # matches the column written by run_inference

    if pred_col not in df or prob_col not in df:
        raise ValueError(f"Missing prediction/probability columns for {model_prefix}")

    for group_vals, group_df in df.groupby(groupby_cols):
        if len(group_df) < min_samples:
            continue  # skip small groups

        y_true = group_df[label_col].values
        y_pred = group_df[pred_col].astype(int).values
        y_prob = group_df[prob_col].astype(float).values

        if len(set(y_true)) < 2:
            auc_val = float('nan')  # AUC is undefined if only one class is present
        else:
            auc_val = roc_auc_score(y_true, y_prob)

        results.append({
            **dict(zip(groupby_cols, group_vals)),
            "model": model_prefix,
            "n_samples": len(group_df),
            "accuracy": accuracy_score(y_true, y_pred),
            "f1": f1_score(y_true, y_pred),
            "auc": auc_val,
        })

    return pd.DataFrame(results)

def evaluate_models_by_subgroup(adata, model_prefixes, groupby_cols, label_col, exclude_training_data=True):
    import pandas as pd
    all_metrics = []
    for model_prefix in model_prefixes:
        try:
            df = evaluate_model_by_subgroups(adata, model_prefix=model_prefix, suffix="GpC_site_CpG_site", groupby_cols=groupby_cols, label_col=label_col, exclude_training_data=exclude_training_data)
            all_metrics.append(df)
        except Exception as e:
            print(f"Skipping {model_prefix} due to error: {e}")

    final_df = pd.concat(all_metrics, ignore_index=True)
    return final_df

def prepare_melted_model_data(adata, outkey='melted_model_df', groupby=['Enhancer_Open', 'Promoter_Open'], label_col='activity_status', model_names=['cnn', 'mlp', 'rf'], suffix='GpC_site_CpG_site', omit_training=True):
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    from sklearn.metrics import precision_recall_curve, roc_curve, auc
    # list.append returns None, so build the column list with concatenation instead
    cols = groupby + [label_col]
    if omit_training:
        subset = adata[adata.obs['used_for_training'] == False]
    else:
        subset = adata
    df = subset.obs[cols].copy()
    df[label_col] = df[label_col].astype('category').cat.codes

    for model in model_names:
        col = f"{model}_active_prob_{suffix}"
        if col in subset.obs.columns:
            df[f"{model}_prob"] = subset.obs[col].astype(float)

    # Melt into long format
    melted = df.melt(
        id_vars=cols,
        value_vars=[f"{m}_prob" for m in model_names if f"{m}_active_prob_{suffix}" in subset.obs.columns],
        var_name='model',
        value_name='prob'
    )
    melted['model'] = melted['model'].str.replace('_prob', '', regex=False)

    adata.uns[outkey] = melted