pyg-nightly 2.7.0.dev20250903__py3-none-any.whl → 2.7.0.dev20250905__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- pyg_nightly-2.7.0.dev20250903.dist-info/METADATA
+++ pyg_nightly-2.7.0.dev20250905.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250903
+Version: 2.7.0.dev20250905
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
--- pyg_nightly-2.7.0.dev20250903.dist-info/RECORD
+++ pyg_nightly-2.7.0.dev20250905.dist-info/RECORD
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=FdFwyBPcM-xmx9bjJNEjeOhn8oy2Hji4JY1KU0OrrH4,2292
+torch_geometric/__init__.py,sha256=JaI2udwlYGKK9_OGrMEQn05CBfOSEPFFsk9OtBnPN2c,2292
 torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
 torch_geometric/_onnx.py,sha256=ODB_8cwFUiwBUjngXn6-K5HHb7IDul7DDXuuGX7vj_0,8178
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -33,7 +33,7 @@ torch_geometric/contrib/transforms/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uY
 torch_geometric/data/__init__.py,sha256=D6Iz5A9vEb_2rpf96Zn7uM-lchZ3WpW8X7WdAD1yxKw,4565
 torch_geometric/data/batch.py,sha256=8X8CN4_1rjrh48R3R2--mZUgfsO7Po9JP-H6SbrBiBA,8740
 torch_geometric/data/collate.py,sha256=tOUvttXoEo-bOvJx_qMivJq2JqOsB9iDdjovtiyys4o,12644
-torch_geometric/data/data.py,sha256=F6DLTd-hJI9hiHemS04Eh-kGZTbg_16-gvRVFsbXfRM,43827
+torch_geometric/data/data.py,sha256=-E6el1knNgSJyapV8KUk2aRRHOfvwEvjUFfe_BapLfc,47490
 torch_geometric/data/database.py,sha256=K3KLefYVfsBN9HRItgFZNkbUIllfDt4ueauBFxk3Rxk,23106
 torch_geometric/data/datapipes.py,sha256=9_Cq3j_7LIF4plQFzbLaqyy0LcpKdAic6yiKgMqSX9A,3083
 torch_geometric/data/dataset.py,sha256=AaJH0N9eZgvxX0ljyTH8cXutKJ0AGFAyE-H4Sw9D51w,16834
@@ -41,7 +41,7 @@ torch_geometric/data/download.py,sha256=kcesTu6jlgmCeePpOxDQOnVhxB_GuZ9iu9ds72KE
 torch_geometric/data/extract.py,sha256=DMG8_6ps4O5xKfkb7j1gUBX_jlWpFdmz6OLY2jBSEx4,2339
 torch_geometric/data/feature_store.py,sha256=pl2pJL25wqzEZnNZbW8c8Ee_yH0DnE2AK8TioTWZV-g,20045
 torch_geometric/data/graph_store.py,sha256=dSMCcMYlka2elfw-Rof-lG_iGQv6NHX98uPEVcgDn_g,13900
-torch_geometric/data/hetero_data.py,sha256=2LV8pSvv-IWkTUk8xg7VeI17YMhikg1RkeQhMwA8tkE,48583
+torch_geometric/data/hetero_data.py,sha256=i4_Wt8MFyEG1ZhsU_zg7IKqA2ZatrYyjtYs_R_Qrf3U,53353
 torch_geometric/data/hypergraph_data.py,sha256=LfriiuJRx9ZrrSGj_fO5NUsh4kvyXJuRdCOqsWo__vc,8307
 torch_geometric/data/in_memory_dataset.py,sha256=ilFxjF4pvBILsS4wOqocwRBc2j6toI2S_KMFF19KB1w,13413
 torch_geometric/data/large_graph_indexer.py,sha256=myXTXhbRHQPxEOHNHPeNHB_pBzXCIQBr1KQt9WwBoi8,25468
@@ -294,7 +294,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=Z7L_rYdl6SYBQXAgtr18FVcmfMH
 torch_geometric/loader/utils.py,sha256=3hzKzIgB52QIZu7Jdn4JeXZaegIJinIQfIUP9DrUWUQ,14903
 torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
 torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7z48e94,699
-torch_geometric/metrics/link_pred.py,sha256=mRQTSYYJgLKXFCelZHMKVOSbPED11JVhbryp7ajjxDU,31137
+torch_geometric/metrics/link_pred.py,sha256=1_hE3KiRqAdZLI6QuUbjgyFC__mTyFu_RimM3bD8wRw,31678
 torch_geometric/nn/__init__.py,sha256=kQHHHUxFDht2ztD-XFQuv98TvC8MdodaFsIjAvltJBw,874
 torch_geometric/nn/data_parallel.py,sha256=YiybTWoSFyfSzlXAamZ_-y1f7B6tvDEFHOuy_AyJz9Q,4761
 torch_geometric/nn/encoding.py,sha256=3DCOCO-XFt-lMb97sHWGN-4KeGUFY5lVo9P00SzrCNk,3559
@@ -620,7 +620,7 @@ torch_geometric/utils/_trim_to_layer.py,sha256=cauOEzMJJK4w9BC-Pg1bHVncBYqG9XxQe
 torch_geometric/utils/_unbatch.py,sha256=B0vjKI96PtHvSBG8F_lqvsiJE134aVjUurPZsG6UZRI,2378
 torch_geometric/utils/augmentation.py,sha256=1F0YCuaklZ9ZbXxdFV0oOoemWvLd8p60WvFo2chzl7E,8600
 torch_geometric/utils/convert.py,sha256=RE5n5no74Xu39-QMWFE0-1RvTgykdK33ymyjF9WcuSs,21938
-torch_geometric/utils/cross_entropy.py,sha256=ZFS5bivtzv3EV9zqgKsekmuQyoZZggPSclhl_tRNHxo,3047
+torch_geometric/utils/cross_entropy.py,sha256=_6whuSCWKNzavOLf83uZbI9RFU0wGRHQict_4-8XIYs,3773
 torch_geometric/utils/dropout.py,sha256=gg0rDnD4FLvBaKSoLAkZwViAQflhLefJm6_Mju5dmQs,11416
 torch_geometric/utils/embedding.py,sha256=Ac_MPSrZGpw-e-gU6Yz-seUioC2WZxBSSzXFeclGwMk,5232
 torch_geometric/utils/functions.py,sha256=orQdS_6EpzWSmBHSok3WhxCzLy9neB-cin1aTnlXY-8,703
@@ -646,7 +646,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
 torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.dev20250903.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
-pyg_nightly-2.7.0.dev20250903.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
-pyg_nightly-2.7.0.dev20250903.dist-info/METADATA,sha256=vYdkEEMQ1rrRaIksgih36EJ9htFz8pD4ZkQM7Zb3UsA,64100
-pyg_nightly-2.7.0.dev20250903.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20250905.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250905.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+pyg_nightly-2.7.0.dev20250905.dist-info/METADATA,sha256=7zFs8SgMgRdkwK679QLO1W08UGscHMakHb-3NXsiEok,64100
+pyg_nightly-2.7.0.dev20250905.dist-info/RECORD,,
--- torch_geometric/__init__.py
+++ torch_geometric/__init__.py
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250903'
+__version__ = '2.7.0.dev20250905'
 
 __all__ = [
     'Index',
--- torch_geometric/data/data.py
+++ torch_geometric/data/data.py
@@ -1,5 +1,6 @@
 import copy
 import warnings
+from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
 from itertools import chain
@@ -904,6 +905,60 @@ class Data(BaseData, FeatureStore, GraphStore):
 
         return data
 
+    def connected_components(self) -> List[Self]:
+        r"""Extracts connected components of the graph using a union-find
+        algorithm. The components are returned as a list of
+        :class:`~torch_geometric.data.Data` objects, where each object
+        represents a connected component of the graph.
+
+        .. code-block::
+
+            data = Data()
+            data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
+            data.y = torch.tensor([[1.1], [2.1], [3.1], [4.1]])
+            data.edge_index = torch.tensor(
+                [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
+            )
+
+            components = data.connected_components()
+            print(len(components))
+            >>> 2
+
+            print(components[0])
+            >>> Data(x=[2, 1], y=[2, 1], edge_index=[2, 2])
+
+        Returns:
+            List[Data]: A list of disconnected components.
+        """
+        # Union-find algorithm to find connected components:
+        self._parents: Dict[int, int] = {}
+        self._ranks: Dict[int, int] = {}
+        for edge in self.edge_index.t().tolist():
+            self._union(edge[0], edge[1])
+
+        # Re-run `_find_parent` to ensure all nodes are covered correctly:
+        for node in range(self.num_nodes):
+            self._find_parent(node)
+
+        # Group nodes by their representative parent:
+        grouped_parents = defaultdict(list)
+        for node, parent in self._parents.items():
+            grouped_parents[parent].append(node)
+        del self._ranks
+        del self._parents
+
+        # Create components based on the found parents (roots):
+        components: List[Self] = []
+        for nodes in grouped_parents.values():
+            # Convert the list of node IDs to a tensor:
+            subset = torch.tensor(nodes, dtype=torch.long)
+
+            # Use the existing `subgraph` function:
+            component_data = self.subgraph(subset)
+            components.append(component_data)
+
+        return components
+
     ###########################################################################
 
     @classmethod
@@ -1150,6 +1205,49 @@ class Data(BaseData, FeatureStore, GraphStore):
 
         return list(edge_attrs.values())
 
+    # Connected Components Helper Functions ##################################
+
+    def _find_parent(self, node: int) -> int:
+        r"""Finds and returns the representative parent of the given node in
+        a disjoint-set (union-find) data structure. Implements path
+        compression to optimize future queries.
+
+        Args:
+            node (int): The node for which to find the representative parent.
+
+        Returns:
+            int: The representative parent of the node.
+        """
+        if node not in self._parents:
+            self._parents[node] = node
+            self._ranks[node] = 0
+        if self._parents[node] != node:
+            self._parents[node] = self._find_parent(self._parents[node])
+        return self._parents[node]
+
+    def _union(self, node1: int, node2: int):
+        r"""Merges the sets containing node1 and node2 in the disjoint-set
+        data structure.
+
+        Finds the root parents of node1 and node2 using the _find_parent
+        method. If they belong to different sets, the root of lower rank is
+        attached to the root of higher rank, merging the two sets.
+
+        Args:
+            node1 (int): The index of the first node to union.
+            node2 (int): The index of the second node to union.
+        """
+        root1 = self._find_parent(node1)
+        root2 = self._find_parent(node2)
+        if root1 != root2:
+            if self._ranks[root1] < self._ranks[root2]:
+                self._parents[root1] = root2
+            elif self._ranks[root1] > self._ranks[root2]:
+                self._parents[root2] = root1
+            else:
+                self._parents[root2] = root1
+                self._ranks[root1] += 1
+
 
 ###############################################################################
 
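The diff above adds `Data.connected_components()` together with its union-find helpers. A minimal usage sketch, mirroring the docstring example; the tensors and the asserted values are illustrative assumptions:

    import torch
    from torch_geometric.data import Data

    # Edges 0-1 and 2-3 form two separate components:
    data = Data(
        x=torch.tensor([[1.0], [2.0], [3.0], [4.0]]),
        edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long),
    )

    components = data.connected_components()
    assert len(components) == 2          # Two disconnected components.
    assert components[0].num_nodes == 2  # Node features follow via `subgraph`.
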
--- torch_geometric/data/hetero_data.py
+++ torch_geometric/data/hetero_data.py
@@ -487,6 +487,77 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
 
         return status
 
+    def connected_components(self) -> List[Self]:
+        r"""Extracts connected components of the heterogeneous graph using
+        a union-find algorithm. The components are returned as a list of
+        :class:`~torch_geometric.data.HeteroData` objects.
+
+        .. code-block::
+
+            data = HeteroData()
+            data["red"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
+            data["blue"].x = torch.tensor([[5.0], [6.0]])
+            data["red", "to", "red"].edge_index = torch.tensor(
+                [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
+            )
+
+            components = data.connected_components()
+            print(len(components))
+            >>> 4
+
+            print(components[0])
+            >>> HeteroData(
+                red={x: tensor([[1.], [2.]])},
+                blue={x: tensor([[]])},
+                red, to, red={edge_index: tensor([[0, 1], [1, 0]])}
+            )
+
+        Returns:
+            List[HeteroData]: A list of connected components.
+        """
+        # Initialize union-find structures:
+        self._parents: Dict[Tuple[str, int], Tuple[str, int]] = {}
+        self._ranks: Dict[Tuple[str, int], int] = {}
+
+        # Union-find algorithm to find connected components:
+        for edge_type in self.edge_types:
+            src, _, dst = edge_type
+            edge_index = self[edge_type].edge_index
+            for src_node, dst_node in edge_index.t().tolist():
+                self._union((src, src_node), (dst, dst_node))
+
+        # Re-run `_find_parent` to ensure all nodes are covered correctly:
+        for node_type in self.node_types:
+            for node_index in range(self[node_type].num_nodes):
+                self._find_parent((node_type, node_index))
+
+        # Group nodes by their representative parent:
+        components_map = defaultdict(list)
+        for node, parent in self._parents.items():
+            components_map[parent].append(node)
+        del self._parents
+        del self._ranks
+
+        components: List[Self] = []
+        for nodes in components_map.values():
+            # Prefill subset_dict with all node types to ensure all are present:
+            subset_dict = {node_type: [] for node_type in self.node_types}
+
+            # Convert the list of (node_type, node_id) tuples to a subset_dict:
+            for node_type, node_id in nodes:
+                subset_dict[node_type].append(node_id)
+
+            # Convert lists to tensors:
+            for node_type, node_ids in subset_dict.items():
+                subset_dict[node_type] = torch.tensor(node_ids,
+                                                      dtype=torch.long)
+
+            # Use the existing `subgraph` function to do all the heavy lifting:
+            component_data = self.subgraph(subset_dict)
+            components.append(component_data)
+
+        return components
+
     def debug(self):
         pass  # TODO
 
@@ -1148,6 +1219,51 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
 
         return list(edge_attrs.values())
 
+    # Connected Components Helper Functions ##################################
+
+    def _find_parent(self, node: Tuple[str, int]) -> Tuple[str, int]:
+        r"""Finds and returns the representative parent of the given node in
+        a disjoint-set (union-find) data structure. Implements path
+        compression to optimize future queries.
+
+        Args:
+            node (tuple[str, int]): The node for which to find the parent.
+                First element is the node type, second is the node index.
+
+        Returns:
+            tuple[str, int]: The representative parent of the node.
+        """
+        if node not in self._parents:
+            self._parents[node] = node
+            self._ranks[node] = 0
+        if self._parents[node] != node:
+            self._parents[node] = self._find_parent(self._parents[node])
+        return self._parents[node]
+
+    def _union(self, node1: Tuple[str, int], node2: Tuple[str, int]):
+        r"""Merges the sets containing node1 and node2.
+
+        Finds the root parents of node1 and node2 using the _find_parent
+        method. If they belong to different sets, the root of lower rank is
+        attached to the root of higher rank, merging the two sets.
+
+        Args:
+            node1 (Tuple[str, int]): The first node to union. First element is
+                the node type, second is the node index.
+            node2 (Tuple[str, int]): The second node to union. First element is
+                the node type, second is the node index.
+        """
+        root1 = self._find_parent(node1)
+        root2 = self._find_parent(node2)
+        if root1 != root2:
+            if self._ranks[root1] < self._ranks[root2]:
+                self._parents[root1] = root2
+            elif self._ranks[root1] > self._ranks[root2]:
+                self._parents[root2] = root1
+            else:
+                self._parents[root2] = root1
+                self._ranks[root1] += 1
+
 
 # Helper functions ############################################################
 
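The heterogeneous variant follows the same union-find scheme, keyed by `(node_type, node_index)` tuples. A short sketch mirroring the docstring example, with assumed tensors; note that isolated nodes (the two `blue` nodes) each form their own component, which is why four components are returned:

    import torch
    from torch_geometric.data import HeteroData

    data = HeteroData()
    data["red"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    data["blue"].x = torch.tensor([[5.0], [6.0]])
    data["red", "to", "red"].edge_index = torch.tensor(
        [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long)

    components = data.connected_components()
    assert len(components) == 4  # Two 'red' pairs plus two isolated 'blue' nodes.
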
--- torch_geometric/metrics/link_pred.py
+++ torch_geometric/metrics/link_pred.py
@@ -21,6 +21,19 @@ class LinkPredMetricData:
     edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
     edge_label_weight: Optional[Tensor] = None
 
+    def __post_init__(self) -> None:
+        # Filter all negative weights - they should not be used as
+        # ground-truth:
+        if self.edge_label_weight is not None:
+            pos_mask = self.edge_label_weight > 0
+            self.edge_label_weight = self.edge_label_weight[pos_mask]
+            if isinstance(self.edge_label_index, Tensor):
+                self.edge_label_index = self.edge_label_index[:, pos_mask]
+            else:
+                self.edge_label_index = (
+                    self.edge_label_index[0][pos_mask],
+                    self.edge_label_index[1][pos_mask],
+                )
+
     @property
     def pred_rel_mat(self) -> Tensor:
         r"""Returns a matrix indicating the relevance of the `k`-th prediction.
@@ -374,8 +387,6 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
         if self.weighted and edge_label_weight is None:
             raise ValueError(f"'edge_label_weight' is a required argument for "
                              f"weighted '{self.__class__.__name__}' metrics")
-        if not self.weighted:
-            edge_label_weight = None
 
         data = LinkPredMetricData(  # Share metric data across metrics.
             pred_index_mat=pred_index_mat,
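With the `__post_init__` hook added above, negatively weighted ground-truth edges are dropped from `LinkPredMetricData` before any metric consumes it, and the collection no longer discards `edge_label_weight` for unweighted metrics. A small sketch of the filtering behaviour; the concrete tensors are assumptions, and `LinkPredMetricData` is an internal helper whose remaining required field (`pred_index_mat`) is inferred from the diff context:

    import torch
    from torch_geometric.metrics.link_pred import LinkPredMetricData

    data = LinkPredMetricData(
        pred_index_mat=torch.tensor([[2, 3], [4, 5]]),  # Top-2 predictions.
        edge_label_index=torch.tensor([[0, 0, 1], [2, 3, 4]]),
        edge_label_weight=torch.tensor([1.0, -0.5, 2.0]),
    )

    # The edge with weight -0.5 was filtered out in `__post_init__`:
    assert data.edge_label_index.size(1) == 2
    assert torch.equal(data.edge_label_weight, torch.tensor([1.0, 2.0]))
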
--- torch_geometric/utils/cross_entropy.py
+++ torch_geometric/utils/cross_entropy.py
@@ -18,30 +18,51 @@ class SparseCrossEntropy(torch.autograd.Function):
     ) -> Tensor:
         assert inputs.dim() == 2
 
-        logsumexp = inputs.logsumexp(dim=-1)
-        ctx.save_for_backward(inputs, edge_label_index, edge_label_weight,
-                              logsumexp)
+        # Support for both positive and negative weights:
+        # Positive weights scale the logits *after* softmax.
+        # Negative weights scale the denominator *before* softmax:
+        pos_y = edge_label_index
+        neg_y = pos_weight = neg_weight = None
 
-        out = inputs[edge_label_index[0], edge_label_index[1]]
-        out.neg_().add_(logsumexp[edge_label_index[0]])
         if edge_label_weight is not None:
-            out *= edge_label_weight
+            pos_mask = edge_label_weight >= 0
+            pos_y = edge_label_index[:, pos_mask]
+            pos_weight = edge_label_weight[pos_mask]
+
+            if pos_y.size(1) < edge_label_index.size(1):
+                neg_mask = ~pos_mask
+                neg_y = edge_label_index[:, neg_mask]
+                neg_weight = edge_label_weight[neg_mask]
+
+        if neg_y is not None and neg_weight is not None:
+            inputs = inputs.clone()
+            inputs[
+                neg_y[0],
+                neg_y[1],
+            ] += neg_weight.abs().log().clamp(min=1e-12)
+
+        logsumexp = inputs.logsumexp(dim=-1)
+        ctx.save_for_backward(inputs, pos_y, pos_weight, logsumexp)
+
+        out = inputs[pos_y[0], pos_y[1]]
+        out.neg_().add_(logsumexp[pos_y[0]])
+        if pos_weight is not None:
+            out *= pos_weight
 
         return out.sum() / inputs.size(0)
 
     @staticmethod
     @torch.autograd.function.once_differentiable
     def backward(ctx: Any, grad_out: Tensor) -> Tuple[Tensor, None, None]:
-        inputs, edge_label_index, edge_label_weight, logsumexp = (
-            ctx.saved_tensors)
+        inputs, pos_y, pos_weight, logsumexp = ctx.saved_tensors
 
         grad_out = grad_out / inputs.size(0)
-        grad_out = grad_out.expand(edge_label_index.size(1))
+        grad_out = grad_out.expand(pos_y.size(1))
 
-        if edge_label_weight is not None:
-            grad_out = grad_out * edge_label_weight
+        if pos_weight is not None:
+            grad_out = grad_out * pos_weight
 
-        grad_logsumexp = scatter(grad_out, edge_label_index[0], dim=0,
+        grad_logsumexp = scatter(grad_out, pos_y[0], dim=0,
                                  dim_size=inputs.size(0), reduce='sum')
 
         # Gradient computation of `logsumexp`: `grad * (self - result).exp()`
@@ -49,7 +70,7 @@ class SparseCrossEntropy(torch.autograd.Function):
         grad_input.exp_()
         grad_input.mul_(grad_logsumexp.view(-1, 1))
 
-        grad_input[edge_label_index[0], edge_label_index[1]] -= grad_out
+        grad_input[pos_y[0], pos_y[1]] -= grad_out
 
         return grad_input, None, None
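With this change, `SparseCrossEntropy` splits `edge_label_weight` by sign: non-negative weights scale the per-edge loss terms, while negative weights are folded into the corresponding logits before the `logsumexp`, i.e. they rescale the softmax denominator, and are excluded from the positives saved for the backward pass. A rough sketch of calling the autograd function directly; shapes and values are illustrative, and in practice one would typically go through the module's `sparse_cross_entropy` wrapper rather than calling the Function itself:

    import torch
    from torch_geometric.utils.cross_entropy import SparseCrossEntropy

    inputs = torch.randn(2, 5, requires_grad=True)      # [num_rows, num_classes]
    edge_label_index = torch.tensor([[0, 0, 1], [1, 3, 2]])
    edge_label_weight = torch.tensor([1.0, -0.5, 2.0])  # Mixed-sign weights.

    loss = SparseCrossEntropy.apply(inputs, edge_label_index, edge_label_weight)
    loss.backward()
    assert inputs.grad is not None  # The custom backward produces gradients for `inputs`.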