pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (228)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_trim_to_layer.py +2 -2
  215. torch_geometric/utils/convert.py +17 -10
  216. torch_geometric/utils/cross_entropy.py +34 -13
  217. torch_geometric/utils/embedding.py +91 -2
  218. torch_geometric/utils/geodesic.py +4 -3
  219. torch_geometric/utils/influence.py +279 -0
  220. torch_geometric/utils/map.py +13 -9
  221. torch_geometric/utils/nested.py +1 -1
  222. torch_geometric/utils/smiles.py +3 -3
  223. torch_geometric/utils/sparse.py +7 -14
  224. torch_geometric/visualization/__init__.py +2 -1
  225. torch_geometric/visualization/graph.py +250 -5
  226. torch_geometric/warnings.py +11 -2
  227. torch_geometric/nn/nlp/__init__.py +0 -7
  228. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -30,11 +30,13 @@ class LayerNorm(torch.nn.Module):
30
30
  affine (bool, optional): If set to :obj:`True`, this module has
31
31
  learnable affine parameters :math:`\gamma` and :math:`\beta`.
32
32
  (default: :obj:`True`)
33
- mode (str, optinal): The normalization mode to use for layer
33
+ mode (str, optional): The normalization mode to use for layer
34
34
  normalization (:obj:`"graph"` or :obj:`"node"`). If :obj:`"graph"`
35
35
  is used, each graph will be considered as an element to be
36
36
  normalized. If `"node"` is used, each node will be considered as
37
37
  an element to be normalized. (default: :obj:`"graph"`)
38
+ device (torch.device, optional): The device to use for the module.
39
+ (default: :obj:`None`)
38
40
  """
39
41
  def __init__(
40
42
  self,
@@ -42,6 +44,7 @@ class LayerNorm(torch.nn.Module):
42
44
  eps: float = 1e-5,
43
45
  affine: bool = True,
44
46
  mode: str = 'graph',
47
+ device: Optional[torch.device] = None,
45
48
  ):
46
49
  super().__init__()
47
50
 
@@ -51,8 +54,8 @@ class LayerNorm(torch.nn.Module):
51
54
  self.mode = mode
52
55
 
53
56
  if affine:
54
- self.weight = Parameter(torch.empty(in_channels))
55
- self.bias = Parameter(torch.empty(in_channels))
57
+ self.weight = Parameter(torch.empty(in_channels, device=device))
58
+ self.bias = Parameter(torch.empty(in_channels, device=device))
56
59
  else:
57
60
  self.register_parameter('weight', None)
58
61
  self.register_parameter('bias', None)
@@ -108,7 +111,7 @@ class LayerNorm(torch.nn.Module):
108
111
  return F.layer_norm(x, (self.in_channels, ), self.weight,
109
112
  self.bias, self.eps)
110
113
 
111
- raise ValueError(f"Unknow normalization mode: {self.mode}")
114
+ raise ValueError(f"Unknownn normalization mode: {self.mode}")
112
115
 
113
116
  def __repr__(self):
114
117
  return (f'{self.__class__.__name__}({self.in_channels}, '
@@ -130,10 +133,12 @@ class HeteroLayerNorm(torch.nn.Module):
130
133
  affine (bool, optional): If set to :obj:`True`, this module has
131
134
  learnable affine parameters :math:`\gamma` and :math:`\beta`.
132
135
  (default: :obj:`True`)
133
- mode (str, optinal): The normalization mode to use for layer
136
+ mode (str, optional): The normalization mode to use for layer
134
137
  normalization (:obj:`"node"`). If `"node"` is used, each node will
135
138
  be considered as an element to be normalized.
136
139
  (default: :obj:`"node"`)
140
+ device (torch.device, optional): The device to use for the module.
141
+ (default: :obj:`None`)
137
142
  """
138
143
  def __init__(
139
144
  self,
@@ -142,6 +147,7 @@ class HeteroLayerNorm(torch.nn.Module):
142
147
  eps: float = 1e-5,
143
148
  affine: bool = True,
144
149
  mode: str = 'node',
150
+ device: Optional[torch.device] = None,
145
151
  ):
146
152
  super().__init__()
147
153
  assert mode == 'node'
@@ -152,8 +158,10 @@ class HeteroLayerNorm(torch.nn.Module):
152
158
  self.affine = affine
153
159
 
154
160
  if affine:
155
- self.weight = Parameter(torch.empty(num_types, in_channels))
156
- self.bias = Parameter(torch.empty(num_types, in_channels))
161
+ self.weight = Parameter(
162
+ torch.empty(num_types, in_channels, device=device))
163
+ self.bias = Parameter(
164
+ torch.empty(num_types, in_channels, device=device))
157
165
  else:
158
166
  self.register_parameter('weight', None)
159
167
  self.register_parameter('bias', None)
@@ -1,3 +1,5 @@
1
+ from typing import Optional
2
+
1
3
  import torch
2
4
  import torch.nn.functional as F
3
5
  from torch import Tensor
@@ -19,10 +21,14 @@ class MessageNorm(torch.nn.Module):
19
21
  learn_scale (bool, optional): If set to :obj:`True`, will learn the
20
22
  scaling factor :math:`s` of message normalization.
21
23
  (default: :obj:`False`)
24
+ device (torch.device, optional): The device to use for the module.
25
+ (default: :obj:`None`)
22
26
  """
23
- def __init__(self, learn_scale: bool = False):
27
+ def __init__(self, learn_scale: bool = False,
28
+ device: Optional[torch.device] = None):
24
29
  super().__init__()
25
- self.scale = Parameter(torch.empty(1), requires_grad=learn_scale)
30
+ self.scale = Parameter(torch.empty(1, device=device),
31
+ requires_grad=learn_scale)
26
32
  self.reset_parameters()
27
33
 
28
34
  def reset_parameters(self):
@@ -163,8 +163,10 @@ def knn_graph(
163
163
  :rtype: :class:`torch.Tensor`
164
164
  """
165
165
  if batch is not None and x.device != batch.device:
166
- warnings.warn("Input tensor 'x' and 'batch' are on different devices "
167
- "in 'knn_graph'. Performing blocking device transfer")
166
+ warnings.warn(
167
+ "Input tensor 'x' and 'batch' are on different devices "
168
+ "in 'knn_graph'. Performing blocking device transfer",
169
+ stacklevel=2)
168
170
  batch = batch.to(x.device)
169
171
 
170
172
  if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
@@ -285,8 +287,10 @@ def radius_graph(
285
287
  inputs to GPU before proceeding.
286
288
  """
287
289
  if batch is not None and x.device != batch.device:
288
- warnings.warn("Input tensor 'x' and 'batch' are on different devices "
289
- "in 'radius_graph'. Performing blocking device transfer")
290
+ warnings.warn(
291
+ "Input tensor 'x' and 'batch' are on different devices "
292
+ "in 'radius_graph'. Performing blocking device transfer",
293
+ stacklevel=2)
290
294
  batch = batch.to(x.device)
291
295
 
292
296
  if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
@@ -20,8 +20,7 @@ class UnpoolInfo(NamedTuple):
20
20
 
21
21
  class ClusterPooling(torch.nn.Module):
22
22
  r"""The cluster pooling operator from the `"Edge-Based Graph Component
23
- Pooling" <paper url>`_ paper.
24
-
23
+ Pooling" <https://arxiv.org/abs/2409.11856>`_ paper.
25
24
  :class:`ClusterPooling` computes a score for each edge.
26
25
  Based on the selected edges, graph clusters are calculated and compressed
27
26
  to one node using the injective :obj:`"sum"` aggregation function.
@@ -55,7 +54,7 @@ class ClusterPooling(torch.nn.Module):
55
54
  self.in_channels = in_channels
56
55
  self.edge_score_method = edge_score_method
57
56
  self.dropout = dropout
58
- self.threshhold = threshold
57
+ self.threshold = threshold
59
58
 
60
59
  self.lin = torch.nn.Linear(2 * in_channels, 1)
61
60
 
@@ -116,7 +115,7 @@ class ClusterPooling(torch.nn.Module):
116
115
 
117
116
  from scipy.sparse.csgraph import connected_components
118
117
 
119
- edge_contract = edge_index[:, edge_score > self.threshhold]
118
+ edge_contract = edge_index[:, edge_score > self.threshold]
120
119
 
121
120
  adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
122
121
  _, cluster_np = connected_components(adj, directed=True,
@@ -4,7 +4,6 @@ from typing import Optional
4
4
  import torch
5
5
  from torch import Tensor
6
6
 
7
- import torch_geometric.typing
8
7
  from torch_geometric.nn.pool.select import SelectOutput
9
8
 
10
9
 
@@ -49,8 +48,7 @@ class ConnectOutput:
49
48
  self.batch = batch
50
49
 
51
50
 
52
- if torch_geometric.typing.WITH_PT113:
53
- ConnectOutput = torch.jit.script(ConnectOutput)
51
+ ConnectOutput = torch.jit.script(ConnectOutput)
54
52
 
55
53
 
56
54
  class Connect(torch.nn.Module):
@@ -91,9 +91,10 @@ class KNNIndex:
91
91
  if hasattr(self.index, 'reserveMemory'):
92
92
  self.index.reserveMemory(self.reserve)
93
93
  else:
94
- warnings.warn(f"'{self.index.__class__.__name__}' "
95
- f"does not support pre-allocation of "
96
- f"memory")
94
+ warnings.warn(
95
+ f"'{self.index.__class__.__name__}' "
96
+ f"does not support pre-allocation of "
97
+ f"memory", stacklevel=2)
97
98
 
98
99
  self.index.train(emb)
99
100
 
@@ -135,14 +136,16 @@ class KNNIndex:
135
136
  query_k = min(query_k, self.numel)
136
137
 
137
138
  if k > 2048: # `faiss` supports up-to `k=2048`:
138
- warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
139
- f"(got {k}). This may cause some relevant items to "
140
- f"not be retrieved.")
139
+ warnings.warn(
140
+ f"Capping 'k' to faiss' upper limit of 2048 "
141
+ f"(got {k}). This may cause some relevant items to "
142
+ f"not be retrieved.", stacklevel=2)
141
143
  elif query_k > 2048:
142
- warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
143
- f"(got {k} which got extended to {query_k} due to "
144
- f"the exclusion of existing links). This may cause "
145
- f"some relevant items to not be retrieved.")
144
+ warnings.warn(
145
+ f"Capping 'k' to faiss' upper limit of 2048 "
146
+ f"(got {k} which got extended to {query_k} due to "
147
+ f"the exclusion of existing links). This may cause "
148
+ f"some relevant items to not be retrieved.", stacklevel=2)
146
149
  query_k = 2048
147
150
 
148
151
  score, index = self.index.search(emb.detach(), query_k)
@@ -4,8 +4,6 @@ from typing import Optional
4
4
  import torch
5
5
  from torch import Tensor
6
6
 
7
- import torch_geometric.typing
8
-
9
7
 
10
8
  @dataclass(init=False)
11
9
  class SelectOutput:
@@ -64,8 +62,7 @@ class SelectOutput:
64
62
  self.weight = weight
65
63
 
66
64
 
67
- if torch_geometric.typing.WITH_PT113:
68
- SelectOutput = torch.jit.script(SelectOutput)
65
+ SelectOutput = torch.jit.script(SelectOutput)
69
66
 
70
67
 
71
68
  class Select(torch.nn.Module):
@@ -108,9 +108,10 @@ class ToHeteroMessagePassing(torch.nn.Module):
108
108
 
109
109
  if (not hasattr(module, 'reset_parameters')
110
110
  and sum([p.numel() for p in module.parameters()]) > 0):
111
- warnings.warn(f"'{module}' will be duplicated, but its parameters "
112
- f"cannot be reset. To suppress this warning, add a "
113
- f"'reset_parameters()' method to '{module}'")
111
+ warnings.warn(
112
+ f"'{module}' will be duplicated, but its parameters "
113
+ f"cannot be reset. To suppress this warning, add a "
114
+ f"'reset_parameters()' method to '{module}'", stacklevel=2)
114
115
 
115
116
  convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}
116
117
  self.hetero_module = HeteroConv(convs, aggr)
@@ -157,7 +157,7 @@ class ToHeteroTransformer(Transformer):
157
157
  f"There exist node types ({unused_node_types}) whose "
158
158
  f"representations do not get updated during message passing "
159
159
  f"as they do not occur as destination type in any edge type. "
160
- f"This may lead to unexpected behavior.")
160
+ f"This may lead to unexpected behavior.", stacklevel=2)
161
161
 
162
162
  names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
163
163
  for name in names:
@@ -166,7 +166,7 @@ class ToHeteroTransformer(Transformer):
166
166
  f"The type '{name}' contains invalid characters which "
167
167
  f"may lead to unexpected behavior. To avoid any issues, "
168
168
  f"ensure that your types only contain letters, numbers "
169
- f"and underscores.")
169
+ f"and underscores.", stacklevel=2)
170
170
 
171
171
  def placeholder(self, node: Node, target: Any, name: str):
172
172
  # Adds a `get` call to the input dictionary for every node-type or
@@ -379,7 +379,7 @@ class ToHeteroTransformer(Transformer):
379
379
  warnings.warn(
380
380
  f"'{target}' will be duplicated, but its parameters "
381
381
  f"cannot be reset. To suppress this warning, add a "
382
- f"'reset_parameters()' method to '{target}'")
382
+ f"'reset_parameters()' method to '{target}'", stacklevel=2)
383
383
 
384
384
  return module_dict
385
385
 
@@ -165,7 +165,7 @@ class ToHeteroWithBasesTransformer(Transformer):
165
165
  f"There exist node types ({unused_node_types}) whose "
166
166
  f"representations do not get updated during message passing "
167
167
  f"as they do not occur as destination type in any edge type. "
168
- f"This may lead to unexpected behavior.")
168
+ f"This may lead to unexpected behavior.", stacklevel=2)
169
169
 
170
170
  names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
171
171
  for name in names:
@@ -174,7 +174,7 @@ class ToHeteroWithBasesTransformer(Transformer):
174
174
  f"The type '{name}' contains invalid characters which "
175
175
  f"may lead to unexpected behavior. To avoid any issues, "
176
176
  f"ensure that your types only contain letters, numbers "
177
- f"and underscores.")
177
+ f"and underscores.", stacklevel=2)
178
178
 
179
179
  def transform(self) -> GraphModule:
180
180
  self._node_offset_dict_initialized = False
@@ -361,7 +361,7 @@ class HeteroBasisConv(torch.nn.Module):
361
361
  warnings.warn(
362
362
  f"'{conv}' will be duplicated, but its parameters cannot "
363
363
  f"be reset. To suppress this warning, add a "
364
- f"'reset_parameters()' method to '{conv}'")
364
+ f"'reset_parameters()' method to '{conv}'", stacklevel=2)
365
365
  torch.nn.init.xavier_uniform_(conv.edge_type_weight)
366
366
 
367
367
  def forward(self, edge_type: Tensor, *args, **kwargs) -> Tensor:
@@ -380,7 +380,7 @@ class HeteroBasisConv(torch.nn.Module):
380
380
 
381
381
 
382
382
  class LinearAlign(torch.nn.Module):
383
- # Aligns representions to the same dimensionality. Note that this will
383
+ # Aligns representations to the same dimensionality. Note that this will
384
384
  # create lazy modules, and as such requires a forward pass in order to
385
385
  # initialize parameters.
386
386
  def __init__(self, keys: List[Union[NodeType, EdgeType]],
@@ -20,6 +20,7 @@ from .utils import (
20
20
  get_gpu_memory_from_nvidia_smi,
21
21
  get_model_size,
22
22
  )
23
+ from .nvtx import nvtxit
23
24
 
24
25
  __all__ = [
25
26
  'profileit',
@@ -38,6 +39,7 @@ __all__ = [
38
39
  'get_gpu_memory_from_nvidia_smi',
39
40
  'get_gpu_memory_from_ipex',
40
41
  'benchmark',
42
+ 'nvtxit',
41
43
  ]
42
44
 
43
45
  classes = __all__
@@ -0,0 +1,66 @@
1
+ from functools import wraps
2
+ from typing import Optional
3
+
4
+ import torch
5
+
6
+ CUDA_PROFILE_STARTED = False
7
+
8
+
9
+ def begin_cuda_profile():
10
+ global CUDA_PROFILE_STARTED
11
+ prev_state = CUDA_PROFILE_STARTED
12
+ if prev_state is False:
13
+ CUDA_PROFILE_STARTED = True
14
+ torch.cuda.cudart().cudaProfilerStart()
15
+ return prev_state
16
+
17
+
18
+ def end_cuda_profile(prev_state: bool):
19
+ global CUDA_PROFILE_STARTED
20
+ CUDA_PROFILE_STARTED = prev_state
21
+ if prev_state is False:
22
+ torch.cuda.cudart().cudaProfilerStop()
23
+
24
+
25
+ def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
26
+ n_iters: Optional[int] = None):
27
+ """Enables NVTX profiling for a function.
28
+
29
+ Args:
30
+ name (Optional[str], optional): Name to give the reference frame for
31
+ the function being wrapped. Defaults to the name of the
32
+ function in code.
33
+ n_warmups (int, optional): Number of iters to call that function
34
+ before starting. Defaults to 0.
35
+ n_iters (Optional[int], optional): Number of iters of that function to
36
+ record. Defaults to all of them.
37
+ """
38
+ def nvtx(func):
39
+
40
+ nonlocal name
41
+ iters_so_far = 0
42
+ if name is None:
43
+ name = func.__name__
44
+
45
+ @wraps(func)
46
+ def wrapper(*args, **kwargs):
47
+ nonlocal iters_so_far
48
+ if not torch.cuda.is_available():
49
+ return func(*args, **kwargs)
50
+ elif iters_so_far < n_warmups:
51
+ iters_so_far += 1
52
+ return func(*args, **kwargs)
53
+ elif n_iters is None or iters_so_far < n_iters + n_warmups:
54
+ prev_state = begin_cuda_profile()
55
+ torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
56
+ result = func(*args, **kwargs)
57
+ torch.cuda.nvtx.range_pop()
58
+ end_cuda_profile(prev_state)
59
+ iters_so_far += 1
60
+ return result
61
+ else:
62
+ return func(*args, **kwargs)
63
+
64
+ return wrapper
65
+
66
+ return nvtx
@@ -119,19 +119,34 @@ def get_gpu_memory_from_nvidia_smi( # pragma: no cover
119
119
  digits (int): The number of decimals to use for megabytes.
120
120
  (default: :obj:`2`)
121
121
  """
122
+ def parse_memory(output: str) -> list:
123
+ lines = output.decode('utf-8').split('\n')[1:-1]
124
+ mem_list = []
125
+ for line in lines:
126
+ try:
127
+ mem_list.append(int(line.split()[0]))
128
+ except (TypeError, ValueError):
129
+ mem_list.append(None)
130
+ return mem_list
131
+
132
+ def get_gpu_memory(out_device, digits):
133
+ if out_device is None:
134
+ return 0
135
+
136
+ return medibyte_to_megabyte(out_device, digits)
137
+
122
138
  CMD = 'nvidia-smi --query-gpu=memory.free --format=csv'
123
- free_out = sp.check_output(CMD.split()).decode('utf-8').split('\n')[1:-1]
139
+ free_out = parse_memory(sp.check_output(CMD.split()))
124
140
 
125
141
  CMD = 'nvidia-smi --query-gpu=memory.used --format=csv'
126
- used_out = sp.check_output(CMD.split()).decode('utf-8').split('\n')[1:-1]
142
+ used_out = parse_memory(sp.check_output(CMD.split()))
127
143
 
128
144
  if device < 0 or device >= len(free_out):
129
145
  raise AttributeError(
130
146
  f'GPU {device} not available (found {len(free_out)} GPUs)')
131
147
 
132
- free_mem = medibyte_to_megabyte(int(free_out[device].split()[0]), digits)
133
- used_mem = medibyte_to_megabyte(int(used_out[device].split()[0]), digits)
134
-
148
+ free_mem = get_gpu_memory(free_out[device], digits)
149
+ used_mem = get_gpu_memory(used_out[device], digits)
135
150
  return free_mem, used_mem
136
151
 
137
152
 
@@ -3,7 +3,7 @@ r"""Graph sampler package."""
3
3
  from .base import (BaseSampler, NodeSamplerInput, EdgeSamplerInput,
4
4
  SamplerOutput, HeteroSamplerOutput, NegativeSampling,
5
5
  NumNeighbors)
6
- from .neighbor_sampler import NeighborSampler
6
+ from .neighbor_sampler import NeighborSampler, BidirectionalNeighborSampler
7
7
  from .hgt_sampler import HGTSampler
8
8
 
9
9
  __all__ = classes = [
@@ -15,5 +15,6 @@ __all__ = classes = [
15
15
  'NumNeighbors',
16
16
  'NegativeSampling',
17
17
  'NeighborSampler',
18
+ 'BidirectionalNeighborSampler',
18
19
  'HGTSampler',
19
20
  ]