pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +13 -7
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +317 -65
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +3 -5
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +329 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +56 -22
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -70,8 +70,8 @@ class HeteroConv(torch.nn.Module):
70
70
  for edge_type, module in convs.items():
71
71
  check_add_self_loops(module, [edge_type])
72
72
 
73
- src_node_types = set([key[0] for key in convs.keys()])
74
- dst_node_types = set([key[-1] for key in convs.keys()])
73
+ src_node_types = {key[0] for key in convs.keys()}
74
+ dst_node_types = {key[-1] for key in convs.keys()}
75
75
  if len(src_node_types - dst_node_types) > 0:
76
76
  warnings.warn(
77
77
  f"There exist node types ({src_node_types - dst_node_types}) "
@@ -102,7 +102,7 @@ class HeteroConv(torch.nn.Module):
102
102
  individual edge type, either as a :class:`torch.Tensor` of
103
103
  shape :obj:`[2, num_edges]` or a
104
104
  :class:`torch_sparse.SparseTensor`.
105
- *args_dict (optional): Additional forward arguments of invididual
105
+ *args_dict (optional): Additional forward arguments of individual
106
106
  :class:`torch_geometric.nn.conv.MessagePassing` layers.
107
107
  **kwargs_dict (optional): Additional forward arguments of
108
108
  individual :class:`torch_geometric.nn.conv.MessagePassing`
@@ -67,7 +67,7 @@ class HGTConv(MessagePassing):
67
67
  for i, edge_type in enumerate(metadata[1])
68
68
  }
69
69
 
70
- self.dst_node_types = set([key[-1] for key in self.edge_types])
70
+ self.dst_node_types = {key[-1] for key in self.edge_types}
71
71
 
72
72
  self.kqv_lin = HeteroDictLinear(self.in_channels,
73
73
  self.out_channels * 3)
@@ -6,6 +6,7 @@ from typing import (
6
6
  Any,
7
7
  Callable,
8
8
  Dict,
9
+ Final,
9
10
  List,
10
11
  Optional,
11
12
  OrderedDict,
@@ -19,6 +20,7 @@ from torch import Tensor
19
20
  from torch.utils.hooks import RemovableHandle
20
21
 
21
22
  from torch_geometric import EdgeIndex, is_compiling
23
+ from torch_geometric.index import ptr2index
22
24
  from torch_geometric.inspector import Inspector, Signature
23
25
  from torch_geometric.nn.aggr import Aggregation
24
26
  from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
@@ -29,7 +31,6 @@ from torch_geometric.utils import (
29
31
  is_torch_sparse_tensor,
30
32
  to_edge_index,
31
33
  )
32
- from torch_geometric.utils.sparse import ptr2index
33
34
 
34
35
  FUSE_AGGRS = {'add', 'sum', 'mean', 'min', 'max'}
35
36
  HookDict = OrderedDict[int, Callable]
@@ -102,6 +103,10 @@ class MessagePassing(torch.nn.Module):
102
103
  'size_i', 'size_j', 'ptr', 'index', 'dim_size'
103
104
  }
104
105
 
106
+ # Supports `message_and_aggregate` via `EdgeIndex`.
107
+ # TODO Remove once migration is finished.
108
+ SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = False
109
+
105
110
  def __init__(
106
111
  self,
107
112
  aggr: Optional[Union[str, List[str], Aggregation]] = 'sum',
@@ -162,63 +167,8 @@ class MessagePassing(torch.nn.Module):
162
167
  self._edge_update_forward_pre_hooks: HookDict = OrderedDict()
163
168
  self._edge_update_forward_hooks: HookDict = OrderedDict()
164
169
 
165
- root_dir = osp.dirname(osp.realpath(__file__))
166
- jinja_prefix = f'{self.__module__}_{self.__class__.__name__}'
167
- # Optimize `propagate()` via `*.jinja` templates:
168
- if not self.propagate.__module__.startswith(jinja_prefix):
169
- try:
170
- module = module_from_template(
171
- module_name=f'{jinja_prefix}_propagate',
172
- template_path=osp.join(root_dir, 'propagate.jinja'),
173
- tmp_dirname='message_passing',
174
- # Keyword arguments:
175
- modules=self.inspector._modules,
176
- collect_name='collect',
177
- signature=self._get_propagate_signature(),
178
- collect_param_dict=self.inspector.get_flat_param_dict(
179
- ['message', 'aggregate', 'update']),
180
- message_args=self.inspector.get_param_names('message'),
181
- aggregate_args=self.inspector.get_param_names('aggregate'),
182
- message_and_aggregate_args=self.inspector.get_param_names(
183
- 'message_and_aggregate'),
184
- update_args=self.inspector.get_param_names('update'),
185
- fuse=self.fuse,
186
- )
187
-
188
- self.__class__._orig_propagate = self.__class__.propagate
189
- self.__class__._jinja_propagate = module.propagate
190
-
191
- self.__class__.propagate = module.propagate
192
- self.__class__.collect = module.collect
193
- except Exception: # pragma: no cover
194
- self.__class__._orig_propagate = self.__class__.propagate
195
- self.__class__._jinja_propagate = self.__class__.propagate
196
-
197
- # Optimize `edge_updater()` via `*.jinja` templates (if implemented):
198
- if (self.inspector.implements('edge_update')
199
- and not self.edge_updater.__module__.startswith(jinja_prefix)):
200
- try:
201
- module = module_from_template(
202
- module_name=f'{jinja_prefix}_edge_updater',
203
- template_path=osp.join(root_dir, 'edge_updater.jinja'),
204
- tmp_dirname='message_passing',
205
- # Keyword arguments:
206
- modules=self.inspector._modules,
207
- collect_name='edge_collect',
208
- signature=self._get_edge_updater_signature(),
209
- collect_param_dict=self.inspector.get_param_dict(
210
- 'edge_update'),
211
- )
212
-
213
- self.__class__._orig_edge_updater = self.__class__.edge_updater
214
- self.__class__._jinja_edge_updater = module.edge_updater
215
-
216
- self.__class__.edge_updater = module.edge_updater
217
- self.__class__.edge_collect = module.edge_collect
218
- except Exception: # pragma: no cover
219
- self.__class__._orig_edge_updater = self.__class__.edge_updater
220
- self.__class__._jinja_edge_updater = (
221
- self.__class__.edge_updater)
170
+ # Set jittable `propagate` and `edge_updater` function templates:
171
+ self._set_jittable_templates()
222
172
 
223
173
  # Explainability:
224
174
  self._explain: Optional[bool] = None
@@ -227,6 +177,7 @@ class MessagePassing(torch.nn.Module):
227
177
  self._apply_sigmoid: bool = True
228
178
 
229
179
  # Inference Decomposition:
180
+ self._decomposed_layers = 1
230
181
  self.decomposed_layers = decomposed_layers
231
182
 
232
183
  def reset_parameters(self) -> None:
@@ -234,6 +185,12 @@ class MessagePassing(torch.nn.Module):
234
185
  if self.aggr_module is not None:
235
186
  self.aggr_module.reset_parameters()
236
187
 
188
+ def __setstate__(self, data: Dict[str, Any]) -> None:
189
+ self.inspector = data['inspector']
190
+ self.fuse = data['fuse']
191
+ self._set_jittable_templates()
192
+ super().__setstate__(data)
193
+
237
194
  def __repr__(self) -> str:
238
195
  channels_repr = ''
239
196
  if hasattr(self, 'in_channels') and hasattr(self, 'out_channels'):
@@ -247,7 +204,7 @@ class MessagePassing(torch.nn.Module):
247
204
  def _check_input(
248
205
  self,
249
206
  edge_index: Union[Tensor, SparseTensor],
250
- size: Optional[Tuple[int, int]],
207
+ size: Optional[Tuple[Optional[int], Optional[int]]],
251
208
  ) -> List[Optional[int]]:
252
209
 
253
210
  if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
@@ -256,19 +213,20 @@ class MessagePassing(torch.nn.Module):
256
213
  if is_sparse(edge_index):
257
214
  if self.flow == 'target_to_source':
258
215
  raise ValueError(
259
- ('Flow direction "target_to_source" is invalid for '
260
- 'message propagation via `torch_sparse.SparseTensor` '
261
- 'or `torch.sparse.Tensor`. If you really want to make '
262
- 'use of a reverse message passing flow, pass in the '
263
- 'transposed sparse tensor to the message passing module, '
264
- 'e.g., `adj_t.t()`.'))
216
+ 'Flow direction "target_to_source" is invalid for '
217
+ 'message propagation via `torch_sparse.SparseTensor` '
218
+ 'or `torch.sparse.Tensor`. If you really want to make '
219
+ 'use of a reverse message passing flow, pass in the '
220
+ 'transposed sparse tensor to the message passing module, '
221
+ 'e.g., `adj_t.t()`.')
265
222
 
266
223
  if isinstance(edge_index, SparseTensor):
267
224
  return [edge_index.size(1), edge_index.size(0)]
268
225
  return [edge_index.size(1), edge_index.size(0)]
269
226
 
270
227
  elif isinstance(edge_index, Tensor):
271
- int_dtypes = (torch.uint8, torch.int8, torch.int32, torch.int64)
228
+ int_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32,
229
+ torch.int64)
272
230
 
273
231
  if edge_index.dtype not in int_dtypes:
274
232
  raise ValueError(f"Expected 'edge_index' to be of integer "
@@ -284,9 +242,9 @@ class MessagePassing(torch.nn.Module):
284
242
  return list(size) if size is not None else [None, None]
285
243
 
286
244
  raise ValueError(
287
- ('`MessagePassing.propagate` only supports integer tensors of '
288
- 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '
289
- '`torch.sparse.Tensor` for argument `edge_index`.'))
245
+ '`MessagePassing.propagate` only supports integer tensors of '
246
+ 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '
247
+ '`torch.sparse.Tensor` for argument `edge_index`.')
290
248
 
291
249
  def _set_size(
292
250
  self,
@@ -299,8 +257,8 @@ class MessagePassing(torch.nn.Module):
299
257
  size[dim] = src.size(self.node_dim)
300
258
  elif the_size != src.size(self.node_dim):
301
259
  raise ValueError(
302
- (f'Encountered tensor with size {src.size(self.node_dim)} in '
303
- f'dimension {self.node_dim}, but expected size {the_size}.'))
260
+ f'Encountered tensor with size {src.size(self.node_dim)} in '
261
+ f'dimension {self.node_dim}, but expected size {the_size}.')
304
262
 
305
263
  def _index_select(self, src: Tensor, index) -> Tensor:
306
264
  if torch.jit.is_scripting() or is_compiling():
@@ -370,9 +328,9 @@ class MessagePassing(torch.nn.Module):
370
328
  return src.index_select(self.node_dim, row)
371
329
 
372
330
  raise ValueError(
373
- ('`MessagePassing.propagate` only supports integer tensors of '
374
- 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` '
375
- 'or `torch.sparse.Tensor` for argument `edge_index`.'))
331
+ '`MessagePassing.propagate` only supports integer tensors of '
332
+ 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` '
333
+ 'or `torch.sparse.Tensor` for argument `edge_index`.')
376
334
 
377
335
  def _collect(
378
336
  self,
@@ -459,7 +417,6 @@ class MessagePassing(torch.nn.Module):
459
417
 
460
418
  def forward(self, *args: Any, **kwargs: Any) -> Any:
461
419
  r"""Runs the forward pass of the module."""
462
- pass
463
420
 
464
421
  def propagate(
465
422
  self,
@@ -509,7 +466,17 @@ class MessagePassing(torch.nn.Module):
509
466
  mutable_size = self._check_input(edge_index, size)
510
467
 
511
468
  # Run "fused" message and aggregation (if applicable).
512
- if is_sparse(edge_index) and self.fuse and not self.explain:
469
+ fuse = False
470
+ if self.fuse and not self.explain:
471
+ if is_sparse(edge_index):
472
+ fuse = True
473
+ elif (not torch.jit.is_scripting()
474
+ and isinstance(edge_index, EdgeIndex)):
475
+ if (self.SUPPORTS_FUSED_EDGE_INDEX
476
+ and edge_index.is_sorted_by_col):
477
+ fuse = True
478
+
479
+ if fuse:
513
480
  coll_dict = self._collect(self._fused_user_args, edge_index,
514
481
  mutable_size, kwargs)
515
482
 
@@ -628,7 +595,7 @@ class MessagePassing(torch.nn.Module):
628
595
  dim=self.node_dim)
629
596
 
630
597
  @abstractmethod
631
- def message_and_aggregate(self, adj_t: Adj) -> Tensor:
598
+ def message_and_aggregate(self, edge_index: Adj) -> Tensor:
632
599
  r"""Fuses computations of :func:`message` and :func:`aggregate` into a
633
600
  single function.
634
601
  If applicable, this saves both time and memory since messages do not
@@ -720,16 +687,20 @@ class MessagePassing(torch.nn.Module):
720
687
  raise ValueError("Inference decomposition of message passing "
721
688
  "modules is only supported on the Python module")
722
689
 
690
+ if decomposed_layers == self._decomposed_layers:
691
+ return # Abort early if nothing to do.
692
+
723
693
  self._decomposed_layers = decomposed_layers
724
694
 
725
695
  if decomposed_layers != 1:
726
- self.propagate = self.__class__._orig_propagate.__get__(
727
- self, MessagePassing)
696
+ if hasattr(self.__class__, '_orig_propagate'):
697
+ self.propagate = self.__class__._orig_propagate.__get__(
698
+ self, MessagePassing)
728
699
 
729
- elif ((self.explain is None or self.explain is False)
730
- and not self.propagate.__module__.endswith('_propagate')):
731
- self.propagate = self.__class__._jinja_propagate.__get__(
732
- self, MessagePassing)
700
+ elif self.explain is None or self.explain is False:
701
+ if hasattr(self.__class__, '_jinja_propagate'):
702
+ self.propagate = self.__class__._jinja_propagate.__get__(
703
+ self, MessagePassing)
733
704
 
734
705
  # Explainability ##########################################################
735
706
 
@@ -743,6 +714,9 @@ class MessagePassing(torch.nn.Module):
743
714
  raise ValueError("Explainability of message passing modules "
744
715
  "is only supported on the Python module")
745
716
 
717
+ if explain == self._explain:
718
+ return # Abort early if nothing to do.
719
+
746
720
  self._explain = explain
747
721
 
748
722
  if explain is True:
@@ -753,16 +727,18 @@ class MessagePassing(torch.nn.Module):
753
727
  funcs=['message', 'explain_message', 'aggregate', 'update'],
754
728
  exclude=self.special_args,
755
729
  )
756
- self.propagate = self.__class__._orig_propagate.__get__(
757
- self, MessagePassing)
730
+ if hasattr(self.__class__, '_orig_propagate'):
731
+ self.propagate = self.__class__._orig_propagate.__get__(
732
+ self, MessagePassing)
758
733
  else:
759
734
  self._user_args = self.inspector.get_flat_param_names(
760
735
  funcs=['message', 'aggregate', 'update'],
761
736
  exclude=self.special_args,
762
737
  )
763
738
  if self.decomposed_layers == 1:
764
- self.propagate = self.__class__._jinja_propagate.__get__(
765
- self, MessagePassing)
739
+ if hasattr(self.__class__, '_jinja_propagate'):
740
+ self.propagate = self.__class__._jinja_propagate.__get__(
741
+ self, MessagePassing)
766
742
 
767
743
  def explain_message(
768
744
  self,
@@ -947,6 +923,81 @@ class MessagePassing(torch.nn.Module):
947
923
 
948
924
  # TorchScript Support #####################################################
949
925
 
926
+ def _set_jittable_templates(self, raise_on_error: bool = False) -> None:
927
+ root_dir = osp.dirname(osp.realpath(__file__))
928
+ jinja_prefix = f'{self.__module__}_{self.__class__.__name__}'
929
+ # Optimize `propagate()` via `*.jinja` templates:
930
+ if not self.propagate.__module__.startswith(jinja_prefix):
931
+ try:
932
+ if ('propagate' in self.__class__.__dict__
933
+ and self.__class__.__dict__['propagate']
934
+ != MessagePassing.propagate):
935
+ raise ValueError("Cannot compile custom 'propagate' "
936
+ "method")
937
+
938
+ module = module_from_template(
939
+ module_name=f'{jinja_prefix}_propagate',
940
+ template_path=osp.join(root_dir, 'propagate.jinja'),
941
+ tmp_dirname='message_passing',
942
+ # Keyword arguments:
943
+ modules=self.inspector._modules,
944
+ collect_name='collect',
945
+ signature=self._get_propagate_signature(),
946
+ collect_param_dict=self.inspector.get_flat_param_dict(
947
+ ['message', 'aggregate', 'update']),
948
+ message_args=self.inspector.get_param_names('message'),
949
+ aggregate_args=self.inspector.get_param_names('aggregate'),
950
+ message_and_aggregate_args=self.inspector.get_param_names(
951
+ 'message_and_aggregate'),
952
+ update_args=self.inspector.get_param_names('update'),
953
+ fuse=self.fuse,
954
+ )
955
+
956
+ self.__class__._orig_propagate = self.__class__.propagate
957
+ self.__class__._jinja_propagate = module.propagate
958
+
959
+ self.__class__.propagate = module.propagate
960
+ self.__class__.collect = module.collect
961
+ except Exception as e: # pragma: no cover
962
+ if raise_on_error:
963
+ raise e
964
+ self.__class__._orig_propagate = self.__class__.propagate
965
+ self.__class__._jinja_propagate = self.__class__.propagate
966
+
967
+ # Optimize `edge_updater()` via `*.jinja` templates (if implemented):
968
+ if (self.inspector.implements('edge_update')
969
+ and not self.edge_updater.__module__.startswith(jinja_prefix)):
970
+ try:
971
+ if ('edge_updater' in self.__class__.__dict__
972
+ and self.__class__.__dict__['edge_updater']
973
+ != MessagePassing.edge_updater):
974
+ raise ValueError("Cannot compile custom 'edge_updater' "
975
+ "method")
976
+
977
+ module = module_from_template(
978
+ module_name=f'{jinja_prefix}_edge_updater',
979
+ template_path=osp.join(root_dir, 'edge_updater.jinja'),
980
+ tmp_dirname='message_passing',
981
+ # Keyword arguments:
982
+ modules=self.inspector._modules,
983
+ collect_name='edge_collect',
984
+ signature=self._get_edge_updater_signature(),
985
+ collect_param_dict=self.inspector.get_param_dict(
986
+ 'edge_update'),
987
+ )
988
+
989
+ self.__class__._orig_edge_updater = self.__class__.edge_updater
990
+ self.__class__._jinja_edge_updater = module.edge_updater
991
+
992
+ self.__class__.edge_updater = module.edge_updater
993
+ self.__class__.edge_collect = module.edge_collect
994
+ except Exception as e: # pragma: no cover
995
+ if raise_on_error:
996
+ raise e
997
+ self.__class__._orig_edge_updater = self.__class__.edge_updater
998
+ self.__class__._jinja_edge_updater = (
999
+ self.__class__.edge_updater)
1000
+
950
1001
  def _get_propagate_signature(self) -> Signature:
951
1002
  param_dict = self.inspector.get_params_from_method_call(
952
1003
  'propagate', exclude=[0, 'edge_index', 'size'])
@@ -14,7 +14,7 @@ from torch_geometric.utils import spmm
14
14
 
15
15
  class MixHopConv(MessagePassing):
16
16
  r"""The Mix-Hop graph convolutional operator from the
17
- `"MixHop: Higher-Order Graph Convolutional Architecturesvia Sparsified
17
+ `"MixHop: Higher-Order Graph Convolutional Architectures via Sparsified
18
18
  Neighborhood Mixing" <https://arxiv.org/abs/1905.00067>`_ paper.
19
19
 
20
20
  .. math::
@@ -42,7 +42,15 @@ def propagate(
42
42
  # End Propagate Forward Pre Hook ###########################################
43
43
 
44
44
  mutable_size = self._check_input(edge_index, size)
45
- fuse = is_sparse(edge_index) and self.fuse
45
+
46
+ # Run "fused" message and aggregation (if applicable).
47
+ fuse = False
48
+ if self.fuse:
49
+ if is_sparse(edge_index):
50
+ fuse = True
51
+ elif not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
52
+ if self.SUPPORTS_FUSED_EDGE_INDEX and edge_index.is_sorted_by_col:
53
+ fuse = True
46
54
 
47
55
  if fuse:
48
56
 
@@ -3,11 +3,11 @@ from typing import Optional, Tuple, Union
3
3
  import torch
4
4
  from torch import Tensor
5
5
  from torch.nn import Parameter
6
- from torch.nn import Parameter as Param
7
6
 
8
7
  import torch_geometric.backend
9
8
  import torch_geometric.typing
10
9
  from torch_geometric import is_compiling
10
+ from torch_geometric.index import index2ptr
11
11
  from torch_geometric.nn.conv import MessagePassing
12
12
  from torch_geometric.nn.inits import glorot, zeros
13
13
  from torch_geometric.typing import (
@@ -18,7 +18,6 @@ from torch_geometric.typing import (
18
18
  torch_sparse,
19
19
  )
20
20
  from torch_geometric.utils import index_sort, one_hot, scatter, spmm
21
- from torch_geometric.utils.sparse import index2ptr
22
21
 
23
22
 
24
23
  def masked_edge_index(edge_index: Adj, edge_mask: Tensor) -> Adj:
@@ -121,7 +120,8 @@ class RGCNConv(MessagePassing):
121
120
  in_channels = (in_channels, in_channels)
122
121
  self.in_channels_l = in_channels[0]
123
122
 
124
- self._use_segment_matmul_heuristic_output: Optional[bool] = None
123
+ self._use_segment_matmul_heuristic_output: torch.jit.Attribute(
124
+ None, Optional[float])
125
125
 
126
126
  if num_bases is not None:
127
127
  self.weight = Parameter(
@@ -143,12 +143,12 @@ class RGCNConv(MessagePassing):
143
143
  self.register_parameter('comp', None)
144
144
 
145
145
  if root_weight:
146
- self.root = Param(torch.empty(in_channels[1], out_channels))
146
+ self.root = Parameter(torch.empty(in_channels[1], out_channels))
147
147
  else:
148
148
  self.register_parameter('root', None)
149
149
 
150
150
  if bias:
151
- self.bias = Param(torch.empty(out_channels))
151
+ self.bias = Parameter(torch.empty(out_channels))
152
152
  else:
153
153
  self.register_parameter('bias', None)
154
154
 
@@ -5,17 +5,17 @@ import torch
5
5
  from torch import Tensor, nn
6
6
  from torch.nn import Parameter
7
7
 
8
+ import torch_geometric.typing
8
9
  from torch_geometric.nn.conv import MessagePassing
9
10
  from torch_geometric.nn.dense.linear import Linear
10
11
  from torch_geometric.nn.inits import uniform, zeros
11
12
  from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
12
13
  from torch_geometric.utils.repeat import repeat
13
14
 
14
- try:
15
+ if torch_geometric.typing.WITH_TORCH_SPLINE_CONV:
15
16
  from torch_spline_conv import spline_basis, spline_weighting
16
- except (ImportError, OSError): # Fail gracefully on GLIBC errors
17
- spline_basis = None
18
- spline_weighting = None
17
+ else:
18
+ spline_basis = spline_weighting = None
19
19
 
20
20
 
21
21
  class SplineConv(MessagePassing):
@@ -9,12 +9,13 @@ from torch.nn import Conv1d
9
9
  from torch.nn import Linear as L
10
10
  from torch.nn import Sequential as S
11
11
 
12
+ import torch_geometric.typing
12
13
  from torch_geometric.nn import Reshape
13
14
  from torch_geometric.nn.inits import reset
14
15
 
15
- try:
16
+ if torch_geometric.typing.WITH_TORCH_CLUSTER:
16
17
  from torch_cluster import knn_graph
17
- except ImportError:
18
+ else:
18
19
  knn_graph = None
19
20
 
20
21
 
@@ -12,10 +12,10 @@ from torch.nn.parameter import Parameter
12
12
  import torch_geometric.backend
13
13
  import torch_geometric.typing
14
14
  from torch_geometric import is_compiling
15
+ from torch_geometric.index import index2ptr
15
16
  from torch_geometric.nn import inits
16
17
  from torch_geometric.typing import pyg_lib
17
18
  from torch_geometric.utils import index_sort
18
- from torch_geometric.utils.sparse import index2ptr
19
19
 
20
20
 
21
21
  def is_uninitialized_parameter(x: Any) -> bool:
@@ -58,7 +58,7 @@ def reset_bias_(bias: Optional[Tensor], in_channels: int,
58
58
 
59
59
 
60
60
  class Linear(torch.nn.Module):
61
- r"""Applies a linear tranformation to the incoming data.
61
+ r"""Applies a linear transformation to the incoming data.
62
62
 
63
63
  .. math::
64
64
  \mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}
@@ -192,7 +192,7 @@ class Linear(torch.nn.Module):
192
192
 
193
193
 
194
194
  class HeteroLinear(torch.nn.Module):
195
- r"""Applies separate linear tranformations to the incoming data according
195
+ r"""Applies separate linear transformations to the incoming data according
196
196
  to types.
197
197
 
198
198
  For type :math:`\kappa`, it computes
@@ -222,6 +222,8 @@ class HeteroLinear(torch.nn.Module):
222
222
  type vector :math:`(*)`
223
223
  - **output:** features :math:`(*, F_{out})`
224
224
  """
225
+ _timing_cache: Dict[int, Tuple[float, float]]
226
+
225
227
  def __init__(
226
228
  self,
227
229
  in_channels: int,
@@ -245,15 +247,17 @@ class HeteroLinear(torch.nn.Module):
245
247
  else:
246
248
  self.weight = torch.nn.Parameter(
247
249
  torch.empty(num_types, in_channels, out_channels))
250
+
248
251
  if kwargs.get('bias', True):
249
252
  self.bias = Parameter(torch.empty(num_types, out_channels))
250
253
  else:
251
254
  self.register_parameter('bias', None)
252
- self.reset_parameters()
253
255
 
254
256
  # Timing cache for benchmarking naive vs. segment matmul usage:
255
257
  self._timing_cache: Dict[int, Tuple[float, float]] = {}
256
258
 
259
+ self.reset_parameters()
260
+
257
261
  def reset_parameters(self):
258
262
  r"""Resets all learnable parameters of the module."""
259
263
  reset_weight_(self.weight, self.in_channels,
@@ -361,7 +365,8 @@ class HeteroLinear(torch.nn.Module):
361
365
 
362
366
 
363
367
  class HeteroDictLinear(torch.nn.Module):
364
- r"""Applies separate linear tranformations to the incoming data dictionary.
368
+ r"""Applies separate linear transformations to the incoming data
369
+ dictionary.
365
370
 
366
371
  For key :math:`\kappa`, it computes
367
372
 
@@ -475,7 +480,7 @@ class HeteroDictLinear(torch.nn.Module):
475
480
  lin = self.lins[key]
476
481
  if is_uninitialized_parameter(lin.weight):
477
482
  self.lins[key].initialize_parameters(None, x)
478
- self.reset_parameters()
483
+ self.lins[key].reset_parameters()
479
484
  self._hook.remove()
480
485
  self.in_channels = {key: x.size(-1) for key, x in input[0].items()}
481
486
  delattr(self, '_hook')
torch_geometric/nn/fx.py CHANGED
@@ -18,8 +18,8 @@ class Transformer:
18
18
  :class:`~torch.nn.Module`.
19
19
  :class:`Transformer` works entirely symbolically.
20
20
 
21
- Methods in the :class:`Transformer` class can be overriden to customize the
22
- behavior of transformation.
21
+ Methods in the :class:`Transformer` class can be overridden to customize
22
+ the behavior of transformation.
23
23
 
24
24
  .. code-block:: none
25
25
 
@@ -283,7 +283,7 @@ def symbolic_trace(
283
283
  # TODO We currently only trace top-level modules.
284
284
  return not isinstance(module, torch.nn.Sequential)
285
285
 
286
- # Note: This is a hack around the fact that `Aggregaton.__call__`
286
+ # Note: This is a hack around the fact that `Aggregation.__call__`
287
287
  # is not patched by the base implementation of `trace`.
288
288
  # see https://github.com/pyg-team/pytorch_geometric/pull/5021 for
289
289
  # details on the rationale
@@ -4,6 +4,8 @@ from typing import Any, Dict, Optional, Union
4
4
 
5
5
  import torch
6
6
 
7
+ from torch_geometric.io import fs
8
+
7
9
  try:
8
10
  from huggingface_hub import ModelHubMixin, hf_hub_download
9
11
  except ImportError:
@@ -175,7 +177,7 @@ class PyGModelHubMixin(ModelHubMixin):
175
177
 
176
178
  model = cls(dataset_name, model_name, model_kwargs)
177
179
 
178
- state_dict = torch.load(model_file, map_location=map_location)
180
+ state_dict = fs.torch_load(model_file, map_location=map_location)
179
181
  model.load_state_dict(state_dict, strict=strict)
180
182
  model.eval()
181
183
 
@@ -2,7 +2,7 @@ r"""Model package."""
2
2
 
3
3
  from .mlp import MLP
4
4
  from .basic_gnn import GCN, GraphSAGE, GIN, GAT, PNA, EdgeCNN
5
- from .jumping_knowledge import JumpingKnowledge
5
+ from .jumping_knowledge import JumpingKnowledge, HeteroJumpingKnowledge
6
6
  from .meta import MetaLayer
7
7
  from .node2vec import Node2Vec
8
8
  from .deep_graph_infomax import DeepGraphInfomax
@@ -28,7 +28,10 @@ from .gnnff import GNNFF
28
28
  from .pmlp import PMLP
29
29
  from .neural_fingerprint import NeuralFingerprint
30
30
  from .visnet import ViSNet
31
-
31
+ from .g_retriever import GRetriever
32
+ from .git_mol import GITMol
33
+ from .molecule_gpt import MoleculeGPT
34
+ from .glem import GLEM
32
35
  # Deprecated:
33
36
  from torch_geometric.explain.algorithm.captum import (to_captum_input,
34
37
  captum_output_to_dicts)
@@ -42,6 +45,7 @@ __all__ = classes = [
42
45
  'PNA',
43
46
  'EdgeCNN',
44
47
  'JumpingKnowledge',
48
+ 'HeteroJumpingKnowledge',
45
49
  'MetaLayer',
46
50
  'Node2Vec',
47
51
  'DeepGraphInfomax',
@@ -74,4 +78,8 @@ __all__ = classes = [
74
78
  'PMLP',
75
79
  'NeuralFingerprint',
76
80
  'ViSNet',
81
+ 'GRetriever',
82
+ 'GITMol',
83
+ 'MoleculeGPT',
84
+ 'GLEM',
77
85
  ]