grasp-tool 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
import numpy as np
import pandas as pd
import random
from sklearn.cluster import KMeans
from scipy.stats import t
import math
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# --- GATEncoder and ProjectionHead (from original.py) ---
class ProjectionHead(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, out_channels),
        )

    def forward(self, x):
        return self.mlp(x)

class GATEncoder(nn.Module):  # From original.py
    def __init__(
        self, in_channels, hidden_channels, out_channels, heads=1, dropout=0.1
    ):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.bn1 = nn.BatchNorm1d(hidden_channels * heads)
        self.dropout1 = nn.Dropout(dropout)
        # concat=False averages the attention heads on the final layer, so the
        # output stays out_channels for any `heads` value, matching
        # expected_output_dim below (with the default concat=True the output
        # would be out_channels * heads and break the projection head).
        self.conv2 = GATConv(
            hidden_channels * heads, out_channels, heads=heads, concat=False,
            dropout=dropout,
        )
        self.expected_output_dim = (
            out_channels  # For compatibility with checked.py logic if needed elsewhere
        )

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.elu(x)
        x = self.dropout1(x)
        x = self.conv2(x, edge_index)
        if batch is not None:
            graph_representation = global_mean_pool(x, batch)
        else:
            graph_representation = x.mean(dim=0)
        return x, graph_representation

class GCNEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.1):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.dropout1 = nn.Dropout(dropout)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.expected_output_dim = out_channels

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.elu(x)
        x = self.dropout1(x)
        x = self.conv2(x, edge_index)
        if batch is not None:
            graph_representation = global_mean_pool(x, batch)
        else:
            graph_representation = x.mean(dim=0)
        return x, graph_representation
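
# Both encoders share the same interface: forward(x, edge_index, batch)
# returns (node_embeddings of shape [num_nodes, out_channels],
# graph_representation of shape [num_graphs, out_channels]), where
# global_mean_pool averages node embeddings per graph.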


class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=128, K=1024, m=0.999, T=0.07):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T

        self.encoder_q = base_encoder
        self.encoder_k = copy.deepcopy(base_encoder)

        # The projection head takes the base encoder's output dimension as its
        # input (e.g. expected_output_dim=128 -> in_channels=128); `dim` is the
        # projection output dimension.
        projector_in_channels = getattr(
            base_encoder, "expected_output_dim", 128
        )  # Default to 128 if the attribute is missing

        self.projector_q = ProjectionHead(
            in_channels=projector_in_channels,
            hidden_channels=256,  # original.py uses 256 in the projector hidden layer
            out_channels=dim,
        )
        self.projector_k = copy.deepcopy(self.projector_q)

        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        for param_q, param_k in zip(
            self.projector_q.parameters(), self.projector_k.parameters()
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        self.register_buffer("queue", torch.randn(K, dim))
        self.queue = F.normalize(
            self.queue, dim=1
        )  # Row-wise normalization (dim=1), matching fixed.py
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    # --- Loss functions (from fixed.py) ---
    def _compute_reconstruction_loss(self, node_embeddings, edge_index, num_nodes):
        if hasattr(self, "weighted_recon_loss") and self.weighted_recon_loss:
            return self._compute_weighted_reconstruction_loss(
                node_embeddings, edge_index, num_nodes
            )
        else:
            return self._compute_basic_reconstruction_loss(
                node_embeddings, edge_index, num_nodes
            )

    def _compute_basic_reconstruction_loss(
        self, node_embeddings, edge_index, num_nodes
    ):
        # NOTE: this variant uses raw node embeddings (no normalization).
        reconstructed_adj = torch.sigmoid(
            torch.mm(node_embeddings, node_embeddings.t())
        )
        original_dense_adj = torch.zeros(
            (num_nodes, num_nodes), device=node_embeddings.device
        )
        original_dense_adj[edge_index[0], edge_index[1]] = 1
        return F.binary_cross_entropy(reconstructed_adj, original_dense_adj)

    def _compute_weighted_reconstruction_loss(
        self, node_embeddings, edge_index, num_nodes
    ):
        """Weighted reconstruction loss to mitigate class imbalance."""

        # L2-normalize node embeddings
        node_embeddings_normalized = F.normalize(node_embeddings, p=2, dim=1)

        # Pairwise similarity (logits)
        sim_scores = torch.mm(
            node_embeddings_normalized, node_embeddings_normalized.t()
        )

        # Build dense adjacency labels
        original_dense_adj = torch.zeros(
            (num_nodes, num_nodes), device=node_embeddings.device
        )
        original_dense_adj[edge_index[0], edge_index[1]] = 1

        # Compute a capped positive weight (log-scaled)
        edge_count = edge_index.size(1)
        total_pairs = num_nodes * num_nodes
        pos_weight = torch.tensor(
            min(10.0, math.log(total_pairs / max(edge_count, 1)) + 1),
            device=node_embeddings.device,
        )

        loss = F.binary_cross_entropy_with_logits(
            sim_scores,
            original_dense_adj,
            pos_weight=pos_weight,
            reduction="mean",
        )

        return loss
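
    # Worked example (editor's note): for a graph with 100 nodes and 300
    # directed edges, total_pairs = 10_000, so
    #   pos_weight = min(10.0, ln(10_000 / 300) + 1) ≈ 4.51,
    # i.e. the ~3% positive (edge) entries are up-weighted about 4.5x.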

    def _compute_clustering_loss(
        self,
        embeddings,
        num_clusters,
        dist_type="uniform",
        input_features=None,
        batch=None,
    ):
        if dist_type == "spectral":
            k_neighbors = getattr(self, "k_neighbors", 100)
            sigma = getattr(self, "sigma", 1.0)
            if input_features is not None and batch is not None:
                graph_level_input = global_mean_pool(input_features, batch)
                return self._compute_spectral_clustering_loss(
                    graph_level_input, embeddings, k=k_neighbors, sigma=sigma
                )
            elif input_features is not None:
                print(
                    "WARNING: batch is None; pooling input_features by mean as a fallback"
                )
                graph_level_input = input_features.mean(dim=0, keepdim=True)
                return self._compute_spectral_clustering_loss(
                    graph_level_input, embeddings, k=k_neighbors, sigma=sigma
                )
            else:
                print(
                    "WARNING: input_features not provided; using embeddings as a proxy for clustering loss"
                )
                return self._compute_spectral_clustering_loss(
                    embeddings, embeddings, k=k_neighbors, sigma=sigma
                )

        batch_size = embeddings.size(0)
        if batch_size < num_clusters:
            print(
                f"WARNING: batch_size ({batch_size}) < num_clusters ({num_clusters}); using simplified clustering loss"
            )
            return self._compute_simple_clustering_loss(
                embeddings, min(num_clusters, batch_size)
            )

        try:
            embeddings_np = embeddings.detach().cpu().numpy()
            # KMeans clustering (default n_init).
            kmeans = KMeans(n_clusters=num_clusters, random_state=0, max_iter=100).fit(
                embeddings_np
            )
            pseudo_labels = kmeans.labels_

            pseudo_label_dist = torch.zeros(num_clusters, device=embeddings.device)
            if pseudo_labels is not None:  # Guard against None.
                for label in pseudo_labels:
                    pseudo_label_dist[label] += 1

            eps = 1e-8
            pseudo_label_dist = pseudo_label_dist + eps
            pseudo_label_dist /= pseudo_label_dist.sum()

            if dist_type == "t-distribution":
                # Despite the name, this is the per-sample median over feature dims.
                embedding_mean = torch.median(embeddings, dim=1).values
                dof = 10
                scale = torch.std(embedding_mean)
                embedding_mean_np = embedding_mean.detach().cpu().numpy()
                ideal_t_dist_pdf = t.pdf(
                    embedding_mean_np,
                    dof,
                    loc=embedding_mean_np.mean(),
                    scale=scale.item(),
                )
                ideal_t_dist_tensor = torch.tensor(
                    ideal_t_dist_pdf, dtype=torch.float, device=embeddings.device
                )

                min_val = ideal_t_dist_tensor.min().item()  # Python scalar.
                max_val = ideal_t_dist_tensor.max().item()  # Python scalar.
                if max_val == min_val:  # Handle constant values.
                    max_val = min_val + eps

                ideal_t_dist_tensor = torch.histc(
                    ideal_t_dist_tensor, bins=num_clusters, min=min_val, max=max_val
                )
                ideal_t_dist_tensor = ideal_t_dist_tensor + eps
                ideal_t_dist_tensor /= ideal_t_dist_tensor.sum()
                return F.kl_div(
                    pseudo_label_dist.log(), ideal_t_dist_tensor, reduction="batchmean"
                )
            else:  # 'uniform'
                ideal_dist = (
                    torch.ones(num_clusters, device=embeddings.device) / num_clusters
                )
                return F.kl_div(
                    pseudo_label_dist.log(), ideal_dist, reduction="batchmean"
                )

        except Exception as e:
            print(
                f"ERROR: KMeans clustering failed: {e}; using simplified clustering loss"
            )
            return self._compute_simple_clustering_loss(embeddings, num_clusters)

    def _compute_simple_clustering_loss(self, embeddings, num_clusters):
        batch_size = embeddings.size(0)
        if batch_size < 2:
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        normalized_emb = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(normalized_emb, normalized_emb.t())

        mask = torch.eye(batch_size, device=embeddings.device).bool()
        similarity_matrix.masked_fill_(mask, 0)

        clustering_loss = similarity_matrix.abs().mean()
        return clustering_loss
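
    # This fallback simply drives off-diagonal cosine similarities toward
    # zero, spreading embeddings apart when K-means is not applicable (tiny
    # batches, or a failed fit). Note that `num_clusters` is unused here.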

    def _compute_spectral_clustering_loss(
        self, input_features, output_embeddings, k=100, sigma=1.0
    ):
        # NOTE: assumes input_features and output_embeddings have the same
        # number of rows (one per graph).
        batch_size = output_embeddings.size(0)
        if batch_size <= 1:
            return torch.tensor(0.0, device=output_embeddings.device)

        input_squared_dist = torch.cdist(input_features, input_features, p=2) ** 2
        sigma_squared = 2 * (sigma**2)
        similarity = torch.exp(-input_squared_dist / sigma_squared)

        mask = torch.eye(batch_size, device=output_embeddings.device).bool()
        similarity.masked_fill_(mask, 0)

        if k < batch_size - 1:
            _, topk_indices = torch.topk(similarity, k=min(k, batch_size - 1), dim=1)
            adjacency = torch.zeros_like(similarity)
            batch_indices_arange = (
                torch.arange(batch_size, device=output_embeddings.device)
                .unsqueeze(1)
                .expand(-1, topk_indices.size(1))
            )
            adjacency[batch_indices_arange, topk_indices] = similarity[
                batch_indices_arange, topk_indices
            ]
            adjacency = 0.5 * (adjacency + adjacency.t())
        else:
            adjacency = similarity

        output_squared_dist = (
            torch.cdist(output_embeddings, output_embeddings, p=2) ** 2
        )
        loss = torch.sum(adjacency * output_squared_dist)

        edge_weights_sum = torch.sum(adjacency)
        if edge_weights_sum > 0:
            loss = loss / edge_weights_sum

        return loss
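
    # In matrix form this is the classic spectral-embedding objective,
    #   L = sum_ij W_ij * ||z_i - z_j||^2 / sum_ij W_ij,
    # with W the symmetrized k-NN Gaussian affinity built from the inputs.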

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)
        for param_q, param_k in zip(
            self.projector_q.parameters(), self.projector_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):  # From fixed.py (robust version)
        batch_size = keys.shape[0]
        ptr_val = int(self.queue_ptr.item())  # Scalar value from tensor.

        if ptr_val + batch_size > self.K:
            first_part = self.K - ptr_val
            self.queue[ptr_val:, :] = keys[:first_part]
            remaining = batch_size - first_part
            if remaining > 0:
                self.queue[:remaining, :] = keys[first_part:]
            new_ptr = remaining
        else:
            self.queue[ptr_val : ptr_val + batch_size, :] = keys
            new_ptr = (ptr_val + batch_size) % self.K

        self.queue_ptr[0] = new_ptr
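
    # Wrap-around example (editor's note): with K=1024, ptr=1020 and 8 keys,
    # rows 1020-1023 take keys[0:4], rows 0-3 take keys[4:8], and the pointer
    # wraps to 4.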


class MoCoMultiPositive(MoCo):
    """MoCo variant that contrasts each query against multiple positive keys."""

    def __init__(self, base_encoder, dim=128, K=1024, m=0.999, T=0.07):
        super().__init__(base_encoder, dim, K, m, T)

    # --- forward method (batch_k_list removed) ---
    def forward(
        self,
        im_q,
        im_k_list,
        edge_index_q,
        edge_index_k_list,
        batch,
        num_clusters=None,
        dist_type="uniform",
        a=1.0,
        b=1.0,
        c=1.0,
        use_clustering=True,
    ):
        """Forward pass with multiple positive samples.

        Args:
            im_q: Node features of the query graph batch.
            im_k_list: List of node features for positive key graph batches.
            edge_index_q: Edge index for the query batch.
            edge_index_k_list: List of edge indices for the key batches.
            batch: Batch vector mapping nodes to graphs. Because batch_k_list
                was removed, this same vector is reused for every key batch,
                so each positive must match the query's per-graph node counts.
            num_clusters: Number of clusters used by the clustering loss.
            dist_type: Target distribution for the clustering loss
                ('uniform', 't-distribution', or 'spectral').
            a: Weight for the reconstruction loss.
            b: Weight for the contrastive loss.
            c: Weight for the clustering loss.
            use_clustering: Enable/disable the clustering loss.

        Returns:
            (total_loss, reconstruction_loss, contrastive_loss, clustering_loss,
            adjusted_reconstruction, adjusted_contrastive, adjusted_clustering)
        """
        # 1. Query features
        q_node_embeddings, q = self.encoder_q(im_q, edge_index_q, batch)
        q = self.projector_q(q)
        q = F.normalize(q, dim=1)

        # 2. Positive key features
        k_features = []
        with torch.no_grad():
            self._momentum_update_key_encoder()
            for im_k, edge_index_k in zip(im_k_list, edge_index_k_list):
                _, k = self.encoder_k(im_k, edge_index_k, batch)
                k = self.projector_k(k)
                k = F.normalize(k, dim=1)
                k_features.append(k)

        # 3. Similarities to all positive samples
        l_pos_list = []
        for k in k_features:
            pos_sim = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
            l_pos_list.append(pos_sim)

        # l_pos: [batch_size, num_positives]
        # l_neg: [batch_size, K]  (K is the queue size)
        l_pos = torch.cat(l_pos_list, dim=1)
        l_neg = torch.einsum("nc,kc->nk", [q, self.queue.clone().detach()])

        # 4. Contrastive loss
        pos_exp = torch.exp(l_pos / self.T)  # [batch_size, num_positives]
        neg_exp = torch.exp(l_neg / self.T)
        # Numerator: sum over positives
        numerator = pos_exp.sum(dim=1)  # [batch_size]
        # Denominator: sum over positives + sum over negatives
        denominator = numerator + neg_exp.sum(dim=1)  # [batch_size]
        # loss = -log(sum(exp(pos)) / sum(exp(all)))
        contrastive_loss = -torch.log(numerator / denominator).mean()

        # 5. Reconstruction loss
        reconstruction_loss = self._compute_reconstruction_loss(
            q_node_embeddings, edge_index_q, im_q.size(0)
        )

        # 6. Clustering loss (optional)
        if use_clustering:
            # Extra params for spectral clustering
            if dist_type == "spectral":
                clustering_loss = self._compute_clustering_loss(
                    q, num_clusters, dist_type, input_features=im_q, batch=batch
                )
            else:
                clustering_loss = self._compute_clustering_loss(
                    q, num_clusters, dist_type
                )
        else:
            clustering_loss = torch.tensor(0.0, device=contrastive_loss.device)

        # 7. Loss alignment: rescale each term to the reconstruction loss's
        # magnitude via a detached ratio.
        eps = 1e-6
        adjusted_contrastive = contrastive_loss / (
            (contrastive_loss / (reconstruction_loss + eps)).detach() + eps
        )

        if use_clustering:
            adjusted_clustering = clustering_loss / (
                (clustering_loss / (reconstruction_loss + eps)).detach() + eps
            )
        else:
            adjusted_clustering = torch.tensor(0.0, device=contrastive_loss.device)

        adjusted_reconstruction = reconstruction_loss

        # 8. Total loss
        if use_clustering:
            total_loss = (
                a * adjusted_reconstruction
                + b * adjusted_contrastive
                + c * adjusted_clustering
            )
        else:
            # Only reconstruction + contrastive
            total_loss = a * adjusted_reconstruction + b * adjusted_contrastive

        # 9. Update queue (use the last positive)
        if k_features:  # Ensure k_features is not empty.
            self._dequeue_and_enqueue(k_features[-1])

        # 10. Return losses
        return (
            total_loss,
            reconstruction_loss,
            contrastive_loss,
            clustering_loss,
            adjusted_reconstruction,
            adjusted_contrastive,
            adjusted_clustering,
        )
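
    # Numerical note (editor's): with unit-normalized q and k, the logits
    # above are bounded by 1/T ≈ 14.3 at the default T=0.07, so the direct
    # exp() cannot overflow; forward_supcon below uses the equivalent
    # logsumexp formulation instead.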

    def forward_supcon(
        self,
        im_q,
        im_k_list,
        edge_index_q,
        edge_index_k_list,
        batch,
        num_clusters=None,
        dist_type="uniform",
        a=1.0,
        b=1.0,
        c=1.0,
        use_clustering=True,
    ):
        """SupCon-style variant: average -log p over the positives instead of
        pooling them inside a single ratio."""
        q_node_embeddings, q = self.encoder_q(im_q, edge_index_q, batch)
        q = self.projector_q(q)
        q = F.normalize(q, dim=1)

        k_features = []
        with torch.no_grad():
            self._momentum_update_key_encoder()
            for im_k_single, edge_index_k_single in zip(im_k_list, edge_index_k_list):
                _, k_single = self.encoder_k(im_k_single, edge_index_k_single, batch)
                k_single = self.projector_k(k_single)
                k_single = F.normalize(k_single, dim=1)
                k_features.append(k_single)

        contrastive_loss = torch.tensor(0.0, device=q.device)
        if q.size(0) > 0 and k_features:
            k_stacked = torch.stack(k_features).permute(1, 0, 2)  # [B, P, D]
            q_expanded = q.unsqueeze(1)  # [B, 1, D]
            pos_sims = (
                torch.bmm(q_expanded, k_stacked.transpose(1, 2)).squeeze(1) / self.T
            )  # [B, P]
            neg_sims = torch.matmul(q, self.queue.clone().detach().T) / self.T  # [B, K]
            logits_all = torch.cat([pos_sims, neg_sims], dim=1)
            logsum = torch.logsumexp(logits_all, dim=1, keepdim=True)
            log_probs = pos_sims - logsum
            contrastive_loss = -log_probs.mean()

        reconstruction_loss = self._compute_reconstruction_loss(
            q_node_embeddings, edge_index_q, im_q.size(0)
        )

        if use_clustering:
            clustering_loss = self._compute_clustering_loss(
                q, num_clusters, dist_type, input_features=im_q, batch=batch
            )
        else:
            clustering_loss = torch.tensor(
                0.0, device=q.device if q.numel() > 0 else torch.device("cpu")
            )

        eps = 1e-6
        adjusted_contrastive = contrastive_loss / (
            (contrastive_loss / (reconstruction_loss + eps)).detach() + eps
        )
        adjusted_reconstruction = reconstruction_loss

        if use_clustering:
            adjusted_clustering = clustering_loss / (
                (clustering_loss / (reconstruction_loss + eps)).detach() + eps
            )
            total_loss = (
                a * adjusted_reconstruction
                + b * adjusted_contrastive
                + c * adjusted_clustering
            )
        else:
            adjusted_clustering = torch.tensor(
                0.0, device=q.device if q.numel() > 0 else torch.device("cpu")
            )
            total_loss = a * adjusted_reconstruction + b * adjusted_contrastive

        if k_features:
            k_avg = torch.mean(torch.stack(k_features), dim=0)
            k_avg = F.normalize(k_avg, dim=1)
            self._dequeue_and_enqueue(k_avg)

        return (
            total_loss,
            reconstruction_loss,
            contrastive_loss,
            clustering_loss,
            adjusted_reconstruction,
            adjusted_contrastive,
            adjusted_clustering,
        )
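
    # Unlike forward(), which sums the positives inside one -log ratio and
    # enqueues only the last key, forward_supcon averages -log p over the
    # positives and enqueues the normalized mean key.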

    def forward_avg(
        self,
        im_q,
        im_k_list,
        edge_index_q,
        edge_index_k_list,
        batch,
        num_clusters=None,
        dist_type="uniform",
        a=1.0,
        b=1.0,
        c=1.0,
        use_clustering=True,
    ):
        """Like forward_supcon, but contrasts q against the averaged key and
        enqueues that same average."""
        q_node_embeddings, q = self.encoder_q(im_q, edge_index_q, batch)
        q = self.projector_q(q)
        q = F.normalize(q, dim=1)

        k_features = []
        with torch.no_grad():
            self._momentum_update_key_encoder()
            for im_k_single, edge_index_k_single in zip(im_k_list, edge_index_k_list):
                _, k_single = self.encoder_k(im_k_single, edge_index_k_single, batch)
                k_single = self.projector_k(k_single)
                k_single = F.normalize(k_single, dim=1)
                k_features.append(k_single)

        contrastive_loss = torch.tensor(0.0, device=q.device)
        k_avg_for_loss = None
        if q.size(0) > 0 and k_features:
            k_avg_for_loss = torch.mean(torch.stack(k_features), dim=0)  # [B, D]
            k_avg_for_loss = F.normalize(k_avg_for_loss, dim=1)

            pos_sims_avg = (
                torch.einsum("nc,nc->n", q, k_avg_for_loss).unsqueeze(-1) / self.T
            )  # [B, 1]
            neg_sims = torch.matmul(q, self.queue.clone().detach().T) / self.T  # [B, K]

            logits_all = torch.cat([pos_sims_avg, neg_sims], dim=1)  # [B, 1+K]
            logsum = torch.logsumexp(logits_all, dim=1, keepdim=True)  # [B, 1]
            log_probs = pos_sims_avg - logsum  # [B, 1]
            contrastive_loss = -log_probs.mean()

        reconstruction_loss = self._compute_reconstruction_loss(
            q_node_embeddings, edge_index_q, im_q.size(0)
        )

        if use_clustering:
            clustering_loss = self._compute_clustering_loss(
                q, num_clusters, dist_type, input_features=im_q, batch=batch
            )
        else:
            clustering_loss = torch.tensor(
                0.0, device=q.device if q.numel() > 0 else torch.device("cpu")
            )

        eps = 1e-6
        adjusted_contrastive = contrastive_loss / (
            (contrastive_loss / (reconstruction_loss + eps)).detach() + eps
        )
        adjusted_reconstruction = reconstruction_loss

        if use_clustering:
            adjusted_clustering = clustering_loss / (
                (clustering_loss / (reconstruction_loss + eps)).detach() + eps
            )
            total_loss = (
                a * adjusted_reconstruction
                + b * adjusted_contrastive
                + c * adjusted_clustering
            )
        else:
            adjusted_clustering = torch.tensor(
                0.0, device=q.device if q.numel() > 0 else torch.device("cpu")
            )
            total_loss = a * adjusted_reconstruction + b * adjusted_contrastive

        if k_avg_for_loss is not None:  # Use the precomputed k_avg_for_loss.
            self._dequeue_and_enqueue(k_avg_for_loss)

        return (
            total_loss,
            reconstruction_loss,
            contrastive_loss,
            clustering_loss,
            adjusted_reconstruction,
            adjusted_contrastive,
            adjusted_clustering,
        )

    # TODO: Prefer JS-distance positives when available; if insufficient,
    # sample from same-gene graphs, and only fall back to the query graph as
    # a last resort.
    @staticmethod
    def generate_samples_gw(
        original_graphs,
        augmented_graphs,
        gene_labels,
        cell_labels,
        num_positive,
        gw_distances_df,
    ):
        """Generate positive samples using GW distances.

        Returns a list of (query_idx, [pos_idx1, pos_idx2, ...]). The first
        positive is always the augmented view of the same graph.
        """
        # (Optional) preprocessing: build gene/cell index maps to replace the
        # linear scans below.
        # gene_to_indices = {}
        # cell_to_indices = {}
        # for i, (gene, cell) in enumerate(zip(gene_labels, cell_labels)):
        #     gene_to_indices.setdefault(gene, []).append(i)
        #     cell_to_indices.setdefault(cell, []).append(i)

        positive_samples = []
        # Build positives for each graph.
        for i in range(len(original_graphs)):
            current_positives = []

            # 1) Always include the augmented positive
            current_positives.append(i)  # augmented_graphs[i] index

            # 2) Add positives by GW distance
            target_cell = cell_labels[i]
            target_gene = gene_labels[i]

            # Filter GW distances for the current (cell, gene).
            filtered_distances = gw_distances_df[
                (gw_distances_df["target_cell"] == target_cell)
                & (gw_distances_df["target_gene"] == target_gene)
            ]

            # Take the closest (num_positive - 1) samples.
            closest_samples = filtered_distances.nsmallest(
                num_positive - 1, "gw_distance"
            )

            for _, row in closest_samples.iterrows():
                other_cell = row["cell"]
                other_gene = row["gene"]

                # Find the matching graph index (linear scan).
                for j in range(len(original_graphs)):
                    if (cell_labels[j] == other_cell) and (
                        gene_labels[j] == other_gene
                    ):
                        current_positives.append(j)
                        break

            # Ensure enough positives; fall back to the query index.
            while len(current_positives) < num_positive:
                current_positives.append(i)

            positive_samples.append((i, current_positives))

        return positive_samples  # [(query_idx, [pos_idx1, pos_idx2, ...]), ...]
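
    # gw_distances_df is assumed to carry one row per (target, candidate)
    # pair with columns [target_cell, target_gene, cell, gene, gw_distance],
    # mirroring the schema validated for js_distances_df below.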

    @staticmethod
    def generate_samples_random_window(
        original_graphs,
        augmented_graphs,
        gene_labels,
        cell_labels,
        num_positive,
        window_size=5,
    ):
        """Generate positives by sampling same-gene graphs with similar node counts.

        The first positive is always the augmented view of the same graph.
        """
        # Precompute gene->indices and node counts.
        gene_to_indices = {}
        node_counts = {}

        # Build index mapping.
        for i, graph in enumerate(original_graphs):
            gene = gene_labels[i]

            # Number of nodes for this graph: prefer the explicit attribute,
            # otherwise count nodes whose third feature column is 0.
            if hasattr(graph, "num_real_nodes"):
                node_count = graph.num_real_nodes
            else:
                node_count = int((graph.x[:, 2] == 0).sum())
            node_counts[i] = node_count

            # Update gene->indices mapping.
            if gene not in gene_to_indices:
                gene_to_indices[gene] = []
            gene_to_indices[gene].append(i)

        positive_samples = []

        # Build positives for each graph.
        for i in range(len(original_graphs)):
            current_positives = []

            # 1) Always include the augmented positive
            current_positives.append(i)  # augmented_graphs[i] index

            # 2) Random positives from the same gene within a node-count window
            target_gene = gene_labels[i]
            target_node_count = node_counts[i]

            # All indices for the same gene.
            same_gene_indices = gene_to_indices[target_gene]

            # Candidates within the node-count window.
            candidates = []
            for idx in same_gene_indices:
                if (
                    idx != i
                    and abs(node_counts[idx] - target_node_count) <= window_size
                ):
                    candidates.append(idx)

            # If too few candidates, double the window.
            if len(candidates) < num_positive - 1 and window_size < 20:
                extended_candidates = []
                for idx in same_gene_indices:
                    if (
                        idx != i
                        and abs(node_counts[idx] - target_node_count) <= window_size * 2
                    ):
                        extended_candidates.append(idx)
                candidates = extended_candidates

            # Randomly sample up to (num_positive - 1) candidates.
            if candidates:
                sample_size = min(len(candidates), num_positive - 1)
                sampled_indices = random.sample(candidates, sample_size)
                current_positives.extend(sampled_indices)

            # Ensure enough positives; fall back to the query index.
            while len(current_positives) < num_positive:
                current_positives.append(i)

            positive_samples.append((i, current_positives))

        return positive_samples  # [(query_idx, [pos_idx1, pos_idx2, ...]), ...]

    @staticmethod
    def generate_samples_js(
        original_graphs,
        augmented_graphs,
        gene_labels,
        cell_labels,
        num_positive,
        js_distances_df,
    ):
        """Generate positive samples using Jensen-Shannon distances."""
        print(f"Selecting positives by JS distance: num_positive={num_positive}")

        # Validate js_distances_df format
        required_columns = ["target_cell", "target_gene", "cell", "gene", "js_distance"]
        missing_columns = [
            col for col in required_columns if col not in js_distances_df.columns
        ]

        if missing_columns:
            raise ValueError(
                f"js_distances_df is missing required columns: {missing_columns}. "
                f"Expected columns: {required_columns}"
            )

        positive_samples = []

        # Build positives for each graph.
        for i in range(len(original_graphs)):
            current_positives = []

            # 1) Always include the augmented positive
            current_positives.append(i)  # augmented_graphs[i] index

            # 2) Add positives by JS distance
            target_cell = cell_labels[i]
            target_gene = gene_labels[i]

            # Filter JS distances for the current (cell, gene).
            filtered_distances = js_distances_df[
                (js_distances_df["target_cell"] == target_cell)
                & (js_distances_df["target_gene"] == target_gene)
            ]

            # Take the closest (num_positive - 1) samples.
            closest_samples = filtered_distances.nsmallest(
                num_positive - 1, "js_distance"
            )

            for _, row in closest_samples.iterrows():
                other_cell = row["cell"]
                other_gene = row["gene"]

                # Find the matching graph index.
                for j in range(len(original_graphs)):
                    if (cell_labels[j] == other_cell) and (
                        gene_labels[j] == other_gene
                    ):
                        current_positives.append(j)
                        break

            # Ensure enough positives; fall back to the query index.
            while len(current_positives) < num_positive:
                current_positives.append(i)

            positive_samples.append((i, current_positives))

        return positive_samples  # [(query_idx, [pos_idx1, pos_idx2, ...]), ...]

    @staticmethod
    def prepare_multi_positive_batch(
        original_graphs, augmented_graphs, positive_samples, batch_size
    ):
        """Yield (query_batch, positive_batches) from positive_samples.

        Batches are sized so that no batch contains a single graph (a size-1
        batch would break the BatchNorm layers in training): a lone leftover
        sample is absorbed into the last full batch, and a partial batch is
        only yielded when it has at least two samples.
        """
        # Compute the number of full batches and the leftover size.
        total_samples = len(original_graphs)
        num_batches = total_samples // batch_size
        last_batch_size = total_samples % batch_size

        # Iterate over batches.
        for i in range(num_batches + (1 if last_batch_size >= 2 else 0)):
            # Decide which kind of batch we are producing.
            if i < num_batches - 1:
                # Regular full batch
                start_idx = i * batch_size
                end_idx = (i + 1) * batch_size
            elif i == num_batches - 1:
                # Last full batch (may absorb one leftover sample)
                start_idx = i * batch_size
                if last_batch_size == 1:
                    # Absorb the lone leftover sample
                    end_idx = (i + 1) * batch_size + 1
                else:
                    # Normal end
                    end_idx = (i + 1) * batch_size
            else:
                # Final partial batch (last_batch_size >= 2)
                start_idx = num_batches * batch_size
                end_idx = total_samples

            # Batch indices
            batch_indices = list(range(start_idx, end_idx))

            # Query batch
            query_batch = Batch.from_data_list(
                [original_graphs[j] for j in batch_indices]
            )  # Batch handles node feature concat and edge_index offsets.

            # Positive batches
            positive_batches = []
            num_positives = len(
                positive_samples[0][1]
            )  # All samples should have the same number of positives.

            for pos_idx in range(num_positives):
                pos_batch = Batch.from_data_list(
                    [  # positive_samples[j] is (j, [pos1, pos2, ...])
                        augmented_graphs[positive_samples[j][1][0]]
                        if pos_idx == 0  # First positive uses the augmented graph
                        else original_graphs[
                            positive_samples[j][1][pos_idx]
                        ]  # Similar original graph
                        for j in batch_indices
                    ]
                )
                positive_batches.append(pos_batch)

            # # Alternative implementation that tolerates a variable number of
            # # positives per query, padding with the query graph itself:
            # positive_batches = []
            # max_positives = max(len(positive_samples[j][1]) for j in batch_indices)
            # for pos_idx in range(max_positives):
            #     pos_graphs = []
            #     for j in batch_indices:
            #         query_idx, pos_indices = positive_samples[j]
            #         if pos_idx < len(pos_indices):
            #             if pos_idx == 0:
            #                 # First positive uses the augmented graph
            #                 pos_graphs.append(augmented_graphs[pos_indices[pos_idx]])
            #             else:
            #                 # Similar original graph
            #                 pos_graphs.append(original_graphs[pos_indices[pos_idx]])
            #         else:
            #             # Not enough positives: fall back to the query graph
            #             pos_graphs.append(original_graphs[j])
            #     pos_batch = Batch.from_data_list(pos_graphs)
            #     positive_batches.append(pos_batch)

            yield query_batch, positive_batches
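

# The smoke test below is an editor's illustrative sketch, not part of the
# published wheel: it wires the pieces together end to end on random,
# fixed-size graphs (the single `batch` vector in forward() requires every
# positive batch to match the query's per-graph node counts). All names and
# sizes here (random_graph, 8 graphs, num_positive=2, K=64) are arbitrary.
if __name__ == "__main__":
    torch.manual_seed(0)

    def random_graph(num_nodes=10, num_feats=8):
        # Random node features and a random (possibly duplicated) edge list.
        x = torch.randn(num_nodes, num_feats)
        edge_index = torch.randint(0, num_nodes, (2, 2 * num_nodes))
        return Data(x=x, edge_index=edge_index)

    graphs = [random_graph() for _ in range(8)]
    augmented = [random_graph() for _ in range(8)]
    genes = ["g1"] * 4 + ["g2"] * 4
    cells = [f"c{i}" for i in range(8)]

    encoder = GATEncoder(in_channels=8, hidden_channels=32, out_channels=128, heads=2)
    model = MoCoMultiPositive(encoder, dim=128, K=64)

    positives = MoCoMultiPositive.generate_samples_random_window(
        graphs, augmented, genes, cells, num_positive=2
    )
    for query_batch, pos_batches in MoCoMultiPositive.prepare_multi_positive_batch(
        graphs, augmented, positives, batch_size=4
    ):
        losses = model(
            query_batch.x,
            [pb.x for pb in pos_batches],
            query_batch.edge_index,
            [pb.edge_index for pb in pos_batches],
            query_batch.batch,
            num_clusters=2,
        )
        print("total loss:", losses[0].item())
        break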