pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -1,4 +1,4 @@
1
- import pickle
1
+ import io
2
2
  import warnings
3
3
  from abc import ABC, abstractmethod
4
4
  from dataclasses import dataclass
@@ -111,13 +111,17 @@ class Database(ABC):
111
111
  for key, value in schema_dict.items()
112
112
  }
113
113
 
114
+ @abstractmethod
114
115
  def connect(self) -> None:
115
116
  r"""Connects to the database.
116
117
  Databases will automatically connect on instantiation.
117
118
  """
119
+ raise NotImplementedError
118
120
 
121
+ @abstractmethod
119
122
  def close(self) -> None:
120
123
  r"""Closes the connection to the database."""
124
+ raise NotImplementedError
121
125
 
122
126
  @abstractmethod
123
127
  def insert(self, index: int, data: Any) -> None:
@@ -496,7 +500,9 @@ class SQLiteDatabase(Database):
496
500
  out.append(col)
497
501
 
498
502
  else:
499
- out.append(pickle.dumps(col))
503
+ buffer = io.BytesIO()
504
+ torch.save(col, buffer)
505
+ out.append(buffer.getvalue())
500
506
 
501
507
  return out
502
508
 
@@ -559,7 +565,10 @@ class SQLiteDatabase(Database):
559
565
  out_dict[key] = value
560
566
 
561
567
  else:
562
- out_dict[key] = pickle.loads(value)
568
+ out_dict[key] = torch.load(
569
+ io.BytesIO(value),
570
+ weights_only=False,
571
+ )
563
572
 
564
573
  # In case `0` exists as integer in the schema, this means that the
565
574
  # schema was passed as either a single entry or a tuple:
@@ -644,7 +653,12 @@ class RocksDatabase(Database):
644
653
  # Ensure that data is not a view of a larger tensor:
645
654
  if isinstance(row, Tensor):
646
655
  row = row.clone()
647
- return pickle.dumps(row)
656
+ buffer = io.BytesIO()
657
+ torch.save(row, buffer)
658
+ return buffer.getvalue()
648
659
 
649
660
  def _deserialize(self, row: bytes) -> Any:
650
- return pickle.loads(row)
661
+ return torch.load(
662
+ io.BytesIO(row),
663
+ weights_only=False,
664
+ )
@@ -166,10 +166,11 @@ class Dataset(torch.utils.data.Dataset):
166
166
  elif y.numel() == y.size(0) and torch.is_floating_point(y):
167
167
  num_classes = torch.unique(y).numel()
168
168
  if num_classes > 2:
169
- warnings.warn("Found floating-point labels while calling "
170
- "`dataset.num_classes`. Returning the number of "
171
- "unique elements. Please make sure that this "
172
- "is expected before proceeding.")
169
+ warnings.warn(
170
+ "Found floating-point labels while calling "
171
+ "`dataset.num_classes`. Returning the number of "
172
+ "unique elements. Please make sure that this "
173
+ "is expected before proceeding.", stacklevel=2)
173
174
  return num_classes
174
175
  else:
175
176
  return y.size(-1)
@@ -235,20 +236,24 @@ class Dataset(torch.utils.data.Dataset):
235
236
 
236
237
  def _process(self):
237
238
  f = osp.join(self.processed_dir, 'pre_transform.pt')
238
- if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
239
+ if not self.force_reload and osp.exists(f) and torch.load(
240
+ f, weights_only=False) != _repr(self.pre_transform):
239
241
  warnings.warn(
240
242
  "The `pre_transform` argument differs from the one used in "
241
243
  "the pre-processed version of this dataset. If you want to "
242
244
  "make use of another pre-processing technique, pass "
243
- "`force_reload=True` explicitly to reload the dataset.")
245
+ "`force_reload=True` explicitly to reload the dataset.",
246
+ stacklevel=2)
244
247
 
245
248
  f = osp.join(self.processed_dir, 'pre_filter.pt')
246
- if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
249
+ if not self.force_reload and osp.exists(f) and torch.load(
250
+ f, weights_only=False) != _repr(self.pre_filter):
247
251
  warnings.warn(
248
252
  "The `pre_filter` argument differs from the one used in "
249
253
  "the pre-processed version of this dataset. If you want to "
250
254
  "make use of another pre-fitering technique, pass "
251
- "`force_reload=True` explicitly to reload the dataset.")
255
+ "`force_reload=True` explicitly to reload the dataset.",
256
+ stacklevel=2)
252
257
 
253
258
  if not self.force_reload and files_exist(self.processed_paths):
254
259
  return
@@ -381,7 +386,7 @@ class Dataset(torch.utils.data.Dataset):
381
386
  r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`.
382
387
 
383
388
  The returned instance can then be used with :pyg:`PyG's` built-in
384
- :class:`DataPipes` for baching graphs as follows:
389
+ :class:`DataPipes` for batching graphs as follows:
385
390
 
386
391
  .. code-block:: python
387
392
 
@@ -28,7 +28,7 @@ def extract_tar(
28
28
  """
29
29
  maybe_log(path, log)
30
30
  with tarfile.open(path, mode) as f:
31
- f.extractall(folder)
31
+ f.extractall(folder, filter='data')
32
32
 
33
33
 
34
34
  def extract_zip(path: str, folder: str, log: bool = True) -> None:
@@ -11,7 +11,7 @@ This particular feature store abstraction makes a few key assumptions:
11
11
  * A feature can be uniquely identified from any associated attributes specified
12
12
  in `TensorAttr`.
13
13
 
14
- It is the job of a feature store implementor class to handle these assumptions
14
+ It is the job of a feature store implementer class to handle these assumptions
15
15
  properly. For example, a simple in-memory feature store implementation may
16
16
  concatenate all metadata values with a feature index and use this as a unique
17
17
  index in a KV store. More complicated implementations may choose to partition
@@ -74,13 +74,6 @@ class TensorAttr(CastMixin):
74
74
  r"""Whether the :obj:`TensorAttr` has no unset fields."""
75
75
  return all([self.is_set(key) for key in self.__dataclass_fields__])
76
76
 
77
- def fully_specify(self) -> 'TensorAttr':
78
- r"""Sets all :obj:`UNSET` fields to :obj:`None`."""
79
- for key in self.__dataclass_fields__:
80
- if not self.is_set(key):
81
- setattr(self, key, None)
82
- return self
83
-
84
77
  def update(self, attr: 'TensorAttr') -> 'TensorAttr':
85
78
  r"""Updates an :class:`TensorAttr` with set attributes from another
86
79
  :class:`TensorAttr`.
@@ -230,10 +223,11 @@ class AttrView(CastMixin):
230
223
 
231
224
  store[group_name, attr_name]()
232
225
  """
233
- # Set all UNSET values to None:
234
- out = copy.copy(self)
235
- out._attr.fully_specify()
236
- return out._store.get_tensor(out._attr)
226
+ attr = copy.copy(self._attr)
227
+ for key in attr.__dataclass_fields__: # Set all UNSET values to None.
228
+ if not attr.is_set(key):
229
+ setattr(attr, key, None)
230
+ return self._store.get_tensor(attr)
237
231
 
238
232
  def __copy__(self) -> 'AttrView':
239
233
  out = self.__class__.__new__(self.__class__)
@@ -358,7 +352,7 @@ class FeatureStore(ABC):
358
352
 
359
353
  .. note::
360
354
  The default implementation simply iterates over all calls to
361
- :meth:`get_tensor`. Implementor classes that can provide
355
+ :meth:`get_tensor`. Implementer classes that can provide
362
356
  additional, more performant functionality are recommended to
363
357
  to override this method.
364
358
 
@@ -415,10 +409,10 @@ class FeatureStore(ABC):
415
409
  def update_tensor(self, tensor: FeatureTensorType, *args,
416
410
  **kwargs) -> bool:
417
411
  r"""Updates a :obj:`tensor` in the :class:`FeatureStore` with a new
418
- value. Returns whether the update was succesful.
412
+ value. Returns whether the update was successful.
419
413
 
420
414
  .. note::
421
- Implementor classes can choose to define more efficient update
415
+ Implementer classes can choose to define more efficient update
422
416
  methods; the default performs a removal and insertion.
423
417
 
424
418
  Args:
@@ -479,9 +473,7 @@ class FeatureStore(ABC):
479
473
  # CastMixin will handle the case of key being a tuple or TensorAttr
480
474
  # object:
481
475
  key = self._tensor_attr_cls.cast(key)
482
- # We need to fully-specify the key for __setitem__ as it does not make
483
- # sense to work with a view here:
484
- key.fully_specify()
476
+ assert key.is_fully_specified()
485
477
  self.put_tensor(value, key)
486
478
 
487
479
  def __getitem__(self, key: TensorAttr) -> Any:
@@ -503,13 +495,16 @@ class FeatureStore(ABC):
503
495
  # If the view is not fully-specified, return a :class:`AttrView`:
504
496
  return self.view(attr)
505
497
 
506
- def __delitem__(self, key: TensorAttr):
498
+ def __delitem__(self, attr: TensorAttr):
507
499
  r"""Supports :obj:`del store[tensor_attr]`."""
508
500
  # CastMixin will handle the case of key being a tuple or TensorAttr
509
501
  # object:
510
- key = self._tensor_attr_cls.cast(key)
511
- key.fully_specify()
512
- self.remove_tensor(key)
502
+ attr = self._tensor_attr_cls.cast(attr)
503
+ attr = copy.copy(attr)
504
+ for key in attr.__dataclass_fields__: # Set all UNSET values to None.
505
+ if not attr.is_set(key):
506
+ setattr(attr, key, None)
507
+ self.remove_tensor(attr)
513
508
 
514
509
  def __iter__(self):
515
510
  raise NotImplementedError
@@ -10,7 +10,7 @@ This particular graph store abstraction makes a few key assumptions:
10
10
  support dynamic modification of edge indices once they have been inserted
11
11
  into the graph store.
12
12
 
13
- It is the job of a graph store implementor class to handle these assumptions
13
+ It is the job of a graph store implementer class to handle these assumptions
14
14
  properly. For example, a simple in-memory graph store implementation may
15
15
  concatenate all metadata values with an edge index and use this as a unique
16
16
  index in a KV store. More complicated implementations may choose to partition
@@ -261,7 +261,8 @@ class GraphStore(ABC):
261
261
  col = ptr2index(col)
262
262
 
263
263
  if attr.layout != EdgeLayout.CSR: # COO->CSR
264
- num_rows = attr.size[0] if attr.size else int(row.max()) + 1
264
+ num_rows = attr.size[0] if attr.size is not None else int(
265
+ row.max()) + 1
265
266
  row, perm = index_sort(row, max_value=num_rows)
266
267
  col = col[perm]
267
268
  row = index2ptr(row, num_rows)
@@ -282,6 +282,21 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
282
282
  r"""Returns a list of edge type and edge storage pairs."""
283
283
  return list(self._edge_store_dict.items())
284
284
 
285
+ @property
286
+ def input_type(self) -> Optional[Union[NodeType, EdgeType]]:
287
+ r"""Returns the seed/input node/edge type of the graph in case it
288
+ refers to a sampled subgraph, *e.g.*, obtained via
289
+ :class:`~torch_geometric.loader.NeighborLoader` or
290
+ :class:`~torch_geometric.loader.LinkNeighborLoader`.
291
+ """
292
+ for node_type, store in self.node_items():
293
+ if hasattr(store, 'input_id'):
294
+ return node_type
295
+ for edge_type, store in self.edge_items():
296
+ if hasattr(store, 'input_id'):
297
+ return edge_type
298
+ return None
299
+
285
300
  def to_dict(self) -> Dict[str, Any]:
286
301
  out_dict: Dict[str, Any] = {}
287
302
  out_dict['_global_store'] = self._global_store.to_dict()
@@ -472,6 +487,77 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
472
487
 
473
488
  return status
474
489
 
490
+ def connected_components(self) -> List[Self]:
491
+ r"""Extracts connected components of the heterogeneous graph using
492
+ a union-find algorithm. The components are returned as a list of
493
+ :class:`~torch_geometric.data.HeteroData` objects.
494
+
495
+ .. code-block::
496
+
497
+ data = HeteroData()
498
+ data["red"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
499
+ data["blue"].x = torch.tensor([[5.0], [6.0]])
500
+ data["red", "to", "red"].edge_index = torch.tensor(
501
+ [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
502
+ )
503
+
504
+ components = data.connected_components()
505
+ print(len(components))
506
+ >>> 4
507
+
508
+ print(components[0])
509
+ >>> HeteroData(
510
+ red={x: tensor([[1.], [2.]])},
511
+ blue={x: tensor([[]])},
512
+ red, to, red={edge_index: tensor([[0, 1], [1, 0]])}
513
+ )
514
+
515
+ Returns:
516
+ List[HeteroData]: A list of connected components.
517
+ """
518
+ # Initialize union-find structures
519
+ self._parents: Dict[Tuple[str, int], Tuple[str, int]] = {}
520
+ self._ranks: Dict[Tuple[str, int], int] = {}
521
+
522
+ # Union-Find algorithm to find connected components
523
+ for edge_type in self.edge_types:
524
+ src, _, dst = edge_type
525
+ edge_index = self[edge_type].edge_index
526
+ for src_node, dst_node in edge_index.t().tolist():
527
+ self._union((src, src_node), (dst, dst_node))
528
+
529
+ # Rerun _find_parent to ensure all nodes are covered correctly
530
+ for node_type in self.node_types:
531
+ for node_index in range(self[node_type].num_nodes):
532
+ self._find_parent((node_type, node_index))
533
+
534
+ # Group nodes by their representative parent
535
+ components_map = defaultdict(list)
536
+ for node, parent in self._parents.items():
537
+ components_map[parent].append(node)
538
+ del self._parents
539
+ del self._ranks
540
+
541
+ components: List[Self] = []
542
+ for nodes in components_map.values():
543
+ # Prefill subset_dict with all node types to ensure all are present
544
+ subset_dict = {node_type: [] for node_type in self.node_types}
545
+
546
+ # Convert the list of (node_type, node_id) tuples to a subset_dict
547
+ for node_type, node_id in nodes:
548
+ subset_dict[node_type].append(node_id)
549
+
550
+ # Convert lists to tensors
551
+ for node_type, node_ids in subset_dict.items():
552
+ subset_dict[node_type] = torch.tensor(node_ids,
553
+ dtype=torch.long)
554
+
555
+ # Use the existing subgraph function to do all the heavy lifting
556
+ component_data = self.subgraph(subset_dict)
557
+ components.append(component_data)
558
+
559
+ return components
560
+
475
561
  def debug(self):
476
562
  pass # TODO
477
563
 
@@ -551,7 +637,7 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
551
637
  This is equivalent to writing :obj:`data.x_dict`.
552
638
 
553
639
  Args:
554
- key (str): The attribute to collect from all node and ege types.
640
+ key (str): The attribute to collect from all node and edge types.
555
641
  allow_empty (bool, optional): If set to :obj:`True`, will not raise
556
642
  an error in case the attribute does not exit in any node or
557
643
  edge type. (default: :obj:`False`)
@@ -570,12 +656,13 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
570
656
  global _DISPLAYED_TYPE_NAME_WARNING
571
657
  if not _DISPLAYED_TYPE_NAME_WARNING and '__' in name:
572
658
  _DISPLAYED_TYPE_NAME_WARNING = True
573
- warnings.warn(f"There exist type names in the "
574
- f"'{self.__class__.__name__}' object that contain "
575
- f"double underscores '__' (e.g., '{name}'). This "
576
- f"may lead to unexpected behavior. To avoid any "
577
- f"issues, ensure that your type names only contain "
578
- f"single underscores.")
659
+ warnings.warn(
660
+ f"There exist type names in the "
661
+ f"'{self.__class__.__name__}' object that contain "
662
+ f"double underscores '__' (e.g., '{name}'). This "
663
+ f"may lead to unexpected behavior. To avoid any "
664
+ f"issues, ensure that your type names only contain "
665
+ f"single underscores.", stacklevel=2)
579
666
 
580
667
  def get_node_store(self, key: NodeType) -> NodeStorage:
581
668
  r"""Gets the :class:`~torch_geometric.data.storage.NodeStorage` object
@@ -1132,6 +1219,51 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
1132
1219
 
1133
1220
  return list(edge_attrs.values())
1134
1221
 
1222
+ # Connected Components Helper Functions ###################################
1223
+
1224
+ def _find_parent(self, node: Tuple[str, int]) -> Tuple[str, int]:
1225
+ r"""Finds and returns the representative parent of the given node in a
1226
+ disjoint-set (union-find) data structure. Implements path compression
1227
+ to optimize future queries.
1228
+
1229
+ Args:
1230
+ node (tuple[str, int]): The node for which to find the parent.
1231
+ First element is the node type, second is the node index.
1232
+
1233
+ Returns:
1234
+ tuple[str, int]: The representative parent of the node.
1235
+ """
1236
+ if node not in self._parents:
1237
+ self._parents[node] = node
1238
+ self._ranks[node] = 0
1239
+ if self._parents[node] != node:
1240
+ self._parents[node] = self._find_parent(self._parents[node])
1241
+ return self._parents[node]
1242
+
1243
+ def _union(self, node1: Tuple[str, int], node2: Tuple[str, int]):
1244
+ r"""Merges the node1 and node2 in the disjoint-set data structure.
1245
+
1246
+ Finds the root parents of node1 and node2 using the _find_parent
1247
+ method. If they belong to different sets, updates the parent of
1248
+ root2 to be root1, effectively merging the two sets.
1249
+
1250
+ Args:
1251
+ node1 (Tuple[str, int]): The first node to union. First element is
1252
+ the node type, second is the node index.
1253
+ node2 (Tuple[str, int]): The second node to union. First element is
1254
+ the node type, second is the node index.
1255
+ """
1256
+ root1 = self._find_parent(node1)
1257
+ root2 = self._find_parent(node2)
1258
+ if root1 != root2:
1259
+ if self._ranks[root1] < self._ranks[root2]:
1260
+ self._parents[root1] = root2
1261
+ elif self._ranks[root1] > self._ranks[root2]:
1262
+ self._parents[root2] = root1
1263
+ else:
1264
+ self._parents[root2] = root1
1265
+ self._ranks[root1] += 1
1266
+
1135
1267
 
1136
1268
  # Helper functions ############################################################
1137
1269
 
@@ -39,7 +39,7 @@ class HyperGraphData(Data):
39
39
  edge_index (LongTensor, optional): Hyperedge tensor
40
40
  with shape :obj:`[2, num_edges*num_nodes_per_edge]`.
41
41
  Where `edge_index[1]` denotes the hyperedge index and
42
- `edge_index[0]` denotes the node indicies that are connected
42
+ `edge_index[0]` denotes the node indices that are connected
43
43
  by the hyperedge. (default: :obj:`None`)
44
44
  (default: :obj:`None`)
45
45
  edge_attr (torch.Tensor, optional): Edge feature matrix with shape
@@ -223,4 +223,4 @@ def warn_or_raise(msg: str, raise_on_error: bool = True) -> None:
223
223
  if raise_on_error:
224
224
  raise ValueError(msg)
225
225
  else:
226
- warnings.warn(msg)
226
+ warnings.warn(msg, stacklevel=2)
@@ -297,7 +297,7 @@ class InMemoryDataset(Dataset):
297
297
  self._data_list = None
298
298
  msg += f' {msg4}'
299
299
 
300
- warnings.warn(msg)
300
+ warnings.warn(msg, stacklevel=2)
301
301
 
302
302
  return self._data
303
303
 
@@ -346,7 +346,7 @@ class InMemoryDataset(Dataset):
346
346
 
347
347
  def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
348
348
  if isinstance(node, Mapping):
349
- for key, value in node.items():
349
+ for value in node.values():
350
350
  yield from nested_iter(value)
351
351
  elif isinstance(node, Sequence):
352
352
  yield from enumerate(node)
@@ -11,21 +11,27 @@ from torch_geometric.sampler import BaseSampler, NeighborSampler
11
11
  from torch_geometric.typing import InputEdges, InputNodes, OptTensor
12
12
 
13
13
  try:
14
- from pytorch_lightning import LightningDataModule as PLLightningDataModule
15
- no_pytorch_lightning = False
14
+ from lightning.pytorch import LightningDataModule as _LightningDataModule
15
+ _pl_is_available = True
16
16
  except ImportError:
17
- PLLightningDataModule = object # type: ignore
18
- no_pytorch_lightning = True
17
+ try:
18
+ from pytorch_lightning import \
19
+ LightningDataModule as _LightningDataModule
20
+ _pl_is_available = True
21
+ except ImportError:
22
+ _pl_is_available = False
23
+ _LightningDataModule = object
19
24
 
20
25
 
21
- class LightningDataModule(PLLightningDataModule):
26
+ class LightningDataModule(_LightningDataModule):
22
27
  def __init__(self, has_val: bool, has_test: bool, **kwargs: Any) -> None:
23
28
  super().__init__()
24
29
 
25
- if no_pytorch_lightning:
30
+ if not _pl_is_available:
26
31
  raise ModuleNotFoundError(
27
- "No module named 'pytorch_lightning' found on this machine. "
28
- "Run 'pip install pytorch_lightning' to install the library.")
32
+ "No module named 'pytorch_lightning' (or 'lightning') found "
33
+ "in your Python environment. Run 'pip install "
34
+ "pytorch_lightning' or 'pip install lightning'")
29
35
 
30
36
  if not has_val:
31
37
  self.val_dataloader = None # type: ignore
@@ -40,9 +46,11 @@ class LightningDataModule(PLLightningDataModule):
40
46
  kwargs.get('num_workers', 0) > 0)
41
47
 
42
48
  if 'shuffle' in kwargs:
43
- warnings.warn(f"The 'shuffle={kwargs['shuffle']}' option is "
44
- f"ignored in '{self.__class__.__name__}'. Remove it "
45
- f"from the argument list to disable this warning")
49
+ warnings.warn(
50
+ f"The 'shuffle={kwargs['shuffle']}' option is "
51
+ f"ignored in '{self.__class__.__name__}'. Remove it "
52
+ f"from the argument list to disable this warning",
53
+ stacklevel=2)
46
54
  del kwargs['shuffle']
47
55
 
48
56
  self.kwargs = kwargs
@@ -74,34 +82,39 @@ class LightningData(LightningDataModule):
74
82
  raise ValueError(f"Undefined 'loader' option (got '{loader}')")
75
83
 
76
84
  if loader == 'full' and kwargs['batch_size'] != 1:
77
- warnings.warn(f"Re-setting 'batch_size' to 1 in "
78
- f"'{self.__class__.__name__}' for loader='full' "
79
- f"(got '{kwargs['batch_size']}')")
85
+ warnings.warn(
86
+ f"Re-setting 'batch_size' to 1 in "
87
+ f"'{self.__class__.__name__}' for loader='full' "
88
+ f"(got '{kwargs['batch_size']}')", stacklevel=2)
80
89
  kwargs['batch_size'] = 1
81
90
 
82
91
  if loader == 'full' and kwargs['num_workers'] != 0:
83
- warnings.warn(f"Re-setting 'num_workers' to 0 in "
84
- f"'{self.__class__.__name__}' for loader='full' "
85
- f"(got '{kwargs['num_workers']}')")
92
+ warnings.warn(
93
+ f"Re-setting 'num_workers' to 0 in "
94
+ f"'{self.__class__.__name__}' for loader='full' "
95
+ f"(got '{kwargs['num_workers']}')", stacklevel=2)
86
96
  kwargs['num_workers'] = 0
87
97
 
88
98
  if loader == 'full' and kwargs.get('sampler') is not None:
89
- warnings.warn("'sampler' option is not supported for "
90
- "loader='full'")
99
+ warnings.warn(
100
+ "'sampler' option is not supported for "
101
+ "loader='full'", stacklevel=2)
91
102
  kwargs.pop('sampler', None)
92
103
 
93
104
  if loader == 'full' and kwargs.get('batch_sampler') is not None:
94
- warnings.warn("'batch_sampler' option is not supported for "
95
- "loader='full'")
105
+ warnings.warn(
106
+ "'batch_sampler' option is not supported for "
107
+ "loader='full'", stacklevel=2)
96
108
  kwargs.pop('batch_sampler', None)
97
109
 
98
110
  super().__init__(has_val, has_test, **kwargs)
99
111
 
100
112
  if loader == 'full':
101
113
  if kwargs.get('pin_memory', False):
102
- warnings.warn(f"Re-setting 'pin_memory' to 'False' in "
103
- f"'{self.__class__.__name__}' for loader='full' "
104
- f"(got 'True')")
114
+ warnings.warn(
115
+ f"Re-setting 'pin_memory' to 'False' in "
116
+ f"'{self.__class__.__name__}' for loader='full' "
117
+ f"(got 'True')", stacklevel=2)
105
118
  self.kwargs['pin_memory'] = False
106
119
 
107
120
  self.data = data
@@ -127,10 +140,11 @@ class LightningData(LightningDataModule):
127
140
  graph_sampler.__class__,
128
141
  )
129
142
  if len(sampler_kwargs) > 0:
130
- warnings.warn(f"Ignoring the arguments "
131
- f"{list(sampler_kwargs.keys())} in "
132
- f"'{self.__class__.__name__}' since a custom "
133
- f"'graph_sampler' was passed")
143
+ warnings.warn(
144
+ f"Ignoring the arguments "
145
+ f"{list(sampler_kwargs.keys())} in "
146
+ f"'{self.__class__.__name__}' since a custom "
147
+ f"'graph_sampler' was passed", stacklevel=2)
134
148
  self.graph_sampler = graph_sampler
135
149
 
136
150
  else:
@@ -454,7 +454,7 @@ class NodeStorage(BaseStorage):
454
454
  f"'{set(self.keys())}'. Please explicitly set 'num_nodes' as an "
455
455
  f"attribute of " +
456
456
  ("'data'" if self._key is None else f"'data[{self._key}]'") +
457
- " to suppress this warning")
457
+ " to suppress this warning", stacklevel=2)
458
458
  if 'edge_index' in self and isinstance(self.edge_index, Tensor):
459
459
  if self.edge_index.numel() > 0:
460
460
  return int(self.edge_index.max()) + 1
@@ -806,6 +806,10 @@ class GlobalStorage(NodeStorage, EdgeStorage):
806
806
  return False
807
807
 
808
808
  cat_dim = self._parent().__cat_dim__(key, value, self)
809
+
810
+ if not isinstance(cat_dim, int):
811
+ return False
812
+
809
813
  num_nodes, num_edges = self.num_nodes, self.num_edges
810
814
 
811
815
  if value.shape[cat_dim] != num_nodes:
@@ -852,6 +856,10 @@ class GlobalStorage(NodeStorage, EdgeStorage):
852
856
  return False
853
857
 
854
858
  cat_dim = self._parent().__cat_dim__(key, value, self)
859
+
860
+ if not isinstance(cat_dim, int):
861
+ return False
862
+
855
863
  num_nodes, num_edges = self.num_nodes, self.num_edges
856
864
 
857
865
  if value.shape[cat_dim] != num_edges:
@@ -30,6 +30,7 @@ from .faust import FAUST
30
30
  from .dynamic_faust import DynamicFAUST
31
31
  from .shapenet import ShapeNet
32
32
  from .modelnet import ModelNet
33
+ from .medshapenet import MedShapeNet
33
34
  from .coma import CoMA
34
35
  from .shrec2016 import SHREC2016
35
36
  from .tosca import TOSCA
@@ -61,7 +62,6 @@ from .gemsec import GemsecDeezer
61
62
  from .twitch import Twitch
62
63
  from .airports import Airports
63
64
  from .lrgb import LRGBDataset
64
- from .neurograph import NeuroGraphDataset
65
65
  from .malnet_tiny import MalNetTiny
66
66
  from .omdb import OMDB
67
67
  from .polblogs import PolBlogs
@@ -76,6 +76,15 @@ from .jodie import JODIEDataset
76
76
  from .wikidata import Wikidata5M
77
77
  from .myket import MyketDataset
78
78
  from .brca_tgca import BrcaTcga
79
+ from .neurograph import NeuroGraphDataset
80
+ from .web_qsp_dataset import WebQSPDataset, CWQDataset
81
+ from .git_mol_dataset import GitMolDataset
82
+ from .molecule_gpt_dataset import MoleculeGPTDataset
83
+ from .instruct_mol_dataset import InstructMolDataset
84
+ from .protein_mpnn_dataset import ProteinMPNNDataset
85
+ from .tag_dataset import TAGDataset
86
+ from .city import CityNetwork
87
+ from .teeth3ds import Teeth3DS
79
88
 
80
89
  from .dbp15k import DBP15K
81
90
  from .aminer import AMiner
@@ -141,6 +150,7 @@ homo_datasets = [
141
150
  'DynamicFAUST',
142
151
  'ShapeNet',
143
152
  'ModelNet',
153
+ 'MedShapeNet',
144
154
  'CoMA',
145
155
  'SHREC2016',
146
156
  'TOSCA',
@@ -188,6 +198,15 @@ homo_datasets = [
188
198
  'MyketDataset',
189
199
  'BrcaTcga',
190
200
  'NeuroGraphDataset',
201
+ 'WebQSPDataset',
202
+ 'CWQDataset',
203
+ 'GitMolDataset',
204
+ 'MoleculeGPTDataset',
205
+ 'InstructMolDataset',
206
+ 'ProteinMPNNDataset',
207
+ 'TAGDataset',
208
+ 'CityNetwork',
209
+ 'Teeth3DS',
191
210
  ]
192
211
 
193
212
  hetero_datasets = [