pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pyg-nightly might be problematic.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -1,9 +1,11 @@
1
1
  import warnings
2
- from typing import Any, List
2
+ from typing import Any, Dict, List, Optional, Type
3
3
 
4
4
  import torch
5
5
  from torch import Tensor
6
6
 
7
+ from torch_geometric.typing import NodeType
8
+
7
9
 
8
10
  def get_embeddings(
9
11
  model: torch.nn.Module,
@@ -40,7 +42,8 @@ def get_embeddings(
40
42
  hook_handles.append(module.register_forward_hook(hook))
41
43
 
42
44
  if len(hook_handles) == 0:
43
- warnings.warn("The 'model' does not have any 'MessagePassing' layers")
45
+ warnings.warn("The 'model' does not have any 'MessagePassing' layers",
46
+ stacklevel=2)
44
47
 
45
48
  training = model.training
46
49
  model.eval()
@@ -52,3 +55,89 @@ def get_embeddings(
52
55
  handle.remove()
53
56
 
54
57
  return embeddings
58
+
59
+
60
def get_embeddings_hetero(
    model: torch.nn.Module,
    supported_models: Optional[List[Type[torch.nn.Module]]] = None,
    *args: Any,
    **kwargs: Any,
) -> Dict[NodeType, List[Tensor]]:
    """Returns the output embeddings of all heterogeneous layers in
    :obj:`model`, organized by node type.

    Internally, this method registers forward hooks on every sub-module that
    is either an instance of one of :obj:`supported_models` or that directly
    owns a :class:`torch.nn.ModuleDict` child (the layout expected from
    models converted to a heterogeneous variant), and then runs a single
    forward pass of the model in evaluation mode without gradient tracking.
    Whenever a hooked module returns a dictionary, a copy of each tensor
    value is appended to the list stored under the same key.

    Args:
        model (torch.nn.Module): The heterogeneous GNN model.
        supported_models (List[Type[torch.nn.Module]], optional): A list of
            supported model classes. If not provided (or empty), defaults to
            [HGTConv, HANConv, HeteroConv].
        *args: Arguments passed to the model.
        **kwargs (optional): Additional keyword arguments passed to the model.

    Returns:
        Dict[NodeType, List[Tensor]]: A dictionary mapping each key found in
        the hooked modules' output dictionaries (expected to be node types)
        to the list of embeddings captured for it, in the order in which the
        hooks fired.
    """
    if not supported_models:
        # Only import the default layer classes when they are needed.
        from torch_geometric.nn import HANConv, HeteroConv, HGTConv
        supported_models = [HGTConv, HANConv, HeteroConv]
    supported = tuple(supported_models)

    # Dictionary to store node embeddings by type
    node_embeddings_dict: Dict[NodeType, List[Tensor]] = {}

    def hook(module: torch.nn.Module, inputs: Any, outputs: Any) -> None:
        # Only dictionary outputs can map node types to embeddings.
        if not isinstance(outputs, dict):
            return
        for node_type, embedding in outputs.items():
            # Skip false positives: entries that are not tensors cannot be
            # embeddings (and cannot be cloned).
            if not isinstance(embedding, Tensor):
                continue
            node_embeddings_dict.setdefault(node_type, []).append(
                embedding.clone())

    hook_handles = []
    for module in model.modules():
        # Native heterogeneous layers (e.g. HGTConv, HANConv, HeteroConv) are
        # matched by type; models produced by to_hetero() are matched by the
        # presence of a ModuleDict among the direct children.
        if isinstance(module, supported) or any(
                isinstance(child, torch.nn.ModuleDict)
                for child in module.children()):
            hook_handles.append(module.register_forward_hook(hook))

    if len(hook_handles) == 0:
        warnings.warn(
            "The 'model' does not have any heterogenous "
            "'MessagePassing' layers", stacklevel=2)

    training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(*args, **kwargs)
    finally:
        # Always restore the training mode and detach the hooks, even when
        # the forward pass raises, so the caller's model is left untouched.
        model.train(training)
        for handle in hook_handles:
            handle.remove()

    return node_embeddings_dict
@@ -2,6 +2,7 @@ import multiprocessing as mp
2
2
  import warnings
3
3
  from typing import Optional
4
4
 
5
+ import numpy as np
5
6
  import torch
6
7
  from torch import Tensor
7
8
 
@@ -65,9 +66,10 @@ def geodesic_distance( # noqa: D417
65
66
 
66
67
  if 'dest' in kwargs:
67
68
  dst = kwargs['dest']
68
- warnings.warn("'dest' attribute in 'geodesic_distance' is deprecated "
69
- "and will be removed in a future release. Use the 'dst' "
70
- "argument instead.")
69
+ warnings.warn(
70
+ "'dest' attribute in 'geodesic_distance' is deprecated "
71
+ "and will be removed in a future release. Use the 'dst' "
72
+ "argument instead.", stacklevel=2)
71
73
 
72
74
  max_distance = float('inf') if max_distance is None else max_distance
73
75
 
@@ -82,54 +84,55 @@ def geodesic_distance( # noqa: D417
82
84
 
83
85
  dtype = pos.dtype
84
86
 
85
- pos = pos.detach().cpu().to(torch.double).numpy()
86
- face = face.detach().t().cpu().to(torch.int).numpy()
87
+ pos_np = pos.detach().cpu().to(torch.double).numpy()
88
+ face_np = face.detach().t().cpu().to(torch.int).numpy()
87
89
 
88
90
  if src is None and dst is None:
89
- out = gdist.local_gdist_matrix(pos, face,
90
- max_distance * scale).toarray() / scale
91
+ out = gdist.local_gdist_matrix(
92
+ pos_np,
93
+ face_np,
94
+ max_distance * scale,
95
+ ).toarray() / scale
91
96
  return torch.from_numpy(out).to(dtype)
92
97
 
93
98
  if src is None:
94
- src = torch.arange(pos.shape[0], dtype=torch.int).numpy()
99
+ src_np = torch.arange(pos.size(0), dtype=torch.int).numpy()
95
100
  else:
96
- src = src.detach().cpu().to(torch.int).numpy()
97
- assert src is not None
101
+ src_np = src.detach().cpu().to(torch.int).numpy()
98
102
 
99
- dst = None if dst is None else dst.detach().cpu().to(torch.int).numpy()
103
+ dst_np = None if dst is None else dst.detach().cpu().to(torch.int).numpy()
100
104
 
101
105
  def _parallel_loop(
102
- pos: Tensor,
103
- face: Tensor,
104
- src: Tensor,
105
- dst: Optional[Tensor],
106
+ pos_np: np.ndarray,
107
+ face_np: np.ndarray,
108
+ src_np: np.ndarray,
109
+ dst_np: Optional[np.ndarray],
106
110
  max_distance: float,
107
111
  scale: float,
108
112
  i: int,
109
113
  dtype: torch.dtype,
110
114
  ) -> Tensor:
111
- s = src[i:i + 1]
112
- d = None if dst is None else dst[i:i + 1]
113
- out = gdist.compute_gdist(pos, face, s, d, max_distance * scale)
115
+ s = src_np[i:i + 1]
116
+ d = None if dst_np is None else dst_np[i:i + 1]
117
+ out = gdist.compute_gdist(pos_np, face_np, s, d, max_distance * scale)
114
118
  out = out / scale
115
119
  return torch.from_numpy(out).to(dtype)
116
120
 
117
121
  num_workers = mp.cpu_count() if num_workers <= -1 else num_workers
118
122
  if num_workers > 0:
119
123
  with mp.Pool(num_workers) as pool:
120
- outs = pool.starmap(
121
- _parallel_loop,
122
- [(pos, face, src, dst, max_distance, scale, i, dtype)
123
- for i in range(len(src))])
124
+ data = [(pos_np, face_np, src_np, dst_np, max_distance, scale, i,
125
+ dtype) for i in range(len(src_np))]
126
+ outs = pool.starmap(_parallel_loop, data)
124
127
  else:
125
128
  outs = [
126
- _parallel_loop(pos, face, src, dst, max_distance, scale, i, dtype)
127
- for i in range(len(src))
129
+ _parallel_loop(pos_np, face_np, src_np, dst_np, max_distance,
130
+ scale, i, dtype) for i in range(len(src_np))
128
131
  ]
129
132
 
130
133
  out = torch.cat(outs, dim=0)
131
134
 
132
135
  if dst is None:
133
- out = out.view(-1, pos.shape[0])
136
+ out = out.view(-1, pos.size(0))
134
137
 
135
138
  return out
@@ -0,0 +1,279 @@
1
+ from typing import List, Tuple, Union, cast
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.autograd.functional import jacobian
6
+ from tqdm.auto import tqdm
7
+
8
+ from torch_geometric.data import Data
9
+ from torch_geometric.utils import k_hop_subgraph
10
+
11
+
12
def k_hop_subsets_rough(
    node_idx: Union[int, List[int], Tensor],
    num_hops: int,
    edge_index: Tensor,
    num_nodes: int,
) -> List[Tensor]:
    r"""Return *rough* (possibly overlapping) *k*-hop node subsets.

    Starting from the seed node(s), every hop marks the nodes of the previous
    subset, selects all edges whose second-row endpoint
    (:obj:`edge_index[1]`) is marked, and collects the first-row endpoints
    (:obj:`edge_index[0]`) of those edges. All intermediate hop subsets are
    returned rather than their union, and nothing is de-duplicated.

    Args:
        node_idx (int or List[int] or Tensor): Index or indices of the
            central node(s).
        num_hops (int): Number of hops *k*.
        edge_index (Tensor): Edge index in COO format with shape
            :math:`[2, \text{num_edges}]`.
        num_nodes (int): Total number of nodes in the graph. Required to
            allocate the node mask.

    Returns:
        List[Tensor]: A list ``[H_0, H_1, ..., H_k]`` where ``H_0`` contains
        the seed node(s) and ``H_i`` (for *i* > 0) contains the nodes
        collected at hop *i*. Nodes may repeat within a subset and across
        subsets (i.e. overlaps are *not* removed).
    """
    src, dst = edge_index

    node_mask = dst.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = dst.new_empty(dst.size(0), dtype=torch.bool)

    # Flatten so that a scalar, a list and a tensor of seeds all yield a 1-D
    # index tensor on the same device as the graph.
    seeds = torch.as_tensor(node_idx, dtype=torch.long,
                            device=dst.device).flatten()

    subsets = [seeds]
    for _ in range(num_hops):
        node_mask.zero_()
        node_mask[subsets[-1]] = True
        # An edge is active when its `dst` endpoint was reached in the
        # previous hop; its `src` endpoint belongs to the next hop.
        torch.index_select(node_mask, 0, dst, out=edge_mask)
        subsets.append(src[edge_mask])

    return subsets
58
+
59
+
60
def k_hop_subsets_exact(
    node_idx: int,
    num_hops: int,
    edge_index: Tensor,
    num_nodes: int,
    device: Union[torch.device, str],
) -> List[Tensor]:
    """Return **disjoint** *k*-hop subsets.

    This function refines :func:`k_hop_subsets_rough` by removing nodes that
    have already appeared in previous hops, so that every node shows up only
    in the first hop at which it is reached.

    Args:
        node_idx (int): Index of the central node.
        num_hops (int): Number of hops *k*.
        edge_index (Tensor): Edge index in COO format.
        num_nodes (int): Total number of nodes in the graph.
        device (torch.device or str): Device on which the returned tensors
            are created.

    Returns:
        List[Tensor]: One tensor per hop (``num_hops + 1`` in total) with the
        dtype of :obj:`edge_index`. The first entry holds the seed node(s) as
        returned by :func:`k_hop_subsets_rough`; every later entry holds the
        newly reached nodes in ascending order and may be empty.
    """
    rough_subsets = k_hop_subsets_rough(node_idx, num_hops, edge_index,
                                        num_nodes)

    seed_nodes: List[int] = rough_subsets[0].tolist()
    exact_subsets: List[List[int]] = [seed_nodes]
    visited = set(seed_nodes)

    for hop_subset in rough_subsets[1:]:
        fresh = set(hop_subset.tolist()) - visited
        visited |= fresh
        # Sort so the result does not depend on set iteration order.
        exact_subsets.append(sorted(fresh))

    return [
        torch.tensor(nodes, device=device, dtype=edge_index.dtype)
        for nodes in exact_subsets
    ]
88
+
89
+
90
def jacobian_l1(
    model: torch.nn.Module,
    data: Data,
    max_hops: int,
    node_idx: int,
    device: Union[torch.device, str],
    *,
    vectorize: bool = True,
) -> Tensor:
    """Compute the **L1 norm** of the Jacobian for a given node.

    The Jacobian of the model output at ``node_idx`` is evaluated w.r.t. the
    node features of the *k*-hop induced sub-graph centred at ``node_idx``.
    The result is *folded back* onto the **original** node index space so
    that the returned tensor has length ``data.num_nodes``, where the
    influence score is zero for nodes outside the *k*-hop subgraph.

    Args:
        model (torch.nn.Module): Model called as ``model(x, edge_index)``.
            It is moved to :obj:`device` as a side effect.
        data (torch_geometric.data.Data): Graph providing :obj:`x`,
            :obj:`edge_index` and :obj:`num_nodes`.
        max_hops (int): Number of hops *k* of the induced sub-graph.
        node_idx (int): Index of the central node.
        device (torch.device or str): Device on which the model is evaluated
            and on which the result is returned.
        vectorize (bool, optional): Forwarded to
            :func:`torch.autograd.functional.jacobian`.
            (default: :obj:`True`)

    Returns:
        Tensor: Influence scores of shape ``[data.num_nodes]`` on
        :obj:`device`.

    Notes:
        The function assumes that the model *and* ``data.x`` share the same
        floating-point precision (e.g. both ``float32`` or both ``float16``).
    """
    edge_index = cast(Tensor, data.edge_index)
    x = cast(Tensor, data.x)
    num_nodes = cast(int, data.num_nodes)

    # Build the induced *k*-hop sub-graph (with node re-labelling). Passing
    # `num_nodes` keeps this valid for nodes whose index exceeds every index
    # found in `edge_index` (e.g. trailing isolated nodes).
    k_hop_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
        node_idx, max_hops, edge_index, relabel_nodes=True,
        num_nodes=num_nodes)
    # Location of the *center* node inside the sub-graph.
    root_pos = int(mapping[0])

    # Move tensors & model to the correct device. `k_hop_nodes` is moved as
    # well because it later indexes a tensor that lives on `device`.
    device = torch.device(device)
    sub_x = x[k_hop_nodes].to(device)
    sub_edge_index = sub_edge_index.to(device)
    k_hop_nodes = k_hop_nodes.to(device)
    model = model.to(device)

    def _forward(x: Tensor) -> Tensor:
        return model(x, sub_edge_index)[root_pos]

    jac = jacobian(_forward, sub_x, vectorize=vectorize)
    # `jac` has shape `output_shape + sub_x.shape`. Collapse all output
    # dimensions and all per-node feature dimensions so that the sum of
    # absolute values is taken per sub-graph node, whatever the output rank.
    influence_sub = jac.abs().reshape(-1, sub_x.size(0),
                                      sub_x[0].numel()).sum(dim=(0, 2))

    # Scatter the influence scores back to the *global* node space.
    influence_full = torch.zeros(num_nodes, dtype=influence_sub.dtype,
                                 device=device)
    influence_full[k_hop_nodes] = influence_sub

    return influence_full
140
+
141
+
142
def jacobian_l1_agg_per_hop(
    model: torch.nn.Module,
    data: Data,
    max_hops: int,
    node_idx: int,
    device: Union[torch.device, str],
    vectorize: bool = True,
) -> Tensor:
    """Aggregate Jacobian L1 norms **per hop** for node_idx.

    The per-node influence scores from :func:`jacobian_l1` are summed over
    each of the disjoint hop subsets from :func:`k_hop_subsets_exact`.

    Args:
        model (torch.nn.Module): Model called as ``model(x, edge_index)``.
        data (torch_geometric.data.Data): Graph providing :obj:`x`,
            :obj:`edge_index` and :obj:`num_nodes`.
        max_hops (int): Maximum hop distance *k*.
        node_idx (int): Index of the central node.
        device (torch.device or str): Device used for the computation.
        vectorize (bool, optional): Forwarded to :func:`jacobian_l1`.
            (default: :obj:`True`)

    Returns:
        Tensor: A vector ``[I_0, I_1, ..., I_k]`` where ``I_i`` is the
        *total* influence exerted by the nodes first reached at hop *i*
        from ``node_idx``. It lives on the device of the influence scores.
    """
    num_nodes = cast(int, data.num_nodes)
    edge_index = cast(Tensor, data.edge_index)
    influence = jacobian_l1(model, data, max_hops, node_idx, device,
                            vectorize=vectorize)
    hop_subsets = k_hop_subsets_exact(node_idx, max_hops, edge_index,
                                      num_nodes, influence.device)
    # Stack the 0-dim sums directly instead of rebuilding a tensor from
    # scalars; an empty hop subset contributes a zero.
    return torch.stack([influence[subset].sum() for subset in hop_subsets])
164
+
165
+
166
def avg_total_influence(
    influence_all_nodes: Tensor,
    normalize: bool = True,
) -> Tensor:
    """Average the per-hop influence over all seed nodes.

    Args:
        influence_all_nodes (Tensor): Influence matrix whose first dimension
            indexes the seed nodes and whose remaining dimension indexes the
            hops.
        normalize (bool, optional): If :obj:`True`, divide the averaged
            vector by its first entry (the hop-0 influence, i.e. the
            Jacobian w.r.t. the center node's own features), so that the
            first entry becomes one. No guard is applied against a zero
            hop-0 average. (default: :obj:`True`)

    Returns:
        Tensor: The mean of :obj:`influence_all_nodes` over dimension 0,
        optionally normalized by its first entry.
    """
    mean_per_hop = influence_all_nodes.mean(dim=0)
    if not normalize:
        return mean_per_hop
    return mean_per_hop / mean_per_hop[0]
175
+
176
+
177
def influence_weighted_receptive_field(T: Tensor) -> float:
    """Compute the *influence-weighted receptive field* ``R``.

    Given an influence matrix ``T`` of shape ``[N, k+1]`` (i-th row contains
    the per-hop influences of node *i*), every row is normalized to sum to
    one and used as weights for the hop indices ``0, ..., k``. The
    receptive field breadth *R* is the mean of these weighted hop distances
    over all rows.

    A larger *R* indicates that, on average, influence comes from **farther**
    hops.

    Args:
        T (Tensor): Influence matrix of shape ``[N, k+1]``. A row that sums
            to zero produces ``nan`` entries that propagate into the result.

    Returns:
        float: The receptive field breadth *R*.
    """
    normalised = T / torch.sum(T, dim=1, keepdim=True)
    # Create the hop indices with the dtype and device of the weights so the
    # matrix product also works for non-float32 or non-CPU inputs.
    hops = torch.arange(T.size(1), dtype=normalised.dtype,
                        device=normalised.device)  # 0 ... k
    breadth = normalised @ hops  # shape (N,)
    return breadth.mean().item()
191
+
192
+
193
def total_influence(
    model: torch.nn.Module,
    data: Data,
    max_hops: int,
    num_samples: Union[int, None] = None,
    normalize: bool = True,
    average: bool = True,
    device: Union[torch.device, str] = "cpu",
    vectorize: bool = True,
) -> Tuple[Tensor, float]:
    r"""Compute Jacobian-based influence aggregates for *multiple* seed nodes,
    as introduced in the
    `"Towards Quantifying Long-Range Interactions in Graph Machine Learning:
    a Large Graph Dataset and a Measurement"
    <https://arxiv.org/abs/2503.09008>`_ paper.
    The measurement quantifies how strongly the output of a GNN at a node
    depends on the features of nodes at increasing hop distances.

    For every sampled node :math:`v`, this method

    1. evaluates the **L1-norm** of the Jacobian of the model output at
       :math:`v` w.r.t. the node features of its *k*-hop induced sub-graph;
    2. sums these scores **per hop** to obtain the influence vector
       :math:`(I_{0}, I_{1}, \dots, I_{k})`;
    3. optionally averages those vectors over all sampled nodes, in which
       case the average is optionally normalised by :math:`I_{0}`.

    Please refer to Section 4 of the paper for a more detailed definition.

    Args:
        model (torch.nn.Module): A PyTorch Geometric-compatible model with
            forward signature ``model(x, edge_index) -> Tensor``.
        data (torch_geometric.data.Data): Graph data object providing at least
            :obj:`x` (node features) and :obj:`edge_index` (connectivity).
        max_hops (int): Maximum hop distance :math:`k`.
        num_samples (int, optional): Number of random seed nodes to evaluate.
            If :obj:`None`, all nodes are used. (default: :obj:`None`)
        normalize (bool, optional): If :obj:`True`, normalize the averaged
            hop-wise influence by the influence of hop 0. Only takes effect
            when :obj:`average=True`. (default: :obj:`True`)
        average (bool, optional): If :obj:`True`, return the hop-wise **mean**
            over all seed nodes (shape ``[k+1]``).
            If :obj:`False`, return the full influence matrix of shape
            ``[N, k+1]``. (default: :obj:`True`)
        device (torch.device or str, optional): Device on which to perform the
            computation. (default: :obj:`"cpu"`)
        vectorize (bool, optional): Forwarded to
            :func:`torch.autograd.functional.jacobian`. Keeping this
            :obj:`True` is often faster but increases memory usage.
            (default: :obj:`True`)

    Returns:
        Tuple[Tensor, float]:
            * **avg_influence** (*Tensor*):
              shape ``[k+1]`` if :obj:`average=True`;
              shape ``[N, k+1]`` otherwise.
            * **R** (*float*): Influence-weighted receptive-field breadth
              returned by :func:`influence_weighted_receptive_field`,
              always computed from the un-averaged influence matrix.

    Example::
        >>> avg_I, R = total_influence(model, data, max_hops=3,
        ...                            num_samples=1000)
        >>> avg_I
        tensor([1.0000, 0.1273, 0.0142, 0.0019])
        >>> R
        0.216
    """
    if num_samples is None:
        num_samples = data.num_nodes
    num_nodes = cast(int, data.num_nodes)

    # Random seed nodes: the leading entries of a random permutation.
    seed_nodes = torch.randperm(num_nodes)[:num_samples].tolist()

    per_seed_rows: List[Tensor] = []
    for seed in tqdm(seed_nodes, desc="Influence"):
        row = jacobian_l1_agg_per_hop(model, data, max_hops, seed, device,
                                      vectorize=vectorize)
        per_seed_rows.append(row)

    # One row per seed node, one column per hop; kept on the CPU.
    influence_matrix = torch.vstack(per_seed_rows).detach().cpu()

    if average:
        result = avg_total_influence(influence_matrix, normalize=normalize)
    else:
        result = influence_matrix

    breadth = influence_weighted_receptive_field(influence_matrix)

    return result, breadth
@@ -1,4 +1,3 @@
1
- import warnings
2
1
  from typing import Optional, Tuple, Union
3
2
 
4
3
  import numpy as np
@@ -6,6 +5,10 @@ import torch
6
5
  from torch import Tensor
7
6
  from torch.utils.dlpack import from_dlpack
8
7
 
8
+ from torch_geometric.warnings import WarningCache
9
+
10
+ _warning_cache = WarningCache()
11
+
9
12
 
10
13
  def map_index(
11
14
  src: Tensor,
@@ -14,7 +17,7 @@ def map_index(
14
17
  inclusive: bool = False,
15
18
  ) -> Tuple[Tensor, Optional[Tensor]]:
16
19
  r"""Maps indices in :obj:`src` to the positional value of their
17
- corresponding occurence in :obj:`index`.
20
+ corresponding occurrence in :obj:`index`.
18
21
  Indices must be strictly positive.
19
22
 
20
23
  Args:
@@ -93,10 +96,10 @@ def map_index(
93
96
  WITH_CUDF = True
94
97
  except ImportError:
95
98
  import pandas as pd
96
- warnings.warn("Using CPU-based processing within 'map_index' "
97
- "which may cause slowdowns and device "
98
- "synchronization. Consider installing 'cudf' to "
99
- "accelerate computation")
99
+ _warning_cache.warn("Using CPU-based processing within "
100
+ "'map_index' which may cause slowdowns and "
101
+ "device synchronization. Consider installing "
102
+ "'cudf' to accelerate computation")
100
103
  else:
101
104
  import pandas as pd
102
105
 
@@ -148,10 +151,11 @@ def map_index(
148
151
  if inclusive:
149
152
  try:
150
153
  out = from_dlpack(result['right_ser'].to_dlpack())
151
- except ValueError:
152
- raise ValueError("Found invalid entries in 'src' that do not "
153
- "have a corresponding entry in 'index'. Set "
154
- "`inclusive=False` to ignore these entries.")
154
+ except ValueError as e:
155
+ raise ValueError(
156
+ "Found invalid entries in 'src' that do not "
157
+ "have a corresponding entry in 'index'. Set "
158
+ "`inclusive=False` to ignore these entries.") from e
155
159
  else:
156
160
  out = from_dlpack(result['right_ser'].fillna(-1).to_dlpack())
157
161
 
@@ -43,7 +43,7 @@ def to_nested_tensor(
43
43
  xs = [x]
44
44
 
45
45
  # This currently copies the data, although `x` is already contiguous.
46
- # Sadly, there does not exist any (public) API to preven this :(
46
+ # Sadly, there does not exist any (public) API to prevent this :(
47
47
  return torch.nested.as_nested_tensor(xs)
48
48
 
49
49
 
@@ -91,7 +91,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data':
91
91
  assert isinstance(mol, Chem.Mol)
92
92
 
93
93
  xs: List[List[int]] = []
94
- for atom in mol.GetAtoms(): # type: ignore
94
+ for atom in mol.GetAtoms():
95
95
  row: List[int] = []
96
96
  row.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
97
97
  row.append(x_map['chirality'].index(str(atom.GetChiralTag())))
@@ -108,7 +108,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data':
108
108
  x = torch.tensor(xs, dtype=torch.long).view(-1, 9)
109
109
 
110
110
  edge_indices, edge_attrs = [], []
111
- for bond in mol.GetBonds(): # type: ignore
111
+ for bond in mol.GetBonds():
112
112
  i = bond.GetBeginAtomIdx()
113
113
  j = bond.GetEndAtomIdx()
114
114
 
@@ -148,7 +148,7 @@ def from_smiles(
148
148
  """
149
149
  from rdkit import Chem, RDLogger
150
150
 
151
- RDLogger.DisableLog('rdApp.*') # type: ignore
151
+ RDLogger.DisableLog('rdApp.*') # type: ignore[attr-defined]
152
152
 
153
153
  mol = Chem.MolFromSmiles(smiles)
154
154