pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.

This version of pyg-nightly has been flagged as potentially problematic.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94

torch_geometric/nn/norm/layer_norm.py

@@ -30,11 +30,13 @@ class LayerNorm(torch.nn.Module):
         affine (bool, optional): If set to :obj:`True`, this module has
             learnable affine parameters :math:`\gamma` and :math:`\beta`.
             (default: :obj:`True`)
-        mode (str, optinal): The normalization mode to use for layer
+        mode (str, optional): The normalization mode to use for layer
             normalization (:obj:`"graph"` or :obj:`"node"`). If :obj:`"graph"`
             is used, each graph will be considered as an element to be
             normalized. If `"node"` is used, each node will be considered as
             an element to be normalized. (default: :obj:`"graph"`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -42,6 +44,7 @@ class LayerNorm(torch.nn.Module):
         eps: float = 1e-5,
         affine: bool = True,
         mode: str = 'graph',
+        device: Optional[torch.device] = None,
     ):
         super().__init__()

@@ -51,8 +54,8 @@ class LayerNorm(torch.nn.Module):
         self.mode = mode

         if affine:
-            self.weight = Parameter(torch.empty(in_channels))
-            self.bias = Parameter(torch.empty(in_channels))
+            self.weight = Parameter(torch.empty(in_channels, device=device))
+            self.bias = Parameter(torch.empty(in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
@@ -108,7 +111,7 @@ class LayerNorm(torch.nn.Module):
             return F.layer_norm(x, (self.in_channels, ), self.weight,
                                 self.bias, self.eps)

-        raise ValueError(f"Unknow normalization mode: {self.mode}")
+        raise ValueError(f"Unknown normalization mode: {self.mode}")

     def __repr__(self):
         return (f'{self.__class__.__name__}({self.in_channels}, '
@@ -130,10 +133,12 @@ class HeteroLayerNorm(torch.nn.Module):
         affine (bool, optional): If set to :obj:`True`, this module has
             learnable affine parameters :math:`\gamma` and :math:`\beta`.
             (default: :obj:`True`)
-        mode (str, optinal): The normalization mode to use for layer
+        mode (str, optional): The normalization mode to use for layer
             normalization (:obj:`"node"`). If `"node"` is used, each node will
             be considered as an element to be normalized.
             (default: :obj:`"node"`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -142,6 +147,7 @@ class HeteroLayerNorm(torch.nn.Module):
         eps: float = 1e-5,
         affine: bool = True,
         mode: str = 'node',
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
         assert mode == 'node'
@@ -152,8 +158,10 @@ class HeteroLayerNorm(torch.nn.Module):
         self.affine = affine

         if affine:
-            self.weight = Parameter(torch.empty(num_types, in_channels))
-            self.bias = Parameter(torch.empty(num_types, in_channels))
+            self.weight = Parameter(
+                torch.empty(num_types, in_channels, device=device))
+            self.bias = Parameter(
+                torch.empty(num_types, in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
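
Both layers (and MessageNorm below) now accept a `device` keyword so that affine parameters are materialized directly on the target device instead of being allocated on CPU and moved afterwards. A minimal usage sketch, assuming a CUDA-capable machine:

import torch
from torch_geometric.nn import LayerNorm

# Parameters are created on the GPU at construction time; no `.to()` round-trip:
norm = LayerNorm(in_channels=64, mode='graph', device=torch.device('cuda'))

x = torch.randn(10, 64, device='cuda')
batch = torch.zeros(10, dtype=torch.long, device='cuda')
out = norm(x, batch)  # shape: [10, 64]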

torch_geometric/nn/norm/msg_norm.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
 from torch import Tensor
@@ -19,10 +21,14 @@ class MessageNorm(torch.nn.Module):
         learn_scale (bool, optional): If set to :obj:`True`, will learn the
             scaling factor :math:`s` of message normalization.
             (default: :obj:`False`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
-    def __init__(self, learn_scale: bool = False):
+    def __init__(self, learn_scale: bool = False,
+                 device: Optional[torch.device] = None):
         super().__init__()
-        self.scale = Parameter(torch.empty(1), requires_grad=learn_scale)
+        self.scale = Parameter(torch.empty(1, device=device),
+                               requires_grad=learn_scale)
         self.reset_parameters()

     def reset_parameters(self):

torch_geometric/nn/pool/__init__.py

@@ -7,18 +7,19 @@ from torch import Tensor
 import torch_geometric.typing
 from torch_geometric.typing import OptTensor, torch_cluster

-from .asap import ASAPooling
 from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x
-from .edge_pool import EdgePooling
 from .glob import global_add_pool, global_max_pool, global_mean_pool
 from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,
                   ApproxMIPSKNNIndex)
 from .graclus import graclus
 from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
-from .mem_pool import MemPooling
-from .pan_pool import PANPooling
-from .sag_pool import SAGPooling
 from .topk_pool import TopKPooling
+from .sag_pool import SAGPooling
+from .edge_pool import EdgePooling
+from .cluster_pool import ClusterPooling
+from .asap import ASAPooling
+from .pan_pool import PANPooling
+from .mem_pool import MemPooling
 from .voxel_grid import voxel_grid
 from .approx_knn import approx_knn, approx_knn_graph

@@ -162,8 +163,10 @@ def knn_graph(
     :rtype: :class:`torch.Tensor`
     """
     if batch is not None and x.device != batch.device:
-        warnings.warn("Input tensor 'x' and 'batch' are on different devices "
-                      "in 'knn_graph'. Performing blocking device transfer")
+        warnings.warn(
+            "Input tensor 'x' and 'batch' are on different devices "
+            "in 'knn_graph'. Performing blocking device transfer",
+            stacklevel=2)
         batch = batch.to(x.device)

     if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
@@ -284,8 +287,10 @@ def radius_graph(
         inputs to GPU before proceeding.
     """
     if batch is not None and x.device != batch.device:
-        warnings.warn("Input tensor 'x' and 'batch' are on different devices "
-                      "in 'radius_graph'. Performing blocking device transfer")
+        warnings.warn(
+            "Input tensor 'x' and 'batch' are on different devices "
+            "in 'radius_graph'. Performing blocking device transfer",
+            stacklevel=2)
         batch = batch.to(x.device)

     if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
@@ -344,6 +349,7 @@ __all__ = [
     'TopKPooling',
     'SAGPooling',
     'EdgePooling',
+    'ClusterPooling',
     'ASAPooling',
     'PANPooling',
     'MemPooling',
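
The knn_graph and radius_graph changes above, like many edits in this release, add stacklevel=2 to warnings.warn so that the reported source location is the caller's code rather than PyG internals. A minimal sketch of the effect (the function name here is hypothetical):

import warnings

def library_fn():
    # With stacklevel=2, Python attributes the warning to the caller's line:
    warnings.warn("'x' and 'batch' are on different devices", stacklevel=2)

library_fn()  # the UserWarning now points at this line in user code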

torch_geometric/nn/pool/cluster_pool.py (new file)

@@ -0,0 +1,144 @@
+from typing import NamedTuple, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from torch_geometric.utils import (
+    dense_to_sparse,
+    one_hot,
+    to_dense_adj,
+    to_scipy_sparse_matrix,
+)
+
+
+class UnpoolInfo(NamedTuple):
+    edge_index: Tensor
+    cluster: Tensor
+    batch: Tensor
+
+
+class ClusterPooling(torch.nn.Module):
+    r"""The cluster pooling operator from the `"Edge-Based Graph Component
+    Pooling" <https://arxiv.org/abs/2409.11856>`_ paper.
+    :class:`ClusterPooling` computes a score for each edge.
+    Based on the selected edges, graph clusters are calculated and compressed
+    to one node using the injective :obj:`"sum"` aggregation function.
+    Edges are remapped based on the nodes created by each cluster and the
+    original edges.
+
+    Args:
+        in_channels (int): Size of each input sample.
+        edge_score_method (str, optional): The function to apply
+            to compute the edge score from raw edge scores (:obj:`"tanh"`,
+            :obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`)
+        dropout (float, optional): The probability with
+            which to drop edge scores during training. (default: :obj:`0.0`)
+        threshold (float, optional): The threshold of edge scores. If set to
+            :obj:`None`, will be automatically inferred depending on
+            :obj:`edge_score_method`. (default: :obj:`None`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        edge_score_method: str = 'tanh',
+        dropout: float = 0.0,
+        threshold: Optional[float] = None,
+    ):
+        super().__init__()
+        assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax']
+
+        if threshold is None:
+            threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0
+
+        self.in_channels = in_channels
+        self.edge_score_method = edge_score_method
+        self.dropout = dropout
+        self.threshold = threshold
+
+        self.lin = torch.nn.Linear(2 * in_channels, 1)
+
+    def reset_parameters(self):
+        r"""Resets all learnable parameters of the module."""
+        self.lin.reset_parameters()
+
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): The node features.
+            edge_index (torch.Tensor): The edge indices.
+            batch (torch.Tensor): Batch vector
+                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                each node to a specific example.
+
+        Return types:
+            * **x** *(torch.Tensor)* - The pooled node features.
+            * **edge_index** *(torch.Tensor)* - The coarsened edge indices.
+            * **batch** *(torch.Tensor)* - The coarsened batch vector.
+            * **unpool_info** *(UnpoolInfo)* - Information that can be consumed
+              for unpooling.
+        """
+        mask = edge_index[0] != edge_index[1]
+        edge_index = edge_index[:, mask]
+
+        edge_attr = torch.cat(
+            [x[edge_index[0]], x[edge_index[1]]],
+            dim=-1,
+        )
+        edge_score = self.lin(edge_attr).view(-1)
+        edge_score = F.dropout(edge_score, p=self.dropout,
+                               training=self.training)
+
+        if self.edge_score_method == 'tanh':
+            edge_score = edge_score.tanh()
+        elif self.edge_score_method == 'sigmoid':
+            edge_score = edge_score.sigmoid()
+        else:
+            assert self.edge_score_method == 'log_softmax'
+            edge_score = F.log_softmax(edge_score, dim=0)
+
+        return self._merge_edges(x, edge_index, batch, edge_score)
+
+    def _merge_edges(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        edge_score: Tensor,
+    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
+
+        from scipy.sparse.csgraph import connected_components
+
+        edge_contract = edge_index[:, edge_score > self.threshold]
+
+        adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
+        _, cluster_np = connected_components(adj, directed=True,
+                                             connection="weak")
+
+        cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)
+        C = one_hot(cluster)
+        A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
+        S = to_dense_adj(edge_index, edge_attr=edge_score,
+                         max_num_nodes=x.size(0)).squeeze(0)
+
+        A_contract = to_dense_adj(edge_contract,
+                                  max_num_nodes=x.size(0)).squeeze(0)
+        nodes_single = ((A_contract.sum(dim=-1) +
+                         A_contract.sum(dim=-2)) == 0).nonzero()
+        S[nodes_single, nodes_single] = 1.0
+
+        x_out = (S @ C).t() @ x
+        edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
+        batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)
+        unpool_info = UnpoolInfo(edge_index, cluster, batch)
+
+        return x_out, edge_index_out, batch_out, unpool_info
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}({self.in_channels})'
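
The new operator is exported through torch_geometric.nn.pool (see the __init__.py hunk above). A usage sketch; the toy graph and feature sizes are made up for illustration:

import torch
from torch_geometric.nn.pool import ClusterPooling

x = torch.randn(6, 16)                           # 6 nodes, 16 features
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 0, 3, 2, 5, 4]])  # three 2-node components
batch = torch.zeros(6, dtype=torch.long)         # a single graph

pool = ClusterPooling(in_channels=16)
x_out, edge_index_out, batch_out, unpool_info = pool(x, edge_index, batch)
# Edges whose score exceeds the threshold are contracted; each resulting
# connected component is summed into one node, so between 3 and 6 nodes
# remain here depending on the (randomly initialized) edge scores.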

torch_geometric/nn/pool/connect/base.py

@@ -4,7 +4,6 @@ from typing import Optional
 import torch
 from torch import Tensor

-import torch_geometric.typing
 from torch_geometric.nn.pool.select import SelectOutput


@@ -49,8 +48,7 @@ class ConnectOutput:
         self.batch = batch


-if torch_geometric.typing.WITH_PT113:
-    ConnectOutput = torch.jit.script(ConnectOutput)
+ConnectOutput = torch.jit.script(ConnectOutput)


 class Connect(torch.nn.Module):

torch_geometric/nn/pool/edge_pool.py

@@ -58,7 +58,7 @@ class EdgePooling(torch.nn.Module):
         self,
         in_channels: int,
         edge_score_method: Optional[Callable] = None,
-        dropout: Optional[float] = 0.0,
+        dropout: float = 0.0,
         add_to_edge_score: float = 0.5,
     ):
         super().__init__()

torch_geometric/nn/pool/knn.py

@@ -91,9 +91,10 @@ class KNNIndex:
             if hasattr(self.index, 'reserveMemory'):
                 self.index.reserveMemory(self.reserve)
             else:
-                warnings.warn(f"'{self.index.__class__.__name__}' "
-                              f"does not support pre-allocation of "
-                              f"memory")
+                warnings.warn(
+                    f"'{self.index.__class__.__name__}' "
+                    f"does not support pre-allocation of "
+                    f"memory", stacklevel=2)

             self.index.train(emb)

@@ -135,14 +136,16 @@ class KNNIndex:
         query_k = min(query_k, self.numel)

         if k > 2048:  # `faiss` supports up-to `k=2048`:
-            warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
-                          f"(got {k}). This may cause some relevant items to "
-                          f"not be retrieved.")
+            warnings.warn(
+                f"Capping 'k' to faiss' upper limit of 2048 "
+                f"(got {k}). This may cause some relevant items to "
+                f"not be retrieved.", stacklevel=2)
         elif query_k > 2048:
-            warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
-                          f"(got {k} which got extended to {query_k} due to "
-                          f"the exclusion of existing links). This may cause "
-                          f"some relevant items to not be retrieved.")
+            warnings.warn(
+                f"Capping 'k' to faiss' upper limit of 2048 "
+                f"(got {k} which got extended to {query_k} due to "
+                f"the exclusion of existing links). This may cause "
+                f"some relevant items to not be retrieved.", stacklevel=2)
             query_k = 2048

         score, index = self.index.search(emb.detach(), query_k)
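
For reference, the capped path is hit through KNNIndex.search. A sketch of typical usage, assuming a working faiss installation and the MIPSKNNIndex constructor/search signatures as in current PyG:

import torch
from torch_geometric.nn.pool import MIPSKNNIndex

emb = torch.randn(10_000, 64)    # item embeddings to index
query = torch.randn(8, 64)       # query embeddings

index = MIPSKNNIndex(emb)
out = index.search(query, k=10)  # top-10 inner-product neighbors per query
# Requests with k > 2048 are now capped, with the warning attributed to the
# calling frame via `stacklevel=2`.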

torch_geometric/nn/pool/select/base.py

@@ -4,8 +4,6 @@ from typing import Optional
 import torch
 from torch import Tensor

-import torch_geometric.typing
-

 @dataclass(init=False)
 class SelectOutput:
@@ -64,8 +62,7 @@ class SelectOutput:
         self.weight = weight


-if torch_geometric.typing.WITH_PT113:
-    SelectOutput = torch.jit.script(SelectOutput)
+SelectOutput = torch.jit.script(SelectOutput)


 class Select(torch.nn.Module):

torch_geometric/nn/summary.py

@@ -141,7 +141,7 @@ def get_shape(inputs: Any) -> str:
 def postprocess(info_list: List[dict]) -> List[dict]:
     for idx, info in enumerate(info_list):
         depth = info['depth']
-        if idx > 0:  # root module (0) is exclued
+        if idx > 0:  # root module (0) is excluded
             if depth == 1:
                 prefix = '├─'
             else:

torch_geometric/nn/to_hetero_module.py

@@ -108,9 +108,10 @@ class ToHeteroMessagePassing(torch.nn.Module):

         if (not hasattr(module, 'reset_parameters')
                 and sum([p.numel() for p in module.parameters()]) > 0):
-            warnings.warn(f"'{module}' will be duplicated, but its parameters "
-                          f"cannot be reset. To suppress this warning, add a "
-                          f"'reset_parameters()' method to '{module}'")
+            warnings.warn(
+                f"'{module}' will be duplicated, but its parameters "
+                f"cannot be reset. To suppress this warning, add a "
+                f"'reset_parameters()' method to '{module}'", stacklevel=2)

         convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}
         self.hetero_module = HeteroConv(convs, aggr)

torch_geometric/nn/to_hetero_transformer.py

@@ -157,7 +157,7 @@ class ToHeteroTransformer(Transformer):
                 f"There exist node types ({unused_node_types}) whose "
                 f"representations do not get updated during message passing "
                 f"as they do not occur as destination type in any edge type. "
-                f"This may lead to unexpected behavior.")
+                f"This may lead to unexpected behavior.", stacklevel=2)

         names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
         for name in names:
@@ -166,7 +166,7 @@ class ToHeteroTransformer(Transformer):
                     f"The type '{name}' contains invalid characters which "
                     f"may lead to unexpected behavior. To avoid any issues, "
                     f"ensure that your types only contain letters, numbers "
-                    f"and underscores.")
+                    f"and underscores.", stacklevel=2)

     def placeholder(self, node: Node, target: Any, name: str):
         # Adds a `get` call to the input dictionary for every node-type or
@@ -379,7 +379,7 @@ class ToHeteroTransformer(Transformer):
             warnings.warn(
                 f"'{target}' will be duplicated, but its parameters "
                 f"cannot be reset. To suppress this warning, add a "
-                f"'reset_parameters()' method to '{target}'")
+                f"'reset_parameters()' method to '{target}'", stacklevel=2)

         return module_dict

torch_geometric/nn/to_hetero_with_bases_transformer.py

@@ -165,7 +165,7 @@ class ToHeteroWithBasesTransformer(Transformer):
                 f"There exist node types ({unused_node_types}) whose "
                 f"representations do not get updated during message passing "
                 f"as they do not occur as destination type in any edge type. "
-                f"This may lead to unexpected behavior.")
+                f"This may lead to unexpected behavior.", stacklevel=2)

         names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
         for name in names:
@@ -174,7 +174,7 @@ class ToHeteroWithBasesTransformer(Transformer):
                     f"The type '{name}' contains invalid characters which "
                     f"may lead to unexpected behavior. To avoid any issues, "
                     f"ensure that your types only contain letters, numbers "
-                    f"and underscores.")
+                    f"and underscores.", stacklevel=2)

     def transform(self) -> GraphModule:
         self._node_offset_dict_initialized = False
@@ -361,7 +361,7 @@ class HeteroBasisConv(torch.nn.Module):
             warnings.warn(
                 f"'{conv}' will be duplicated, but its parameters cannot "
                 f"be reset. To suppress this warning, add a "
-                f"'reset_parameters()' method to '{conv}'")
+                f"'reset_parameters()' method to '{conv}'", stacklevel=2)
             torch.nn.init.xavier_uniform_(conv.edge_type_weight)

     def forward(self, edge_type: Tensor, *args, **kwargs) -> Tensor:
@@ -380,7 +380,7 @@ class HeteroBasisConv(torch.nn.Module):


 class LinearAlign(torch.nn.Module):
-    # Aligns representions to the same dimensionality. Note that this will
+    # Aligns representations to the same dimensionality. Note that this will
     # create lazy modules, and as such requires a forward pass in order to
     # initialize parameters.
     def __init__(self, keys: List[Union[NodeType, EdgeType]],
@@ -468,7 +468,7 @@ def get_edge_type(
 ###############################################################################

 # These methods are used to group the individual type-wise components into a
-# unfied single representation.
+# unified single representation.


 def group_node_placeholder(input_dict: Dict[NodeType, Tensor],
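
The duplicated-module warnings above fire when a wrapped module holds parameters but offers no way to re-initialize them per copy. A minimal sketch of the remedy the warning asks for (the module itself is hypothetical):

import torch

class MyBias(torch.nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(channels))

    def reset_parameters(self):
        # Invoked on each duplicated copy so every node/edge type starts
        # from its own fresh initialization:
        torch.nn.init.zeros_(self.bias)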

torch_geometric/profile/__init__.py

@@ -20,6 +20,7 @@ from .utils import (
     get_gpu_memory_from_nvidia_smi,
     get_model_size,
 )
+from .nvtx import nvtxit

 __all__ = [
     'profileit',
@@ -38,6 +39,7 @@ __all__ = [
     'get_gpu_memory_from_nvidia_smi',
     'get_gpu_memory_from_ipex',
     'benchmark',
+    'nvtxit',
 ]

 classes = __all__

torch_geometric/profile/nvtx.py (new file)

@@ -0,0 +1,66 @@
+from functools import wraps
+from typing import Optional
+
+import torch
+
+CUDA_PROFILE_STARTED = False
+
+
+def begin_cuda_profile():
+    global CUDA_PROFILE_STARTED
+    prev_state = CUDA_PROFILE_STARTED
+    if prev_state is False:
+        CUDA_PROFILE_STARTED = True
+        torch.cuda.cudart().cudaProfilerStart()
+    return prev_state
+
+
+def end_cuda_profile(prev_state: bool):
+    global CUDA_PROFILE_STARTED
+    CUDA_PROFILE_STARTED = prev_state
+    if prev_state is False:
+        torch.cuda.cudart().cudaProfilerStop()
+
+
+def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
+           n_iters: Optional[int] = None):
+    """Enables NVTX profiling for a function.
+
+    Args:
+        name (Optional[str], optional): Name to give the reference frame for
+            the function being wrapped. Defaults to the name of the
+            function in code.
+        n_warmups (int, optional): Number of iters to call that function
+            before starting. Defaults to 0.
+        n_iters (Optional[int], optional): Number of iters of that function to
+            record. Defaults to all of them.
+    """
+    def nvtx(func):
+
+        nonlocal name
+        iters_so_far = 0
+        if name is None:
+            name = func.__name__
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            nonlocal iters_so_far
+            if not torch.cuda.is_available():
+                return func(*args, **kwargs)
+            elif iters_so_far < n_warmups:
+                iters_so_far += 1
+                return func(*args, **kwargs)
+            elif n_iters is None or iters_so_far < n_iters + n_warmups:
+                prev_state = begin_cuda_profile()
+                torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
+                result = func(*args, **kwargs)
+                torch.cuda.nvtx.range_pop()
+                end_cuda_profile(prev_state)
+                iters_so_far += 1
+                return result
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return nvtx
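
A usage sketch for the new decorator, assuming a CUDA-capable machine and an Nsight Systems capture (e.g. `nsys profile python script.py`); the model and data objects are placeholders:

from torch_geometric.profile import nvtxit

@nvtxit(name='gnn_forward', n_warmups=3, n_iters=10)
def forward_pass(model, data):
    return model(data.x, data.edge_index)

# The first 3 calls run unannotated as warm-up. The next 10 are wrapped in
# NVTX ranges ('gnn_forward_3' .. 'gnn_forward_12') and bracketed by
# cudaProfilerStart/Stop, so the profiler records only the region of
# interest. Without CUDA available, the wrapper is a pass-through.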

torch_geometric/profile/profiler.py

@@ -5,6 +5,8 @@ from typing import Any, List, NamedTuple, Optional, Tuple
 import torch
 import torch.profiler as torch_profiler

+import torch_geometric.typing
+
 # predefined namedtuple for variable setting (global template)
 Trace = namedtuple('Trace', ['path', 'leaf', 'module'])

@@ -325,6 +327,8 @@ def _flatten_tree(t, depth=0):


 def _build_measure_tuple(events: List, occurrences: List) -> NamedTuple:
+    device_str = 'device' if torch_geometric.typing.WITH_PT24 else 'cuda'
+
     # memory profiling supported in torch >= 1.6
     self_cpu_memory = None
     has_self_cpu_memory = any(
@@ -339,29 +343,34 @@ def _build_measure_tuple(events: List, occurrences: List) -> NamedTuple:
         [getattr(e, "cpu_memory_usage", 0) or 0 for e in events])
     self_cuda_memory = None
     has_self_cuda_memory = any(
-        hasattr(e, "self_cuda_memory_usage") for e in events)
+        hasattr(e, f"self_{device_str}_memory_usage") for e in events)
     if has_self_cuda_memory:
-        self_cuda_memory = sum(
-            [getattr(e, "self_cuda_memory_usage", 0) or 0 for e in events])
+        self_cuda_memory = sum([
+            getattr(e, f"self_{device_str}_memory_usage", 0) or 0
+            for e in events
+        ])
     cuda_memory = None
-    has_cuda_memory = any(hasattr(e, "cuda_memory_usage") for e in events)
+    has_cuda_memory = any(
+        hasattr(e, f"{device_str}_memory_usage") for e in events)
     if has_cuda_memory:
         cuda_memory = sum(
-            [getattr(e, "cuda_memory_usage", 0) or 0 for e in events])
+            [getattr(e, f"{device_str}_memory_usage", 0) or 0 for e in events])

     # self CUDA time supported in torch >= 1.7
     self_cuda_total = None
     has_self_cuda_time = any(
-        hasattr(e, "self_cuda_time_total") for e in events)
+        hasattr(e, f"self_{device_str}_time_total") for e in events)
     if has_self_cuda_time:
-        self_cuda_total = sum(
-            [getattr(e, "self_cuda_time_total", 0) or 0 for e in events])
+        self_cuda_total = sum([
+            getattr(e, f"self_{device_str}_time_total", 0) or 0 for e in events
+        ])

     return Measure(
         self_cpu_total=sum([e.self_cpu_time_total or 0 for e in events]),
         cpu_total=sum([e.cpu_time_total or 0 for e in events]),
         self_cuda_total=self_cuda_total,
-        cuda_total=sum([e.cuda_time_total or 0 for e in events]),
+        cuda_total=sum(
+            [getattr(e, f"{device_str}_time_total") or 0 for e in events]),
         self_cpu_memory=self_cpu_memory,
         cpu_memory=cpu_memory,
         self_cuda_memory=self_cuda_memory,
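
The attribute renames track PyTorch 2.4, where profiler events expose device_*_memory_usage / device_time_total in place of the legacy cuda_* names (the cut-over PyG encodes as torch_geometric.typing.WITH_PT24). A standalone sketch of the same compatibility shim, assuming that version boundary:

import torch

# Resolve the attribute prefix for the running PyTorch version:
_major, _minor = (int(v) for v in torch.__version__.split('.')[:2])
device_str = 'device' if (_major, _minor) >= (2, 4) else 'cuda'

def total_device_time(events):
    # Works on both old and new PyTorch by building the attribute name:
    return sum(getattr(e, f'{device_str}_time_total', 0) or 0 for e in events)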