pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_to_dense_batch.py +2 -2
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
torch_geometric/nn/norm/layer_norm.py CHANGED

@@ -30,11 +30,13 @@ class LayerNorm(torch.nn.Module):
         affine (bool, optional): If set to :obj:`True`, this module has
             learnable affine parameters :math:`\gamma` and :math:`\beta`.
             (default: :obj:`True`)
-        mode (str,
+        mode (str, optional): The normalization mode to use for layer
             normalization (:obj:`"graph"` or :obj:`"node"`). If :obj:`"graph"`
             is used, each graph will be considered as an element to be
             normalized. If `"node"` is used, each node will be considered as
             an element to be normalized. (default: :obj:`"graph"`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,

@@ -42,6 +44,7 @@ class LayerNorm(torch.nn.Module):
         eps: float = 1e-5,
         affine: bool = True,
         mode: str = 'graph',
+        device: Optional[torch.device] = None,
     ):
         super().__init__()

@@ -51,8 +54,8 @@ class LayerNorm(torch.nn.Module):
         self.mode = mode

         if affine:
-            self.weight = Parameter(torch.empty(in_channels))
-            self.bias = Parameter(torch.empty(in_channels))
+            self.weight = Parameter(torch.empty(in_channels, device=device))
+            self.bias = Parameter(torch.empty(in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)

@@ -108,7 +111,7 @@ class LayerNorm(torch.nn.Module):
             return F.layer_norm(x, (self.in_channels, ), self.weight,
                                 self.bias, self.eps)

-        raise ValueError(f"
+        raise ValueError(f"Unknown normalization mode: {self.mode}")

     def __repr__(self):
         return (f'{self.__class__.__name__}({self.in_channels}, '

@@ -130,10 +133,12 @@ class HeteroLayerNorm(torch.nn.Module):
         affine (bool, optional): If set to :obj:`True`, this module has
             learnable affine parameters :math:`\gamma` and :math:`\beta`.
             (default: :obj:`True`)
-        mode (str,
+        mode (str, optional): The normalization mode to use for layer
             normalization (:obj:`"node"`). If `"node"` is used, each node will
             be considered as an element to be normalized.
             (default: :obj:`"node"`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,

@@ -142,6 +147,7 @@ class HeteroLayerNorm(torch.nn.Module):
         eps: float = 1e-5,
         affine: bool = True,
         mode: str = 'node',
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
         assert mode == 'node'

@@ -152,8 +158,10 @@ class HeteroLayerNorm(torch.nn.Module):
         self.affine = affine

         if affine:
-            self.weight = Parameter(torch.empty(num_types, in_channels))
-            self.bias = Parameter(torch.empty(num_types, in_channels))
+            self.weight = Parameter(
+                torch.empty(num_types, in_channels, device=device))
+            self.bias = Parameter(
+                torch.empty(num_types, in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
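For reference, a minimal usage sketch of the new `device` argument (an illustration against this nightly; the forward signature itself is unchanged):

import torch
from torch_geometric.nn import LayerNorm

# Affine parameters are now allocated directly on the requested device:
norm = LayerNorm(in_channels=64, mode='graph', device=torch.device('cpu'))

x = torch.randn(6, 64)                    # 6 nodes, 64 features
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs with 3 nodes each
out = norm(x, batch)                      # each graph normalized separately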
torch_geometric/nn/norm/msg_norm.py CHANGED

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
 from torch import Tensor

@@ -19,10 +21,14 @@ class MessageNorm(torch.nn.Module):
         learn_scale (bool, optional): If set to :obj:`True`, will learn the
             scaling factor :math:`s` of message normalization.
             (default: :obj:`False`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
-    def __init__(self, learn_scale: bool = False
+    def __init__(self, learn_scale: bool = False,
+                 device: Optional[torch.device] = None):
         super().__init__()
-        self.scale = Parameter(torch.empty(1
+        self.scale = Parameter(torch.empty(1, device=device),
+                               requires_grad=learn_scale)
         self.reset_parameters()

     def reset_parameters(self):
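A corresponding sketch for `MessageNorm` (usage otherwise as in previous releases; `forward` takes the node features and the aggregated messages):

import torch
from torch_geometric.nn import MessageNorm

norm = MessageNorm(learn_scale=True, device=torch.device('cpu'))

x = torch.randn(4, 16)    # node features
msg = torch.randn(4, 16)  # aggregated messages
out = norm(x, msg)        # rescales `msg` relative to the norm of `x`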
torch_geometric/nn/pool/__init__.py CHANGED

@@ -163,8 +163,10 @@ def knn_graph(
     :rtype: :class:`torch.Tensor`
     """
     if batch is not None and x.device != batch.device:
-        warnings.warn(
-
+        warnings.warn(
+            "Input tensor 'x' and 'batch' are on different devices "
+            "in 'knn_graph'. Performing blocking device transfer",
+            stacklevel=2)
         batch = batch.to(x.device)

     if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:

@@ -285,8 +287,10 @@ def radius_graph(
     inputs to GPU before proceeding.
     """
     if batch is not None and x.device != batch.device:
-        warnings.warn(
-
+        warnings.warn(
+            "Input tensor 'x' and 'batch' are on different devices "
+            "in 'radius_graph'. Performing blocking device transfer",
+            stacklevel=2)
         batch = batch.to(x.device)

     if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
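The warning fires when `x` and `batch` live on different devices; a sketch that triggers it (assumes a CUDA build with `torch-cluster` installed):

import torch
from torch_geometric.nn import knn_graph

x = torch.randn(10, 3, device='cuda')
batch = torch.zeros(10, dtype=torch.long)  # accidentally left on the CPU

# Warns as above, then performs a blocking copy of `batch` to `x.device`:
edge_index = knn_graph(x, k=4, batch=batch)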
torch_geometric/nn/pool/cluster_pool.py CHANGED

@@ -20,8 +20,7 @@ class UnpoolInfo(NamedTuple):

 class ClusterPooling(torch.nn.Module):
     r"""The cluster pooling operator from the `"Edge-Based Graph Component
-    Pooling" <
-
+    Pooling" <https://arxiv.org/abs/2409.11856>`_ paper.
     :class:`ClusterPooling` computes a score for each edge.
     Based on the selected edges, graph clusters are calculated and compressed
     to one node using the injective :obj:`"sum"` aggregation function.

@@ -55,7 +54,7 @@ class ClusterPooling(torch.nn.Module):
         self.in_channels = in_channels
         self.edge_score_method = edge_score_method
         self.dropout = dropout
-        self.
+        self.threshold = threshold

         self.lin = torch.nn.Linear(2 * in_channels, 1)

@@ -116,7 +115,7 @@ class ClusterPooling(torch.nn.Module):

         from scipy.sparse.csgraph import connected_components

-        edge_contract = edge_index[:, edge_score > self.
+        edge_contract = edge_index[:, edge_score > self.threshold]

         adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
         _, cluster_np = connected_components(adj, directed=True,
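A usage sketch of `ClusterPooling` with the `threshold` attribute touched above (the `(x, edge_index, batch, unpool_info)` return shape is an assumption based on the `UnpoolInfo` tuple defined in this file):

import torch
from torch_geometric.nn import ClusterPooling

pool = ClusterPooling(in_channels=32, threshold=0.5)

x = torch.randn(6, 32)
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 5]])
batch = torch.zeros(6, dtype=torch.long)

# Edges scoring above `threshold` are contracted into single cluster nodes:
x, edge_index, batch, unpool_info = pool(x, edge_index, batch)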
torch_geometric/nn/pool/connect/base.py CHANGED

@@ -4,7 +4,6 @@ from typing import Optional
 import torch
 from torch import Tensor

-import torch_geometric.typing
 from torch_geometric.nn.pool.select import SelectOutput

@@ -49,8 +48,7 @@ class ConnectOutput:
         self.batch = batch


-if torch_geometric.typing.WITH_PT113:
-    ConnectOutput = torch.jit.script(ConnectOutput)
+ConnectOutput = torch.jit.script(ConnectOutput)


 class Connect(torch.nn.Module):
torch_geometric/nn/pool/knn.py CHANGED

@@ -91,9 +91,10 @@ class KNNIndex:
         if hasattr(self.index, 'reserveMemory'):
             self.index.reserveMemory(self.reserve)
         else:
-            warnings.warn(
-
-
+            warnings.warn(
+                f"'{self.index.__class__.__name__}' "
+                f"does not support pre-allocation of "
+                f"memory", stacklevel=2)

         self.index.train(emb)

@@ -135,14 +136,16 @@ class KNNIndex:
         query_k = min(query_k, self.numel)

         if k > 2048:  # `faiss` supports up-to `k=2048`:
-            warnings.warn(
-
-
+            warnings.warn(
+                f"Capping 'k' to faiss' upper limit of 2048 "
+                f"(got {k}). This may cause some relevant items to "
+                f"not be retrieved.", stacklevel=2)
         elif query_k > 2048:
-            warnings.warn(
-
-
-
+            warnings.warn(
+                f"Capping 'k' to faiss' upper limit of 2048 "
+                f"(got {k} which got extended to {query_k} due to "
+                f"the exclusion of existing links). This may cause "
+                f"some relevant items to not be retrieved.", stacklevel=2)
             query_k = 2048

         score, index = self.index.search(emb.detach(), query_k)
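The capping applies to all `KNNIndex` subclasses; a sketch against the `faiss`-backed `MIPSKNNIndex` (assumes `faiss` is installed):

import torch
from torch_geometric.nn.pool import MIPSKNNIndex

emb = torch.randn(10_000, 128)  # database embeddings
index = MIPSKNNIndex(emb)       # maximum inner-product search index

query = torch.randn(8, 128)
out = index.search(query, k=16)  # a `k` beyond 2048 triggers the warning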
torch_geometric/nn/pool/select/base.py CHANGED

@@ -4,8 +4,6 @@ from typing import Optional
 import torch
 from torch import Tensor

-import torch_geometric.typing
-

 @dataclass(init=False)
 class SelectOutput:

@@ -64,8 +62,7 @@ class SelectOutput:
         self.weight = weight


-if torch_geometric.typing.WITH_PT113:
-    SelectOutput = torch.jit.script(SelectOutput)
+SelectOutput = torch.jit.script(SelectOutput)


 class Select(torch.nn.Module):
torch_geometric/nn/to_hetero_module.py CHANGED

@@ -108,9 +108,10 @@ class ToHeteroMessagePassing(torch.nn.Module):

         if (not hasattr(module, 'reset_parameters')
                 and sum([p.numel() for p in module.parameters()]) > 0):
-            warnings.warn(
-
-
+            warnings.warn(
+                f"'{module}' will be duplicated, but its parameters "
+                f"cannot be reset. To suppress this warning, add a "
+                f"'reset_parameters()' method to '{module}'", stacklevel=2)

         convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}
         self.hetero_module = HeteroConv(convs, aggr)
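To silence the duplication warning, the wrapped module only needs a `reset_parameters()` method; a minimal sketch:

import torch
from torch_geometric.nn import MessagePassing

class MyConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def reset_parameters(self):  # lets each per-edge-type copy be reset
        self.lin.reset_parameters()

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=self.lin(x))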
torch_geometric/nn/to_hetero_transformer.py CHANGED

@@ -157,7 +157,7 @@ class ToHeteroTransformer(Transformer):
             f"There exist node types ({unused_node_types}) whose "
             f"representations do not get updated during message passing "
             f"as they do not occur as destination type in any edge type. "
-            f"This may lead to unexpected behavior.")
+            f"This may lead to unexpected behavior.", stacklevel=2)

         names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
         for name in names:

@@ -166,7 +166,7 @@ class ToHeteroTransformer(Transformer):
             f"The type '{name}' contains invalid characters which "
             f"may lead to unexpected behavior. To avoid any issues, "
             f"ensure that your types only contain letters, numbers "
-            f"and underscores.")
+            f"and underscores.", stacklevel=2)

     def placeholder(self, node: Node, target: Any, name: str):
         # Adds a `get` call to the input dictionary for every node-type or

@@ -379,7 +379,7 @@ class ToHeteroTransformer(Transformer):
         warnings.warn(
             f"'{target}' will be duplicated, but its parameters "
             f"cannot be reset. To suppress this warning, add a "
-            f"'reset_parameters()' method to '{target}'")
+            f"'reset_parameters()' method to '{target}'", stacklevel=2)

         return module_dict
torch_geometric/nn/to_hetero_with_bases_transformer.py CHANGED

@@ -165,7 +165,7 @@ class ToHeteroWithBasesTransformer(Transformer):
             f"There exist node types ({unused_node_types}) whose "
             f"representations do not get updated during message passing "
             f"as they do not occur as destination type in any edge type. "
-            f"This may lead to unexpected behavior.")
+            f"This may lead to unexpected behavior.", stacklevel=2)

         names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
         for name in names:

@@ -174,7 +174,7 @@ class ToHeteroWithBasesTransformer(Transformer):
             f"The type '{name}' contains invalid characters which "
             f"may lead to unexpected behavior. To avoid any issues, "
             f"ensure that your types only contain letters, numbers "
-            f"and underscores.")
+            f"and underscores.", stacklevel=2)

     def transform(self) -> GraphModule:
         self._node_offset_dict_initialized = False

@@ -361,7 +361,7 @@ class HeteroBasisConv(torch.nn.Module):
         warnings.warn(
             f"'{conv}' will be duplicated, but its parameters cannot "
             f"be reset. To suppress this warning, add a "
-            f"'reset_parameters()' method to '{conv}'")
+            f"'reset_parameters()' method to '{conv}'", stacklevel=2)
         torch.nn.init.xavier_uniform_(conv.edge_type_weight)

     def forward(self, edge_type: Tensor, *args, **kwargs) -> Tensor:

@@ -380,7 +380,7 @@ class HeteroBasisConv(torch.nn.Module):


 class LinearAlign(torch.nn.Module):
-    # Aligns
+    # Aligns representations to the same dimensionality. Note that this will
     # create lazy modules, and as such requires a forward pass in order to
     # initialize parameters.
     def __init__(self, keys: List[Union[NodeType, EdgeType]],
torch_geometric/profile/__init__.py CHANGED

@@ -20,6 +20,7 @@ from .utils import (
     get_gpu_memory_from_nvidia_smi,
     get_model_size,
 )
+from .nvtx import nvtxit

 __all__ = [
     'profileit',

@@ -38,6 +39,7 @@ __all__ = [
     'get_gpu_memory_from_nvidia_smi',
     'get_gpu_memory_from_ipex',
     'benchmark',
+    'nvtxit',
 ]

 classes = __all__
torch_geometric/profile/nvtx.py ADDED

@@ -0,0 +1,66 @@
+from functools import wraps
+from typing import Optional
+
+import torch
+
+CUDA_PROFILE_STARTED = False
+
+
+def begin_cuda_profile():
+    global CUDA_PROFILE_STARTED
+    prev_state = CUDA_PROFILE_STARTED
+    if prev_state is False:
+        CUDA_PROFILE_STARTED = True
+        torch.cuda.cudart().cudaProfilerStart()
+    return prev_state
+
+
+def end_cuda_profile(prev_state: bool):
+    global CUDA_PROFILE_STARTED
+    CUDA_PROFILE_STARTED = prev_state
+    if prev_state is False:
+        torch.cuda.cudart().cudaProfilerStop()
+
+
+def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
+           n_iters: Optional[int] = None):
+    """Enables NVTX profiling for a function.
+
+    Args:
+        name (Optional[str], optional): Name to give the reference frame for
+            the function being wrapped. Defaults to the name of the
+            function in code.
+        n_warmups (int, optional): Number of iters to call that function
+            before starting. Defaults to 0.
+        n_iters (Optional[int], optional): Number of iters of that function to
+            record. Defaults to all of them.
+    """
+    def nvtx(func):
+
+        nonlocal name
+        iters_so_far = 0
+        if name is None:
+            name = func.__name__
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            nonlocal iters_so_far
+            if not torch.cuda.is_available():
+                return func(*args, **kwargs)
+            elif iters_so_far < n_warmups:
+                iters_so_far += 1
+                return func(*args, **kwargs)
+            elif n_iters is None or iters_so_far < n_iters + n_warmups:
+                prev_state = begin_cuda_profile()
+                torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
+                result = func(*args, **kwargs)
+                torch.cuda.nvtx.range_pop()
+                end_cuda_profile(prev_state)
+                iters_so_far += 1
+                return result
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return nvtx
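A usage sketch of the new decorator (the `nsys` invocation mentioned below is a typical setup, not something this package mandates):

from torch_geometric.profile import nvtxit

@nvtxit(name='gnn_forward', n_warmups=1, n_iters=10)
def forward_pass(model, x, edge_index):
    return model(x, edge_index)

Because the wrapper brackets the recorded calls with `cudaProfilerStart()`/`cudaProfilerStop()`, running under `nsys profile --capture-range=cudaProfilerApi` restricts capture to exactly those iterations.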
torch_geometric/profile/utils.py CHANGED

@@ -119,19 +119,34 @@ def get_gpu_memory_from_nvidia_smi(  # pragma: no cover
         digits (int): The number of decimals to use for megabytes.
             (default: :obj:`2`)
     """
+    def parse_memory(output: str) -> list:
+        lines = output.decode('utf-8').split('\n')[1:-1]
+        mem_list = []
+        for line in lines:
+            try:
+                mem_list.append(int(line.split()[0]))
+            except (TypeError, ValueError):
+                mem_list.append(None)
+        return mem_list
+
+    def get_gpu_memory(out_device, digits):
+        if out_device is None:
+            return 0
+
+        return medibyte_to_megabyte(out_device, digits)
+
     CMD = 'nvidia-smi --query-gpu=memory.free --format=csv'
-    free_out = sp.check_output(CMD.split())
+    free_out = parse_memory(sp.check_output(CMD.split()))

     CMD = 'nvidia-smi --query-gpu=memory.used --format=csv'
-    used_out = sp.check_output(CMD.split())
+    used_out = parse_memory(sp.check_output(CMD.split()))

     if device < 0 or device >= len(free_out):
         raise AttributeError(
             f'GPU {device} not available (found {len(free_out)} GPUs)')

-    free_mem =
-    used_mem =
-
+    free_mem = get_gpu_memory(free_out[device], digits)
+    used_mem = get_gpu_memory(used_out[device], digits)
     return free_mem, used_mem
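A call sketch for the hardened helper (argument names follow the function shown above; fields that `nvidia-smi` fails to report now come back as 0 rather than raising):

from torch_geometric.profile import get_gpu_memory_from_nvidia_smi

free_mem, used_mem = get_gpu_memory_from_nvidia_smi(device=0, digits=2)
print(f'Free: {free_mem} MB / Used: {used_mem} MB')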
torch_geometric/sampler/__init__.py CHANGED

@@ -3,7 +3,7 @@ r"""Graph sampler package."""
 from .base import (BaseSampler, NodeSamplerInput, EdgeSamplerInput,
                    SamplerOutput, HeteroSamplerOutput, NegativeSampling,
                    NumNeighbors)
-from .neighbor_sampler import NeighborSampler
+from .neighbor_sampler import NeighborSampler, BidirectionalNeighborSampler
 from .hgt_sampler import HGTSampler

 __all__ = classes = [

@@ -15,5 +15,6 @@ __all__ = classes = [
     'NumNeighbors',
     'NegativeSampling',
     'NeighborSampler',
+    'BidirectionalNeighborSampler',
     'HGTSampler',
 ]