pyg-nightly 2.7.0.dev20250902__py3-none-any.whl → 2.7.0.dev20250904__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250902
3
+ Version: 2.7.0.dev20250904
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=NO3jwSl5Hf33qP9GchFcqCCeokVu342RzUIGxdqipDI,2292
1
+ torch_geometric/__init__.py,sha256=kAQ6fnK2P5phSQX95Xx4qdMo6r6k-i2p8cDLRERlxuw,2292
2
2
  torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
3
3
  torch_geometric/_onnx.py,sha256=ODB_8cwFUiwBUjngXn6-K5HHb7IDul7DDXuuGX7vj_0,8178
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -33,15 +33,15 @@ torch_geometric/contrib/transforms/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uY
33
33
  torch_geometric/data/__init__.py,sha256=D6Iz5A9vEb_2rpf96Zn7uM-lchZ3WpW8X7WdAD1yxKw,4565
34
34
  torch_geometric/data/batch.py,sha256=8X8CN4_1rjrh48R3R2--mZUgfsO7Po9JP-H6SbrBiBA,8740
35
35
  torch_geometric/data/collate.py,sha256=tOUvttXoEo-bOvJx_qMivJq2JqOsB9iDdjovtiyys4o,12644
36
- torch_geometric/data/data.py,sha256=F6DLTd-hJI9hiHemS04Eh-kGZTbg_16-gvRVFsbXfRM,43827
36
+ torch_geometric/data/data.py,sha256=-E6el1knNgSJyapV8KUk2aRRHOfvwEvjUFfe_BapLfc,47490
37
37
  torch_geometric/data/database.py,sha256=K3KLefYVfsBN9HRItgFZNkbUIllfDt4ueauBFxk3Rxk,23106
38
38
  torch_geometric/data/datapipes.py,sha256=9_Cq3j_7LIF4plQFzbLaqyy0LcpKdAic6yiKgMqSX9A,3083
39
39
  torch_geometric/data/dataset.py,sha256=AaJH0N9eZgvxX0ljyTH8cXutKJ0AGFAyE-H4Sw9D51w,16834
40
40
  torch_geometric/data/download.py,sha256=kcesTu6jlgmCeePpOxDQOnVhxB_GuZ9iu9ds72KEORc,1889
41
41
  torch_geometric/data/extract.py,sha256=DMG8_6ps4O5xKfkb7j1gUBX_jlWpFdmz6OLY2jBSEx4,2339
42
- torch_geometric/data/feature_store.py,sha256=BIMgIWpP1y7OCIQxnkdSWcnm8_BFJXuS_zOqfZZQOjI,20045
43
- torch_geometric/data/graph_store.py,sha256=EtIgsyY7RdBHRTCn34VypEBOG8cg8WzsNT_kTFKxJR4,13900
44
- torch_geometric/data/hetero_data.py,sha256=2LV8pSvv-IWkTUk8xg7VeI17YMhikg1RkeQhMwA8tkE,48583
42
+ torch_geometric/data/feature_store.py,sha256=pl2pJL25wqzEZnNZbW8c8Ee_yH0DnE2AK8TioTWZV-g,20045
43
+ torch_geometric/data/graph_store.py,sha256=dSMCcMYlka2elfw-Rof-lG_iGQv6NHX98uPEVcgDn_g,13900
44
+ torch_geometric/data/hetero_data.py,sha256=i4_Wt8MFyEG1ZhsU_zg7IKqA2ZatrYyjtYs_R_Qrf3U,53353
45
45
  torch_geometric/data/hypergraph_data.py,sha256=LfriiuJRx9ZrrSGj_fO5NUsh4kvyXJuRdCOqsWo__vc,8307
46
46
  torch_geometric/data/in_memory_dataset.py,sha256=ilFxjF4pvBILsS4wOqocwRBc2j6toI2S_KMFF19KB1w,13413
47
47
  torch_geometric/data/large_graph_indexer.py,sha256=myXTXhbRHQPxEOHNHPeNHB_pBzXCIQBr1KQt9WwBoi8,25468
@@ -57,7 +57,7 @@ torch_geometric/data/lightning/__init__.py,sha256=w3En1tJfy3kSqe1MycpOyZpHFO3fxB
57
57
  torch_geometric/data/lightning/datamodule.py,sha256=jDv9ibLQV_FyPeq8ncq77oOU8qy-STCf-aaYd0R8JE8,29545
58
58
  torch_geometric/datasets/__init__.py,sha256=rgfUmjd9U3o8renKVl81Brscx4LOtwWmt6qAoaG41C4,6417
59
59
  torch_geometric/datasets/actor.py,sha256=oUxgJIX8bi5hJr1etWNYIFyVQNDDXi1nyVpHGGMEAGQ,4304
60
- torch_geometric/datasets/airfrans.py,sha256=8cCBmHPttrlKY_iwfyr-K-CUX_JEDjrIOg3r9dQSN7o,5439
60
+ torch_geometric/datasets/airfrans.py,sha256=Pc9C7IuEKkKzko_RmFPQ5gzOAGJ3132DoZZ4HaePBT8,5440
61
61
  torch_geometric/datasets/airports.py,sha256=b3gkv3gY2JkUpmGiz36Z-g7EcnSfU8lBG1YsCOWdJ6k,3758
62
62
  torch_geometric/datasets/amazon.py,sha256=zLiAgrd_44LAFb8drsrIphRJNyuWa6TMjZgmoWdf97Y,3005
63
63
  torch_geometric/datasets/amazon_book.py,sha256=I-8kRsKgk9M60D4icYDELajlsRwksMLDaHMyn6sBC1Y,3214
@@ -191,7 +191,7 @@ torch_geometric/distributed/dist_neighbor_sampler.py,sha256=YrL-NMFOJwHJpF189o4k
191
191
  torch_geometric/distributed/event_loop.py,sha256=wr3iwMYEWOGkBlvC5huD2k5YxisaGE9w1Z-8RcQiIQk,3309
192
192
  torch_geometric/distributed/local_feature_store.py,sha256=CLW37RN0ouDufEs2tY9d2nLLvpxubiT6zgW3LIHAU8k,19058
193
193
  torch_geometric/distributed/local_graph_store.py,sha256=wNHXSS824Kk2HynbtWFXx-W4pl97UUBv6qFHAVqO3W4,8445
194
- torch_geometric/distributed/partition.py,sha256=BgjmhDloaooAXM7onGizrcikZs8oRnz5drQZHPDDO_g,14734
194
+ torch_geometric/distributed/partition.py,sha256=VYw_3CdpRKXr1O4C80JRSMm8Od6xrS3t6H2bfmfJlGE,14733
195
195
  torch_geometric/distributed/rpc.py,sha256=j4TZQkk7NB2CIovRrasyvL9l9L4J6_YOq43gpzFMxow,5713
196
196
  torch_geometric/distributed/utils.py,sha256=FGrr3qw7hx7EQaIjjqasurloCFJ9q_0jt8jdSIUjBeM,6567
197
197
  torch_geometric/explain/__init__.py,sha256=pRxVB33zsxhED1StRWdHboQWh3e06__g9N298Hzi42Y,359
@@ -271,7 +271,7 @@ torch_geometric/io/txt_array.py,sha256=LDeX2qtlNKW-kVe-wpnskMwAdXQp1jVCGQnrJce7S
271
271
  torch_geometric/loader/__init__.py,sha256=DJrdCD1A5PuBYRSgxFbZU9GTBStNuKngqkUV1oEfQQ4,1971
272
272
  torch_geometric/loader/base.py,sha256=ataIwNEYL0px3CN3LJEgXIVTRylDHB6-yBFXXuX2JN0,1615
273
273
  torch_geometric/loader/cache.py,sha256=S65heO3YTyUPbttqizCNtKPHIoAw5iHRpbvw6KlXmok,2106
274
- torch_geometric/loader/cluster.py,sha256=eMNxVkvZt5oQ_gJRgmWm1NBX7zU2tZI_BPaXeB0wuyk,13465
274
+ torch_geometric/loader/cluster.py,sha256=CbZUy739vzMqOKgof2N73uc-Br4Daw56G3XMzptLUT8,13469
275
275
  torch_geometric/loader/data_list_loader.py,sha256=uLNqeMTkza8EEBjzqZWvgQS5kv5qWa9dyyxt6lIlcUA,1459
276
276
  torch_geometric/loader/dataloader.py,sha256=XzboK_Ygnzvaj2UQ1Q0az-6fdlKsUKlsbjo07sbErrQ,3527
277
277
  torch_geometric/loader/dense_data_loader.py,sha256=GDb_Vu2XyNL5iYzw2zoh1YiurZRr6d7mnT6HF2GWKxM,1685
@@ -377,7 +377,7 @@ torch_geometric/nn/conv/hgt_conv.py,sha256=lUhTWUMovMtn9yR_b2-kLNLqHChGOUl2OtXBY
377
377
  torch_geometric/nn/conv/hypergraph_conv.py,sha256=4BosbbqJyprlI6QjPqIfMxCqnARU_0mUn1zcAQhbw90,8691
378
378
  torch_geometric/nn/conv/le_conv.py,sha256=DonmmYZOKk5wIlTZzzIfNKqBY6MO0MRxYhyr0YtNz-Q,3494
379
379
  torch_geometric/nn/conv/lg_conv.py,sha256=8jMa79iPsOUbXEfBIc3wmbvAD8T3d1j37LeIFTX3Yag,2369
380
- torch_geometric/nn/conv/meshcnn_conv.py,sha256=92zUcgfS0Fwv-MpddF4Ia1a65y7ddPAkazYf7D6kvwg,21951
380
+ torch_geometric/nn/conv/meshcnn_conv.py,sha256=qt9oAlj6krDU2DBkgr6s_dPw1_vtxfish4iW74JZ70g,21951
381
381
  torch_geometric/nn/conv/message_passing.py,sha256=ZuTvSvodGy1GyAW4mHtuoMUuxclam-7opidYNY5IHm8,44377
382
382
  torch_geometric/nn/conv/mf_conv.py,sha256=SkOGMN1tFT9dcqy8xYowsB2ozw6QfkoArgR1BksZZaU,4340
383
383
  torch_geometric/nn/conv/mixhop_conv.py,sha256=qVDPWeWcnO7_eHM0ZnpKtr8SISjb4jp0xjgpoDrwjlk,4555
@@ -463,11 +463,11 @@ torch_geometric/nn/models/molecule_gpt.py,sha256=k-XULH6jaurj-R2EE4sIWTkqlNqa3Cz
463
463
  torch_geometric/nn/models/neural_fingerprint.py,sha256=pTLJgU9Uh2Lnf9bggLj4cKI8YdEFcMF-9MALuubqbuQ,2378
464
464
  torch_geometric/nn/models/node2vec.py,sha256=81Ku4Rp4IwLEAy06KEgJ2fYtXXVL_uv_Hb8lBr6YXrE,7664
465
465
  torch_geometric/nn/models/pmlp.py,sha256=dcAASVSyQMMhItSfEJWPeAFh0R3tNCwAHwdrShwQ8o4,3538
466
- torch_geometric/nn/models/polynormer.py,sha256=mayWdzdolT5PCt_Oo7UGG-JUripMHWB2lUWF1bh6goU,7640
466
+ torch_geometric/nn/models/polynormer.py,sha256=JgUngkF18sgepAAJTO7js9RISmYLWiO04-JEeV4J__8,7641
467
467
  torch_geometric/nn/models/protein_mpnn.py,sha256=SwTgafSbI2KJ-yqzn0trZtVWLmfo0_kPEaWSNJUCt70,12266
468
468
  torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3WAuvITk,8986
469
469
  torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
470
- torch_geometric/nn/models/rev_gnn.py,sha256=Bpme087Zs227lcB0ODOKWsxaly67q96wseaRt6bacjs,11796
470
+ torch_geometric/nn/models/rev_gnn.py,sha256=bkKBAd_vXZ3UDMTJgdVObqleYHOTVsVcisftK7XoDlo,11797
471
471
  torch_geometric/nn/models/schnet.py,sha256=0aaHrVtxApdvn3RHCGLQJW1MbIb--CSYUrx9O3hDOZM,16656
472
472
  torch_geometric/nn/models/sgformer.py,sha256=3NDzkEVRtM1QmeJsXSq7FBhGGchyUvyX1SDPKYg9F70,6875
473
473
  torch_geometric/nn/models/signed_gcn.py,sha256=HEKaXZIWoDnsBRxIytviTpwsjQIFKl44c9glNUpwhlM,9841
@@ -626,7 +626,7 @@ torch_geometric/utils/embedding.py,sha256=Ac_MPSrZGpw-e-gU6Yz-seUioC2WZxBSSzXFec
626
626
  torch_geometric/utils/functions.py,sha256=orQdS_6EpzWSmBHSok3WhxCzLy9neB-cin1aTnlXY-8,703
627
627
  torch_geometric/utils/geodesic.py,sha256=e_XCn7dxqeYJBL-sAc2DfxF3kp_ZUIP0vwqsx1yshmU,4777
628
628
  torch_geometric/utils/hetero.py,sha256=ok4uAAOyMiaeEPmvyS4DNoDwdKnLS2gmgs5WVVklxOo,5539
629
- torch_geometric/utils/influence.py,sha256=7R-NW4myJMJPkbNiwcHTmO_m_B3gPB2IlBbQkB446xc,10348
629
+ torch_geometric/utils/influence.py,sha256=wZqt4iR5s0t5LyB--srDthJnR8h9HQN3IqdVtjzD3Cc,10351
630
630
  torch_geometric/utils/isolated.py,sha256=nUxCfMY3q9IIFjelr4eyAJH4sYG9W3lGdpWidnp3dm4,3588
631
631
  torch_geometric/utils/laplacian.py,sha256=ludDil4yS1A27PEuYOjZtCtE3o-t0lnucJKfiqENhvM,3695
632
632
  torch_geometric/utils/loop.py,sha256=MUWUS7a5GxuxLKlCtRq95U1hc3MndybAhqKD5IAe2RY,23051
@@ -646,7 +646,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
646
646
  torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
647
647
  torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
648
648
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
649
- pyg_nightly-2.7.0.dev20250902.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
650
- pyg_nightly-2.7.0.dev20250902.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
651
- pyg_nightly-2.7.0.dev20250902.dist-info/METADATA,sha256=voU97tZI_J6eLyfyPVL0Q-7pwlS9ifAPhaMzuAsqiis,64100
652
- pyg_nightly-2.7.0.dev20250902.dist-info/RECORD,,
649
+ pyg_nightly-2.7.0.dev20250904.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
650
+ pyg_nightly-2.7.0.dev20250904.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
651
+ pyg_nightly-2.7.0.dev20250904.dist-info/METADATA,sha256=-vA9MGqxs1s0QEIIV_Vj3gZop1fABzvtq8fMNiso7jY,64100
652
+ pyg_nightly-2.7.0.dev20250904.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
31
31
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
32
32
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
33
33
 
34
- __version__ = '2.7.0.dev20250902'
34
+ __version__ = '2.7.0.dev20250904'
35
35
 
36
36
  __all__ = [
37
37
  'Index',
@@ -1,5 +1,6 @@
1
1
  import copy
2
2
  import warnings
3
+ from collections import defaultdict
3
4
  from collections.abc import Mapping, Sequence
4
5
  from dataclasses import dataclass
5
6
  from itertools import chain
@@ -904,6 +905,60 @@ class Data(BaseData, FeatureStore, GraphStore):
904
905
 
905
906
  return data
906
907
 
908
+ def connected_components(self) -> List[Self]:
909
+ r"""Extracts connected components of the graph using a union-find
910
+ algorithm. The components are returned as a list of
911
+ :class:`~torch_geometric.data.Data` objects, where each object
912
+ represents a connected component of the graph.
913
+
914
+ .. code-block::
915
+
916
+ data = Data()
917
+ data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
918
+ data.y = torch.tensor([[1.1], [2.1], [3.1], [4.1]])
919
+ data.edge_index = torch.tensor(
920
+ [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
921
+ )
922
+
923
+ components = data.connected_components()
924
+ print(len(components))
925
+ >>> 2
926
+
927
+ print(components[0].x)
928
+ >>> Data(x=[2, 1], y=[2, 1], edge_index=[2, 2])
929
+
930
+ Returns:
931
+ List[Data]: A list of disconnected components.
932
+ """
933
+ # Union-Find algorithm to find connected components
934
+ self._parents: Dict[int, int] = {}
935
+ self._ranks: Dict[int, int] = {}
936
+ for edge in self.edge_index.t().tolist():
937
+ self._union(edge[0], edge[1])
938
+
939
+ # Rerun _find_parent to ensure all nodes are covered correctly
940
+ for node in range(self.num_nodes):
941
+ self._find_parent(node)
942
+
943
+ # Group parents
944
+ grouped_parents = defaultdict(list)
945
+ for node, parent in self._parents.items():
946
+ grouped_parents[parent].append(node)
947
+ del self._ranks
948
+ del self._parents
949
+
950
+ # Create components based on the found parents (roots)
951
+ components: List[Self] = []
952
+ for nodes in grouped_parents.values():
953
+ # Convert the list of node IDs to a tensor
954
+ subset = torch.tensor(nodes, dtype=torch.long)
955
+
956
+ # Use the existing subgraph function
957
+ component_data = self.subgraph(subset)
958
+ components.append(component_data)
959
+
960
+ return components
961
+
907
962
  ###########################################################################
908
963
 
909
964
  @classmethod
@@ -1150,6 +1205,49 @@ class Data(BaseData, FeatureStore, GraphStore):
1150
1205
 
1151
1206
  return list(edge_attrs.values())
1152
1207
 
1208
+ # Connected Components Helper Functions ###################################
1209
+
1210
+ def _find_parent(self, node: int) -> int:
1211
+ r"""Finds and returns the representative parent of the given node in a
1212
+ disjoint-set (union-find) data structure. Implements path compression
1213
+ to optimize future queries.
1214
+
1215
+ Args:
1216
+ node (int): The node for which to find the representative parent.
1217
+
1218
+ Returns:
1219
+ int: The representative parent of the node.
1220
+ """
1221
+ if node not in self._parents:
1222
+ self._parents[node] = node
1223
+ self._ranks[node] = 0
1224
+ if self._parents[node] != node:
1225
+ self._parents[node] = self._find_parent(self._parents[node])
1226
+ return self._parents[node]
1227
+
1228
+ def _union(self, node1: int, node2: int):
1229
+ r"""Merges the sets containing node1 and node2 in the disjoint-set
1230
+ data structure.
1231
+
1232
+ Finds the root parents of node1 and node2 using the _find_parent
1233
+ method. If they belong to different sets, updates the parent of
1234
+ root2 to be root1, effectively merging the two sets.
1235
+
1236
+ Args:
1237
+ node1 (int): The index of the first node to union.
1238
+ node2 (int): The index of the second node to union.
1239
+ """
1240
+ root1 = self._find_parent(node1)
1241
+ root2 = self._find_parent(node2)
1242
+ if root1 != root2:
1243
+ if self._ranks[root1] < self._ranks[root2]:
1244
+ self._parents[root1] = root2
1245
+ elif self._ranks[root1] > self._ranks[root2]:
1246
+ self._parents[root2] = root1
1247
+ else:
1248
+ self._parents[root2] = root1
1249
+ self._ranks[root1] += 1
1250
+
1153
1251
 
1154
1252
  ###############################################################################
1155
1253
 
@@ -11,7 +11,7 @@ This particular feature store abstraction makes a few key assumptions:
11
11
  * A feature can be uniquely identified from any associated attributes specified
12
12
  in `TensorAttr`.
13
13
 
14
- It is the job of a feature store implementor class to handle these assumptions
14
+ It is the job of a feature store implementer class to handle these assumptions
15
15
  properly. For example, a simple in-memory feature store implementation may
16
16
  concatenate all metadata values with a feature index and use this as a unique
17
17
  index in a KV store. More complicated implementations may choose to partition
@@ -352,7 +352,7 @@ class FeatureStore(ABC):
352
352
 
353
353
  .. note::
354
354
  The default implementation simply iterates over all calls to
355
- :meth:`get_tensor`. Implementor classes that can provide
355
+ :meth:`get_tensor`. Implementer classes that can provide
356
356
  additional, more performant functionality are recommended to
357
357
  to override this method.
358
358
 
@@ -412,7 +412,7 @@ class FeatureStore(ABC):
412
412
  value. Returns whether the update was successful.
413
413
 
414
414
  .. note::
415
- Implementor classes can choose to define more efficient update
415
+ Implementer classes can choose to define more efficient update
416
416
  methods; the default performs a removal and insertion.
417
417
 
418
418
  Args:
@@ -10,7 +10,7 @@ This particular graph store abstraction makes a few key assumptions:
10
10
  support dynamic modification of edge indices once they have been inserted
11
11
  into the graph store.
12
12
 
13
- It is the job of a graph store implementor class to handle these assumptions
13
+ It is the job of a graph store implementer class to handle these assumptions
14
14
  properly. For example, a simple in-memory graph store implementation may
15
15
  concatenate all metadata values with an edge index and use this as a unique
16
16
  index in a KV store. More complicated implementations may choose to partition
@@ -487,6 +487,77 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
487
487
 
488
488
  return status
489
489
 
490
+ def connected_components(self) -> List[Self]:
491
+ r"""Extracts connected components of the heterogeneous graph using
492
+ a union-find algorithm. The components are returned as a list of
493
+ :class:`~torch_geometric.data.HeteroData` objects.
494
+
495
+ .. code-block::
496
+
497
+ data = HeteroData()
498
+ data["red"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
499
+ data["blue"].x = torch.tensor([[5.0], [6.0]])
500
+ data["red", "to", "red"].edge_index = torch.tensor(
501
+ [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
502
+ )
503
+
504
+ components = data.connected_components()
505
+ print(len(components))
506
+ >>> 4
507
+
508
+ print(components[0])
509
+ >>> HeteroData(
510
+ red={x: tensor([[1.], [2.]])},
511
+ blue={x: tensor([[]])},
512
+ red, to, red={edge_index: tensor([[0, 1], [1, 0]])}
513
+ )
514
+
515
+ Returns:
516
+ List[HeteroData]: A list of connected components.
517
+ """
518
+ # Initialize union-find structures
519
+ self._parents: Dict[Tuple[str, int], Tuple[str, int]] = {}
520
+ self._ranks: Dict[Tuple[str, int], int] = {}
521
+
522
+ # Union-Find algorithm to find connected components
523
+ for edge_type in self.edge_types:
524
+ src, _, dst = edge_type
525
+ edge_index = self[edge_type].edge_index
526
+ for src_node, dst_node in edge_index.t().tolist():
527
+ self._union((src, src_node), (dst, dst_node))
528
+
529
+ # Rerun _find_parent to ensure all nodes are covered correctly
530
+ for node_type in self.node_types:
531
+ for node_index in range(self[node_type].num_nodes):
532
+ self._find_parent((node_type, node_index))
533
+
534
+ # Group nodes by their representative parent
535
+ components_map = defaultdict(list)
536
+ for node, parent in self._parents.items():
537
+ components_map[parent].append(node)
538
+ del self._parents
539
+ del self._ranks
540
+
541
+ components: List[Self] = []
542
+ for nodes in components_map.values():
543
+ # Prefill subset_dict with all node types to ensure all are present
544
+ subset_dict = {node_type: [] for node_type in self.node_types}
545
+
546
+ # Convert the list of (node_type, node_id) tuples to a subset_dict
547
+ for node_type, node_id in nodes:
548
+ subset_dict[node_type].append(node_id)
549
+
550
+ # Convert lists to tensors
551
+ for node_type, node_ids in subset_dict.items():
552
+ subset_dict[node_type] = torch.tensor(node_ids,
553
+ dtype=torch.long)
554
+
555
+ # Use the existing subgraph function to do all the heavy lifting
556
+ component_data = self.subgraph(subset_dict)
557
+ components.append(component_data)
558
+
559
+ return components
560
+
490
561
  def debug(self):
491
562
  pass # TODO
492
563
 
@@ -1148,6 +1219,51 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
1148
1219
 
1149
1220
  return list(edge_attrs.values())
1150
1221
 
1222
+ # Connected Components Helper Functions ###################################
1223
+
1224
+ def _find_parent(self, node: Tuple[str, int]) -> Tuple[str, int]:
1225
+ r"""Finds and returns the representative parent of the given node in a
1226
+ disjoint-set (union-find) data structure. Implements path compression
1227
+ to optimize future queries.
1228
+
1229
+ Args:
1230
+ node (tuple[str, int]): The node for which to find the parent.
1231
+ First element is the node type, second is the node index.
1232
+
1233
+ Returns:
1234
+ tuple[str, int]: The representative parent of the node.
1235
+ """
1236
+ if node not in self._parents:
1237
+ self._parents[node] = node
1238
+ self._ranks[node] = 0
1239
+ if self._parents[node] != node:
1240
+ self._parents[node] = self._find_parent(self._parents[node])
1241
+ return self._parents[node]
1242
+
1243
+ def _union(self, node1: Tuple[str, int], node2: Tuple[str, int]):
1244
+ r"""Merges the node1 and node2 in the disjoint-set data structure.
1245
+
1246
+ Finds the root parents of node1 and node2 using the _find_parent
1247
+ method. If they belong to different sets, updates the parent of
1248
+ root2 to be root1, effectively merging the two sets.
1249
+
1250
+ Args:
1251
+ node1 (Tuple[str, int]): The first node to union. First element is
1252
+ the node type, second is the node index.
1253
+ node2 (Tuple[str, int]): The second node to union. First element is
1254
+ the node type, second is the node index.
1255
+ """
1256
+ root1 = self._find_parent(node1)
1257
+ root2 = self._find_parent(node2)
1258
+ if root1 != root2:
1259
+ if self._ranks[root1] < self._ranks[root2]:
1260
+ self._parents[root1] = root2
1261
+ elif self._ranks[root1] > self._ranks[root2]:
1262
+ self._parents[root2] = root1
1263
+ else:
1264
+ self._parents[root2] = root1
1265
+ self._ranks[root1] += 1
1266
+
1151
1267
 
1152
1268
  # Helper functions ############################################################
1153
1269
 
@@ -30,8 +30,8 @@ class AirfRANS(InMemoryDataset):
30
30
  divided by the specific mass (one component in meter squared per second
31
31
  squared), the turbulent kinematic viscosity (one component in meter squared
32
32
  per second).
33
- Finaly, a boolean is attached to each point to inform if this point lies on
34
- the airfoil or not.
33
+ Finally, a boolean is attached to each point to inform if this point lies
34
+ on the airfoil or not.
35
35
 
36
36
  A library for manipulating simulations of the dataset is available `here
37
37
  <https://airfrans.readthedocs.io/en/latest/index.html>`_.
@@ -361,7 +361,7 @@ class Partitioner:
361
361
  'edge_types': self.edge_types,
362
362
  'node_offset': list(node_offset.values()) if node_offset else None,
363
363
  'is_hetero': self.is_hetero,
364
- 'is_sorted': True, # Based on columnn/destination.
364
+ 'is_sorted': True, # Based on column/destination.
365
365
  }
366
366
  with open(osp.join(self.root, 'META.json'), 'w') as f:
367
367
  json.dump(meta, f)
@@ -235,9 +235,9 @@ class ClusterData(torch.utils.data.Dataset):
235
235
  class ClusterLoader(torch.utils.data.DataLoader):
236
236
  r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
237
237
  for Training Deep and Large Graph Convolutional Networks"
238
- <https://arxiv.org/abs/1905.07953>`_ paper which merges partioned subgraphs
239
- and their between-cluster links from a large-scale graph data object to
240
- form a mini-batch.
238
+ <https://arxiv.org/abs/1905.07953>`_ paper which merges partitioned
239
+ subgraphs and their between-cluster links from a large-scale graph data
240
+ object to form a mini-batch.
241
241
 
242
242
  .. note::
243
243
 
@@ -252,7 +252,7 @@ class ClusterLoader(torch.utils.data.DataLoader):
252
252
 
253
253
  Args:
254
254
  cluster_data (torch_geometric.loader.ClusterData): The already
255
- partioned data object.
255
+ partitioned data object.
256
256
  **kwargs (optional): Additional arguments of
257
257
  :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
258
258
  :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
@@ -64,7 +64,7 @@ class MeshCNNConv(MessagePassing):
64
64
  :math:`\mathcal{N}(1) = (a(1), b(1), c(1), d(1)) = (2, 3, 4, 5)`
65
65
 
66
66
 
67
- Because of this ordering constrait, :obj:`MeshCNNConv` **requires
67
+ Because of this ordering constraint, :obj:`MeshCNNConv` **requires
68
68
  that the columns of** :math:`A`
69
69
  **be ordered in the following way**:
70
70
 
@@ -149,7 +149,7 @@ class MeshCNNConv(MessagePassing):
149
149
 
150
150
 
151
151
  Args:
152
- in_channels (int): Corresonds to :math:`\text{Dim-Out}(k)`
152
+ in_channels (int): Corresponds to :math:`\text{Dim-Out}(k)`
153
153
  in the above overview. This
154
154
  represents the output dimension of the prior layer. For the given
155
155
  input mesh :math:`\mathcal{m} = (V, F)`, the prior layer is
@@ -184,7 +184,7 @@ class MeshCNNConv(MessagePassing):
184
184
  a vector of dimensions :attr:`out_channels`.
185
185
 
186
186
  Discussion:
187
- The key difference that seperates :obj:`MeshCNNConv` from a traditional
187
+ The key difference that separates :obj:`MeshCNNConv` from a traditional
188
188
  message passing graph neural network is that :obj:`MeshCNNConv`
189
189
  requires the set of neighbors for a node
190
190
  :math:`\mathcal{N}(u) = (v_1, v_2, ...)`
@@ -198,7 +198,7 @@ class MeshCNNConv(MessagePassing):
198
198
  :math:`\mathbb{S}_4`. Put more plainly, in tradition message passing
199
199
  GNNs, the network is *unable* to distinguish one neighboring node
200
200
  from another.
201
- In constrast, in :obj:`MeshCNNConv`, each of the 4 neighbors has a
201
+ In contrast, in :obj:`MeshCNNConv`, each of the 4 neighbors has a
202
202
  "role", either the "a", "b", "c", or "d" neighbor. We encode this fact
203
203
  by requiring that :math:`\mathcal{N}` return the 4-tuple,
204
204
  where the first component is the "a" neighbor, and so on.
@@ -444,7 +444,7 @@ class MeshCNNConv(MessagePassing):
444
444
  """
445
445
  assert isinstance(kernels, ModuleList), \
446
446
  f"Parameter 'kernels' must be a \
447
- torch.nn.module.ModuleList with 5 memebers, but we got \
447
+ torch.nn.module.ModuleList with 5 members, but we got \
448
448
  {type(kernels)}."
449
449
 
450
450
  assert len(kernels) == 5, "Parameter 'kernels' must be a \
@@ -37,7 +37,7 @@ class Polynormer(torch.nn.Module):
37
37
  (default: :obj:`True`)
38
38
  pre_ln (bool): Pre layer normalization.
39
39
  (default: :obj:`False`)
40
- post_bn (bool): Post batch normlization.
40
+ post_bn (bool): Post batch normalization.
41
41
  (default: :obj:`True`)
42
42
  local_attn (bool): Whether use local attention.
43
43
  (default: :obj:`False`)
@@ -196,8 +196,8 @@ class InvertibleModule(torch.nn.Module):
196
196
  class GroupAddRev(InvertibleModule):
197
197
  r"""The Grouped Reversible GNN module from the `"Graph Neural Networks with
198
198
  1000 Layers" <https://arxiv.org/abs/2106.07476>`_ paper.
199
- This module enables training of arbitary deep GNNs with a memory complexity
200
- independent of the number of layers.
199
+ This module enables training of arbitrary deep GNNs with a memory
200
+ complexity independent of the number of layers.
201
201
 
202
202
  It does so by partitioning input node features :math:`\mathbf{X}` into
203
203
  :math:`C` groups across the feature dimension. Then, a grouped reversible
@@ -159,8 +159,8 @@ def jacobian_l1_agg_per_hop(
159
159
  vectorize=vectorize)
160
160
  hop_subsets = k_hop_subsets_exact(node_idx, max_hops, edge_index,
161
161
  num_nodes, influence.device)
162
- sigle_node_influence_per_hop = [influence[s].sum() for s in hop_subsets]
163
- return torch.tensor(sigle_node_influence_per_hop, device=influence.device)
162
+ single_node_influence_per_hop = [influence[s].sum() for s in hop_subsets]
163
+ return torch.tensor(single_node_influence_per_hop, device=influence.device)
164
164
 
165
165
 
166
166
  def avg_total_influence(
@@ -169,7 +169,7 @@ def avg_total_influence(
169
169
  ) -> Tensor:
170
170
  """Compute the *influence‑weighted receptive field* ``R``."""
171
171
  avg_total_influences = torch.mean(influence_all_nodes, dim=0)
172
- if normalize: # nomalize by hop_0 (jacobian of the center node feature)
172
+ if normalize: # normalize by hop_0 (jacobian of the center node feature)
173
173
  avg_total_influences = avg_total_influences / avg_total_influences[0]
174
174
  return avg_total_influences
175
175