pyg-nightly 2.7.0.dev20250902__py3-none-any.whl → 2.7.0.dev20250904__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250902.dist-info → pyg_nightly-2.7.0.dev20250904.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250902.dist-info → pyg_nightly-2.7.0.dev20250904.dist-info}/RECORD +16 -16
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/data.py +98 -0
- torch_geometric/data/feature_store.py +3 -3
- torch_geometric/data/graph_store.py +1 -1
- torch_geometric/data/hetero_data.py +116 -0
- torch_geometric/datasets/airfrans.py +2 -2
- torch_geometric/distributed/partition.py +1 -1
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/nn/conv/meshcnn_conv.py +5 -5
- torch_geometric/nn/models/polynormer.py +1 -1
- torch_geometric/nn/models/rev_gnn.py +2 -2
- torch_geometric/utils/influence.py +3 -3
- {pyg_nightly-2.7.0.dev20250902.dist-info → pyg_nightly-2.7.0.dev20250904.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250902.dist-info → pyg_nightly-2.7.0.dev20250904.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250902.dist-info → pyg_nightly-2.7.0.dev20250904.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250902
+Version: 2.7.0.dev20250904
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
{pyg_nightly-2.7.0.dev20250902.dist-info → pyg_nightly-2.7.0.dev20250904.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=
+torch_geometric/__init__.py,sha256=kAQ6fnK2P5phSQX95Xx4qdMo6r6k-i2p8cDLRERlxuw,2292
 torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
 torch_geometric/_onnx.py,sha256=ODB_8cwFUiwBUjngXn6-K5HHb7IDul7DDXuuGX7vj_0,8178
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -33,15 +33,15 @@ torch_geometric/contrib/transforms/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uY
 torch_geometric/data/__init__.py,sha256=D6Iz5A9vEb_2rpf96Zn7uM-lchZ3WpW8X7WdAD1yxKw,4565
 torch_geometric/data/batch.py,sha256=8X8CN4_1rjrh48R3R2--mZUgfsO7Po9JP-H6SbrBiBA,8740
 torch_geometric/data/collate.py,sha256=tOUvttXoEo-bOvJx_qMivJq2JqOsB9iDdjovtiyys4o,12644
-torch_geometric/data/data.py,sha256
+torch_geometric/data/data.py,sha256=-E6el1knNgSJyapV8KUk2aRRHOfvwEvjUFfe_BapLfc,47490
 torch_geometric/data/database.py,sha256=K3KLefYVfsBN9HRItgFZNkbUIllfDt4ueauBFxk3Rxk,23106
 torch_geometric/data/datapipes.py,sha256=9_Cq3j_7LIF4plQFzbLaqyy0LcpKdAic6yiKgMqSX9A,3083
 torch_geometric/data/dataset.py,sha256=AaJH0N9eZgvxX0ljyTH8cXutKJ0AGFAyE-H4Sw9D51w,16834
 torch_geometric/data/download.py,sha256=kcesTu6jlgmCeePpOxDQOnVhxB_GuZ9iu9ds72KEORc,1889
 torch_geometric/data/extract.py,sha256=DMG8_6ps4O5xKfkb7j1gUBX_jlWpFdmz6OLY2jBSEx4,2339
-torch_geometric/data/feature_store.py,sha256=
-torch_geometric/data/graph_store.py,sha256=
-torch_geometric/data/hetero_data.py,sha256=
+torch_geometric/data/feature_store.py,sha256=pl2pJL25wqzEZnNZbW8c8Ee_yH0DnE2AK8TioTWZV-g,20045
+torch_geometric/data/graph_store.py,sha256=dSMCcMYlka2elfw-Rof-lG_iGQv6NHX98uPEVcgDn_g,13900
+torch_geometric/data/hetero_data.py,sha256=i4_Wt8MFyEG1ZhsU_zg7IKqA2ZatrYyjtYs_R_Qrf3U,53353
 torch_geometric/data/hypergraph_data.py,sha256=LfriiuJRx9ZrrSGj_fO5NUsh4kvyXJuRdCOqsWo__vc,8307
 torch_geometric/data/in_memory_dataset.py,sha256=ilFxjF4pvBILsS4wOqocwRBc2j6toI2S_KMFF19KB1w,13413
 torch_geometric/data/large_graph_indexer.py,sha256=myXTXhbRHQPxEOHNHPeNHB_pBzXCIQBr1KQt9WwBoi8,25468
@@ -57,7 +57,7 @@ torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxB
 torch_geometric/data/lightning/datamodule.py,sha256=jDv9ibLQV_FyPeq8ncq77oOU8qy-STCf-aaYd0R8JE8,29545
 torch_geometric/datasets/__init__.py,sha256=rgfUmjd9U3o8renKVl81Brscx4LOtwWmt6qAoaG41C4,6417
 torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
-torch_geometric/datasets/airfrans.py,sha256=
+torch_geometric/datasets/airfrans.py,sha256=Pc9C7IuEKkKzko_RmFPQ5gzOAGJ3132DoZZ4HaePBT8,5440
 torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
 torch_geometric/datasets/amazon.py,sha256=zLiAgrd_44LAFb8drsrIphRJNyuWa6TMjZgmoWdf97Y,3005
 torch_geometric/datasets/amazon_book.py,sha256=I-8kRsKgk9M60D4icYDELajlsRwksMLDaHMyn6sBC1Y,3214
@@ -191,7 +191,7 @@ torch_geometric/distributed/dist_neighbor_sampler.py,sha256=YrL-NMFOJwHJpF189o4k
 torch_geometric/distributed/event_loop.py,sha256=wr3iwMYEWOGkBlvC5huD2k5YxisaGE9w1Z-8RcQiIQk,3309
 torch_geometric/distributed/local_feature_store.py,sha256=CLW37RN0ouDufEs2tY9d2nLLvpxubiT6zgW3LIHAU8k,19058
 torch_geometric/distributed/local_graph_store.py,sha256=wNHXSS824Kk2HynbtWFXx-W4pl97UUBv6qFHAVqO3W4,8445
-torch_geometric/distributed/partition.py,sha256=
+torch_geometric/distributed/partition.py,sha256=VYw_3CdpRKXr1O4C80JRSMm8Od6xrS3t6H2bfmfJlGE,14733
 torch_geometric/distributed/rpc.py,sha256=j4TZQkk7NB2CIovRrasyvL9l9L4J6_YOq43gpzFMxow,5713
 torch_geometric/distributed/utils.py,sha256=FGrr3qw7hx7EQaIjjqasurloCFJ9q_0jt8jdSIUjBeM,6567
 torch_geometric/explain/__init__.py,sha256=pRxVB33zsxhED1StRWdHboQWh3e06__g9N298Hzi42Y,359
@@ -271,7 +271,7 @@ torch_geometric/io/txt_array.py,sha256=LDeX2qtlNKW-kVe-wpnskMwAdXQp1jVCGQnrJce7S
 torch_geometric/loader/__init__.py,sha256=DJrdCD1A5PuBYRSgxFbZU9GTBStNuKngqkUV1oEfQQ4,1971
 torch_geometric/loader/base.py,sha256=ataIwNEYL0px3CN3LJEgXIVTRylDHB6-yBFXXuX2JN0,1615
 torch_geometric/loader/cache.py,sha256=S65heO3YTyUPbttqizCNtKPHIoAw5iHRpbvw6KlXmok,2106
-torch_geometric/loader/cluster.py,sha256=
+torch_geometric/loader/cluster.py,sha256=CbZUy739vzMqOKgof2N73uc-Br4Daw56G3XMzptLUT8,13469
 torch_geometric/loader/data_list_loader.py,sha256=uLNqeMTkza8EEBjzqZWvgQS5kv5qWa9dyyxt6lIlcUA,1459
 torch_geometric/loader/dataloader.py,sha256=XzboK_Ygnzvaj2UQ1Q0az-6fdlKsUKlsbjo07sbErrQ,3527
 torch_geometric/loader/dense_data_loader.py,sha256=GDb_Vu2XyNL5iYzw2zoh1YiurZRr6d7mnT6HF2GWKxM,1685
@@ -377,7 +377,7 @@ torch_geometric/nn/conv/hgt_conv.py,sha256=lUhTWUMovMtn9yR_b2-kLNLqHChGOUl2OtXBY
 torch_geometric/nn/conv/hypergraph_conv.py,sha256=4BosbbqJyprlI6QjPqIfMxCqnARU_0mUn1zcAQhbw90,8691
 torch_geometric/nn/conv/le_conv.py,sha256=DonmmYZOKk5wIlTZzzIfNKqBY6MO0MRxYhyr0YtNz-Q,3494
 torch_geometric/nn/conv/lg_conv.py,sha256=8jMa79iPsOUbXEfBIc3wmbvAD8T3d1j37LeIFTX3Yag,2369
-torch_geometric/nn/conv/meshcnn_conv.py,sha256=
+torch_geometric/nn/conv/meshcnn_conv.py,sha256=qt9oAlj6krDU2DBkgr6s_dPw1_vtxfish4iW74JZ70g,21951
 torch_geometric/nn/conv/message_passing.py,sha256=ZuTvSvodGy1GyAW4mHtuoMUuxclam-7opidYNY5IHm8,44377
 torch_geometric/nn/conv/mf_conv.py,sha256=SkOGMN1tFT9dcqy8xYowsB2ozw6QfkoArgR1BksZZaU,4340
 torch_geometric/nn/conv/mixhop_conv.py,sha256=qVDPWeWcnO7_eHM0ZnpKtr8SISjb4jp0xjgpoDrwjlk,4555
@@ -463,11 +463,11 @@ torch_geometric/nn/models/molecule_gpt.py,sha256=k-XULH6jaurj-R2EE4sIWTkqlNqa3Cz
 torch_geometric/nn/models/neural_fingerprint.py,sha256=pTLJgU9Uh2Lnf9bggLj4cKI8YdEFcMF-9MALuubqbuQ,2378
 torch_geometric/nn/models/node2vec.py,sha256=81Ku4Rp4IwLEAy06KEgJ2fYtXXVL_uv_Hb8lBr6YXrE,7664
 torch_geometric/nn/models/pmlp.py,sha256=dcAASVSyQMMhItSfEJWPeAFh0R3tNCwAHwdrShwQ8o4,3538
-torch_geometric/nn/models/polynormer.py,sha256=
+torch_geometric/nn/models/polynormer.py,sha256=JgUngkF18sgepAAJTO7js9RISmYLWiO04-JEeV4J__8,7641
 torch_geometric/nn/models/protein_mpnn.py,sha256=SwTgafSbI2KJ-yqzn0trZtVWLmfo0_kPEaWSNJUCt70,12266
 torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3WAuvITk,8986
 torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
-torch_geometric/nn/models/rev_gnn.py,sha256=
+torch_geometric/nn/models/rev_gnn.py,sha256=bkKBAd_vXZ3UDMTJgdVObqleYHOTVsVcisftK7XoDlo,11797
 torch_geometric/nn/models/schnet.py,sha256=0aaHrVtxApdvn3RHCGLQJW1MbIb--CSYUrx9O3hDOZM,16656
 torch_geometric/nn/models/sgformer.py,sha256=3NDzkEVRtM1QmeJsXSq7FBhGGchyUvyX1SDPKYg9F70,6875
 torch_geometric/nn/models/signed_gcn.py,sha256=HEKaXZIWoDnsBRxIytviTpwsjQIFKl44c9glNUpwhlM,9841
@@ -626,7 +626,7 @@ torch_geometric/utils/embedding.py,sha256=Ac_MPSrZGpw-e-gU6Yz-seUioC2WZxBSSzXFec
 torch_geometric/utils/functions.py,sha256=orQdS_6EpzWSmBHSok3WhxCzLy9neB-cin1aTnlXY-8,703
 torch_geometric/utils/geodesic.py,sha256=e_XCn7dxqeYJBL-sAc2DfxF3kp_ZUIP0vwqsx1yshmU,4777
 torch_geometric/utils/hetero.py,sha256=ok4uAAOyMiaeEPmvyS4DNoDwdKnLS2gmgs5WVVklxOo,5539
-torch_geometric/utils/influence.py,sha256=
+torch_geometric/utils/influence.py,sha256=wZqt4iR5s0t5LyB--srDthJnR8h9HQN3IqdVtjzD3Cc,10351
 torch_geometric/utils/isolated.py,sha256=nUxCfMY3q9IIFjelr4eyAJH4sYG9W3lGdpWidnp3dm4,3588
 torch_geometric/utils/laplacian.py,sha256=ludDil4yS1A27PEuYOjZtCtE3o-t0lnucJKfiqENhvM,3695
 torch_geometric/utils/loop.py,sha256=MUWUS7a5GxuxLKlCtRq95U1hc3MndybAhqKD5IAe2RY,23051
@@ -646,7 +646,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
 torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
+pyg_nightly-2.7.0.dev20250904.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250904.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+pyg_nightly-2.7.0.dev20250904.dist-info/METADATA,sha256=-vA9MGqxs1s0QEIIV_Vj3gZop1fABzvtq8fMNiso7jY,64100
+pyg_nightly-2.7.0.dev20250904.dist-info/RECORD,,
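
The RECORD entries above only swap per-file digests and sizes for the rebuilt modules. For reference, a minimal sketch (standard library only, illustrative local path) of how a wheel RECORD line of the form `path,sha256=<digest>,<size>` is derived:

    import base64
    import hashlib
    from pathlib import Path

    def record_entry(path: str) -> str:
        """Build a 'path,sha256=<digest>,<size>' line as used in wheel RECORD files."""
        data = Path(path).read_bytes()
        # RECORD uses an urlsafe base64-encoded SHA-256 digest with '=' padding stripped.
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
        return f"{path},sha256={digest},{len(data)}"

    # Example (hypothetical local file):
    # print(record_entry("torch_geometric/__init__.py"))
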
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250902'
+__version__ = '2.7.0.dev20250904'
 
 __all__ = [
     'Index',
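
Only the nightly version string changes here. A quick runtime check of which nightly is installed (a sketch, assuming the wheel above is the one in the environment):

    import torch_geometric

    # Expected to print '2.7.0.dev20250904' for this nightly build.
    print(torch_geometric.__version__)
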
torch_geometric/data/data.py
CHANGED
@@ -1,5 +1,6 @@
 import copy
 import warnings
+from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
 from itertools import chain
@@ -904,6 +905,60 @@ class Data(BaseData, FeatureStore, GraphStore):
 
         return data
 
+    def connected_components(self) -> List[Self]:
+        r"""Extracts connected components of the graph using a union-find
+        algorithm. The components are returned as a list of
+        :class:`~torch_geometric.data.Data` objects, where each object
+        represents a connected component of the graph.
+
+        .. code-block::
+
+            data = Data()
+            data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
+            data.y = torch.tensor([[1.1], [2.1], [3.1], [4.1]])
+            data.edge_index = torch.tensor(
+                [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
+            )
+
+            components = data.connected_components()
+            print(len(components))
+            >>> 2
+
+            print(components[0].x)
+            >>> Data(x=[2, 1], y=[2, 1], edge_index=[2, 2])
+
+        Returns:
+            List[Data]: A list of disconnected components.
+        """
+        # Union-Find algorithm to find connected components
+        self._parents: Dict[int, int] = {}
+        self._ranks: Dict[int, int] = {}
+        for edge in self.edge_index.t().tolist():
+            self._union(edge[0], edge[1])
+
+        # Rerun _find_parent to ensure all nodes are covered correctly
+        for node in range(self.num_nodes):
+            self._find_parent(node)
+
+        # Group parents
+        grouped_parents = defaultdict(list)
+        for node, parent in self._parents.items():
+            grouped_parents[parent].append(node)
+        del self._ranks
+        del self._parents
+
+        # Create components based on the found parents (roots)
+        components: List[Self] = []
+        for nodes in grouped_parents.values():
+            # Convert the list of node IDs to a tensor
+            subset = torch.tensor(nodes, dtype=torch.long)
+
+            # Use the existing subgraph function
+            component_data = self.subgraph(subset)
+            components.append(component_data)
+
+        return components
+
     ###########################################################################
 
     @classmethod
@@ -1150,6 +1205,49 @@ class Data(BaseData, FeatureStore, GraphStore):
 
         return list(edge_attrs.values())
 
+    # Connected Components Helper Functions ###################################
+
+    def _find_parent(self, node: int) -> int:
+        r"""Finds and returns the representative parent of the given node in a
+        disjoint-set (union-find) data structure. Implements path compression
+        to optimize future queries.
+
+        Args:
+            node (int): The node for which to find the representative parent.
+
+        Returns:
+            int: The representative parent of the node.
+        """
+        if node not in self._parents:
+            self._parents[node] = node
+            self._ranks[node] = 0
+        if self._parents[node] != node:
+            self._parents[node] = self._find_parent(self._parents[node])
+        return self._parents[node]
+
+    def _union(self, node1: int, node2: int):
+        r"""Merges the sets containing node1 and node2 in the disjoint-set
+        data structure.
+
+        Finds the root parents of node1 and node2 using the _find_parent
+        method. If they belong to different sets, updates the parent of
+        root2 to be root1, effectively merging the two sets.
+
+        Args:
+            node1 (int): The index of the first node to union.
+            node2 (int): The index of the second node to union.
+        """
+        root1 = self._find_parent(node1)
+        root2 = self._find_parent(node2)
+        if root1 != root2:
+            if self._ranks[root1] < self._ranks[root2]:
+                self._parents[root1] = root2
+            elif self._ranks[root1] > self._ranks[root2]:
+                self._parents[root2] = root1
+            else:
+                self._parents[root2] = root1
+                self._ranks[root1] += 1
+
 
 ###############################################################################
 
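
A runnable sketch of the new Data.connected_components() API added above, using the same two-component toy graph as its docstring (assumes a build that includes this nightly change):

    import torch
    from torch_geometric.data import Data

    # Two undirected components: nodes {0, 1} and nodes {2, 3}.
    data = Data(
        x=torch.tensor([[1.0], [2.0], [3.0], [4.0]]),
        y=torch.tensor([[1.1], [2.1], [3.1], [4.1]]),
        edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long),
    )

    components = data.connected_components()
    print(len(components))   # 2
    print(components[0])     # e.g. Data(x=[2, 1], y=[2, 1], edge_index=[2, 2])

Internally, each edge merges its two endpoints in a disjoint-set structure, every node is then assigned to its root, and each root's node set is materialized via the existing subgraph() call.
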
torch_geometric/data/feature_store.py
CHANGED
@@ -11,7 +11,7 @@ This particular feature store abstraction makes a few key assumptions:
 * A feature can be uniquely identified from any associated attributes specified
   in `TensorAttr`.
 
-It is the job of a feature store
+It is the job of a feature store implementer class to handle these assumptions
 properly. For example, a simple in-memory feature store implementation may
 concatenate all metadata values with a feature index and use this as a unique
 index in a KV store. More complicated implementations may choose to partition
@@ -352,7 +352,7 @@ class FeatureStore(ABC):
 
         .. note::
             The default implementation simply iterates over all calls to
-            :meth:`get_tensor`.
+            :meth:`get_tensor`. Implementer classes that can provide
             additional, more performant functionality are recommended to
             to override this method.
 
@@ -412,7 +412,7 @@ class FeatureStore(ABC):
         value. Returns whether the update was successful.
 
         .. note::
-
+            Implementer classes can choose to define more efficient update
             methods; the default performs a removal and insertion.
 
         Args:
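
The completed docstring describes the contract for feature store implementers: the attributes in a `TensorAttr` uniquely identify a tensor, so a simple backend can key a flat KV map by those attributes. A minimal, framework-free sketch of that idea (an illustration only, not the actual FeatureStore interface; class and method names here are hypothetical):

    from typing import Dict, Tuple

    import torch

    class InMemoryKVStore:
        """Keys a plain dict by (group_name, attr_name), mimicking the
        'concatenate all metadata values and use this as a unique index'
        strategy from the docstring."""
        def __init__(self) -> None:
            self._store: Dict[Tuple[str, str], torch.Tensor] = {}

        def put(self, group_name: str, attr_name: str, tensor: torch.Tensor) -> None:
            self._store[(group_name, attr_name)] = tensor

        def get(self, group_name: str, attr_name: str, index: torch.Tensor) -> torch.Tensor:
            # Row indexing stands in for the per-node/per-edge slicing a feature store exposes.
            return self._store[(group_name, attr_name)][index]

    store = InMemoryKVStore()
    store.put('paper', 'x', torch.randn(4, 16))
    print(store.get('paper', 'x', torch.tensor([0, 2])).shape)  # torch.Size([2, 16])
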
torch_geometric/data/graph_store.py
CHANGED
@@ -10,7 +10,7 @@ This particular graph store abstraction makes a few key assumptions:
   support dynamic modification of edge indices once they have been inserted
   into the graph store.
 
-It is the job of a graph store
+It is the job of a graph store implementer class to handle these assumptions
 properly. For example, a simple in-memory graph store implementation may
 concatenate all metadata values with an edge index and use this as a unique
 index in a KV store. More complicated implementations may choose to partition
torch_geometric/data/hetero_data.py
CHANGED
@@ -487,6 +487,77 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
 
         return status
 
+    def connected_components(self) -> List[Self]:
+        r"""Extracts connected components of the heterogeneous graph using
+        a union-find algorithm. The components are returned as a list of
+        :class:`~torch_geometric.data.HeteroData` objects.
+
+        .. code-block::
+
+            data = HeteroData()
+            data["red"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
+            data["blue"].x = torch.tensor([[5.0], [6.0]])
+            data["red", "to", "red"].edge_index = torch.tensor(
+                [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
+            )
+
+            components = data.connected_components()
+            print(len(components))
+            >>> 4
+
+            print(components[0])
+            >>> HeteroData(
+                red={x: tensor([[1.], [2.]])},
+                blue={x: tensor([[]])},
+                red, to, red={edge_index: tensor([[0, 1], [1, 0]])}
+            )
+
+        Returns:
+            List[HeteroData]: A list of connected components.
+        """
+        # Initialize union-find structures
+        self._parents: Dict[Tuple[str, int], Tuple[str, int]] = {}
+        self._ranks: Dict[Tuple[str, int], int] = {}
+
+        # Union-Find algorithm to find connected components
+        for edge_type in self.edge_types:
+            src, _, dst = edge_type
+            edge_index = self[edge_type].edge_index
+            for src_node, dst_node in edge_index.t().tolist():
+                self._union((src, src_node), (dst, dst_node))
+
+        # Rerun _find_parent to ensure all nodes are covered correctly
+        for node_type in self.node_types:
+            for node_index in range(self[node_type].num_nodes):
+                self._find_parent((node_type, node_index))
+
+        # Group nodes by their representative parent
+        components_map = defaultdict(list)
+        for node, parent in self._parents.items():
+            components_map[parent].append(node)
+        del self._parents
+        del self._ranks
+
+        components: List[Self] = []
+        for nodes in components_map.values():
+            # Prefill subset_dict with all node types to ensure all are present
+            subset_dict = {node_type: [] for node_type in self.node_types}
+
+            # Convert the list of (node_type, node_id) tuples to a subset_dict
+            for node_type, node_id in nodes:
+                subset_dict[node_type].append(node_id)
+
+            # Convert lists to tensors
+            for node_type, node_ids in subset_dict.items():
+                subset_dict[node_type] = torch.tensor(node_ids,
+                                                      dtype=torch.long)
+
+            # Use the existing subgraph function to do all the heavy lifting
+            component_data = self.subgraph(subset_dict)
+            components.append(component_data)
+
+        return components
+
     def debug(self):
         pass  # TODO
 
@@ -1148,6 +1219,51 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
 
         return list(edge_attrs.values())
 
+    # Connected Components Helper Functions ###################################
+
+    def _find_parent(self, node: Tuple[str, int]) -> Tuple[str, int]:
+        r"""Finds and returns the representative parent of the given node in a
+        disjoint-set (union-find) data structure. Implements path compression
+        to optimize future queries.
+
+        Args:
+            node (tuple[str, int]): The node for which to find the parent.
+                First element is the node type, second is the node index.
+
+        Returns:
+            tuple[str, int]: The representative parent of the node.
+        """
+        if node not in self._parents:
+            self._parents[node] = node
+            self._ranks[node] = 0
+        if self._parents[node] != node:
+            self._parents[node] = self._find_parent(self._parents[node])
+        return self._parents[node]
+
+    def _union(self, node1: Tuple[str, int], node2: Tuple[str, int]):
+        r"""Merges the node1 and node2 in the disjoint-set data structure.
+
+        Finds the root parents of node1 and node2 using the _find_parent
+        method. If they belong to different sets, updates the parent of
+        root2 to be root1, effectively merging the two sets.
+
+        Args:
+            node1 (Tuple[str, int]): The first node to union. First element is
+                the node type, second is the node index.
+            node2 (Tuple[str, int]): The second node to union. First element is
+                the node type, second is the node index.
+        """
+        root1 = self._find_parent(node1)
+        root2 = self._find_parent(node2)
+        if root1 != root2:
+            if self._ranks[root1] < self._ranks[root2]:
+                self._parents[root1] = root2
+            elif self._ranks[root1] > self._ranks[root2]:
+                self._parents[root2] = root1
+            else:
+                self._parents[root2] = root1
+                self._ranks[root1] += 1
+
 
 # Helper functions ############################################################
 
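
As with the homogeneous case, a short usage sketch of the new HeteroData.connected_components() added above (mirroring the docstring example; nodes are keyed by (node_type, index), so isolated nodes of other types each form their own component):

    import torch
    from torch_geometric.data import HeteroData

    data = HeteroData()
    data['red'].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    data['blue'].x = torch.tensor([[5.0], [6.0]])
    data['red', 'to', 'red'].edge_index = torch.tensor(
        [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long)

    components = data.connected_components()
    print(len(components))  # 4: two 'red' pairs plus two isolated 'blue' nodes
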
torch_geometric/datasets/airfrans.py
CHANGED
@@ -30,8 +30,8 @@ class AirfRANS(InMemoryDataset):
     divided by the specific mass (one component in meter squared per second
     squared), the turbulent kinematic viscosity (one component in meter squared
     per second).
-
-    the airfoil or not.
+    Finally, a boolean is attached to each point to inform if this point lies
+    on the airfoil or not.
 
     A library for manipulating simulations of the dataset is available `here
     <https://airfrans.readthedocs.io/en/latest/index.html>`_.
torch_geometric/distributed/partition.py
CHANGED
@@ -361,7 +361,7 @@ class Partitioner:
             'edge_types': self.edge_types,
             'node_offset': list(node_offset.values()) if node_offset else None,
             'is_hetero': self.is_hetero,
-            'is_sorted': True,  # Based on
+            'is_sorted': True,  # Based on column/destination.
         }
         with open(osp.join(self.root, 'META.json'), 'w') as f:
            json.dump(meta, f)
torch_geometric/loader/cluster.py
CHANGED
@@ -235,9 +235,9 @@ class ClusterData(torch.utils.data.Dataset):
 class ClusterLoader(torch.utils.data.DataLoader):
     r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
     for Training Deep and Large Graph Convolutional Networks"
-    <https://arxiv.org/abs/1905.07953>`_ paper which merges
-    and their between-cluster links from a large-scale graph data
-    form a mini-batch.
+    <https://arxiv.org/abs/1905.07953>`_ paper which merges partitioned
+    subgraphs and their between-cluster links from a large-scale graph data
+    object to form a mini-batch.
 
     .. note::
 
@@ -252,7 +252,7 @@ class ClusterLoader(torch.utils.data.DataLoader):
 
     Args:
         cluster_data (torch_geometric.loader.ClusterData): The already
-
+            partitioned data object.
         **kwargs (optional): Additional arguments of
             :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
             :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
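
The reworded docstring describes the two-stage Cluster-GCN scheme: partition once with ClusterData, then merge a few partitions (and the links between them) per mini-batch with ClusterLoader. A typical usage sketch (graph size, `num_parts` and `batch_size` are illustrative; METIS-based partitioning backends must be available):

    import torch
    from torch_geometric.data import Data
    from torch_geometric.loader import ClusterData, ClusterLoader

    # A small random graph stands in for a large-scale one.
    data = Data(x=torch.randn(1000, 16),
                edge_index=torch.randint(0, 1000, (2, 5000)))

    cluster_data = ClusterData(data, num_parts=50)     # 1. Partition the graph.
    loader = ClusterLoader(cluster_data, batch_size=5,  # 2. Merge 5 partitions per
                           shuffle=True)                #    mini-batch.

    for batch in loader:
        print(batch.num_nodes)  # Nodes in the merged subgraph of this mini-batch.
        break
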
torch_geometric/nn/conv/meshcnn_conv.py
CHANGED
@@ -64,7 +64,7 @@ class MeshCNNConv(MessagePassing):
     :math:`\mathcal{N}(1) = (a(1), b(1), c(1), d(1)) = (2, 3, 4, 5)`
 
 
-    Because of this ordering
+    Because of this ordering constraint, :obj:`MeshCNNConv` **requires
     that the columns of** :math:`A`
     **be ordered in the following way**:
 
@@ -149,7 +149,7 @@ class MeshCNNConv(MessagePassing):
 
 
     Args:
-        in_channels (int):
+        in_channels (int): Corresponds to :math:`\text{Dim-Out}(k)`
            in the above overview. This
            represents the output dimension of the prior layer. For the given
            input mesh :math:`\mathcal{m} = (V, F)`, the prior layer is
@@ -184,7 +184,7 @@ class MeshCNNConv(MessagePassing):
        a vector of dimensions :attr:`out_channels`.
 
    Discussion:
-        The key difference that
+        The key difference that separates :obj:`MeshCNNConv` from a traditional
        message passing graph neural network is that :obj:`MeshCNNConv`
        requires the set of neighbors for a node
        :math:`\mathcal{N}(u) = (v_1, v_2, ...)`
@@ -198,7 +198,7 @@ class MeshCNNConv(MessagePassing):
        :math:`\mathbb{S}_4`. Put more plainly, in tradition message passing
        GNNs, the network is *unable* to distinguish one neighboring node
        from another.
-        In
+        In contrast, in :obj:`MeshCNNConv`, each of the 4 neighbors has a
        "role", either the "a", "b", "c", or "d" neighbor. We encode this fact
        by requiring that :math:`\mathcal{N}` return the 4-tuple,
        where the first component is the "a" neighbor, and so on.
@@ -444,7 +444,7 @@ class MeshCNNConv(MessagePassing):
        """
        assert isinstance(kernels, ModuleList), \
            f"Parameter 'kernels' must be a \
-            torch.nn.module.ModuleList with 5
+            torch.nn.module.ModuleList with 5 members, but we got \
            {type(kernels)}."
 
        assert len(kernels) == 5, "Parameter 'kernels' must be a \
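
The completed assertion text requires the kernels argument to be a torch.nn.ModuleList with exactly five members. A minimal way to satisfy that contract (the choice of Linear kernels and the channel sizes is illustrative only; the exact MeshCNNConv constructor and expected kernel shapes follow the class docstring):

    import torch
    from torch.nn import Linear, ModuleList

    in_channels, out_channels = 8, 16

    # Exactly five kernels, as the assertion above demands.
    kernels = ModuleList([Linear(in_channels, out_channels) for _ in range(5)])

    assert isinstance(kernels, ModuleList) and len(kernels) == 5
    print(kernels[0](torch.randn(3, in_channels)).shape)  # torch.Size([3, 16])
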
torch_geometric/nn/models/polynormer.py
CHANGED
@@ -37,7 +37,7 @@ class Polynormer(torch.nn.Module):
            (default: :obj:`True`)
        pre_ln (bool): Pre layer normalization.
            (default: :obj:`False`)
-        post_bn (bool): Post batch
+        post_bn (bool): Post batch normalization.
            (default: :obj:`True`)
        local_attn (bool): Whether use local attention.
            (default: :obj:`False`)
torch_geometric/nn/models/rev_gnn.py
CHANGED
@@ -196,8 +196,8 @@ class InvertibleModule(torch.nn.Module):
 class GroupAddRev(InvertibleModule):
     r"""The Grouped Reversible GNN module from the `"Graph Neural Networks with
     1000 Layers" <https://arxiv.org/abs/2106.07476>`_ paper.
-    This module enables training of
-    independent of the number of layers.
+    This module enables training of arbitrary deep GNNs with a memory
+    complexity independent of the number of layers.
 
     It does so by partitioning input node features :math:`\mathbf{X}` into
     :math:`C` groups across the feature dimension. Then, a grouped reversible
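
The completed sentence states why memory does not grow with depth: the inputs of each grouped reversible block can be recomputed exactly from its outputs, so intermediate activations need not be stored. A small pure-PyTorch sketch of that inversion idea with two groups (conceptual only; GroupAddRev implements the general grouped case with a custom backward pass):

    import torch
    from torch import nn

    f = nn.Linear(8, 8)  # Stand-ins for the per-group GNN blocks.
    g = nn.Linear(8, 8)

    x = torch.randn(4, 16)
    x1, x2 = x.chunk(2, dim=-1)   # Partition features into C = 2 groups.

    # Forward pass (additive coupling):
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)

    # Inverse: inputs are recovered from outputs, so activations can be
    # recomputed during the backward pass instead of being cached.
    x2_rec = y2 - g(y1)
    x1_rec = y1 - f(x2_rec)
    print(torch.allclose(x1, x1_rec), torch.allclose(x2, x2_rec))  # True True
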
torch_geometric/utils/influence.py
CHANGED
@@ -159,8 +159,8 @@ def jacobian_l1_agg_per_hop(
                                      vectorize=vectorize)
    hop_subsets = k_hop_subsets_exact(node_idx, max_hops, edge_index,
                                      num_nodes, influence.device)
-
-    return torch.tensor(
+    single_node_influence_per_hop = [influence[s].sum() for s in hop_subsets]
+    return torch.tensor(single_node_influence_per_hop, device=influence.device)
 
 
 def avg_total_influence(
@@ -169,7 +169,7 @@ def avg_total_influence(
 ) -> Tensor:
    """Compute the *influence‑weighted receptive field* ``R``."""
    avg_total_influences = torch.mean(influence_all_nodes, dim=0)
-    if normalize:  #
+    if normalize:  # normalize by hop_0 (jacobian of the center node feature)
        avg_total_influences = avg_total_influences / avg_total_influences[0]
    return avg_total_influences
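
The fix makes jacobian_l1_agg_per_hop return what it documents: the L1 influence on the center node, summed over each exact k-hop subset. A self-contained sketch of that aggregation step (the `influence` values and `hop_subsets` below are hypothetical stand-ins for the quantities computed inside the function):

    import torch

    # L1 influence of every node on a chosen center node (hypothetical values).
    influence = torch.tensor([5.0, 2.0, 1.0, 0.5, 0.25])

    # Exact k-hop subsets around the center node: hop 0 is the node itself, etc.
    hop_subsets = [torch.tensor([0]), torch.tensor([1, 2]), torch.tensor([3, 4])]

    # Mirrors the fixed return statement: one summed influence value per hop.
    single_node_influence_per_hop = [influence[s].sum() for s in hop_subsets]
    per_hop = torch.tensor(single_node_influence_per_hop, device=influence.device)
    print(per_hop)               # tensor([5.0000, 3.0000, 0.7500])
    print(per_hop / per_hop[0])  # hop-0 normalization, as in avg_total_influence
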
{pyg_nightly-2.7.0.dev20250902.dist-info → pyg_nightly-2.7.0.dev20250904.dist-info}/WHEEL
RENAMED
File without changes

{pyg_nightly-2.7.0.dev20250902.dist-info → pyg_nightly-2.7.0.dev20250904.dist-info}/licenses/LICENSE
RENAMED
File without changes