pyg-nightly 2.7.0.dev20250607__py3-none-any.whl → 2.7.0.dev20250609__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250607.dist-info → pyg_nightly-2.7.0.dev20250609.dist-info}/METADATA +3 -2
- {pyg_nightly-2.7.0.dev20250607.dist-info → pyg_nightly-2.7.0.dev20250609.dist-info}/RECORD +79 -79
- torch_geometric/__init__.py +5 -4
- torch_geometric/_compile.py +3 -2
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/data/data.py +3 -3
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +9 -6
- torch_geometric/data/hetero_data.py +7 -6
- torch_geometric/data/hypergraph_data.py +1 -1
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/large_graph_indexer.py +1 -1
- torch_geometric/data/lightning/datamodule.py +28 -20
- torch_geometric/data/storage.py +1 -1
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/tag_dataset.py +1 -1
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/rpc.py +2 -2
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +7 -7
- torch_geometric/explain/explainer.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +4 -2
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/hash_tensor.py +5 -4
- torch_geometric/io/fs.py +5 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/mixin.py +2 -1
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +6 -4
- torch_geometric/nn/conv/message_passing.py +3 -2
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/fx.py +7 -5
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/glem.py +20 -12
- torch_geometric/nn/models/gpse.py +2 -2
- torch_geometric/nn/models/graph_unet.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/rev_gnn.py +1 -1
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/nlp/llm.py +2 -1
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +3 -3
- torch_geometric/sampler/base.py +7 -4
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +10 -8
- torch_geometric/testing/decorators.py +3 -2
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/add_gpse.py +11 -2
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/typing.py +13 -9
- torch_geometric/utils/_scatter.py +8 -6
- torch_geometric/utils/_spmm.py +15 -12
- torch_geometric/utils/convert.py +2 -2
- torch_geometric/utils/embedding.py +5 -3
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/sparse.py +3 -2
- {pyg_nightly-2.7.0.dev20250607.dist-info → pyg_nightly-2.7.0.dev20250609.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250607.dist-info → pyg_nightly-2.7.0.dev20250609.dist-info}/licenses/LICENSE +0 -0
@@ -106,7 +106,7 @@ class DeepGraphInfomax(torch.nn.Module):
|
|
106
106
|
"""
|
107
107
|
from sklearn.linear_model import LogisticRegression
|
108
108
|
|
109
|
-
clf = LogisticRegression(solver=solver,
|
109
|
+
clf = LogisticRegression(*args, solver=solver,
|
110
110
|
**kwargs).fit(train_z.detach().cpu().numpy(),
|
111
111
|
train_y.detach().cpu().numpy())
|
112
112
|
return clf.score(test_z.detach().cpu().numpy(),
|
@@ -37,20 +37,28 @@ class GLEM(torch.nn.Module):
|
|
37
37
|
See `examples/llm_plus_gnn/glem.py` for example usage.
|
38
38
|
"""
|
39
39
|
def __init__(
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
40
|
+
self,
|
41
|
+
lm_to_use: str = 'prajjwal1/bert-tiny',
|
42
|
+
gnn_to_use: basic_gnn = GraphSAGE,
|
43
|
+
out_channels: int = 47,
|
44
|
+
gnn_loss: Optional[nn.Module] = None,
|
45
|
+
lm_loss: Optional[nn.Module] = None,
|
46
|
+
alpha: float = 0.5,
|
47
|
+
beta: float = 0.5,
|
48
|
+
lm_dtype: torch.dtype = torch.bfloat16,
|
49
|
+
lm_use_lora: bool = True,
|
50
|
+
lora_target_modules: Optional[Union[List[str], str]] = None,
|
51
|
+
device: Optional[Union[str, torch.device]] = None,
|
52
52
|
):
|
53
53
|
super().__init__()
|
54
|
+
|
55
|
+
if gnn_loss is None:
|
56
|
+
gnn_loss = nn.CrossEntropyLoss(reduction='mean')
|
57
|
+
if lm_loss is None:
|
58
|
+
lm_loss = nn.CrossEntropyLoss(reduction='mean')
|
59
|
+
if device is None:
|
60
|
+
device = torch.device('cpu')
|
61
|
+
|
54
62
|
self.device = device
|
55
63
|
self.lm_loss = lm_loss
|
56
64
|
self.gnn = gnn_to_use
|
@@ -801,7 +801,7 @@ def gpse_process(
|
|
801
801
|
shuffle=False, pin_memory=True, **kwargs)
|
802
802
|
out_list = []
|
803
803
|
pbar = trange(data.num_nodes, position=2)
|
804
|
-
for
|
804
|
+
for batch in loader:
|
805
805
|
out, _ = model(batch.to(device))
|
806
806
|
out = out[:batch.batch_size].to("cpu", non_blocking=True)
|
807
807
|
out_list.append(out)
|
@@ -906,7 +906,7 @@ def gpse_process_batch(
|
|
906
906
|
shuffle=False, pin_memory=True, **kwargs)
|
907
907
|
out_list = []
|
908
908
|
pbar = trange(batch.num_nodes, position=2)
|
909
|
-
for
|
909
|
+
for batch in loader:
|
910
910
|
out, _ = model(batch.to(device))
|
911
911
|
out = out[:batch.batch_size].to('cpu', non_blocking=True)
|
912
912
|
out_list.append(out)
|
@@ -64,7 +64,7 @@ class GraphUNet(torch.nn.Module):
|
|
64
64
|
in_channels = channels if sum_res else 2 * channels
|
65
65
|
|
66
66
|
self.up_convs = torch.nn.ModuleList()
|
67
|
-
for
|
67
|
+
for _ in range(depth - 1):
|
68
68
|
self.up_convs.append(GCNConv(in_channels, channels, improved=True))
|
69
69
|
self.up_convs.append(GCNConv(in_channels, out_channels, improved=True))
|
70
70
|
|
@@ -233,7 +233,7 @@ class MetaPath2Vec(torch.nn.Module):
|
|
233
233
|
"""
|
234
234
|
from sklearn.linear_model import LogisticRegression
|
235
235
|
|
236
|
-
clf = LogisticRegression(solver=solver,
|
236
|
+
clf = LogisticRegression(*args, solver=solver,
|
237
237
|
**kwargs).fit(train_z.detach().cpu().numpy(),
|
238
238
|
train_y.detach().cpu().numpy())
|
239
239
|
return clf.score(test_z.detach().cpu().numpy(),
|
torch_geometric/nn/models/mlp.py
CHANGED
@@ -99,8 +99,10 @@ class MLP(torch.nn.Module):
|
|
99
99
|
act_first = act_first or kwargs.get("relu_first", False)
|
100
100
|
batch_norm = kwargs.get("batch_norm", None)
|
101
101
|
if batch_norm is not None and isinstance(batch_norm, bool):
|
102
|
-
warnings.warn(
|
103
|
-
|
102
|
+
warnings.warn(
|
103
|
+
"Argument `batch_norm` is deprecated, "
|
104
|
+
"please use `norm` to specify normalization layer.",
|
105
|
+
stacklevel=2)
|
104
106
|
norm = 'batch_norm' if batch_norm else None
|
105
107
|
batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None)
|
106
108
|
norm_kwargs = batch_norm_kwargs or {}
|
@@ -181,7 +181,7 @@ class Node2Vec(torch.nn.Module):
|
|
181
181
|
"""
|
182
182
|
from sklearn.linear_model import LogisticRegression
|
183
183
|
|
184
|
-
clf = LogisticRegression(solver=solver,
|
184
|
+
clf = LogisticRegression(*args, solver=solver,
|
185
185
|
**kwargs).fit(train_z.detach().cpu().numpy(),
|
186
186
|
train_y.detach().cpu().numpy())
|
187
187
|
return clf.score(test_z.detach().cpu().numpy(),
|
@@ -249,7 +249,7 @@ class GroupAddRev(InvertibleModule):
|
|
249
249
|
else:
|
250
250
|
assert num_groups is not None, "Please specific 'num_groups'"
|
251
251
|
self.convs = torch.nn.ModuleList([conv])
|
252
|
-
for
|
252
|
+
for _ in range(num_groups - 1):
|
253
253
|
conv = copy.deepcopy(self.convs[0])
|
254
254
|
if hasattr(conv, 'reset_parameters'):
|
255
255
|
conv.reset_parameters()
|
@@ -45,7 +45,7 @@ class SignedGCN(torch.nn.Module):
|
|
45
45
|
self.conv1 = SignedConv(in_channels, hidden_channels // 2,
|
46
46
|
first_aggr=True)
|
47
47
|
self.convs = torch.nn.ModuleList()
|
48
|
-
for
|
48
|
+
for _ in range(num_layers - 1):
|
49
49
|
self.convs.append(
|
50
50
|
SignedConv(hidden_channels // 2, hidden_channels // 2,
|
51
51
|
first_aggr=False))
|
torch_geometric/nn/nlp/llm.py
CHANGED
@@ -94,7 +94,8 @@ class LLM(torch.nn.Module):
|
|
94
94
|
self.word_embedding = self.llm.model.get_input_embeddings()
|
95
95
|
|
96
96
|
if 'max_memory' not in kwargs: # Pure CPU:
|
97
|
-
warnings.warn("LLM is being used on CPU, which may be slow"
|
97
|
+
warnings.warn("LLM is being used on CPU, which may be slow",
|
98
|
+
stacklevel=2)
|
98
99
|
self.device = torch.device('cpu')
|
99
100
|
self.autocast_context = nullcontext()
|
100
101
|
else:
|
@@ -163,8 +163,10 @@ def knn_graph(
|
|
163
163
|
:rtype: :class:`torch.Tensor`
|
164
164
|
"""
|
165
165
|
if batch is not None and x.device != batch.device:
|
166
|
-
warnings.warn(
|
167
|
-
|
166
|
+
warnings.warn(
|
167
|
+
"Input tensor 'x' and 'batch' are on different devices "
|
168
|
+
"in 'knn_graph'. Performing blocking device transfer",
|
169
|
+
stacklevel=2)
|
168
170
|
batch = batch.to(x.device)
|
169
171
|
|
170
172
|
if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
|
@@ -285,8 +287,10 @@ def radius_graph(
|
|
285
287
|
inputs to GPU before proceeding.
|
286
288
|
"""
|
287
289
|
if batch is not None and x.device != batch.device:
|
288
|
-
warnings.warn(
|
289
|
-
|
290
|
+
warnings.warn(
|
291
|
+
"Input tensor 'x' and 'batch' are on different devices "
|
292
|
+
"in 'radius_graph'. Performing blocking device transfer",
|
293
|
+
stacklevel=2)
|
290
294
|
batch = batch.to(x.device)
|
291
295
|
|
292
296
|
if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
|
torch_geometric/nn/pool/knn.py
CHANGED
@@ -91,9 +91,10 @@ class KNNIndex:
|
|
91
91
|
if hasattr(self.index, 'reserveMemory'):
|
92
92
|
self.index.reserveMemory(self.reserve)
|
93
93
|
else:
|
94
|
-
warnings.warn(
|
95
|
-
|
96
|
-
|
94
|
+
warnings.warn(
|
95
|
+
f"'{self.index.__class__.__name__}' "
|
96
|
+
f"does not support pre-allocation of "
|
97
|
+
f"memory", stacklevel=2)
|
97
98
|
|
98
99
|
self.index.train(emb)
|
99
100
|
|
@@ -135,14 +136,16 @@ class KNNIndex:
|
|
135
136
|
query_k = min(query_k, self.numel)
|
136
137
|
|
137
138
|
if k > 2048: # `faiss` supports up-to `k=2048`:
|
138
|
-
warnings.warn(
|
139
|
-
|
140
|
-
|
139
|
+
warnings.warn(
|
140
|
+
f"Capping 'k' to faiss' upper limit of 2048 "
|
141
|
+
f"(got {k}). This may cause some relevant items to "
|
142
|
+
f"not be retrieved.", stacklevel=2)
|
141
143
|
elif query_k > 2048:
|
142
|
-
warnings.warn(
|
143
|
-
|
144
|
-
|
145
|
-
|
144
|
+
warnings.warn(
|
145
|
+
f"Capping 'k' to faiss' upper limit of 2048 "
|
146
|
+
f"(got {k} which got extended to {query_k} due to "
|
147
|
+
f"the exclusion of existing links). This may cause "
|
148
|
+
f"some relevant items to not be retrieved.", stacklevel=2)
|
146
149
|
query_k = 2048
|
147
150
|
|
148
151
|
score, index = self.index.search(emb.detach(), query_k)
|
@@ -108,9 +108,10 @@ class ToHeteroMessagePassing(torch.nn.Module):
|
|
108
108
|
|
109
109
|
if (not hasattr(module, 'reset_parameters')
|
110
110
|
and sum([p.numel() for p in module.parameters()]) > 0):
|
111
|
-
warnings.warn(
|
112
|
-
|
113
|
-
|
111
|
+
warnings.warn(
|
112
|
+
f"'{module}' will be duplicated, but its parameters "
|
113
|
+
f"cannot be reset. To suppress this warning, add a "
|
114
|
+
f"'reset_parameters()' method to '{module}'", stacklevel=2)
|
114
115
|
|
115
116
|
convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}
|
116
117
|
self.hetero_module = HeteroConv(convs, aggr)
|
@@ -157,7 +157,7 @@ class ToHeteroTransformer(Transformer):
|
|
157
157
|
f"There exist node types ({unused_node_types}) whose "
|
158
158
|
f"representations do not get updated during message passing "
|
159
159
|
f"as they do not occur as destination type in any edge type. "
|
160
|
-
f"This may lead to unexpected behavior.")
|
160
|
+
f"This may lead to unexpected behavior.", stacklevel=2)
|
161
161
|
|
162
162
|
names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
|
163
163
|
for name in names:
|
@@ -166,7 +166,7 @@ class ToHeteroTransformer(Transformer):
|
|
166
166
|
f"The type '{name}' contains invalid characters which "
|
167
167
|
f"may lead to unexpected behavior. To avoid any issues, "
|
168
168
|
f"ensure that your types only contain letters, numbers "
|
169
|
-
f"and underscores.")
|
169
|
+
f"and underscores.", stacklevel=2)
|
170
170
|
|
171
171
|
def placeholder(self, node: Node, target: Any, name: str):
|
172
172
|
# Adds a `get` call to the input dictionary for every node-type or
|
@@ -379,7 +379,7 @@ class ToHeteroTransformer(Transformer):
|
|
379
379
|
warnings.warn(
|
380
380
|
f"'{target}' will be duplicated, but its parameters "
|
381
381
|
f"cannot be reset. To suppress this warning, add a "
|
382
|
-
f"'reset_parameters()' method to '{target}'")
|
382
|
+
f"'reset_parameters()' method to '{target}'", stacklevel=2)
|
383
383
|
|
384
384
|
return module_dict
|
385
385
|
|
@@ -165,7 +165,7 @@ class ToHeteroWithBasesTransformer(Transformer):
|
|
165
165
|
f"There exist node types ({unused_node_types}) whose "
|
166
166
|
f"representations do not get updated during message passing "
|
167
167
|
f"as they do not occur as destination type in any edge type. "
|
168
|
-
f"This may lead to unexpected behavior.")
|
168
|
+
f"This may lead to unexpected behavior.", stacklevel=2)
|
169
169
|
|
170
170
|
names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
|
171
171
|
for name in names:
|
@@ -174,7 +174,7 @@ class ToHeteroWithBasesTransformer(Transformer):
|
|
174
174
|
f"The type '{name}' contains invalid characters which "
|
175
175
|
f"may lead to unexpected behavior. To avoid any issues, "
|
176
176
|
f"ensure that your types only contain letters, numbers "
|
177
|
-
f"and underscores.")
|
177
|
+
f"and underscores.", stacklevel=2)
|
178
178
|
|
179
179
|
def transform(self) -> GraphModule:
|
180
180
|
self._node_offset_dict_initialized = False
|
@@ -361,7 +361,7 @@ class HeteroBasisConv(torch.nn.Module):
|
|
361
361
|
warnings.warn(
|
362
362
|
f"'{conv}' will be duplicated, but its parameters cannot "
|
363
363
|
f"be reset. To suppress this warning, add a "
|
364
|
-
f"'reset_parameters()' method to '{conv}'")
|
364
|
+
f"'reset_parameters()' method to '{conv}'", stacklevel=2)
|
365
365
|
torch.nn.init.xavier_uniform_(conv.edge_type_weight)
|
366
366
|
|
367
367
|
def forward(self, edge_type: Tensor, *args, **kwargs) -> Tensor:
|
torch_geometric/sampler/base.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import copy
|
2
2
|
import math
|
3
3
|
import warnings
|
4
|
-
from abc import ABC
|
4
|
+
from abc import ABC, abstractmethod
|
5
5
|
from collections import defaultdict
|
6
6
|
from dataclasses import dataclass
|
7
7
|
from enum import Enum
|
@@ -369,9 +369,10 @@ class HeteroSamplerOutput(CastMixin):
|
|
369
369
|
out.edge[edge_type] = None
|
370
370
|
|
371
371
|
else:
|
372
|
-
warnings.warn(
|
373
|
-
|
374
|
-
|
372
|
+
warnings.warn(
|
373
|
+
f"Cannot convert to bidirectional graph "
|
374
|
+
f"since the edge type {edge_type} does not "
|
375
|
+
f"seem to have a reverse edge type", stacklevel=2)
|
375
376
|
|
376
377
|
return out
|
377
378
|
|
@@ -622,6 +623,7 @@ class BaseSampler(ABC):
|
|
622
623
|
As such, it is recommended to limit the amount of information stored in
|
623
624
|
the sampler.
|
624
625
|
"""
|
626
|
+
@abstractmethod
|
625
627
|
def sample_from_nodes(
|
626
628
|
self,
|
627
629
|
index: NodeSamplerInput,
|
@@ -642,6 +644,7 @@ class BaseSampler(ABC):
|
|
642
644
|
"""
|
643
645
|
raise NotImplementedError
|
644
646
|
|
647
|
+
@abstractmethod
|
645
648
|
def sample_from_edges(
|
646
649
|
self,
|
647
650
|
index: EdgeSamplerInput,
|
@@ -1,12 +1,15 @@
|
|
1
|
-
from typing import Dict, List, Union
|
1
|
+
from typing import Dict, List, Optional, Union
|
2
2
|
|
3
3
|
import torch
|
4
4
|
|
5
5
|
from torch_geometric.data import Data, HeteroData
|
6
6
|
from torch_geometric.sampler import (
|
7
7
|
BaseSampler,
|
8
|
+
EdgeSamplerInput,
|
8
9
|
HeteroSamplerOutput,
|
10
|
+
NegativeSampling,
|
9
11
|
NodeSamplerInput,
|
12
|
+
SamplerOutput,
|
10
13
|
)
|
11
14
|
from torch_geometric.sampler.utils import remap_keys, to_hetero_csc
|
12
15
|
from torch_geometric.typing import (
|
@@ -76,6 +79,13 @@ class HGTSampler(BaseSampler):
|
|
76
79
|
metadata=(inputs.input_id, inputs.time),
|
77
80
|
)
|
78
81
|
|
82
|
+
def sample_from_edges(
|
83
|
+
self,
|
84
|
+
index: EdgeSamplerInput,
|
85
|
+
neg_sampling: Optional[NegativeSampling] = None,
|
86
|
+
) -> Union[HeteroSamplerOutput, SamplerOutput]:
|
87
|
+
pass
|
88
|
+
|
79
89
|
@property
|
80
90
|
def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
|
81
91
|
return self.perm
|
@@ -52,16 +52,18 @@ class NeighborSampler(BaseSampler):
|
|
52
52
|
):
|
53
53
|
if not directed:
|
54
54
|
subgraph_type = SubgraphType.induced
|
55
|
-
warnings.warn(
|
56
|
-
|
57
|
-
|
55
|
+
warnings.warn(
|
56
|
+
f"The usage of the 'directed' argument in "
|
57
|
+
f"'{self.__class__.__name__}' is deprecated. Use "
|
58
|
+
f"`subgraph_type='induced'` instead.", stacklevel=2)
|
58
59
|
|
59
60
|
if (not torch_geometric.typing.WITH_PYG_LIB and sys.platform == 'linux'
|
60
61
|
and subgraph_type != SubgraphType.induced):
|
61
|
-
warnings.warn(
|
62
|
-
|
63
|
-
|
64
|
-
|
62
|
+
warnings.warn(
|
63
|
+
f"Using '{self.__class__.__name__}' without a "
|
64
|
+
f"'pyg-lib' installation is deprecated and will be "
|
65
|
+
f"removed soon. Please install 'pyg-lib' for "
|
66
|
+
f"accelerated neighborhood sampling", stacklevel=2)
|
65
67
|
|
66
68
|
self.data_type = DataType.from_data(data)
|
67
69
|
|
@@ -806,7 +808,7 @@ def neg_sample(
|
|
806
808
|
out = out.view(num_samples, seed.numel())
|
807
809
|
mask = node_time[out] > seed_time # holds all invalid samples.
|
808
810
|
neg_sampling_complete = False
|
809
|
-
for
|
811
|
+
for _ in range(5): # pragma: no cover
|
810
812
|
num_invalid = int(mask.sum())
|
811
813
|
if num_invalid == 0:
|
812
814
|
neg_sampling_complete = True
|
@@ -252,8 +252,9 @@ def withDevice(func: Callable) -> Callable:
|
|
252
252
|
if device:
|
253
253
|
backend = os.getenv('TORCH_BACKEND')
|
254
254
|
if backend is None:
|
255
|
-
warnings.warn(
|
256
|
-
|
255
|
+
warnings.warn(
|
256
|
+
f"Please specify the backend via 'TORCH_BACKEND' in"
|
257
|
+
f"order to test against '{device}'", stacklevel=2)
|
257
258
|
else:
|
258
259
|
import_module(backend)
|
259
260
|
devices.append(pytest.param(torch.device(device), id=device))
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
1
3
|
from torch.nn import Module
|
2
4
|
|
3
5
|
from torch_geometric.data import Data
|
@@ -22,13 +24,20 @@ class AddGPSE(BaseTransform):
|
|
22
24
|
(default: :obj:`NormalSE`)
|
23
25
|
|
24
26
|
"""
|
25
|
-
def __init__(
|
26
|
-
|
27
|
+
def __init__(
|
28
|
+
self,
|
29
|
+
model: Module,
|
30
|
+
use_vn: bool = True,
|
31
|
+
rand_type: str = 'NormalSE',
|
32
|
+
):
|
27
33
|
self.model = model
|
28
34
|
self.use_vn = use_vn
|
29
35
|
self.vn = VirtualNode()
|
30
36
|
self.rand_type = rand_type
|
31
37
|
|
38
|
+
def forward(self, data: Data) -> Any:
|
39
|
+
pass
|
40
|
+
|
32
41
|
def __call__(self, data: Data) -> Data:
|
33
42
|
from torch_geometric.nn.models.gpse import gpse_process
|
34
43
|
|
@@ -108,13 +108,15 @@ class AddMetaPaths(BaseTransform):
|
|
108
108
|
**kwargs: bool,
|
109
109
|
) -> None:
|
110
110
|
if 'drop_orig_edges' in kwargs:
|
111
|
-
warnings.warn(
|
112
|
-
|
111
|
+
warnings.warn(
|
112
|
+
"'drop_orig_edges' is deprecated. Use "
|
113
|
+
"'drop_orig_edge_types' instead", stacklevel=2)
|
113
114
|
drop_orig_edge_types = kwargs['drop_orig_edges']
|
114
115
|
|
115
116
|
if 'drop_unconnected_nodes' in kwargs:
|
116
|
-
warnings.warn(
|
117
|
-
|
117
|
+
warnings.warn(
|
118
|
+
"'drop_unconnected_nodes' is deprecated. Use "
|
119
|
+
"'drop_unconnected_node_types' instead", stacklevel=2)
|
118
120
|
drop_unconnected_node_types = kwargs['drop_unconnected_nodes']
|
119
121
|
|
120
122
|
for path in metapaths:
|
@@ -144,7 +146,7 @@ class AddMetaPaths(BaseTransform):
|
|
144
146
|
if self.max_sample is not None:
|
145
147
|
edge_index, edge_weight = self._sample(edge_index, edge_weight)
|
146
148
|
|
147
|
-
for
|
149
|
+
for edge_type in metapath[1:]:
|
148
150
|
edge_index2, edge_weight2 = self._edge_index(data, edge_type)
|
149
151
|
|
150
152
|
edge_index, edge_weight = edge_index.matmul(
|
@@ -276,7 +278,7 @@ class AddRandomMetaPaths(BaseTransform):
|
|
276
278
|
row = start = torch.randperm(num_nodes)[:num_starts].repeat(
|
277
279
|
self.walks_per_node[j])
|
278
280
|
|
279
|
-
for
|
281
|
+
for edge_type in metapath:
|
280
282
|
edge_index = EdgeIndex(
|
281
283
|
data[edge_type].edge_index,
|
282
284
|
sparse_size=data[edge_type].size(),
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import copy
|
2
|
-
from abc import ABC
|
2
|
+
from abc import ABC, abstractmethod
|
3
3
|
from typing import Any
|
4
4
|
|
5
5
|
|
@@ -31,6 +31,7 @@ class BaseTransform(ABC):
|
|
31
31
|
# Shallow-copy the data so that we prevent in-place data modification.
|
32
32
|
return self.forward(copy.copy(data))
|
33
33
|
|
34
|
+
@abstractmethod
|
34
35
|
def forward(self, data: Any) -> Any:
|
35
36
|
pass
|
36
37
|
|
@@ -47,7 +47,7 @@ class LargestConnectedComponents(BaseTransform):
|
|
47
47
|
return data
|
48
48
|
|
49
49
|
_, count = np.unique(component, return_counts=True)
|
50
|
-
subset_np = np.
|
50
|
+
subset_np = np.isin(component, count.argsort()[-self.num_components:])
|
51
51
|
subset = torch.from_numpy(subset_np)
|
52
52
|
subset = subset.to(data.edge_index.device, torch.bool)
|
53
53
|
|
@@ -245,7 +245,7 @@ class RandomLinkSplit(BaseTransform):
|
|
245
245
|
warnings.warn(
|
246
246
|
f"There are not enough negative edges to satisfy "
|
247
247
|
"the provided sampling ratio. The ratio will be "
|
248
|
-
f"adjusted to {ratio:.2f}.")
|
248
|
+
f"adjusted to {ratio:.2f}.", stacklevel=2)
|
249
249
|
num_neg_train = int((num_neg_train / num_neg) * num_neg_found)
|
250
250
|
num_neg_val = int((num_neg_val / num_neg) * num_neg_found)
|
251
251
|
num_neg_test = num_neg_found - num_neg_train - num_neg_val
|
torch_geometric/typing.py
CHANGED
@@ -81,8 +81,9 @@ try:
|
|
81
81
|
WITH_CUDA_HASH_MAP = False
|
82
82
|
except Exception as e:
|
83
83
|
if not isinstance(e, ImportError): # pragma: no cover
|
84
|
-
warnings.warn(
|
85
|
-
|
84
|
+
warnings.warn(
|
85
|
+
f"An issue occurred while importing 'pyg-lib'. "
|
86
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
86
87
|
pyg_lib = object
|
87
88
|
WITH_PYG_LIB = False
|
88
89
|
WITH_GMM = False
|
@@ -125,8 +126,9 @@ try:
|
|
125
126
|
WITH_TORCH_SCATTER = True
|
126
127
|
except Exception as e:
|
127
128
|
if not isinstance(e, ImportError): # pragma: no cover
|
128
|
-
warnings.warn(
|
129
|
-
|
129
|
+
warnings.warn(
|
130
|
+
f"An issue occurred while importing 'torch-scatter'. "
|
131
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
130
132
|
torch_scatter = object
|
131
133
|
WITH_TORCH_SCATTER = False
|
132
134
|
|
@@ -136,8 +138,9 @@ try:
|
|
136
138
|
WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__
|
137
139
|
except Exception as e:
|
138
140
|
if not isinstance(e, ImportError): # pragma: no cover
|
139
|
-
warnings.warn(
|
140
|
-
|
141
|
+
warnings.warn(
|
142
|
+
f"An issue occurred while importing 'torch-cluster'. "
|
143
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
141
144
|
WITH_TORCH_CLUSTER = False
|
142
145
|
WITH_TORCH_CLUSTER_BATCH_SIZE = False
|
143
146
|
|
@@ -154,7 +157,7 @@ except Exception as e:
|
|
154
157
|
if not isinstance(e, ImportError): # pragma: no cover
|
155
158
|
warnings.warn(
|
156
159
|
f"An issue occurred while importing 'torch-spline-conv'. "
|
157
|
-
f"Disabling its usage. Stacktrace: {e}")
|
160
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
158
161
|
WITH_TORCH_SPLINE_CONV = False
|
159
162
|
|
160
163
|
try:
|
@@ -163,8 +166,9 @@ try:
|
|
163
166
|
WITH_TORCH_SPARSE = True
|
164
167
|
except Exception as e:
|
165
168
|
if not isinstance(e, ImportError): # pragma: no cover
|
166
|
-
warnings.warn(
|
167
|
-
|
169
|
+
warnings.warn(
|
170
|
+
f"An issue occurred while importing 'torch-sparse'. "
|
171
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
168
172
|
WITH_TORCH_SPARSE = False
|
169
173
|
|
170
174
|
class SparseStorage: # type: ignore
|
@@ -88,9 +88,10 @@ def scatter(
|
|
88
88
|
|
89
89
|
if (src.is_cuda and src.requires_grad and not is_compiling()
|
90
90
|
and not is_in_onnx_export()):
|
91
|
-
warnings.warn(
|
92
|
-
|
93
|
-
|
91
|
+
warnings.warn(
|
92
|
+
f"The usage of `scatter(reduce='{reduce}')` "
|
93
|
+
f"can be accelerated via the 'torch-scatter'"
|
94
|
+
f" package, but it was not found", stacklevel=2)
|
94
95
|
|
95
96
|
index = broadcast(index, src, dim)
|
96
97
|
if not is_in_onnx_export():
|
@@ -120,9 +121,10 @@ def scatter(
|
|
120
121
|
or not src.is_cuda):
|
121
122
|
|
122
123
|
if src.is_cuda and not is_compiling():
|
123
|
-
warnings.warn(
|
124
|
-
|
125
|
-
|
124
|
+
warnings.warn(
|
125
|
+
f"The usage of `scatter(reduce='{reduce}')` "
|
126
|
+
f"can be accelerated via the 'torch-scatter'"
|
127
|
+
f" package, but it was not found", stacklevel=2)
|
126
128
|
|
127
129
|
index = broadcast(index, src, dim)
|
128
130
|
# We initialize with `one` here to match `scatter_mul` output:
|