pyg-nightly 2.7.0.dev20250903__py3-none-any.whl → 2.7.0.dev20250905__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250903.dist-info → pyg_nightly-2.7.0.dev20250905.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250903.dist-info → pyg_nightly-2.7.0.dev20250905.dist-info}/RECORD +9 -9
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/data.py +98 -0
- torch_geometric/data/hetero_data.py +116 -0
- torch_geometric/metrics/link_pred.py +13 -2
- torch_geometric/utils/cross_entropy.py +34 -13
- {pyg_nightly-2.7.0.dev20250903.dist-info → pyg_nightly-2.7.0.dev20250905.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250903.dist-info → pyg_nightly-2.7.0.dev20250905.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250903.dist-info → pyg_nightly-2.7.0.dev20250905.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.
|
3
|
+
Version: 2.7.0.dev20250905
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=JaI2udwlYGKK9_OGrMEQn05CBfOSEPFFsk9OtBnPN2c,2292
|
2
2
|
torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
|
3
3
|
torch_geometric/_onnx.py,sha256=ODB_8cwFUiwBUjngXn6-K5HHb7IDul7DDXuuGX7vj_0,8178
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -33,7 +33,7 @@ torch_geometric/contrib/transforms/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uY
|
|
33
33
|
torch_geometric/data/__init__.py,sha256=D6Iz5A9vEb_2rpf96Zn7uM-lchZ3WpW8X7WdAD1yxKw,4565
|
34
34
|
torch_geometric/data/batch.py,sha256=8X8CN4_1rjrh48R3R2--mZUgfsO7Po9JP-H6SbrBiBA,8740
|
35
35
|
torch_geometric/data/collate.py,sha256=tOUvttXoEo-bOvJx_qMivJq2JqOsB9iDdjovtiyys4o,12644
|
36
|
-
torch_geometric/data/data.py,sha256
|
36
|
+
torch_geometric/data/data.py,sha256=-E6el1knNgSJyapV8KUk2aRRHOfvwEvjUFfe_BapLfc,47490
|
37
37
|
torch_geometric/data/database.py,sha256=K3KLefYVfsBN9HRItgFZNkbUIllfDt4ueauBFxk3Rxk,23106
|
38
38
|
torch_geometric/data/datapipes.py,sha256=9_Cq3j_7LIF4plQFzbLaqyy0LcpKdAic6yiKgMqSX9A,3083
|
39
39
|
torch_geometric/data/dataset.py,sha256=AaJH0N9eZgvxX0ljyTH8cXutKJ0AGFAyE-H4Sw9D51w,16834
|
@@ -41,7 +41,7 @@ torch_geometric/data/download.py,sha256=kcesTu6jlgmCeePpOxDQOnVhxB_GuZ9iu9ds72KE
|
|
41
41
|
torch_geometric/data/extract.py,sha256=DMG8_6ps4O5xKfkb7j1gUBX_jlWpFdmz6OLY2jBSEx4,2339
|
42
42
|
torch_geometric/data/feature_store.py,sha256=pl2pJL25wqzEZnNZbW8c8Ee_yH0DnE2AK8TioTWZV-g,20045
|
43
43
|
torch_geometric/data/graph_store.py,sha256=dSMCcMYlka2elfw-Rof-lG_iGQv6NHX98uPEVcgDn_g,13900
|
44
|
-
torch_geometric/data/hetero_data.py,sha256=
|
44
|
+
torch_geometric/data/hetero_data.py,sha256=i4_Wt8MFyEG1ZhsU_zg7IKqA2ZatrYyjtYs_R_Qrf3U,53353
|
45
45
|
torch_geometric/data/hypergraph_data.py,sha256=LfriiuJRx9ZrrSGj_fO5NUsh4kvyXJuRdCOqsWo__vc,8307
|
46
46
|
torch_geometric/data/in_memory_dataset.py,sha256=ilFxjF4pvBILsS4wOqocwRBc2j6toI2S_KMFF19KB1w,13413
|
47
47
|
torch_geometric/data/large_graph_indexer.py,sha256=myXTXhbRHQPxEOHNHPeNHB_pBzXCIQBr1KQt9WwBoi8,25468
|
@@ -294,7 +294,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=Z7L_rYdl6SYBQXAgtr18FVcmfMH
|
|
294
294
|
torch_geometric/loader/utils.py,sha256=3hzKzIgB52QIZu7Jdn4JeXZaegIJinIQfIUP9DrUWUQ,14903
|
295
295
|
torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
|
296
296
|
torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7z48e94,699
|
297
|
-
torch_geometric/metrics/link_pred.py,sha256=
|
297
|
+
torch_geometric/metrics/link_pred.py,sha256=1_hE3KiRqAdZLI6QuUbjgyFC__mTyFu_RimM3bD8wRw,31678
|
298
298
|
torch_geometric/nn/__init__.py,sha256=kQHHHUxFDht2ztD-XFQuv98TvC8MdodaFsIjAvltJBw,874
|
299
299
|
torch_geometric/nn/data_parallel.py,sha256=YiybTWoSFyfSzlXAamZ_-y1f7B6tvDEFHOuy_AyJz9Q,4761
|
300
300
|
torch_geometric/nn/encoding.py,sha256=3DCOCO-XFt-lMb97sHWGN-4KeGUFY5lVo9P00SzrCNk,3559
|
@@ -620,7 +620,7 @@ torch_geometric/utils/_trim_to_layer.py,sha256=cauOEzMJJK4w9BC-Pg1bHVncBYqG9XxQe
|
|
620
620
|
torch_geometric/utils/_unbatch.py,sha256=B0vjKI96PtHvSBG8F_lqvsiJE134aVjUurPZsG6UZRI,2378
|
621
621
|
torch_geometric/utils/augmentation.py,sha256=1F0YCuaklZ9ZbXxdFV0oOoemWvLd8p60WvFo2chzl7E,8600
|
622
622
|
torch_geometric/utils/convert.py,sha256=RE5n5no74Xu39-QMWFE0-1RvTgykdK33ymyjF9WcuSs,21938
|
623
|
-
torch_geometric/utils/cross_entropy.py,sha256=
|
623
|
+
torch_geometric/utils/cross_entropy.py,sha256=_6whuSCWKNzavOLf83uZbI9RFU0wGRHQict_4-8XIYs,3773
|
624
624
|
torch_geometric/utils/dropout.py,sha256=gg0rDnD4FLvBaKSoLAkZwViAQflhLefJm6_Mju5dmQs,11416
|
625
625
|
torch_geometric/utils/embedding.py,sha256=Ac_MPSrZGpw-e-gU6Yz-seUioC2WZxBSSzXFeclGwMk,5232
|
626
626
|
torch_geometric/utils/functions.py,sha256=orQdS_6EpzWSmBHSok3WhxCzLy9neB-cin1aTnlXY-8,703
|
@@ -646,7 +646,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
646
646
|
torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
|
647
647
|
torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
|
648
648
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
649
|
-
pyg_nightly-2.7.0.
|
650
|
-
pyg_nightly-2.7.0.
|
651
|
-
pyg_nightly-2.7.0.
|
652
|
-
pyg_nightly-2.7.0.
|
649
|
+
pyg_nightly-2.7.0.dev20250905.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
650
|
+
pyg_nightly-2.7.0.dev20250905.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
651
|
+
pyg_nightly-2.7.0.dev20250905.dist-info/METADATA,sha256=7zFs8SgMgRdkwK679QLO1W08UGscHMakHb-3NXsiEok,64100
|
652
|
+
pyg_nightly-2.7.0.dev20250905.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
33
33
|
|
34
|
-
__version__ = '2.7.0.
|
34
|
+
__version__ = '2.7.0.dev20250905'
|
35
35
|
|
36
36
|
__all__ = [
|
37
37
|
'Index',
|
torch_geometric/data/data.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import copy
|
2
2
|
import warnings
|
3
|
+
from collections import defaultdict
|
3
4
|
from collections.abc import Mapping, Sequence
|
4
5
|
from dataclasses import dataclass
|
5
6
|
from itertools import chain
|
@@ -904,6 +905,60 @@ class Data(BaseData, FeatureStore, GraphStore):
|
|
904
905
|
|
905
906
|
return data
|
906
907
|
|
908
|
+
def connected_components(self) -> List[Self]:
|
909
|
+
r"""Extracts connected components of the graph using a union-find
|
910
|
+
algorithm. The components are returned as a list of
|
911
|
+
:class:`~torch_geometric.data.Data` objects, where each object
|
912
|
+
represents a connected component of the graph.
|
913
|
+
|
914
|
+
.. code-block::
|
915
|
+
|
916
|
+
data = Data()
|
917
|
+
data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
|
918
|
+
data.y = torch.tensor([[1.1], [2.1], [3.1], [4.1]])
|
919
|
+
data.edge_index = torch.tensor(
|
920
|
+
[[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
|
921
|
+
)
|
922
|
+
|
923
|
+
components = data.connected_components()
|
924
|
+
print(len(components))
|
925
|
+
>>> 2
|
926
|
+
|
927
|
+
print(components[0].x)
|
928
|
+
>>> Data(x=[2, 1], y=[2, 1], edge_index=[2, 2])
|
929
|
+
|
930
|
+
Returns:
|
931
|
+
List[Data]: A list of disconnected components.
|
932
|
+
"""
|
933
|
+
# Union-Find algorithm to find connected components
|
934
|
+
self._parents: Dict[int, int] = {}
|
935
|
+
self._ranks: Dict[int, int] = {}
|
936
|
+
for edge in self.edge_index.t().tolist():
|
937
|
+
self._union(edge[0], edge[1])
|
938
|
+
|
939
|
+
# Rerun _find_parent to ensure all nodes are covered correctly
|
940
|
+
for node in range(self.num_nodes):
|
941
|
+
self._find_parent(node)
|
942
|
+
|
943
|
+
# Group parents
|
944
|
+
grouped_parents = defaultdict(list)
|
945
|
+
for node, parent in self._parents.items():
|
946
|
+
grouped_parents[parent].append(node)
|
947
|
+
del self._ranks
|
948
|
+
del self._parents
|
949
|
+
|
950
|
+
# Create components based on the found parents (roots)
|
951
|
+
components: List[Self] = []
|
952
|
+
for nodes in grouped_parents.values():
|
953
|
+
# Convert the list of node IDs to a tensor
|
954
|
+
subset = torch.tensor(nodes, dtype=torch.long)
|
955
|
+
|
956
|
+
# Use the existing subgraph function
|
957
|
+
component_data = self.subgraph(subset)
|
958
|
+
components.append(component_data)
|
959
|
+
|
960
|
+
return components
|
961
|
+
|
907
962
|
###########################################################################
|
908
963
|
|
909
964
|
@classmethod
|
@@ -1150,6 +1205,49 @@ class Data(BaseData, FeatureStore, GraphStore):
|
|
1150
1205
|
|
1151
1206
|
return list(edge_attrs.values())
|
1152
1207
|
|
1208
|
+
# Connected Components Helper Functions ###################################
|
1209
|
+
|
1210
|
+
def _find_parent(self, node: int) -> int:
    r"""Returns the root (representative) of the set containing
    :obj:`node` in the union-find structure held in :obj:`self._parents`
    and :obj:`self._ranks`. Nodes on the visited path are re-linked
    directly to the root (path compression).

    A node that has not been seen before is registered as a singleton set
    with rank :obj:`0`.

    Args:
        node (int): The node whose root should be looked up.

    Returns:
        int: The root of the set containing :obj:`node`.
    """
    parents = self._parents
    if node not in parents:
        # First sighting: the node forms its own set of rank 0.
        parents[node] = node
        self._ranks[node] = 0
        return node

    parent = parents[node]
    if parent == node:
        return parent  # `node` already is a root.

    # Path compression: point `node` straight at its root.
    root = self._find_parent(parent)
    parents[node] = root
    return root
|
1227
|
+
|
1228
|
+
def _union(self, node1: int, node2: int):
    r"""Merges the sets containing :obj:`node1` and :obj:`node2` in the
    disjoint-set (union-find) data structure using union by rank.

    Both roots are looked up via :meth:`_find_parent`. If they differ, the
    root with the lower rank is attached below the other one. On a tie,
    the root of :obj:`node2` is attached below the root of :obj:`node1`,
    whose rank is then incremented.

    Args:
        node1 (int): The index of the first node to union.
        node2 (int): The index of the second node to union.
    """
    first_root = self._find_parent(node1)
    second_root = self._find_parent(node2)
    if first_root == second_root:
        return  # Already in the same set.

    first_rank = self._ranks[first_root]
    second_rank = self._ranks[second_root]
    if first_rank < second_rank:
        self._parents[first_root] = second_root
        return

    # Here `first_rank >= second_rank`: hang `second_root` below
    # `first_root`. Only a tie makes the merged tree taller.
    self._parents[second_root] = first_root
    if first_rank == second_rank:
        self._ranks[first_root] += 1
|
1250
|
+
|
1153
1251
|
|
1154
1252
|
###############################################################################
|
1155
1253
|
|
@@ -487,6 +487,77 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
487
487
|
|
488
488
|
return status
|
489
489
|
|
490
|
+
def connected_components(self) -> List[Self]:
|
491
|
+
r"""Extracts connected components of the heterogeneous graph using
|
492
|
+
a union-find algorithm. The components are returned as a list of
|
493
|
+
:class:`~torch_geometric.data.HeteroData` objects.
|
494
|
+
|
495
|
+
.. code-block::
|
496
|
+
|
497
|
+
data = HeteroData()
|
498
|
+
data["red"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
|
499
|
+
data["blue"].x = torch.tensor([[5.0], [6.0]])
|
500
|
+
data["red", "to", "red"].edge_index = torch.tensor(
|
501
|
+
[[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
|
502
|
+
)
|
503
|
+
|
504
|
+
components = data.connected_components()
|
505
|
+
print(len(components))
|
506
|
+
>>> 4
|
507
|
+
|
508
|
+
print(components[0])
|
509
|
+
>>> HeteroData(
|
510
|
+
red={x: tensor([[1.], [2.]])},
|
511
|
+
blue={x: tensor([[]])},
|
512
|
+
red, to, red={edge_index: tensor([[0, 1], [1, 0]])}
|
513
|
+
)
|
514
|
+
|
515
|
+
Returns:
|
516
|
+
List[HeteroData]: A list of connected components.
|
517
|
+
"""
|
518
|
+
# Initialize union-find structures
|
519
|
+
self._parents: Dict[Tuple[str, int], Tuple[str, int]] = {}
|
520
|
+
self._ranks: Dict[Tuple[str, int], int] = {}
|
521
|
+
|
522
|
+
# Union-Find algorithm to find connected components
|
523
|
+
for edge_type in self.edge_types:
|
524
|
+
src, _, dst = edge_type
|
525
|
+
edge_index = self[edge_type].edge_index
|
526
|
+
for src_node, dst_node in edge_index.t().tolist():
|
527
|
+
self._union((src, src_node), (dst, dst_node))
|
528
|
+
|
529
|
+
# Rerun _find_parent to ensure all nodes are covered correctly
|
530
|
+
for node_type in self.node_types:
|
531
|
+
for node_index in range(self[node_type].num_nodes):
|
532
|
+
self._find_parent((node_type, node_index))
|
533
|
+
|
534
|
+
# Group nodes by their representative parent
|
535
|
+
components_map = defaultdict(list)
|
536
|
+
for node, parent in self._parents.items():
|
537
|
+
components_map[parent].append(node)
|
538
|
+
del self._parents
|
539
|
+
del self._ranks
|
540
|
+
|
541
|
+
components: List[Self] = []
|
542
|
+
for nodes in components_map.values():
|
543
|
+
# Prefill subset_dict with all node types to ensure all are present
|
544
|
+
subset_dict = {node_type: [] for node_type in self.node_types}
|
545
|
+
|
546
|
+
# Convert the list of (node_type, node_id) tuples to a subset_dict
|
547
|
+
for node_type, node_id in nodes:
|
548
|
+
subset_dict[node_type].append(node_id)
|
549
|
+
|
550
|
+
# Convert lists to tensors
|
551
|
+
for node_type, node_ids in subset_dict.items():
|
552
|
+
subset_dict[node_type] = torch.tensor(node_ids,
|
553
|
+
dtype=torch.long)
|
554
|
+
|
555
|
+
# Use the existing subgraph function to do all the heavy lifting
|
556
|
+
component_data = self.subgraph(subset_dict)
|
557
|
+
components.append(component_data)
|
558
|
+
|
559
|
+
return components
|
560
|
+
|
490
561
|
def debug(self):
|
491
562
|
pass # TODO
|
492
563
|
|
@@ -1148,6 +1219,51 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
1148
1219
|
|
1149
1220
|
return list(edge_attrs.values())
|
1150
1221
|
|
1222
|
+
# Connected Components Helper Functions ###################################
|
1223
|
+
|
1224
|
+
def _find_parent(self, node: Tuple[str, int]) -> Tuple[str, int]:
    r"""Returns the root (representative) of the set containing
    :obj:`node` in the union-find structure held in :obj:`self._parents`
    and :obj:`self._ranks`. Nodes on the visited path are re-linked
    directly to the root (path compression).

    A node that has not been seen before is registered as a singleton set
    with rank :obj:`0`.

    Args:
        node (Tuple[str, int]): The node to look up, given as a
            :obj:`(node_type, node_index)` tuple.

    Returns:
        Tuple[str, int]: The root of the set containing :obj:`node`.
    """
    parents = self._parents
    if node not in parents:
        # First sighting: the node forms its own set of rank 0.
        parents[node] = node
        self._ranks[node] = 0
        return node

    parent = parents[node]
    if parent == node:
        return parent  # `node` already is a root.

    # Path compression: point `node` straight at its root.
    root = self._find_parent(parent)
    parents[node] = root
    return root
|
1242
|
+
|
1243
|
+
def _union(self, node1: Tuple[str, int], node2: Tuple[str, int]):
    r"""Merges the sets containing :obj:`node1` and :obj:`node2` in the
    disjoint-set (union-find) data structure using union by rank.

    Both roots are looked up via :meth:`_find_parent`. If they differ, the
    root with the lower rank is attached below the other one. On a tie,
    the root of :obj:`node2` is attached below the root of :obj:`node1`,
    whose rank is then incremented.

    Args:
        node1 (Tuple[str, int]): The first node to union, given as a
            :obj:`(node_type, node_index)` tuple.
        node2 (Tuple[str, int]): The second node to union, given as a
            :obj:`(node_type, node_index)` tuple.
    """
    first_root = self._find_parent(node1)
    second_root = self._find_parent(node2)
    if first_root == second_root:
        return  # Already in the same set.

    first_rank = self._ranks[first_root]
    second_rank = self._ranks[second_root]
    if first_rank < second_rank:
        self._parents[first_root] = second_root
        return

    # Here `first_rank >= second_rank`: hang `second_root` below
    # `first_root`. Only a tie makes the merged tree taller.
    self._parents[second_root] = first_root
    if first_rank == second_rank:
        self._ranks[first_root] += 1
|
1266
|
+
|
1151
1267
|
|
1152
1268
|
# Helper functions ############################################################
|
1153
1269
|
|
@@ -21,6 +21,19 @@ class LinkPredMetricData:
|
|
21
21
|
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
|
22
22
|
edge_label_weight: Optional[Tensor] = None
|
23
23
|
|
24
|
+
def __post_init__(self) -> None:
    # Only edges with a strictly positive weight count as ground-truth
    # links. Entries whose weight is zero or negative are dropped from
    # both `edge_label_weight` and `edge_label_index`; without weights,
    # nothing is filtered.
    weight = self.edge_label_weight
    if weight is None:
        return

    keep = weight > 0
    self.edge_label_weight = weight[keep]

    label_index = self.edge_label_index
    if not isinstance(label_index, Tensor):
        # `(row, col)` pair representation: filter both parts.
        self.edge_label_index = (
            label_index[0][keep],
            label_index[1][keep],
        )
        return

    # Dense `[2, num_labels]` representation: filter columns.
    self.edge_label_index = label_index[:, keep]
|
36
|
+
|
24
37
|
@property
|
25
38
|
def pred_rel_mat(self) -> Tensor:
|
26
39
|
r"""Returns a matrix indicating the relevance of the `k`-th prediction.
|
@@ -374,8 +387,6 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
|
|
374
387
|
if self.weighted and edge_label_weight is None:
|
375
388
|
raise ValueError(f"'edge_label_weight' is a required argument for "
|
376
389
|
f"weighted '{self.__class__.__name__}' metrics")
|
377
|
-
if not self.weighted:
|
378
|
-
edge_label_weight = None
|
379
390
|
|
380
391
|
data = LinkPredMetricData( # Share metric data across metrics.
|
381
392
|
pred_index_mat=pred_index_mat,
|
@@ -18,30 +18,51 @@ class SparseCrossEntropy(torch.autograd.Function):
|
|
18
18
|
) -> Tensor:
|
19
19
|
assert inputs.dim() == 2
|
20
20
|
|
21
|
-
|
22
|
-
|
23
|
-
|
21
|
+
# Support for both positive and negative weights:
|
22
|
+
# Positive weights scale the logits *after* softmax.
|
23
|
+
# Negative weights scale the denominator *before* softmax:
|
24
|
+
pos_y = edge_label_index
|
25
|
+
neg_y = pos_weight = neg_weight = None
|
24
26
|
|
25
|
-
out = inputs[edge_label_index[0], edge_label_index[1]]
|
26
|
-
out.neg_().add_(logsumexp[edge_label_index[0]])
|
27
27
|
if edge_label_weight is not None:
|
28
|
-
|
28
|
+
pos_mask = edge_label_weight >= 0
|
29
|
+
pos_y = edge_label_index[:, pos_mask]
|
30
|
+
pos_weight = edge_label_weight[pos_mask]
|
31
|
+
|
32
|
+
if pos_y.size(1) < edge_label_index.size(1):
|
33
|
+
neg_mask = ~pos_mask
|
34
|
+
neg_y = edge_label_index[:, neg_mask]
|
35
|
+
neg_weight = edge_label_weight[neg_mask]
|
36
|
+
|
37
|
+
if neg_y is not None and neg_weight is not None:
|
38
|
+
inputs = inputs.clone()
|
39
|
+
inputs[
|
40
|
+
neg_y[0],
|
41
|
+
neg_y[1],
|
42
|
+
] += neg_weight.abs().log().clamp(min=1e-12)
|
43
|
+
|
44
|
+
logsumexp = inputs.logsumexp(dim=-1)
|
45
|
+
ctx.save_for_backward(inputs, pos_y, pos_weight, logsumexp)
|
46
|
+
|
47
|
+
out = inputs[pos_y[0], pos_y[1]]
|
48
|
+
out.neg_().add_(logsumexp[pos_y[0]])
|
49
|
+
if pos_weight is not None:
|
50
|
+
out *= pos_weight
|
29
51
|
|
30
52
|
return out.sum() / inputs.size(0)
|
31
53
|
|
32
54
|
@staticmethod
|
33
55
|
@torch.autograd.function.once_differentiable
|
34
56
|
def backward(ctx: Any, grad_out: Tensor) -> Tuple[Tensor, None, None]:
|
35
|
-
inputs,
|
36
|
-
ctx.saved_tensors)
|
57
|
+
inputs, pos_y, pos_weight, logsumexp = ctx.saved_tensors
|
37
58
|
|
38
59
|
grad_out = grad_out / inputs.size(0)
|
39
|
-
grad_out = grad_out.expand(
|
60
|
+
grad_out = grad_out.expand(pos_y.size(1))
|
40
61
|
|
41
|
-
if
|
42
|
-
grad_out = grad_out *
|
62
|
+
if pos_weight is not None:
|
63
|
+
grad_out = grad_out * pos_weight
|
43
64
|
|
44
|
-
grad_logsumexp = scatter(grad_out,
|
65
|
+
grad_logsumexp = scatter(grad_out, pos_y[0], dim=0,
|
45
66
|
dim_size=inputs.size(0), reduce='sum')
|
46
67
|
|
47
68
|
# Gradient computation of `logsumexp`: `grad * (self - result).exp()`
|
@@ -49,7 +70,7 @@ class SparseCrossEntropy(torch.autograd.Function):
|
|
49
70
|
grad_input.exp_()
|
50
71
|
grad_input.mul_(grad_logsumexp.view(-1, 1))
|
51
72
|
|
52
|
-
grad_input[
|
73
|
+
grad_input[pos_y[0], pos_y[1]] -= grad_out
|
53
74
|
|
54
75
|
return grad_input, None, None
|
55
76
|
|
File without changes
|
{pyg_nightly-2.7.0.dev20250903.dist-info → pyg_nightly-2.7.0.dev20250905.dist-info}/licenses/LICENSE
RENAMED
File without changes
|