pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -251,13 +251,13 @@ def from_networkx(
251
251
  if group_edge_attrs is not None and not isinstance(group_edge_attrs, list):
252
252
  group_edge_attrs = edge_attrs
253
253
 
254
- for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
254
+ for _, feat_dict in G.nodes(data=True):
255
255
  if set(feat_dict.keys()) != set(node_attrs):
256
256
  raise ValueError('Not all nodes contain the same attributes')
257
257
  for key, value in feat_dict.items():
258
258
  data_dict[str(key)].append(value)
259
259
 
260
- for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
260
+ for _, _, feat_dict in G.edges(data=True):
261
261
  if set(feat_dict.keys()) != set(edge_attrs):
262
262
  raise ValueError('Not all edges contain the same attributes')
263
263
  for key, value in feat_dict.items():
@@ -452,15 +452,22 @@ def to_cugraph(
452
452
  g = cugraph.Graph(directed=directed)
453
453
  df = cudf.from_dlpack(to_dlpack(edge_index.t()))
454
454
 
455
+ df = cudf.DataFrame({
456
+ 'source':
457
+ cudf.from_dlpack(to_dlpack(edge_index[0])),
458
+ 'destination':
459
+ cudf.from_dlpack(to_dlpack(edge_index[1])),
460
+ })
461
+
455
462
  if edge_weight is not None:
456
463
  assert edge_weight.dim() == 1
457
- df['2'] = cudf.from_dlpack(to_dlpack(edge_weight))
464
+ df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight))
458
465
 
459
466
  g.from_cudf_edgelist(
460
467
  df,
461
- source=0,
462
- destination=1,
463
- edge_attr='2' if edge_weight is not None else None,
468
+ source='source',
469
+ destination='destination',
470
+ edge_attr='weight' if edge_weight is not None else None,
464
471
  renumber=relabel_nodes,
465
472
  )
466
473
 
@@ -476,13 +483,13 @@ def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
476
483
  """
477
484
  df = g.view_edge_list()
478
485
 
479
- src = from_dlpack(df[0].to_dlpack()).long()
480
- dst = from_dlpack(df[1].to_dlpack()).long()
486
+ src = from_dlpack(df[g.source_columns].to_dlpack()).long()
487
+ dst = from_dlpack(df[g.destination_columns].to_dlpack()).long()
481
488
  edge_index = torch.stack([src, dst], dim=0)
482
489
 
483
490
  edge_weight = None
484
- if '2' in df:
485
- edge_weight = from_dlpack(df['2'].to_dlpack())
491
+ if g.weight_column is not None:
492
+ edge_weight = from_dlpack(df[g.weight_column].to_dlpack())
486
493
 
487
494
  return edge_index, edge_weight
488
495
 
@@ -18,30 +18,51 @@ class SparseCrossEntropy(torch.autograd.Function):
18
18
  ) -> Tensor:
19
19
  assert inputs.dim() == 2
20
20
 
21
- logsumexp = inputs.logsumexp(dim=-1)
22
- ctx.save_for_backward(inputs, edge_label_index, edge_label_weight,
23
- logsumexp)
21
+ # Support for both positive and negative weights:
22
+ # Positive weights scale the logits *after* softmax.
23
+ # Negative weights scale the denominator *before* softmax:
24
+ pos_y = edge_label_index
25
+ neg_y = pos_weight = neg_weight = None
24
26
 
25
- out = inputs[edge_label_index[0], edge_label_index[1]]
26
- out.neg_().add_(logsumexp[edge_label_index[0]])
27
27
  if edge_label_weight is not None:
28
- out *= edge_label_weight
28
+ pos_mask = edge_label_weight >= 0
29
+ pos_y = edge_label_index[:, pos_mask]
30
+ pos_weight = edge_label_weight[pos_mask]
31
+
32
+ if pos_y.size(1) < edge_label_index.size(1):
33
+ neg_mask = ~pos_mask
34
+ neg_y = edge_label_index[:, neg_mask]
35
+ neg_weight = edge_label_weight[neg_mask]
36
+
37
+ if neg_y is not None and neg_weight is not None:
38
+ inputs = inputs.clone()
39
+ inputs[
40
+ neg_y[0],
41
+ neg_y[1],
42
+ ] += neg_weight.abs().log().clamp(min=1e-12)
43
+
44
+ logsumexp = inputs.logsumexp(dim=-1)
45
+ ctx.save_for_backward(inputs, pos_y, pos_weight, logsumexp)
46
+
47
+ out = inputs[pos_y[0], pos_y[1]]
48
+ out.neg_().add_(logsumexp[pos_y[0]])
49
+ if pos_weight is not None:
50
+ out *= pos_weight
29
51
 
30
52
  return out.sum() / inputs.size(0)
31
53
 
32
54
  @staticmethod
33
55
  @torch.autograd.function.once_differentiable
34
56
  def backward(ctx: Any, grad_out: Tensor) -> Tuple[Tensor, None, None]:
35
- inputs, edge_label_index, edge_label_weight, logsumexp = (
36
- ctx.saved_tensors)
57
+ inputs, pos_y, pos_weight, logsumexp = ctx.saved_tensors
37
58
 
38
59
  grad_out = grad_out / inputs.size(0)
39
- grad_out = grad_out.expand(edge_label_index.size(1))
60
+ grad_out = grad_out.expand(pos_y.size(1))
40
61
 
41
- if edge_label_weight is not None:
42
- grad_out = grad_out * edge_label_weight
62
+ if pos_weight is not None:
63
+ grad_out = grad_out * pos_weight
43
64
 
44
- grad_logsumexp = scatter(grad_out, edge_label_index[0], dim=0,
65
+ grad_logsumexp = scatter(grad_out, pos_y[0], dim=0,
45
66
  dim_size=inputs.size(0), reduce='sum')
46
67
 
47
68
  # Gradient computation of `logsumexp`: `grad * (self - result).exp()`
@@ -49,7 +70,7 @@ class SparseCrossEntropy(torch.autograd.Function):
49
70
  grad_input.exp_()
50
71
  grad_input.mul_(grad_logsumexp.view(-1, 1))
51
72
 
52
- grad_input[edge_label_index[0], edge_label_index[1]] -= grad_out
73
+ grad_input[pos_y[0], pos_y[1]] -= grad_out
53
74
 
54
75
  return grad_input, None, None
55
76
 
@@ -1,9 +1,11 @@
1
1
  import warnings
2
- from typing import Any, List
2
+ from typing import Any, Dict, List, Optional, Type
3
3
 
4
4
  import torch
5
5
  from torch import Tensor
6
6
 
7
+ from torch_geometric.typing import NodeType
8
+
7
9
 
8
10
  def get_embeddings(
9
11
  model: torch.nn.Module,
@@ -40,7 +42,8 @@ def get_embeddings(
40
42
  hook_handles.append(module.register_forward_hook(hook))
41
43
 
42
44
  if len(hook_handles) == 0:
43
- warnings.warn("The 'model' does not have any 'MessagePassing' layers")
45
+ warnings.warn("The 'model' does not have any 'MessagePassing' layers",
46
+ stacklevel=2)
44
47
 
45
48
  training = model.training
46
49
  model.eval()
@@ -52,3 +55,89 @@ def get_embeddings(
52
55
  handle.remove()
53
56
 
54
57
  return embeddings
58
+
59
+
60
def get_embeddings_hetero(
    model: torch.nn.Module,
    supported_models: Optional[List[Type[torch.nn.Module]]] = None,
    *args: Any,
    **kwargs: Any,
) -> Dict[NodeType, List[Tensor]]:
    """Returns the output embeddings of all
    :class:`~torch_geometric.nn.conv.MessagePassing` layers in a heterogeneous
    :obj:`model`, organized by node type.

    Internally, this method registers forward hooks on all modules that process
    heterogeneous graphs in the model and runs the forward pass of the model
    in evaluation mode without gradient tracking. A module is hooked if it is
    an instance of one of :obj:`supported_models`, or if one of its direct
    children is a :class:`torch.nn.ModuleDict` (as is the case for models
    generated via :meth:`~torch_geometric.nn.to_hetero`). Every hooked output
    that is a non-empty dictionary contributes its tensor values, grouped by
    dictionary key (node type).

    Args:
        model (torch.nn.Module): The heterogeneous GNN model.
        supported_models (List[Type[torch.nn.Module]], optional): A list of
            supported model classes. If not provided, defaults to
            :obj:`[HGTConv, HANConv, HeteroConv]`.
            Note that this is the second *positional* parameter: to forward
            positional arguments to the model, pass :obj:`None` (or a list
            of classes) here first, *e.g.*,
            :obj:`get_embeddings_hetero(model, None, x_dict, edge_index_dict)`.
        *args: Arguments passed to the model.
        **kwargs (optional): Additional keyword arguments passed to the model.

    Returns:
        Dict[NodeType, List[Tensor]]: A dictionary mapping each node type to
        a list of embeddings from different layers, in hook execution order.
    """
    from torch_geometric.nn import HANConv, HeteroConv, HGTConv
    if not supported_models:
        supported_models = [HGTConv, HANConv, HeteroConv]
    hooked_types = tuple(supported_models)

    # Node type -> list of captured per-layer embeddings:
    node_embeddings_dict: Dict[NodeType, List[Tensor]] = {}

    def hook(module: torch.nn.Module, inputs: Any, outputs: Any) -> None:
        # Only heterogeneous outputs, i.e. non-empty dictionaries, are kept:
        if not isinstance(outputs, dict) or not outputs:
            return
        for node_type, embedding in outputs.items():
            # Modules hooked only because they own a `ModuleDict` may return
            # dictionaries holding non-tensor values; skip those entries
            # instead of crashing on `.clone()`:
            if not isinstance(embedding, Tensor):
                continue
            embeddings = node_embeddings_dict.setdefault(node_type, [])
            embeddings.append(embedding.clone())

    hook_handles = []
    for module in model.modules():
        if isinstance(module, hooked_types):
            # Native heterogeneous modules, e.g. HGTConv, HANConv, HeteroConv:
            hook_handles.append(module.register_forward_hook(hook))
        elif any(
                isinstance(child, torch.nn.ModuleDict)
                for child in module.children()):
            # Heterogeneous modules generated by `to_hetero()`:
            hook_handles.append(module.register_forward_hook(hook))

    if len(hook_handles) == 0:
        warnings.warn(
            "The 'model' does not have any heterogenous "
            "'MessagePassing' layers", stacklevel=2)

    training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(*args, **kwargs)
    finally:
        # Restore the original mode and remove all hooks even if the forward
        # pass raises, so the caller's model is left untouched:
        model.train(training)
        for handle in hook_handles:
            handle.remove()

    return node_embeddings_dict
@@ -66,9 +66,10 @@ def geodesic_distance( # noqa: D417
66
66
 
67
67
  if 'dest' in kwargs:
68
68
  dst = kwargs['dest']
69
- warnings.warn("'dest' attribute in 'geodesic_distance' is deprecated "
70
- "and will be removed in a future release. Use the 'dst' "
71
- "argument instead.")
69
+ warnings.warn(
70
+ "'dest' attribute in 'geodesic_distance' is deprecated "
71
+ "and will be removed in a future release. Use the 'dst' "
72
+ "argument instead.", stacklevel=2)
72
73
 
73
74
  max_distance = float('inf') if max_distance is None else max_distance
74
75
 
@@ -0,0 +1,279 @@
1
+ from typing import List, Tuple, Union, cast
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.autograd.functional import jacobian
6
+ from tqdm.auto import tqdm
7
+
8
+ from torch_geometric.data import Data
9
+ from torch_geometric.utils import k_hop_subgraph
10
+
11
+
12
def k_hop_subsets_rough(
    node_idx: Union[int, List[int], Tensor],
    num_hops: int,
    edge_index: Tensor,
    num_nodes: int,
) -> List[Tensor]:
    r"""Return *rough* (possibly overlapping) *k*-hop node subsets.

    Starting from the seed node(s), hop :obj:`i` collects the source nodes of
    all edges whose target lies in hop :obj:`i - 1` (matching the
    ``"source_to_target"`` flow of
    :func:`torch_geometric.utils.k_hop_subgraph`). In contrast to
    :func:`~torch_geometric.utils.k_hop_subgraph`, which only returns the
    union, **all** intermediate hop subsets are returned.

    Args:
        node_idx (int, List[int] or Tensor): Index or indices of the central
            node(s).
        num_hops (int): Number of hops :math:`k`.
        edge_index (Tensor): Edge indices in COO format with shape
            :obj:`[2, num_edges]`.
        num_nodes (int): Total number of nodes in the graph. Required to
            allocate the node mask.

    Returns:
        List[Tensor]: A list ``[H_0, H_1, ..., H_k]`` where ``H_0`` contains
        the seed node(s) and ``H_i`` (for ``i > 0``) contains the source
        nodes of all edges pointing into ``H_{i-1}``. Overlaps are **not**
        removed: ``H_i`` may contain duplicates as well as nodes that are
        already reachable in fewer hops.
    """
    col, row = edge_index

    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

    # A single `int` becomes a one-element index tensor; lists and tensors of
    # seed indices are flattened into a 1-D index tensor:
    seeds = torch.as_tensor(node_idx, device=row.device).flatten()

    subsets = [seeds]
    for _ in range(num_hops):
        node_mask.zero_()
        node_mask[subsets[-1]] = True
        # Select every edge whose target lies in the current frontier:
        torch.index_select(node_mask, 0, row, out=edge_mask)
        subsets.append(col[edge_mask])

    return subsets
58
+
59
+
60
def k_hop_subsets_exact(
    node_idx: int,
    num_hops: int,
    edge_index: Tensor,
    num_nodes: int,
    device: Union[torch.device, str],
) -> List[Tensor]:
    """Return **disjoint** *k*-hop subsets.

    This function refines :func:`k_hop_subsets_rough` by removing duplicates
    and nodes that have already appeared in previous hops. Since the rough
    hop :obj:`i` holds all nodes with a walk of length :obj:`i` to the seed,
    keeping only first appearances leaves each subset with the nodes whose
    shortest hop distance to the seed is exactly :obj:`i`.

    Args:
        node_idx (int): Index of the central node.
        num_hops (int): Number of hops :math:`k`.
        edge_index (Tensor): Edge indices in COO format with shape
            :obj:`[2, num_edges]`.
        num_nodes (int): Total number of nodes in the graph.
        device (torch.device or str): Device of the returned tensors.

    Returns:
        List[Tensor]: ``num_hops + 1`` mutually disjoint, sorted index
        tensors with the dtype of :obj:`edge_index`, placed on :obj:`device`.
    """
    rough_subsets = k_hop_subsets_rough(node_idx, num_hops, edge_index,
                                        num_nodes)

    # Track visited nodes with a boolean mask that lives next to the rough
    # subsets instead of Python sets, avoiding a host round trip per hop:
    visited = torch.zeros(num_nodes, dtype=torch.bool,
                          device=edge_index.device)

    exact_subsets: List[Tensor] = []
    for hop_subset in rough_subsets:
        # `unique()` drops duplicates of the rough expansion and sorts the
        # indices, which makes the output order deterministic:
        candidates = hop_subset.unique()
        fresh = candidates[~visited[candidates]]
        visited[fresh] = True
        exact_subsets.append(fresh.to(device=device, dtype=edge_index.dtype))

    return exact_subsets
88
+
89
+
90
def jacobian_l1(
    model: torch.nn.Module,
    data: Data,
    max_hops: int,
    node_idx: int,
    device: Union[torch.device, str],
    *,
    vectorize: bool = True,
) -> Tensor:
    """Compute the **L1 norm** of the Jacobian for a given node.

    The Jacobian of the model output at :obj:`node_idx` is evaluated w.r.t.
    the node features of the *k*-hop induced sub-graph centred at
    :obj:`node_idx`. For every sub-graph node, the absolute Jacobian entries
    are summed over all output and feature dimensions. The result is *folded
    back* onto the **original** node index space so that the returned tensor
    has length ``data.num_nodes``, where the influence score is zero for
    nodes outside the *k*-hop sub-graph.

    Args:
        model (torch.nn.Module): Model with forward signature
            ``model(x, edge_index) -> Tensor`` whose output is indexed by node
            along its first dimension.
        data (torch_geometric.data.Data): Graph data object providing
            :obj:`x` and :obj:`edge_index`.
        max_hops (int): Number of hops :math:`k` of the induced sub-graph.
        node_idx (int): Index of the center node.
        device (torch.device or str): Device on which to run the computation.
        vectorize (bool, optional): Forwarded to
            :func:`torch.autograd.functional.jacobian`.
            (default: :obj:`True`)

    Returns:
        Tensor: Influence scores of shape ``[data.num_nodes]`` on
        :obj:`device`.

    Notes:
        * The function assumes that the model *and* ``data.x`` share the
          same floating-point precision (e.g. both ``float32`` or both
          ``float16``).
        * :obj:`model` is moved to :obj:`device` via
          :meth:`torch.nn.Module.to`, which modifies it in place.
        * The model is evaluated in its current train/eval mode.
    """
    # Build the induced *k*-hop sub-graph (with node re-labelling):
    edge_index = cast(Tensor, data.edge_index)
    x = cast(Tensor, data.x)
    k_hop_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
        node_idx, max_hops, edge_index, relabel_nodes=True)
    # Position of the *center* node inside the sub-graph:
    root_pos = int(mapping[0])

    # Move tensors & model to the correct device:
    device = torch.device(device)
    sub_x = x[k_hop_nodes].to(device)
    sub_edge_index = sub_edge_index.to(device)
    model = model.to(device)

    def _forward(x: Tensor) -> Tensor:
        return model(x, sub_edge_index)[root_pos]

    jac = jacobian(_forward, sub_x, vectorize=vectorize)

    # `jac` has shape `[*out_shape, *sub_x.shape]`, where `out_shape` is the
    # shape of the center node's output (possibly empty or multi-dim).
    # Move the sub-graph node dimension to the front and sum all remaining
    # absolute entries to obtain one L1 score per sub-graph node:
    node_dim = jac.dim() - sub_x.dim()
    influence_sub = jac.abs().movedim(node_dim, 0).reshape(
        sub_x.size(0), -1).sum(dim=-1)

    # Scatter the influence scores back to the *global* node space:
    num_nodes = cast(int, data.num_nodes)
    influence_full = torch.zeros(num_nodes, dtype=influence_sub.dtype,
                                 device=device)
    influence_full[k_hop_nodes.to(device)] = influence_sub

    return influence_full
140
+
141
+
142
def jacobian_l1_agg_per_hop(
    model: torch.nn.Module,
    data: Data,
    max_hops: int,
    node_idx: int,
    device: Union[torch.device, str],
    vectorize: bool = True,
) -> Tensor:
    """Aggregate Jacobian L1 norms **per hop** for :obj:`node_idx`.

    Args:
        model (torch.nn.Module): Model with forward signature
            ``model(x, edge_index) -> Tensor``.
        data (torch_geometric.data.Data): Graph data object providing
            :obj:`x` and :obj:`edge_index`.
        max_hops (int): Maximum hop distance :math:`k`.
        node_idx (int): Index of the center node.
        device (torch.device or str): Device on which to run the computation.
        vectorize (bool, optional): Forwarded to
            :func:`torch.autograd.functional.jacobian`.
            (default: :obj:`True`)

    Returns:
        Tensor: A vector ``[I_0, I_1, ..., I_k]`` where ``I_i`` is the
        *total* influence exerted by nodes that are exactly *i* hops away
        from :obj:`node_idx`.
    """
    num_nodes = cast(int, data.num_nodes)
    edge_index = cast(Tensor, data.edge_index)
    influence = jacobian_l1(model, data, max_hops, node_idx, device,
                            vectorize=vectorize)
    hop_subsets = k_hop_subsets_exact(node_idx, max_hops, edge_index,
                                      num_nodes, influence.device)
    # `torch.stack` keeps the dtype and device of `influence` and avoids
    # converting every 0-dim tensor to a Python scalar first:
    return torch.stack([influence[subset].sum() for subset in hop_subsets])
164
+
165
+
166
def avg_total_influence(
    influence_all_nodes: Tensor,
    normalize: bool = True,
) -> Tensor:
    """Average the per-hop influence vectors over all seed nodes.

    Args:
        influence_all_nodes (Tensor): Influence matrix of shape ``[N, k+1]``,
            where row ``i`` holds the per-hop influences ``[I_0, ..., I_k]``
            of the ``i``-th seed node.
        normalize (bool, optional): If :obj:`True`, divide the averaged
            vector by its hop-0 entry (the influence of the center node's own
            features), so that its first entry becomes ``1`` (unless the
            hop-0 average is zero). (default: :obj:`True`)

    Returns:
        Tensor: The hop-wise mean influence of shape ``[k+1]``.
    """
    avg_total_influences = torch.mean(influence_all_nodes, dim=0)
    if normalize:  # Normalize by hop 0 (Jacobian of the center node feature).
        avg_total_influences = avg_total_influences / avg_total_influences[0]
    return avg_total_influences
175
+
176
+
177
def influence_weighted_receptive_field(T: Tensor) -> float:
    """Compute the *influence-weighted receptive field* ``R``.

    Given an influence matrix ``T`` of shape ``[N, k+1]`` (the ``i``-th row
    contains the per-hop influences of node ``i``), the receptive field
    breadth ``R`` is the expected hop distance when weighting each hop by its
    share of the row's total influence, averaged over all rows.

    A larger ``R`` indicates that, on average, influence comes from
    **farther** hops.

    Args:
        T (Tensor): Per-hop influence matrix of shape ``[N, k+1]``.

    Returns:
        float: The mean influence-weighted hop distance. A row whose total
        influence is zero produces :obj:`nan`, which propagates to the
        result.
    """
    normalised = T / torch.sum(T, dim=1, keepdim=True)
    # Create the hop indices `0, ..., k` with the dtype and device of
    # `normalised`, so that the matrix product below also works for
    # `float64` and CUDA inputs:
    hops = torch.arange(T.size(1), dtype=normalised.dtype,
                        device=normalised.device)
    breadth = normalised @ hops  # shape [N]
    return breadth.mean().item()
191
+
192
+
193
def total_influence(
    model: torch.nn.Module,
    data: Data,
    max_hops: int,
    num_samples: Union[int, None] = None,
    normalize: bool = True,
    average: bool = True,
    device: Union[torch.device, str] = "cpu",
    vectorize: bool = True,
) -> Tuple[Tensor, float]:
    r"""Compute Jacobian-based influence aggregates for *multiple* seed nodes,
    as introduced in the
    `"Towards Quantifying Long-Range Interactions in Graph Machine Learning:
    a Large Graph Dataset and a Measurement"
    <https://arxiv.org/abs/2503.09008>`_ paper.
    This measurement quantifies how a GNN model's output at a node is
    influenced by features of other nodes at increasing hop distances.

    Specifically, for every sampled node :math:`v`, this method

    1. evaluates the **L1-norm** of the Jacobian of the model output at
       :math:`v` w.r.t. the node features of its *k*-hop induced sub-graph;
    2. sums these scores **per hop** to obtain the influence vector
       :math:`(I_{0}, I_{1}, \dots, I_{k})`;
    3. optionally averages those vectors over all sampled nodes and
       optionally normalises the average by :math:`I_{0}`.

    Please refer to Section 4 of the paper for a more detailed definition.

    Args:
        model (torch.nn.Module): A PyTorch Geometric-compatible model with
            forward signature ``model(x, edge_index) -> Tensor``.
        data (torch_geometric.data.Data): Graph data object providing at least
            :obj:`x` (node features) and :obj:`edge_index` (connectivity).
        max_hops (int): Maximum hop distance :math:`k`.
        num_samples (int, optional): Number of random seed nodes to evaluate.
            If :obj:`None`, all nodes are used. (default: :obj:`None`)
        normalize (bool, optional): If :obj:`True`, normalize the hop-wise
            **mean** influence by the influence of hop 0. Only applies when
            :obj:`average=True`; the full matrix returned for
            :obj:`average=False` is never normalized. (default: :obj:`True`)
        average (bool, optional): If :obj:`True`, return the hop-wise **mean**
            over all seed nodes (shape ``[k+1]``).
            If :obj:`False`, return the full influence matrix of shape
            ``[num_samples, k+1]``. (default: :obj:`True`)
        device (torch.device or str, optional): Device on which to perform the
            computation. (default: :obj:`"cpu"`)
        vectorize (bool, optional): Forwarded to
            :func:`torch.autograd.functional.jacobian`. Keeping this
            :obj:`True` is often faster but increases memory usage.
            (default: :obj:`True`)

    Returns:
        Tuple[Tensor, float]:
        * **avg_influence** (*Tensor*, on CPU):
          shape ``[k+1]`` if :obj:`average=True`;
          shape ``[num_samples, k+1]`` otherwise, with rows in the (random)
          order in which the seed nodes were sampled.
        * **R** (*float*): Influence-weighted receptive-field breadth
          returned by :func:`influence_weighted_receptive_field`, computed
          from the un-normalized per-node influence matrix.

    Example::
        >>> avg_I, R = total_influence(model, data, max_hops=3,
        ...                            num_samples=1000)
        >>> avg_I
        tensor([1.0000, 0.1273, 0.0142, 0.0019])
        >>> R
        0.216
    """
    num_nodes = cast(int, data.num_nodes)
    # Slicing with `num_samples=None` keeps the whole permutation, i.e. every
    # node is used as a seed:
    nodes = torch.randperm(num_nodes)[:num_samples].tolist()

    influence_all_nodes: List[Tensor] = [
        jacobian_l1_agg_per_hop(model, data, max_hops, n, device,
                                vectorize=vectorize)
        for n in tqdm(nodes, desc="Influence")
    ]
    allnodes = torch.vstack(influence_all_nodes).detach().cpu()

    # Average total influence at each hop:
    if average:
        avg_influence = avg_total_influence(allnodes, normalize=normalize)
    else:
        avg_influence = allnodes

    # Influence-weighted receptive field:
    R = influence_weighted_receptive_field(allnodes)

    return avg_influence, R
@@ -1,4 +1,3 @@
1
- import warnings
2
1
  from typing import Optional, Tuple, Union
3
2
 
4
3
  import numpy as np
@@ -6,6 +5,10 @@ import torch
6
5
  from torch import Tensor
7
6
  from torch.utils.dlpack import from_dlpack
8
7
 
8
+ from torch_geometric.warnings import WarningCache
9
+
10
+ _warning_cache = WarningCache()
11
+
9
12
 
10
13
  def map_index(
11
14
  src: Tensor,
@@ -93,10 +96,10 @@ def map_index(
93
96
  WITH_CUDF = True
94
97
  except ImportError:
95
98
  import pandas as pd
96
- warnings.warn("Using CPU-based processing within 'map_index' "
97
- "which may cause slowdowns and device "
98
- "synchronization. Consider installing 'cudf' to "
99
- "accelerate computation")
99
+ _warning_cache.warn("Using CPU-based processing within "
100
+ "'map_index' which may cause slowdowns and "
101
+ "device synchronization. Consider installing "
102
+ "'cudf' to accelerate computation")
100
103
  else:
101
104
  import pandas as pd
102
105
 
@@ -148,10 +151,11 @@ def map_index(
148
151
  if inclusive:
149
152
  try:
150
153
  out = from_dlpack(result['right_ser'].to_dlpack())
151
- except ValueError:
152
- raise ValueError("Found invalid entries in 'src' that do not "
153
- "have a corresponding entry in 'index'. Set "
154
- "`inclusive=False` to ignore these entries.")
154
+ except ValueError as e:
155
+ raise ValueError(
156
+ "Found invalid entries in 'src' that do not "
157
+ "have a corresponding entry in 'index'. Set "
158
+ "`inclusive=False` to ignore these entries.") from e
155
159
  else:
156
160
  out = from_dlpack(result['right_ser'].fillna(-1).to_dlpack())
157
161
 
@@ -43,7 +43,7 @@ def to_nested_tensor(
43
43
  xs = [x]
44
44
 
45
45
  # This currently copies the data, although `x` is already contiguous.
46
- # Sadly, there does not exist any (public) API to preven this :(
46
+ # Sadly, there does not exist any (public) API to prevent this :(
47
47
  return torch.nested.as_nested_tensor(xs)
48
48
 
49
49
 
@@ -91,7 +91,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data':
91
91
  assert isinstance(mol, Chem.Mol)
92
92
 
93
93
  xs: List[List[int]] = []
94
- for atom in mol.GetAtoms(): # type: ignore
94
+ for atom in mol.GetAtoms():
95
95
  row: List[int] = []
96
96
  row.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
97
97
  row.append(x_map['chirality'].index(str(atom.GetChiralTag())))
@@ -108,7 +108,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data':
108
108
  x = torch.tensor(xs, dtype=torch.long).view(-1, 9)
109
109
 
110
110
  edge_indices, edge_attrs = [], []
111
- for bond in mol.GetBonds(): # type: ignore
111
+ for bond in mol.GetBonds():
112
112
  i = bond.GetBeginAtomIdx()
113
113
  j = bond.GetEndAtomIdx()
114
114
 
@@ -148,7 +148,7 @@ def from_smiles(
148
148
  """
149
149
  from rdkit import Chem, RDLogger
150
150
 
151
- RDLogger.DisableLog('rdApp.*') # type: ignore
151
+ RDLogger.DisableLog('rdApp.*') # type: ignore[attr-defined]
152
152
 
153
153
  mol = Chem.MolFromSmiles(smiles)
154
154