pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268) hide show
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -15,6 +15,7 @@ from torch_geometric.distributed.rpc import (
15
15
  rpc_async,
16
16
  rpc_register,
17
17
  )
18
+ from torch_geometric.io import fs
18
19
  from torch_geometric.typing import EdgeType, NodeOrEdgeType, NodeType
19
20
 
20
21
 
@@ -415,11 +416,11 @@ class LocalFeatureStore(FeatureStore):
415
416
 
416
417
  node_feats: Optional[Dict[str, Any]] = None
417
418
  if osp.exists(osp.join(part_dir, 'node_feats.pt')):
418
- node_feats = torch.load(osp.join(part_dir, 'node_feats.pt'))
419
+ node_feats = fs.torch_load(osp.join(part_dir, 'node_feats.pt'))
419
420
 
420
421
  edge_feats: Optional[Dict[str, Any]] = None
421
422
  if osp.exists(osp.join(part_dir, 'edge_feats.pt')):
422
- edge_feats = torch.load(osp.join(part_dir, 'edge_feats.pt'))
423
+ edge_feats = fs.torch_load(osp.join(part_dir, 'edge_feats.pt'))
423
424
 
424
425
  if not meta['is_hetero'] and node_feats is not None:
425
426
  feat_store.put_global_id(node_feats['global_id'], group_name=None)
@@ -6,6 +6,7 @@ from torch import Tensor
6
6
 
7
7
  from torch_geometric.data import EdgeAttr, GraphStore
8
8
  from torch_geometric.distributed.partition import load_partition_info
9
+ from torch_geometric.io import fs
9
10
  from torch_geometric.typing import EdgeTensorType, EdgeType, NodeType
10
11
  from torch_geometric.utils import sort_edge_index
11
12
 
@@ -185,7 +186,7 @@ class LocalGraphStore(GraphStore):
185
186
  graph_store.edge_pb = edge_pb
186
187
  graph_store.meta = meta
187
188
 
188
- graph_data = torch.load(osp.join(part_dir, 'graph.pt'))
189
+ graph_data = fs.torch_load(osp.join(part_dir, 'graph.pt'))
189
190
  graph_store.is_sorted = meta['is_sorted']
190
191
 
191
192
  if not meta['is_hetero']:
@@ -3,15 +3,16 @@ import logging
3
3
  import os
4
4
  import os.path as osp
5
5
  from collections import defaultdict
6
- from typing import List, Optional, Union
6
+ from typing import Dict, List, Optional, Tuple, Union
7
7
 
8
8
  import torch
9
9
 
10
10
  import torch_geometric.distributed as pyg_dist
11
11
  from torch_geometric.data import Data, HeteroData
12
+ from torch_geometric.io import fs
12
13
  from torch_geometric.loader.cluster import ClusterData
13
14
  from torch_geometric.sampler.utils import sort_csc
14
- from torch_geometric.typing import Dict, EdgeType, EdgeTypeStr, NodeType, Tuple
15
+ from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType
15
16
 
16
17
 
17
18
  class Partitioner:
@@ -303,7 +304,7 @@ class Partitioner:
303
304
  elif self.is_node_level_time:
304
305
  node_time = data.time
305
306
 
306
- # Sort by column to avoid keeping track of permuations in
307
+ # Sort by column to avoid keeping track of permutations in
307
308
  # `NeighborSampler` when converting to CSC format:
308
309
  global_row, global_col, perm = sort_csc(
309
310
  global_row, global_col, node_time, edge_time)
@@ -360,7 +361,7 @@ class Partitioner:
360
361
  'edge_types': self.edge_types,
361
362
  'node_offset': list(node_offset.values()) if node_offset else None,
362
363
  'is_hetero': self.is_hetero,
363
- 'is_sorted': True, # Based on colum/destination.
364
+ 'is_sorted': True, # Based on column/destination.
364
365
  }
365
366
  with open(osp.join(self.root, 'META.json'), 'w') as f:
366
367
  json.dump(meta, f)
@@ -380,21 +381,21 @@ def load_partition_info(
380
381
  assert osp.exists(partition_dir)
381
382
 
382
383
  if meta['is_hetero'] is False:
383
- node_pb = torch.load(osp.join(root_dir, 'node_map.pt'))
384
- edge_pb = torch.load(osp.join(root_dir, 'edge_map.pt'))
384
+ node_pb = fs.torch_load(osp.join(root_dir, 'node_map.pt'))
385
+ edge_pb = fs.torch_load(osp.join(root_dir, 'edge_map.pt'))
385
386
 
386
387
  return (meta, num_partitions, partition_idx, node_pb, edge_pb)
387
388
  else:
388
389
  node_pb_dict = {}
389
390
  node_pb_dir = osp.join(root_dir, 'node_map')
390
391
  for ntype in meta['node_types']:
391
- node_pb_dict[ntype] = torch.load(
392
+ node_pb_dict[ntype] = fs.torch_load(
392
393
  osp.join(node_pb_dir, f'{pyg_dist.utils.as_str(ntype)}.pt'))
393
394
 
394
395
  edge_pb_dict = {}
395
396
  edge_pb_dir = osp.join(root_dir, 'edge_map')
396
397
  for etype in meta['edge_types']:
397
- edge_pb_dict[tuple(etype)] = torch.load(
398
+ edge_pb_dict[tuple(etype)] = fs.torch_load(
398
399
  osp.join(edge_pb_dir, f'{pyg_dist.utils.as_str(etype)}.pt'))
399
400
 
400
401
  return (meta, num_partitions, partition_idx, node_pb_dict,
@@ -92,7 +92,7 @@ def shutdown_rpc(id: str = None, graceful: bool = True,
92
92
  class RPCRouter:
93
93
  r"""A router to get the worker based on the partition ID."""
94
94
  def __init__(self, partition_to_workers: List[List[str]]):
95
- for pid, rpc_worker_list in enumerate(partition_to_workers):
95
+ for rpc_worker_list in partition_to_workers:
96
96
  if len(rpc_worker_list) == 0:
97
97
  raise ValueError('No RPC worker is in worker list')
98
98
  self.partition_to_workers = partition_to_workers
@@ -120,7 +120,7 @@ def rpc_partition_to_workers(
120
120
  partition_to_workers = [[] for _ in range(num_partitions)]
121
121
  gathered_results = global_all_gather(
122
122
  (ctx.role, num_partitions, current_partition_idx))
123
- for worker_name, (role, nparts, idx) in gathered_results.items():
123
+ for worker_name, (_, _, idx) in gathered_results.items():
124
124
  partition_to_workers[idx].append(worker_name)
125
125
  return partition_to_workers
126
126
 
@@ -144,7 +144,7 @@ _rpc_call_pool: Dict[int, RPCCallBase] = {}
144
144
  @rpc_require_initialized
145
145
  def rpc_register(call: RPCCallBase) -> int:
146
146
  r"""Registers a call for RPC requests."""
147
- global _rpc_call_id, _rpc_call_pool
147
+ global _rpc_call_id
148
148
 
149
149
  with _rpc_call_lock:
150
150
  call_id = _rpc_call_id
@@ -17,6 +17,7 @@ from typing import (
17
17
  overload,
18
18
  )
19
19
 
20
+ import numpy as np
20
21
  import torch
21
22
  import torch.utils._pytree as pytree
22
23
  from torch import Tensor
@@ -173,7 +174,7 @@ class EdgeIndex(Tensor):
173
174
  :meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its
174
175
  lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).
175
176
 
176
- This representation ensures for optimal computation in GNN message passing
177
+ This representation ensures optimal computation in GNN message passing
177
178
  schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
178
179
  workflows.
179
180
 
@@ -183,7 +184,7 @@ class EdgeIndex(Tensor):
183
184
 
184
185
  edge_index = EdgeIndex(
185
186
  [[0, 1, 1, 2],
186
- [1, 0, 2, 1]]
187
+ [1, 0, 2, 1]],
187
188
  sparse_size=(3, 3),
188
189
  sort_order='row',
189
190
  is_undirected=True,
@@ -210,7 +211,7 @@ class EdgeIndex(Tensor):
210
211
  assert not edge_index.is_undirected
211
212
 
212
213
  # Sparse-Dense Matrix Multiplication:
213
- out = edge_index.flip(0) @ torch.randn(3, 16)
214
+ out = edge_index.flip(0) @ torch.randn(3, 16)
214
215
  assert out.size() == (3, 16)
215
216
  """
216
217
  # See "https://pytorch.org/docs/stable/notes/extending.html"
@@ -297,8 +298,7 @@ class EdgeIndex(Tensor):
297
298
  indptr = None
298
299
  data = torch.stack([row, col], dim=0)
299
300
 
300
- if (torch_geometric.typing.WITH_PT112
301
- and data.layout == torch.sparse_csc):
301
+ if data.layout == torch.sparse_csc:
302
302
  row = data.row_indices()
303
303
  indptr = data.ccol_indices()
304
304
 
@@ -325,7 +325,7 @@ class EdgeIndex(Tensor):
325
325
  elif sparse_size[0] is None and sparse_size[1] is not None:
326
326
  sparse_size = (sparse_size[1], sparse_size[1])
327
327
 
328
- out = Tensor._make_wrapper_subclass( # type: ignore
328
+ out = Tensor._make_wrapper_subclass(
329
329
  cls,
330
330
  size=data.size(),
331
331
  strides=data.stride(),
@@ -803,7 +803,7 @@ class EdgeIndex(Tensor):
803
803
 
804
804
  size = self.get_sparse_size()
805
805
  if value is not None and value.dim() > 1:
806
- size = size + value.size()[1:] # type: ignore
806
+ size = size + value.size()[1:]
807
807
 
808
808
  out = torch.full(size, fill_value, dtype=dtype, device=self.device)
809
809
  out[self._data[0], self._data[1]] = value if value is not None else 1
@@ -820,19 +820,28 @@ class EdgeIndex(Tensor):
820
820
  :obj:`1.0`. (default: :obj:`None`)
821
821
  """
822
822
  value = self._get_value() if value is None else value
823
- out = torch.sparse_coo_tensor(
823
+
824
+ if not torch_geometric.typing.WITH_PT21:
825
+ out = torch.sparse_coo_tensor(
826
+ indices=self._data,
827
+ values=value,
828
+ size=self.get_sparse_size(),
829
+ device=self.device,
830
+ requires_grad=value.requires_grad,
831
+ )
832
+ if self.is_sorted_by_row:
833
+ out = out._coalesced_(True)
834
+ return out
835
+
836
+ return torch.sparse_coo_tensor(
824
837
  indices=self._data,
825
838
  values=value,
826
839
  size=self.get_sparse_size(),
827
840
  device=self.device,
828
841
  requires_grad=value.requires_grad,
842
+ is_coalesced=True if self.is_sorted_by_row else None,
829
843
  )
830
844
 
831
- if self.is_sorted_by_row:
832
- out = out._coalesced_(True)
833
-
834
- return out
835
-
836
845
  def to_sparse_csr( # type: ignore
837
846
  self,
838
847
  value: Optional[Tensor] = None,
@@ -872,10 +881,6 @@ class EdgeIndex(Tensor):
872
881
  If not specified, non-zero elements will be assigned a value of
873
882
  :obj:`1.0`. (default: :obj:`None`)
874
883
  """
875
- if not torch_geometric.typing.WITH_PT112:
876
- raise NotImplementedError(
877
- "'to_sparse_csc' not supported for PyTorch < 1.12")
878
-
879
884
  (colptr, row), perm = self.get_csc()
880
885
  if value is not None and perm is not None:
881
886
  value = value[perm]
@@ -912,7 +917,7 @@ class EdgeIndex(Tensor):
912
917
  return self.to_sparse_coo(value)
913
918
  if layout == torch.sparse_csr:
914
919
  return self.to_sparse_csr(value)
915
- if torch_geometric.typing.WITH_PT112 and layout == torch.sparse_csc:
920
+ if layout == torch.sparse_csc:
916
921
  return self.to_sparse_csc(value)
917
922
 
918
923
  raise ValueError(f"Unexpected tensor layout (got '{layout}')")
@@ -1181,10 +1186,10 @@ class EdgeIndex(Tensor):
1181
1186
  return edge_index
1182
1187
 
1183
1188
  # Prevent auto-wrapping outputs back into the proper subclass type:
1184
- __torch_function__ = torch._C._disabled_torch_function_impl
1189
+ __torch_function__ = torch._C._disabled_torch_function_impl # type: ignore
1185
1190
 
1186
1191
  @classmethod
1187
- def __torch_dispatch__(
1192
+ def __torch_dispatch__( # type: ignore
1188
1193
  cls: Type,
1189
1194
  func: Callable[..., Any],
1190
1195
  types: Iterable[Type[Any]],
@@ -1237,6 +1242,14 @@ class EdgeIndex(Tensor):
1237
1242
  return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
1238
1243
  indent, force_newline=False)
1239
1244
 
1245
+ def tolist(self) -> List[Any]:
1246
+ """""" # noqa: D419
1247
+ return self._data.tolist()
1248
+
1249
+ def numpy(self, *, force: bool = False) -> np.ndarray:
1250
+ """""" # noqa: D419
1251
+ return self._data.numpy(force=force)
1252
+
1240
1253
  # Helpers #################################################################
1241
1254
 
1242
1255
  def _shallow_copy(self) -> 'EdgeIndex':
@@ -1469,7 +1482,7 @@ def _slice(
1469
1482
  step: int = 1,
1470
1483
  ) -> Union[EdgeIndex, Tensor]:
1471
1484
 
1472
- if ((start is None or start <= 0)
1485
+ if ((start is None or start == 0 or start <= -input.size(dim))
1473
1486
  and (end is None or end > input.size(dim)) and step == 1):
1474
1487
  return input._shallow_copy() # No-op.
1475
1488
 
@@ -1928,7 +1941,7 @@ def _spmm(
1928
1941
  if transpose and not input.is_sorted_by_col:
1929
1942
  cls_name = input.__class__.__name__
1930
1943
  raise ValueError(f"'matmul(..., transpose=True)' requires "
1931
- f"'{cls_name}' to be sorted by colums")
1944
+ f"'{cls_name}' to be sorted by columns")
1932
1945
 
1933
1946
  if (torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling()
1934
1947
  and other.is_cuda): # pragma: no cover
@@ -1,13 +1,14 @@
1
1
  import logging
2
- from typing import List, Optional, Union
2
+ from typing import Dict, List, Optional, Union, overload
3
3
 
4
4
  import torch
5
5
  from torch import Tensor
6
6
 
7
- from torch_geometric.explain import Explanation
7
+ from torch_geometric.explain import Explanation, HeteroExplanation
8
8
  from torch_geometric.explain.algorithm import ExplainerAlgorithm
9
9
  from torch_geometric.explain.config import ExplanationType, ModelTaskLevel
10
10
  from torch_geometric.nn.conv.message_passing import MessagePassing
11
+ from torch_geometric.typing import EdgeType, NodeType
11
12
 
12
13
 
13
14
  class AttentionExplainer(ExplainerAlgorithm):
@@ -26,7 +27,9 @@ class AttentionExplainer(ExplainerAlgorithm):
26
27
  def __init__(self, reduce: str = 'max'):
27
28
  super().__init__()
28
29
  self.reduce = reduce
30
+ self.is_hetero = False
29
31
 
32
+ @overload
30
33
  def forward(
31
34
  self,
32
35
  model: torch.nn.Module,
@@ -37,65 +40,252 @@ class AttentionExplainer(ExplainerAlgorithm):
37
40
  index: Optional[Union[int, Tensor]] = None,
38
41
  **kwargs,
39
42
  ) -> Explanation:
40
- if isinstance(x, dict) or isinstance(edge_index, dict):
41
- raise ValueError(f"Heterogeneous graphs not yet supported in "
42
- f"'{self.__class__.__name__}'")
43
+ ...
43
44
 
44
- hard_edge_mask = None
45
- if self.model_config.task_level == ModelTaskLevel.node:
46
- # We need to compute the hard edge mask to properly clean up edge
47
- # attributions not involved during message passing:
48
- _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
49
- num_nodes=x.size(0))
45
+ @overload
46
+ def forward(
47
+ self,
48
+ model: torch.nn.Module,
49
+ x: Dict[NodeType, Tensor],
50
+ edge_index: Dict[EdgeType, Tensor],
51
+ *,
52
+ target: Tensor,
53
+ index: Optional[Union[int, Tensor]] = None,
54
+ **kwargs,
55
+ ) -> HeteroExplanation:
56
+ ...
57
+
58
+ def forward(
59
+ self,
60
+ model: torch.nn.Module,
61
+ x: Union[Tensor, Dict[NodeType, Tensor]],
62
+ edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
63
+ *,
64
+ target: Tensor,
65
+ index: Optional[Union[int, Tensor]] = None,
66
+ **kwargs,
67
+ ) -> Union[Explanation, HeteroExplanation]:
68
+ """Generate explanations based on attention coefficients."""
69
+ self.is_hetero = isinstance(x, dict)
70
+
71
+ # Collect attention coefficients
72
+ alphas_dict = self._collect_attention_coefficients(
73
+ model, x, edge_index, **kwargs)
74
+
75
+ # Process attention coefficients
76
+ if self.is_hetero:
77
+ return self._create_hetero_explanation(model, alphas_dict,
78
+ edge_index, index, x)
79
+ else:
80
+ return self._create_homo_explanation(model, alphas_dict,
81
+ edge_index, index, x)
82
+
83
+ @overload
84
+ def _collect_attention_coefficients(
85
+ self,
86
+ model: torch.nn.Module,
87
+ x: Tensor,
88
+ edge_index: Tensor,
89
+ **kwargs,
90
+ ) -> List[Tensor]:
91
+ ...
92
+
93
+ @overload
94
+ def _collect_attention_coefficients(
95
+ self,
96
+ model: torch.nn.Module,
97
+ x: Dict[NodeType, Tensor],
98
+ edge_index: Dict[EdgeType, Tensor],
99
+ **kwargs,
100
+ ) -> Dict[EdgeType, List[Tensor]]:
101
+ ...
102
+
103
+ def _collect_attention_coefficients(
104
+ self,
105
+ model: torch.nn.Module,
106
+ x: Union[Tensor, Dict[NodeType, Tensor]],
107
+ edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
108
+ **kwargs,
109
+ ) -> Union[List[Tensor], Dict[EdgeType, List[Tensor]]]:
110
+ """Collect attention coefficients from model layers."""
111
+ if self.is_hetero:
112
+ # For heterogeneous graphs, store alphas by edge type
113
+ alphas_dict: Dict[EdgeType, List[Tensor]] = {}
114
+
115
+ # Get list of edge types
116
+ edge_types = list(edge_index.keys())
117
+
118
+ # Hook function to capture attention coefficients by edge type
119
+ def hook(module, msg_kwargs, out):
120
+ # Find edge type from the module's full name
121
+ module_name = getattr(module, '_name', None)
122
+ if module_name is None:
123
+ return
50
124
 
51
- alphas: List[Tensor] = []
125
+ edge_type = None
126
+ for edge_tuple in edge_types:
127
+ src_type, edge_name, dst_type = edge_tuple
128
+ # Check if all components appear in the module name in
129
+ # order
130
+ try:
131
+ src_idx = module_name.index(src_type)
132
+ edge_idx = module_name.index(edge_name, src_idx)
133
+ dst_idx = module_name.index(dst_type, edge_idx)
134
+ if src_idx < edge_idx < dst_idx:
135
+ edge_type = edge_tuple
136
+ break
137
+ except ValueError: # Component not found
138
+ continue
139
+
140
+ if edge_type is None:
141
+ return
142
+
143
+ if edge_type not in alphas_dict:
144
+ alphas_dict[edge_type] = []
145
+
146
+ # Extract alpha from message kwargs or module
147
+ if 'alpha' in msg_kwargs[0]:
148
+ alphas_dict[edge_type].append(
149
+ msg_kwargs[0]['alpha'].detach())
150
+ elif getattr(module, '_alpha', None) is not None:
151
+ alphas_dict[edge_type].append(module._alpha.detach())
152
+ else:
153
+ # For homogeneous graphs, store all alphas in a list
154
+ alphas: List[Tensor] = []
52
155
 
53
- def hook(module, msg_kwargs, out):
54
- if 'alpha' in msg_kwargs[0]:
55
- alphas.append(msg_kwargs[0]['alpha'].detach())
56
- elif getattr(module, '_alpha', None) is not None:
57
- alphas.append(module._alpha.detach())
156
+ def hook(module, msg_kwargs, out):
157
+ if 'alpha' in msg_kwargs[0]:
158
+ alphas.append(msg_kwargs[0]['alpha'].detach())
159
+ elif getattr(module, '_alpha', None) is not None:
160
+ alphas.append(module._alpha.detach())
58
161
 
162
+ # Register hooks for all message passing modules
59
163
  hook_handles = []
60
- for module in model.modules(): # Register message forward hooks:
61
- if (isinstance(module, MessagePassing)
62
- and module.explain is not False):
164
+ for name, module in model.named_modules():
165
+ if isinstance(module,
166
+ MessagePassing) and module.explain is not False:
167
+ # Store name for hetero graph lookup in the hook
168
+ if self.is_hetero:
169
+ module._name = name
170
+
63
171
  hook_handles.append(module.register_message_forward_hook(hook))
64
172
 
173
+ # Forward pass to collect attention coefficients.
65
174
  model(x, edge_index, **kwargs)
66
175
 
67
- for handle in hook_handles: # Remove hooks:
176
+ # Remove hooks
177
+ for handle in hook_handles:
68
178
  handle.remove()
69
179
 
70
- if len(alphas) == 0:
71
- raise ValueError("Could not collect any attention coefficients. "
72
- "Please ensure that your model is using "
73
- "attention-based GNN layers.")
180
+ # Check if we collected any attention coefficients.
181
+ if self.is_hetero:
182
+ if not alphas_dict:
183
+ raise ValueError(
184
+ "Could not collect any attention coefficients. "
185
+ "Please ensure that your model is using "
186
+ "attention-based GNN layers.")
187
+ return alphas_dict
188
+ else:
189
+ if not alphas:
190
+ raise ValueError(
191
+ "Could not collect any attention coefficients. "
192
+ "Please ensure that your model is using "
193
+ "attention-based GNN layers.")
194
+ return alphas
74
195
 
196
+ def _process_attention_coefficients(
197
+ self,
198
+ alphas: List[Tensor],
199
+ edge_index_size: int,
200
+ ) -> Tensor:
201
+ """Process collected attention coefficients into a single mask."""
75
202
  for i, alpha in enumerate(alphas):
76
- alpha = alpha[:edge_index.size(1)] # Respect potential self-loops.
203
+ # Ensure alpha doesn't exceed edge_index size
204
+ alpha = alpha[:edge_index_size]
205
+
206
+ # Reduce multi-head attention
77
207
  if alpha.dim() == 2:
78
208
  alpha = getattr(torch, self.reduce)(alpha, dim=-1)
79
- if isinstance(alpha, tuple): # Respect `torch.max`:
209
+ if isinstance(alpha, tuple): # Handle torch.max output
80
210
  alpha = alpha[0]
81
211
  elif alpha.dim() > 2:
82
- raise ValueError(f"Can not reduce attention coefficients of "
212
+ raise ValueError(f"Cannot reduce attention coefficients of "
83
213
  f"shape {list(alpha.size())}")
84
214
  alphas[i] = alpha
85
215
 
216
+ # Combine attention coefficients across layers
86
217
  if len(alphas) > 1:
87
218
  alpha = torch.stack(alphas, dim=-1)
88
219
  alpha = getattr(torch, self.reduce)(alpha, dim=-1)
89
- if isinstance(alpha, tuple): # Respect `torch.max`:
220
+ if isinstance(alpha, tuple): # Handle torch.max output
90
221
  alpha = alpha[0]
91
222
  else:
92
223
  alpha = alphas[0]
93
224
 
225
+ return alpha
226
+
227
+ def _create_homo_explanation(
228
+ self,
229
+ model: torch.nn.Module,
230
+ alphas: List[Tensor],
231
+ edge_index: Tensor,
232
+ index: Optional[Union[int, Tensor]],
233
+ x: Tensor,
234
+ ) -> Explanation:
235
+ """Create explanation for homogeneous graph."""
236
+ # Get hard edge mask for node-level tasks
237
+ hard_edge_mask = None
238
+ if self.model_config.task_level == ModelTaskLevel.node:
239
+ _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
240
+ num_nodes=x.size(0))
241
+
242
+ # Process attention coefficients
243
+ alpha = self._process_attention_coefficients(alphas,
244
+ edge_index.size(1))
245
+
246
+ # Post-process mask with hard edge mask if needed
94
247
  alpha = self._post_process_mask(alpha, hard_edge_mask,
95
248
  apply_sigmoid=False)
96
249
 
97
250
  return Explanation(edge_mask=alpha)
98
251
 
252
+ def _create_hetero_explanation(
253
+ self,
254
+ model: torch.nn.Module,
255
+ alphas_dict: Dict[EdgeType, List[Tensor]],
256
+ edge_index: Dict[EdgeType, Tensor],
257
+ index: Optional[Union[int, Tensor]],
258
+ x: Dict[NodeType, Tensor],
259
+ ) -> HeteroExplanation:
260
+ """Create explanation for heterogeneous graph."""
261
+ edge_masks_dict = {}
262
+
263
+ # Process each edge type separately
264
+ for edge_type, alphas in alphas_dict.items():
265
+ if not alphas:
266
+ continue
267
+
268
+ # Get hard edge mask for node-level tasks
269
+ hard_edge_mask = None
270
+ if self.model_config.task_level == ModelTaskLevel.node:
271
+ src_type, _, dst_type = edge_type
272
+ _, hard_edge_mask = self._get_hard_masks(
273
+ model, index, edge_index[edge_type],
274
+ num_nodes=max(x[src_type].size(0), x[dst_type].size(0)))
275
+
276
+ # Process attention coefficients for this edge type
277
+ alpha = self._process_attention_coefficients(
278
+ alphas, edge_index[edge_type].size(1))
279
+
280
+ # Apply hard mask if available
281
+ edge_masks_dict[edge_type] = self._post_process_mask(
282
+ alpha, hard_edge_mask, apply_sigmoid=False)
283
+
284
+ # Create heterogeneous explanation
285
+ explanation = HeteroExplanation()
286
+ explanation.set_value_dict('edge_mask', edge_masks_dict)
287
+ return explanation
288
+
99
289
  def supports(self) -> bool:
100
290
  explanation_type = self.explainer_config.explanation_type
101
291
  if explanation_type != ExplanationType.model:
@@ -166,7 +166,7 @@ class ExplainerAlgorithm(torch.nn.Module):
166
166
  elif self.model_config.return_type == ModelReturnType.probs:
167
167
  loss_fn = F.binary_cross_entropy
168
168
  else:
169
- assert False
169
+ raise AssertionError()
170
170
 
171
171
  return loss_fn(y_hat.view_as(y), y.float())
172
172
 
@@ -183,7 +183,7 @@ class ExplainerAlgorithm(torch.nn.Module):
183
183
  elif self.model_config.return_type == ModelReturnType.log_probs:
184
184
  loss_fn = F.nll_loss
185
185
  else:
186
- assert False
186
+ raise AssertionError()
187
187
 
188
188
  return loss_fn(y_hat, y)
189
189
 
@@ -190,7 +190,7 @@ def to_captum_input(
190
190
 
191
191
  Args:
192
192
  x (torch.Tensor or Dict[NodeType, torch.Tensor]): The node features.
193
- For heterogeneous graphs this is a dictionary holding node featues
193
+ For heterogeneous graphs this is a dictionary holding node features
194
194
  for each node type.
195
195
  edge_index(torch.Tensor or Dict[EdgeType, torch.Tensor]): The edge
196
196
  indices. For heterogeneous graphs this is a dictionary holding the
@@ -73,7 +73,8 @@ class CaptumExplainer(ExplainerAlgorithm):
73
73
  f"{self.attribution_method_class.__name__}")
74
74
 
75
75
  if kwargs.get('internal_batch_size', 1) != 1:
76
- warnings.warn("Overriding 'internal_batch_size' to 1")
76
+ warnings.warn("Overriding 'internal_batch_size' to 1",
77
+ stacklevel=2)
77
78
 
78
79
  if 'internal_batch_size' in self._get_attribute_parameters():
79
80
  kwargs['internal_batch_size'] = 1