grasp-tool 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_tool/__init__.py +17 -0
- grasp_tool/__main__.py +6 -0
- grasp_tool/cli/__init__.py +1 -0
- grasp_tool/cli/main.py +793 -0
- grasp_tool/cli/train_moco.py +778 -0
- grasp_tool/gnn/__init__.py +1 -0
- grasp_tool/gnn/embedding.py +165 -0
- grasp_tool/gnn/gat_moco_final.py +990 -0
- grasp_tool/gnn/graphloader.py +1748 -0
- grasp_tool/gnn/plot_refined.py +1556 -0
- grasp_tool/preprocessing/__init__.py +1 -0
- grasp_tool/preprocessing/augumentation.py +66 -0
- grasp_tool/preprocessing/cellplot.py +475 -0
- grasp_tool/preprocessing/filter.py +171 -0
- grasp_tool/preprocessing/network.py +79 -0
- grasp_tool/preprocessing/partition.py +654 -0
- grasp_tool/preprocessing/portrait.py +1862 -0
- grasp_tool/preprocessing/register.py +1021 -0
- grasp_tool-0.1.0.dist-info/METADATA +511 -0
- grasp_tool-0.1.0.dist-info/RECORD +22 -0
- grasp_tool-0.1.0.dist-info/WHEEL +4 -0
- grasp_tool-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,990 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from torch_geometric.data import Data, Batch
|
|
6
|
+
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import random
|
|
10
|
+
from sklearn.cluster import KMeans
|
|
11
|
+
from scipy.stats import t
|
|
12
|
+
import math
|
|
13
|
+
import warnings
|
|
14
|
+
|
|
15
|
+
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
|
16
|
+
warnings.filterwarnings("ignore", category=UserWarning)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# --- GATEncoder and ProjectionHead (from original.py) ---
|
|
20
|
+
class ProjectionHead(nn.Module):
    """Two-layer MLP projection head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_channels: Width of the incoming features.
        hidden_channels: Width of the hidden layer.
        out_channels: Width of the projected output.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        layers = [
            nn.Linear(in_channels, hidden_channels),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, out_channels),
        ]
        # Kept under the name ``mlp`` so state_dict keys stay unchanged.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the projection MLP to ``x`` (BatchNorm1d expects 2-D/3-D input)."""
        projected = self.mlp(x)
        return projected
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class GATEncoder(nn.Module):  # From original.py
    """Two-layer GAT encoder returning node- and graph-level embeddings.

    Architecture: GATConv -> BatchNorm1d -> ELU -> Dropout -> GATConv.
    Both GATConv layers use PyG's default ``concat=True``, so the second layer
    emits ``out_channels * heads`` features per node.

    Args:
        in_channels: Input node-feature width.
        hidden_channels: Per-head width of the first attention layer.
        out_channels: Per-head width of the second attention layer.
        heads: Number of attention heads in both layers.
        dropout: Dropout used inside GATConv (attention) and after layer 1.
    """

    def __init__(
        self, in_channels, hidden_channels, out_channels, heads=1, dropout=0.1
    ):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.bn1 = nn.BatchNorm1d(hidden_channels * heads)
        self.dropout1 = nn.Dropout(dropout)
        self.conv2 = GATConv(
            hidden_channels * heads, out_channels, heads=heads, dropout=dropout
        )
        # Actual width of conv2's (head-concatenated) output. MoCo sizes its
        # projection head from this attribute; reporting only ``out_channels``
        # made the projector reject the encoder output whenever heads > 1.
        # For heads == 1 the value is unchanged.
        self.expected_output_dim = out_channels * heads

    def forward(self, x, edge_index, batch):
        """Encode a (batched) graph.

        Returns:
            ``(node_embeddings, graph_representation)``. The graph
            representation is ``global_mean_pool`` over ``batch`` when
            ``batch`` is given, otherwise the mean over all nodes (a 1-D
            tensor).
        """
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.elu(x)
        x = self.dropout1(x)
        x = self.conv2(x, edge_index)
        if batch is not None:
            graph_representation = global_mean_pool(x, batch)
        else:
            graph_representation = x.mean(dim=0)
        return x, graph_representation
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class GCNEncoder(nn.Module):
    """Two-layer GCN encoder returning node- and graph-level embeddings.

    Architecture: GCNConv -> BatchNorm1d -> ELU -> Dropout -> GCNConv.

    Args:
        in_channels: Input node-feature width.
        hidden_channels: Width of the first convolution.
        out_channels: Width of the second convolution (node embeddings).
        dropout: Dropout probability applied after the first layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.1):
        super().__init__()
        # Construction order and attribute names are kept: they define the
        # parameter-initialisation order and the state_dict keys.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.dropout1 = nn.Dropout(dropout)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        # Read via getattr(..., "expected_output_dim", ...) by MoCo.
        self.expected_output_dim = out_channels

    def forward(self, x, edge_index, batch):
        """Return ``(node_embeddings, graph_representation)``.

        The graph representation is mean-pooled per graph when ``batch`` is
        given, otherwise averaged over all nodes (a 1-D tensor).
        """
        hidden = self.dropout1(F.elu(self.bn1(self.conv1(x, edge_index))))
        node_embeddings = self.conv2(hidden, edge_index)
        if batch is None:
            return node_embeddings, node_embeddings.mean(dim=0)
        return node_embeddings, global_mean_pool(node_embeddings, batch)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class MoCo(nn.Module):
    """Momentum-contrast (MoCo) backbone for graph encoders.

    Holds a trainable query encoder/projector pair, a key encoder/projector
    pair updated only as an exponential moving average of the query pair, and
    a FIFO queue of ``K`` L2-normalised key embeddings used as negatives. Also
    provides the reconstruction and clustering loss helpers used by the
    subclasses' forward methods.

    Args:
        base_encoder: Module called as ``encoder(x, edge_index, batch)`` that
            returns ``(node_embeddings, graph_embeddings)``. Its
            ``expected_output_dim`` attribute (128 when absent) sizes the
            projection-head input.
        dim: Output width of the projection heads and of queue entries.
        K: Queue length (number of stored negatives).
        m: Momentum coefficient of the key-network EMA update.
        T: Softmax temperature (used by subclasses).
    """

    def __init__(self, base_encoder, dim=128, K=1024, m=0.999, T=0.07):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T

        self.encoder_q = base_encoder
        self.encoder_k = copy.deepcopy(base_encoder)

        # The projector consumes the encoder's graph-level output; ``dim`` is
        # the projection output width.
        projector_in_channels = getattr(
            base_encoder, "expected_output_dim", 128
        )  # Default to 128 if not found
        self.projector_q = ProjectionHead(
            in_channels=projector_in_channels,
            hidden_channels=256,  # original.py uses 256 in the projector hidden layer
            out_channels=dim,
        )
        self.projector_k = copy.deepcopy(self.projector_q)

        # Key networks start as exact copies and are excluded from the
        # optimizer; they change only through _momentum_update_key_encoder.
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        for param_q, param_k in zip(
            self.projector_q.parameters(), self.projector_k.parameters()
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # Queue of negatives: one row-wise L2-normalised key per row.
        self.register_buffer("queue", F.normalize(torch.randn(K, dim), dim=1))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    # --- Loss functions (from fixed.py) ---
    @staticmethod
    def _dense_adjacency(edge_index, num_nodes, like):
        """Dense 0/1 ``num_nodes x num_nodes`` matrix with ones at the edges."""
        adjacency = torch.zeros(
            (num_nodes, num_nodes), dtype=like.dtype, device=like.device
        )
        adjacency[edge_index[0], edge_index[1]] = 1
        return adjacency

    def _compute_reconstruction_loss(self, node_embeddings, edge_index, num_nodes):
        """Dispatch to the weighted or the basic adjacency-reconstruction loss.

        The weighted variant is used when the instance carries a truthy
        ``weighted_recon_loss`` attribute (it is not set in ``__init__``).
        """
        if getattr(self, "weighted_recon_loss", False):
            return self._compute_weighted_reconstruction_loss(
                node_embeddings, edge_index, num_nodes
            )
        return self._compute_basic_reconstruction_loss(
            node_embeddings, edge_index, num_nodes
        )

    def _compute_basic_reconstruction_loss(
        self, node_embeddings, edge_index, num_nodes
    ):
        """Inner-product decoder loss on raw (unnormalised) node embeddings.

        Pairwise dot products are treated as logits for the dense adjacency.
        ``binary_cross_entropy_with_logits`` is used instead of ``sigmoid`` +
        ``binary_cross_entropy``: it computes the same loss, but stays accurate
        when dot products are large. There, sigmoid rounds to exactly 0/1, the
        BCE term is clamped and the gradient vanishes.
        """
        if num_nodes == 0:
            # BCE over an empty matrix would be NaN (mean of nothing).
            return node_embeddings.new_zeros(())
        logits = torch.mm(node_embeddings, node_embeddings.t())
        target = self._dense_adjacency(edge_index, num_nodes, logits)
        return F.binary_cross_entropy_with_logits(logits, target)

    def _compute_weighted_reconstruction_loss(
        self, node_embeddings, edge_index, num_nodes
    ):
        """Weighted reconstruction loss to mitigate class imbalance.

        Cosine similarities of the L2-normalised embeddings are the logits.
        Positive (edge) entries are up-weighted by
        ``min(10, log(num_nodes**2 / max(num_edges, 1)) + 1)``.
        """
        if num_nodes == 0:
            # Avoids log(0) in the weight and a NaN mean over an empty matrix.
            return node_embeddings.new_zeros(())

        node_embeddings_normalized = F.normalize(node_embeddings, p=2, dim=1)
        sim_scores = torch.mm(
            node_embeddings_normalized, node_embeddings_normalized.t()
        )
        target = self._dense_adjacency(edge_index, num_nodes, sim_scores)

        edge_count = edge_index.size(1)
        total_pairs = num_nodes * num_nodes
        pos_weight = torch.tensor(
            min(10.0, math.log(total_pairs / max(edge_count, 1)) + 1),
            device=node_embeddings.device,
        )
        return F.binary_cross_entropy_with_logits(
            sim_scores, target, pos_weight=pos_weight, reduction="mean"
        )

    def _compute_clustering_loss(
        self,
        embeddings,
        num_clusters,
        dist_type="uniform",
        input_features=None,
        batch=None,
    ):
        """Clustering regulariser on graph-level embeddings.

        * ``"spectral"``: delegates to ``_compute_spectral_clustering_loss``
          with graph-level input features (mean-pooled from
          ``input_features``), or ``embeddings`` when no features are given.
          ``k_neighbors`` / ``sigma`` attributes override the defaults 100 / 1.0.
        * ``"t-distribution"``: KL divergence between a histogram of Student-t
          pdf values (computed from per-row medians of ``embeddings``) and the
          KMeans pseudo-label distribution.
        * anything else (``"uniform"``): KL divergence between a uniform
          distribution and the KMeans pseudo-label distribution.

        Falls back to ``_compute_simple_clustering_loss`` when the batch is
        smaller than ``num_clusters`` or when KMeans/the KL step raises.

        NOTE(review): the two KMeans-based variants are computed from
        *detached* embeddings (pseudo-label counts), so they carry no gradient
        back to the encoder; they only change the reported loss value.

        Raises:
            ValueError: ``num_clusters`` is None for a non-spectral dist_type.
        """
        if dist_type == "spectral":
            k_neighbors = getattr(self, "k_neighbors", 100)
            sigma = getattr(self, "sigma", 1.0)
            if input_features is not None and batch is not None:
                graph_level_input = global_mean_pool(input_features, batch)
            elif input_features is not None:
                print(
                    "WARNING: batch is None; pooling input_features by mean as a fallback"
                )
                graph_level_input = input_features.mean(dim=0, keepdim=True)
            else:
                print(
                    "WARNING: input_features not provided; using embeddings as a proxy for clustering loss"
                )
                graph_level_input = embeddings
            return self._compute_spectral_clustering_loss(
                graph_level_input, embeddings, k=k_neighbors, sigma=sigma
            )

        if num_clusters is None:
            raise ValueError(
                "num_clusters must be provided when dist_type is not 'spectral'"
            )

        batch_size = embeddings.size(0)
        if batch_size < num_clusters:
            print(
                f"WARNING: batch_size ({batch_size}) < num_clusters ({num_clusters}); using simplified clustering loss"
            )
            return self._compute_simple_clustering_loss(
                embeddings, min(num_clusters, batch_size)
            )

        try:
            embeddings_np = embeddings.detach().cpu().numpy()
            # KMeans clustering (default n_init).
            kmeans = KMeans(n_clusters=num_clusters, random_state=0, max_iter=100).fit(
                embeddings_np
            )
            pseudo_labels = torch.as_tensor(
                kmeans.labels_, dtype=torch.long, device=embeddings.device
            )

            eps = 1e-8
            pseudo_label_dist = (
                torch.bincount(pseudo_labels, minlength=num_clusters).float() + eps
            )
            pseudo_label_dist /= pseudo_label_dist.sum()

            if dist_type == "t-distribution":
                # Per-row median of the embedding (historically named "mean").
                embedding_median = torch.median(embeddings, dim=1).values
                dof = 10
                scale = torch.std(embedding_median)
                embedding_median_np = embedding_median.detach().cpu().numpy()
                ideal_t_dist_pdf = t.pdf(
                    embedding_median_np,
                    dof,
                    loc=embedding_median_np.mean(),
                    scale=scale.item(),
                )
                ideal_t_dist = torch.tensor(
                    ideal_t_dist_pdf, dtype=torch.float, device=embeddings.device
                )

                min_val = ideal_t_dist.min().item()
                max_val = ideal_t_dist.max().item()
                if max_val == min_val:  # histc needs a non-empty range.
                    max_val = min_val + eps

                ideal_t_dist = torch.histc(
                    ideal_t_dist, bins=num_clusters, min=min_val, max=max_val
                )
                ideal_t_dist = ideal_t_dist + eps
                ideal_t_dist /= ideal_t_dist.sum()
                return F.kl_div(
                    pseudo_label_dist.log(), ideal_t_dist, reduction="batchmean"
                )

            ideal_dist = (
                torch.ones(num_clusters, device=embeddings.device) / num_clusters
            )
            return F.kl_div(
                pseudo_label_dist.log(), ideal_dist, reduction="batchmean"
            )

        except Exception as e:
            # Deliberate best-effort: any KMeans/KL failure degrades to the
            # simple similarity-based loss instead of aborting training.
            print(
                f"ERROR: KMeans clustering failed: {e}; using simplified clustering loss"
            )
            return self._compute_simple_clustering_loss(embeddings, num_clusters)

    def _compute_simple_clustering_loss(self, embeddings, num_clusters):
        """Mean absolute off-diagonal cosine similarity of the embeddings.

        ``num_clusters`` is accepted for interface symmetry but not used.
        Returns a zero tensor (requiring grad) for fewer than two rows.
        """
        batch_size = embeddings.size(0)
        if batch_size < 2:
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        normalized_emb = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(normalized_emb, normalized_emb.t())

        mask = torch.eye(batch_size, device=embeddings.device).bool()
        similarity_matrix.masked_fill_(mask, 0)

        return similarity_matrix.abs().mean()

    def _compute_spectral_clustering_loss(
        self, input_features, output_embeddings, k=100, sigma=1.0
    ):
        """Graph-Laplacian smoothness loss.

        Builds a Gaussian-kernel affinity ``exp(-||x_i - x_j||^2 / (2 sigma^2))``
        over ``input_features`` (self-affinity zeroed). It is sparsified to
        each row's top-``k`` neighbours (then symmetrised) when ``k`` is
        smaller than ``batch_size - 1``. The loss is the affinity-weighted sum
        of squared distances between ``output_embeddings``, normalised by the
        total affinity when that is positive.
        """
        batch_size = output_embeddings.size(0)
        if batch_size <= 1:
            return torch.tensor(0.0, device=output_embeddings.device)

        input_squared_dist = torch.cdist(input_features, input_features, p=2) ** 2
        two_sigma_squared = 2 * (sigma**2)
        similarity = torch.exp(-input_squared_dist / two_sigma_squared)

        # Out-of-place on purpose: torch.exp saves its output for backward, so
        # an in-place masked_fill_ breaks autograd whenever input_features
        # require grad (e.g. the embeddings-as-proxy path).
        self_mask = torch.eye(batch_size, device=output_embeddings.device).bool()
        similarity = similarity.masked_fill(self_mask, 0)

        if k < batch_size - 1:
            _, topk_indices = torch.topk(similarity, k=k, dim=1)
            row_indices = (
                torch.arange(batch_size, device=output_embeddings.device)
                .unsqueeze(1)
                .expand(-1, topk_indices.size(1))
            )
            adjacency = torch.zeros_like(similarity)
            adjacency[row_indices, topk_indices] = similarity[row_indices, topk_indices]
            adjacency = 0.5 * (adjacency + adjacency.t())
        else:
            adjacency = similarity

        output_squared_dist = (
            torch.cdist(output_embeddings, output_embeddings, p=2) ** 2
        )
        loss = torch.sum(adjacency * output_squared_dist)

        edge_weights_sum = torch.sum(adjacency)
        if edge_weights_sum > 0:
            loss = loss / edge_weights_sum

        return loss

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """EMA update of the key networks: ``k = m * k + (1 - m) * q``."""
        for query_net, key_net in (
            (self.encoder_q, self.encoder_k),
            (self.projector_q, self.projector_k),
        ):
            for param_q, param_k in zip(query_net.parameters(), key_net.parameters()):
                # In place, so the key parameters keep their storage.
                param_k.data.mul_(self.m).add_(param_q.data, alpha=1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):  # From fixed.py (robust version)
        """Write ``keys`` (one per row) into the circular queue and advance the pointer.

        If more than ``K`` keys arrive at once, only the last ``K`` are kept.
        Writing the whole batch would overrun the queue and fail with a shape
        mismatch.
        """
        if keys.shape[0] > self.K:
            keys = keys[-self.K :]
        batch_size = keys.shape[0]
        ptr_val = int(self.queue_ptr.item())

        if ptr_val + batch_size > self.K:
            # Wrap around: fill the tail, then continue from the head.
            first_part = self.K - ptr_val
            self.queue[ptr_val:, :] = keys[:first_part]
            remaining = batch_size - first_part
            self.queue[:remaining, :] = keys[first_part:]
            new_ptr = remaining
        else:
            self.queue[ptr_val : ptr_val + batch_size, :] = keys
            new_ptr = (ptr_val + batch_size) % self.K

        self.queue_ptr[0] = new_ptr
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
class MoCoMultiPositive(MoCo):
|
|
371
|
+
def __init__(self, base_encoder, dim=128, K=1024, m=0.999, T=0.07):
    """Build the MoCo backbone; this subclass adds no extra state.

    Args:
        base_encoder: Graph encoder shared by the query/key branches.
        dim: Projection / queue embedding width.
        K: Queue length.
        m: Key-network momentum.
        T: Softmax temperature.
    """
    super().__init__(base_encoder, dim=dim, K=K, m=m, T=T)
|
|
373
|
+
|
|
374
|
+
# --- forward method (batch_k_list removed) ---
|
|
375
|
+
def forward(
    self,
    im_q,
    im_k_list,
    edge_index_q,
    edge_index_k_list,
    batch,
    num_clusters=None,
    dist_type="uniform",
    a=1.0,
    b=1.0,
    c=1.0,
    use_clustering=True,
):
    """Forward pass with multiple positive samples.

    The contrastive term is a multi-positive InfoNCE loss. For each query ``q``:
    ``-log(sum_p exp(q.k_p/T) / (sum_p exp(q.k_p/T) + sum_n exp(q.n/T)))``,
    where ``k_p`` are the positive keys and ``n`` are the queued negatives.
    It is evaluated in log space with ``torch.logsumexp``. The value is
    mathematically the same as the exp/sum form, but cannot overflow to
    inf/NaN for small temperatures.

    Args:
        im_q: Node features of the query graph batch.
        im_k_list: List of node features for positive key graph batches.
        edge_index_q: Edge index for the query batch.
        edge_index_k_list: List of edge indices for the key batches.
        batch: Batch vector mapping nodes to graphs (shared by all views).
        num_clusters: Number of clusters used by clustering loss.
        dist_type: Target distribution for clustering loss ('uniform', 't-distribution', 'spectral').
        a: Weight for reconstruction loss.
        b: Weight for contrastive loss.
        c: Weight for clustering loss.
        use_clustering: Enable/disable clustering loss.

    Returns:
        (total_loss, reconstruction_loss, contrastive_loss, clustering_loss,
         adjusted_reconstruction, adjusted_contrastive, adjusted_clustering)
    """
    # 1. Query features (trainable branch).
    q_node_embeddings, q = self.encoder_q(im_q, edge_index_q, batch)
    q = self.projector_q(q)
    q = F.normalize(q, dim=1)

    # 2. Positive key features from the momentum branch (no gradients).
    k_features = []
    with torch.no_grad():
        self._momentum_update_key_encoder()
        for im_k, edge_index_k in zip(im_k_list, edge_index_k_list):
            _, k = self.encoder_k(im_k, edge_index_k, batch)
            k = self.projector_k(k)
            k_features.append(F.normalize(k, dim=1))

    # 3-4. Multi-positive contrastive loss, computed in log space.
    if k_features:
        # l_pos: [batch_size, num_positives]; l_neg: [batch_size, K].
        l_pos = (
            torch.cat(
                [torch.einsum("nc,nc->n", q, k).unsqueeze(-1) for k in k_features],
                dim=1,
            )
            / self.T
        )
        l_neg = torch.einsum("nc,kc->nk", q, self.queue.clone().detach()) / self.T
        log_numerator = torch.logsumexp(l_pos, dim=1)
        log_denominator = torch.logsumexp(torch.cat([l_pos, l_neg], dim=1), dim=1)
        contrastive_loss = (log_denominator - log_numerator).mean()
    else:
        # No positive views: contribute nothing, as forward_supcon/forward_avg do.
        contrastive_loss = torch.tensor(0.0, device=q.device)

    # 5. Reconstruction loss on the query's node embeddings.
    reconstruction_loss = self._compute_reconstruction_loss(
        q_node_embeddings, edge_index_q, im_q.size(0)
    )

    # 6. Clustering loss (optional); spectral also needs the raw features.
    if use_clustering:
        if dist_type == "spectral":
            clustering_loss = self._compute_clustering_loss(
                q, num_clusters, dist_type, input_features=im_q, batch=batch
            )
        else:
            clustering_loss = self._compute_clustering_loss(
                q, num_clusters, dist_type
            )
    else:
        clustering_loss = torch.tensor(0.0, device=contrastive_loss.device)

    # 7. Loss alignment: rescale each auxiliary term to the reconstruction
    # loss's magnitude using a detached ratio, so only its gradient direction
    # is kept.
    eps = 1e-6
    adjusted_contrastive = contrastive_loss / (
        (contrastive_loss / (reconstruction_loss + eps)).detach() + eps
    )
    if use_clustering:
        adjusted_clustering = clustering_loss / (
            (clustering_loss / (reconstruction_loss + eps)).detach() + eps
        )
    else:
        adjusted_clustering = torch.tensor(0.0, device=contrastive_loss.device)
    adjusted_reconstruction = reconstruction_loss

    # 8. Total loss.
    if use_clustering:
        total_loss = (
            a * adjusted_reconstruction
            + b * adjusted_contrastive
            + c * adjusted_clustering
        )
    else:
        total_loss = a * adjusted_reconstruction + b * adjusted_contrastive

    # 9. Update queue (use the last positive).
    if k_features:
        self._dequeue_and_enqueue(k_features[-1])

    # 10. Return losses.
    return (
        total_loss,
        reconstruction_loss,
        contrastive_loss,
        clustering_loss,
        adjusted_reconstruction,
        adjusted_contrastive,
        adjusted_clustering,
    )
|
|
503
|
+
|
|
504
|
+
def forward_supcon(
    self,
    im_q,
    im_k_list,
    edge_index_q,
    edge_index_k_list,
    batch,
    num_clusters=None,
    dist_type="uniform",
    a=1.0,
    b=1.0,
    c=1.0,
    use_clustering=True,
):
    """Forward pass that scores every positive separately (SupCon-style).

    Each positive key ``p`` contributes the log-probability
    ``q.k_p/T - logsumexp([q.k_1..q.k_P, q.queue]/T)``. The contrastive loss
    is the negated mean over all queries and positives. The queue receives the
    re-normalised mean of the positive keys.

    Args:
        im_q / edge_index_q: Query node features and edge index.
        im_k_list / edge_index_k_list: Per-positive node features and edge indices.
        batch: Node-to-graph assignment shared by all views.
        num_clusters, dist_type: Forwarded to the clustering loss.
        a, b, c: Weights of the reconstruction, contrastive and clustering terms.
        use_clustering: Whether to compute the clustering term.

    Returns:
        (total_loss, reconstruction_loss, contrastive_loss, clustering_loss,
         adjusted_reconstruction, adjusted_contrastive, adjusted_clustering)
    """
    q_node_embeddings, q_graph = self.encoder_q(im_q, edge_index_q, batch)
    q = F.normalize(self.projector_q(q_graph), dim=1)

    # Momentum update first, then encode every positive view without gradients.
    with torch.no_grad():
        self._momentum_update_key_encoder()
        k_features = [
            F.normalize(
                self.projector_k(self.encoder_k(view_x, view_edges, batch)[1]), dim=1
            )
            for view_x, view_edges in zip(im_k_list, edge_index_k_list)
        ]

    contrastive_loss = torch.tensor(0.0, device=q.device)
    if k_features and q.size(0) > 0:
        keys = torch.stack(k_features).permute(1, 0, 2)  # [B, P, D]
        pos_logits = (
            torch.bmm(q.unsqueeze(1), keys.transpose(1, 2)).squeeze(1) / self.T
        )  # [B, P]
        neg_logits = torch.matmul(q, self.queue.clone().detach().T) / self.T  # [B, K]
        log_norm = torch.logsumexp(
            torch.cat([pos_logits, neg_logits], dim=1), dim=1, keepdim=True
        )
        contrastive_loss = (log_norm - pos_logits).mean()

    reconstruction_loss = self._compute_reconstruction_loss(
        q_node_embeddings, edge_index_q, im_q.size(0)
    )

    fallback_device = q.device if q.numel() > 0 else torch.device("cpu")
    if not use_clustering:
        clustering_loss = torch.tensor(0.0, device=fallback_device)
    else:
        clustering_loss = self._compute_clustering_loss(
            q, num_clusters, dist_type, input_features=im_q, batch=batch
        )

    eps = 1e-6
    adjusted_reconstruction = reconstruction_loss
    adjusted_contrastive = contrastive_loss / (
        (contrastive_loss / (reconstruction_loss + eps)).detach() + eps
    )
    total_loss = a * adjusted_reconstruction + b * adjusted_contrastive
    if use_clustering:
        adjusted_clustering = clustering_loss / (
            (clustering_loss / (reconstruction_loss + eps)).detach() + eps
        )
        total_loss = total_loss + c * adjusted_clustering
    else:
        adjusted_clustering = torch.tensor(0.0, device=fallback_device)

    if k_features:
        mean_key = F.normalize(torch.mean(torch.stack(k_features), dim=0), dim=1)
        self._dequeue_and_enqueue(mean_key)

    return (
        total_loss,
        reconstruction_loss,
        contrastive_loss,
        clustering_loss,
        adjusted_reconstruction,
        adjusted_contrastive,
        adjusted_clustering,
    )
|
|
592
|
+
|
|
593
|
+
def forward_avg(
    self,
    im_q,
    im_k_list,
    edge_index_q,
    edge_index_k_list,
    batch,
    num_clusters=None,
    dist_type="uniform",
    a=1.0,
    b=1.0,
    c=1.0,
    use_clustering=True,
):
    """Forward pass that contrasts the query against the *averaged* positive.

    The positive keys are averaged and re-normalised into a single key. The
    contrastive loss is the InfoNCE loss of that key against the queue. The
    same averaged key is enqueued, but only when it was computed (non-empty
    batch and at least one positive).

    Args:
        im_q / edge_index_q: Query node features and edge index.
        im_k_list / edge_index_k_list: Per-positive node features and edge indices.
        batch: Node-to-graph assignment shared by all views.
        num_clusters, dist_type: Forwarded to the clustering loss.
        a, b, c: Weights of the reconstruction, contrastive and clustering terms.
        use_clustering: Whether to compute the clustering term.

    Returns:
        (total_loss, reconstruction_loss, contrastive_loss, clustering_loss,
         adjusted_reconstruction, adjusted_contrastive, adjusted_clustering)
    """
    q_node_embeddings, q_graph = self.encoder_q(im_q, edge_index_q, batch)
    q = F.normalize(self.projector_q(q_graph), dim=1)

    with torch.no_grad():
        self._momentum_update_key_encoder()
        k_features = [
            F.normalize(
                self.projector_k(self.encoder_k(view_x, view_edges, batch)[1]), dim=1
            )
            for view_x, view_edges in zip(im_k_list, edge_index_k_list)
        ]

    contrastive_loss = torch.tensor(0.0, device=q.device)
    mean_key = None
    if k_features and q.size(0) > 0:
        mean_key = F.normalize(torch.mean(torch.stack(k_features), dim=0), dim=1)  # [B, D]
        pos_logit = torch.einsum("nc,nc->n", q, mean_key).unsqueeze(-1) / self.T  # [B, 1]
        neg_logits = torch.matmul(q, self.queue.clone().detach().T) / self.T  # [B, K]
        log_norm = torch.logsumexp(
            torch.cat([pos_logit, neg_logits], dim=1), dim=1, keepdim=True
        )  # [B, 1]
        contrastive_loss = (log_norm - pos_logit).mean()

    reconstruction_loss = self._compute_reconstruction_loss(
        q_node_embeddings, edge_index_q, im_q.size(0)
    )

    fallback_device = q.device if q.numel() > 0 else torch.device("cpu")
    if not use_clustering:
        clustering_loss = torch.tensor(0.0, device=fallback_device)
    else:
        clustering_loss = self._compute_clustering_loss(
            q, num_clusters, dist_type, input_features=im_q, batch=batch
        )

    eps = 1e-6
    adjusted_reconstruction = reconstruction_loss
    adjusted_contrastive = contrastive_loss / (
        (contrastive_loss / (reconstruction_loss + eps)).detach() + eps
    )
    total_loss = a * adjusted_reconstruction + b * adjusted_contrastive
    if use_clustering:
        adjusted_clustering = clustering_loss / (
            (clustering_loss / (reconstruction_loss + eps)).detach() + eps
        )
        total_loss = total_loss + c * adjusted_clustering
    else:
        adjusted_clustering = torch.tensor(0.0, device=fallback_device)

    if mean_key is not None:
        self._dequeue_and_enqueue(mean_key)

    return (
        total_loss,
        reconstruction_loss,
        contrastive_loss,
        clustering_loss,
        adjusted_reconstruction,
        adjusted_contrastive,
        adjusted_clustering,
    )
|
|
683
|
+
|
|
684
|
+
# TODO: Prefer JS-distance positives when available; if insufficient, sample from same-gene graphs,
|
|
685
|
+
# and only fall back to the query graph as a last resort.
|
|
686
|
+
@staticmethod
|
|
687
|
+
def generate_samples_gw(
|
|
688
|
+
original_graphs,
|
|
689
|
+
augmented_graphs,
|
|
690
|
+
gene_labels,
|
|
691
|
+
cell_labels,
|
|
692
|
+
num_positive,
|
|
693
|
+
gw_distances_df,
|
|
694
|
+
):
|
|
695
|
+
"""Generate positive samples using GW distances.
|
|
696
|
+
|
|
697
|
+
Returns a list of (query_idx, [pos_idx1, pos_idx2, ...]). The first
|
|
698
|
+
positive is always the augmented view of the same graph.
|
|
699
|
+
"""
|
|
700
|
+
# (Optional) preprocessing: build gene/cell index maps
|
|
701
|
+
# gene_to_indices = {}
|
|
702
|
+
# cell_to_indices = {}
|
|
703
|
+
|
|
704
|
+
# for i, (gene, cell) in enumerate(zip(gene_labels, cell_labels)):
|
|
705
|
+
# if gene not in gene_to_indices:
|
|
706
|
+
# gene_to_indices[gene] = []
|
|
707
|
+
# gene_to_indices[gene].append(i)
|
|
708
|
+
|
|
709
|
+
# if cell not in cell_to_indices:
|
|
710
|
+
# cell_to_indices[cell] = []
|
|
711
|
+
# cell_to_indices[cell].append(i)
|
|
712
|
+
|
|
713
|
+
positive_samples = []
|
|
714
|
+
num_graphs = len(gene_labels)
|
|
715
|
+
# Build positives for each graph.
|
|
716
|
+
for i in range(len(original_graphs)):
|
|
717
|
+
current_positives = []
|
|
718
|
+
|
|
719
|
+
# 1) Always include the augmented positive
|
|
720
|
+
current_positives.append(i) # augmented_graphs[i] index
|
|
721
|
+
|
|
722
|
+
# 2) Add positives by GW distance
|
|
723
|
+
target_cell = cell_labels[i]
|
|
724
|
+
target_gene = gene_labels[i]
|
|
725
|
+
|
|
726
|
+
# Filter GW distances for the current (cell, gene).
|
|
727
|
+
filtered_distances = gw_distances_df[
|
|
728
|
+
(gw_distances_df["target_cell"] == target_cell)
|
|
729
|
+
& (gw_distances_df["target_gene"] == target_gene)
|
|
730
|
+
]
|
|
731
|
+
|
|
732
|
+
# Take the closest (num_positive - 1) samples.
|
|
733
|
+
closest_samples = filtered_distances.nsmallest(
|
|
734
|
+
num_positive - 1, "gw_distance"
|
|
735
|
+
)
|
|
736
|
+
|
|
737
|
+
for _, row in closest_samples.iterrows():
|
|
738
|
+
other_cell = row["cell"]
|
|
739
|
+
other_gene = row["gene"]
|
|
740
|
+
|
|
741
|
+
# Find the matching graph index.
|
|
742
|
+
for j in range(len(original_graphs)):
|
|
743
|
+
if (cell_labels[j] == other_cell) and (
|
|
744
|
+
gene_labels[j] == other_gene
|
|
745
|
+
):
|
|
746
|
+
current_positives.append(j)
|
|
747
|
+
break
|
|
748
|
+
|
|
749
|
+
# Ensure enough positives; fall back to the query index.
|
|
750
|
+
while len(current_positives) < num_positive:
|
|
751
|
+
current_positives.append(i)
|
|
752
|
+
|
|
753
|
+
positive_samples.append((i, current_positives))
|
|
754
|
+
|
|
755
|
+
return positive_samples # [(query_idx, [pos_idx1, pos_idx2, ...]), ...]
|
|
756
|
+
|
|
757
|
+
@staticmethod
|
|
758
|
+
def generate_samples_random_window(
|
|
759
|
+
original_graphs,
|
|
760
|
+
augmented_graphs,
|
|
761
|
+
gene_labels,
|
|
762
|
+
cell_labels,
|
|
763
|
+
num_positive,
|
|
764
|
+
window_size=5,
|
|
765
|
+
):
|
|
766
|
+
"""Generate positives by sampling same-gene graphs with similar node counts.
|
|
767
|
+
|
|
768
|
+
The first positive is always the augmented view of the same graph.
|
|
769
|
+
"""
|
|
770
|
+
import random
|
|
771
|
+
|
|
772
|
+
# Precompute gene->indices and node counts.
|
|
773
|
+
gene_to_indices = {}
|
|
774
|
+
node_counts = {}
|
|
775
|
+
|
|
776
|
+
# Build index mapping.
|
|
777
|
+
for i, graph in enumerate(original_graphs):
|
|
778
|
+
gene = gene_labels[i]
|
|
779
|
+
|
|
780
|
+
# Number of nodes for this graph.
|
|
781
|
+
node_count = (
|
|
782
|
+
graph.num_real_nodes
|
|
783
|
+
if hasattr(graph, "num_real_nodes")
|
|
784
|
+
else sum(1 for _ in graph.x if _[2] == 0)
|
|
785
|
+
)
|
|
786
|
+
node_counts[i] = node_count
|
|
787
|
+
|
|
788
|
+
# Update gene->indices mapping.
|
|
789
|
+
if gene not in gene_to_indices:
|
|
790
|
+
gene_to_indices[gene] = []
|
|
791
|
+
gene_to_indices[gene].append(i)
|
|
792
|
+
|
|
793
|
+
positive_samples = []
|
|
794
|
+
|
|
795
|
+
# Build positives for each graph.
|
|
796
|
+
for i in range(len(original_graphs)):
|
|
797
|
+
current_positives = []
|
|
798
|
+
|
|
799
|
+
# 1) Always include the augmented positive
|
|
800
|
+
current_positives.append(i) # augmented_graphs[i] index
|
|
801
|
+
|
|
802
|
+
# 2) Random positives from same gene within a node-count window
|
|
803
|
+
target_gene = gene_labels[i]
|
|
804
|
+
target_node_count = node_counts[i]
|
|
805
|
+
|
|
806
|
+
# All indices for the same gene.
|
|
807
|
+
same_gene_indices = gene_to_indices[target_gene]
|
|
808
|
+
|
|
809
|
+
# Candidates within the node-count window.
|
|
810
|
+
candidates = []
|
|
811
|
+
for idx in same_gene_indices:
|
|
812
|
+
if (
|
|
813
|
+
idx != i
|
|
814
|
+
and abs(node_counts[idx] - target_node_count) <= window_size
|
|
815
|
+
):
|
|
816
|
+
candidates.append(idx)
|
|
817
|
+
|
|
818
|
+
# If too few candidates, expand the window.
|
|
819
|
+
if len(candidates) < num_positive - 1 and window_size < 20:
|
|
820
|
+
extended_candidates = []
|
|
821
|
+
for idx in same_gene_indices:
|
|
822
|
+
if (
|
|
823
|
+
idx != i
|
|
824
|
+
and abs(node_counts[idx] - target_node_count) <= window_size * 2
|
|
825
|
+
):
|
|
826
|
+
extended_candidates.append(idx)
|
|
827
|
+
candidates = extended_candidates
|
|
828
|
+
|
|
829
|
+
# Randomly sample up to (num_positive - 1) candidates.
|
|
830
|
+
if candidates:
|
|
831
|
+
sample_size = min(len(candidates), num_positive - 1)
|
|
832
|
+
sampled_indices = random.sample(candidates, sample_size)
|
|
833
|
+
current_positives.extend(sampled_indices)
|
|
834
|
+
|
|
835
|
+
# Ensure enough positives; fall back to the query index.
|
|
836
|
+
while len(current_positives) < num_positive:
|
|
837
|
+
current_positives.append(i)
|
|
838
|
+
|
|
839
|
+
positive_samples.append((i, current_positives))
|
|
840
|
+
|
|
841
|
+
return positive_samples # [(query_idx, [pos_idx1, pos_idx2, ...]), ...]
|
|
842
|
+
|
|
843
|
+
@staticmethod
|
|
844
|
+
def generate_samples_js(
|
|
845
|
+
original_graphs,
|
|
846
|
+
augmented_graphs,
|
|
847
|
+
gene_labels,
|
|
848
|
+
cell_labels,
|
|
849
|
+
num_positive,
|
|
850
|
+
js_distances_df,
|
|
851
|
+
):
|
|
852
|
+
"""Generate positive samples using Jensen-Shannon distances."""
|
|
853
|
+
print(f"Selecting positives by JS distance: num_positive={num_positive}")
|
|
854
|
+
|
|
855
|
+
# Validate js_distances_df format
|
|
856
|
+
required_columns = ["target_cell", "target_gene", "cell", "gene", "js_distance"]
|
|
857
|
+
missing_columns = [
|
|
858
|
+
col for col in required_columns if col not in js_distances_df.columns
|
|
859
|
+
]
|
|
860
|
+
|
|
861
|
+
if missing_columns:
|
|
862
|
+
raise ValueError(
|
|
863
|
+
f"js_distances_df is missing required columns: {missing_columns}. "
|
|
864
|
+
f"Expected columns: {required_columns}"
|
|
865
|
+
)
|
|
866
|
+
|
|
867
|
+
positive_samples = []
|
|
868
|
+
|
|
869
|
+
# Build positives for each graph.
|
|
870
|
+
for i in range(len(original_graphs)):
|
|
871
|
+
current_positives = []
|
|
872
|
+
|
|
873
|
+
# 1) Always include the augmented positive
|
|
874
|
+
current_positives.append(i) # augmented_graphs[i] index
|
|
875
|
+
|
|
876
|
+
# 2) Add positives by JS distance
|
|
877
|
+
target_cell = cell_labels[i]
|
|
878
|
+
target_gene = gene_labels[i]
|
|
879
|
+
|
|
880
|
+
# Filter JS distances for the current (cell, gene).
|
|
881
|
+
filtered_distances = js_distances_df[
|
|
882
|
+
(js_distances_df["target_cell"] == target_cell)
|
|
883
|
+
& (js_distances_df["target_gene"] == target_gene)
|
|
884
|
+
]
|
|
885
|
+
|
|
886
|
+
# Take the closest (num_positive - 1) samples.
|
|
887
|
+
closest_samples = filtered_distances.nsmallest(
|
|
888
|
+
num_positive - 1, "js_distance"
|
|
889
|
+
)
|
|
890
|
+
|
|
891
|
+
for _, row in closest_samples.iterrows():
|
|
892
|
+
other_cell = row["cell"]
|
|
893
|
+
other_gene = row["gene"]
|
|
894
|
+
|
|
895
|
+
# Find the matching graph index.
|
|
896
|
+
for j in range(len(original_graphs)):
|
|
897
|
+
if (cell_labels[j] == other_cell) and (
|
|
898
|
+
gene_labels[j] == other_gene
|
|
899
|
+
):
|
|
900
|
+
current_positives.append(j)
|
|
901
|
+
break
|
|
902
|
+
|
|
903
|
+
# Ensure enough positives; fall back to the query index.
|
|
904
|
+
while len(current_positives) < num_positive:
|
|
905
|
+
current_positives.append(i)
|
|
906
|
+
|
|
907
|
+
positive_samples.append((i, current_positives))
|
|
908
|
+
|
|
909
|
+
return positive_samples # [(query_idx, [pos_idx1, pos_idx2, ...]), ...]
|
|
910
|
+
|
|
911
|
+
@staticmethod
def prepare_multi_positive_batch(
    original_graphs, augmented_graphs, positive_samples, batch_size
):
    """Yield ``(query_batch, positive_batches)`` pairs over the dataset in order.

    Samples are grouped in index order into chunks of ``batch_size``.
    Leftover samples are handled as follows:

    - A single leftover sample is merged into the last full chunk.
    - Two or more leftover samples form their own, smaller final chunk.
    - With no full chunk and exactly one sample left over, nothing is yielded.

    For each chunk, ``query_batch`` is ``Batch.from_data_list`` over the
    chunk's ``original_graphs``. ``positive_batches`` holds one ``Batch`` per
    positive slot:

    - slot 0 uses ``augmented_graphs[positive_samples[j][1][0]]``;
    - slot ``k > 0`` uses ``original_graphs[positive_samples[j][1][k]]``.

    The slot count is read from ``positive_samples[0]``, so every entry is
    expected to carry the same number of positives. ``positive_samples[j]``
    is expected to describe query ``j``.
    """
    total = len(original_graphs)
    n_full, leftover = divmod(total, batch_size)

    # Half-open [start, end) index ranges, one per yielded chunk.
    bounds = [(k * batch_size, (k + 1) * batch_size) for k in range(n_full)]
    if leftover == 1 and bounds:
        # Fold the lone trailing sample into the last full chunk.
        head, tail = bounds[-1]
        bounds[-1] = (head, tail + 1)
    elif leftover >= 2:
        bounds.append((n_full * batch_size, total))

    for head, tail in bounds:
        members = list(range(head, tail))

        # Batch concatenates node features and offsets edge_index per graph.
        query_batch = Batch.from_data_list([original_graphs[m] for m in members])

        # Number of positive slots per query, taken from the first entry.
        slots = len(positive_samples[0][1])

        positive_batches = []
        for slot in range(slots):
            if slot == 0:
                # First positive: the augmented view.
                graphs = [
                    augmented_graphs[positive_samples[m][1][0]] for m in members
                ]
            else:
                # Later positives: similar original graphs.
                graphs = [
                    original_graphs[positive_samples[m][1][slot]] for m in members
                ]
            positive_batches.append(Batch.from_data_list(graphs))

        yield query_batch, positive_batches
|