pyg-nightly 2.7.0.dev20250606__py3-none-any.whl → 2.7.0.dev20250608__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250606.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/METADATA +3 -2
- {pyg_nightly-2.7.0.dev20250606.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/RECORD +84 -84
- torch_geometric/__init__.py +5 -4
- torch_geometric/_compile.py +3 -2
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/data/data.py +3 -3
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +9 -6
- torch_geometric/data/hetero_data.py +7 -6
- torch_geometric/data/hypergraph_data.py +1 -1
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/large_graph_indexer.py +1 -1
- torch_geometric/data/lightning/datamodule.py +28 -20
- torch_geometric/data/storage.py +1 -1
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/tag_dataset.py +1 -1
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/rpc.py +2 -2
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +7 -7
- torch_geometric/explain/explainer.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +4 -2
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/utils/comp_budget.py +2 -1
- torch_geometric/hash_tensor.py +5 -4
- torch_geometric/io/fs.py +5 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/mixin.py +2 -1
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +6 -4
- torch_geometric/nn/conv/message_passing.py +3 -2
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/fx.py +7 -5
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/glem.py +20 -12
- torch_geometric/nn/models/gpse.py +30 -13
- torch_geometric/nn/models/graph_unet.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/rev_gnn.py +1 -1
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/nlp/llm.py +2 -1
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +3 -3
- torch_geometric/sampler/base.py +7 -4
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +10 -8
- torch_geometric/testing/decorators.py +3 -2
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/add_gpse.py +11 -2
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/typing.py +13 -9
- torch_geometric/utils/_scatter.py +8 -6
- torch_geometric/utils/_spmm.py +15 -12
- torch_geometric/utils/convert.py +2 -2
- torch_geometric/utils/embedding.py +5 -3
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/sparse.py +3 -2
- {pyg_nightly-2.7.0.dev20250606.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250606.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/licenses/LICENSE +0 -0
torch_geometric/nn/conv/sg_conv.py
CHANGED
@@ -90,7 +90,7 @@ class SGConv(MessagePassing):
                     edge_index, edge_weight, x.size(self.node_dim), False,
                     self.add_self_loops, self.flow, dtype=x.dtype)

-            for
+            for _ in range(self.K):
                 # propagate_type: (x: Tensor, edge_weight: OptTensor)
                 x = self.propagate(edge_index, x=x, edge_weight=edge_weight)
                 if self.cached:
torch_geometric/nn/conv/spline_conv.py
CHANGED
@@ -132,7 +132,8 @@ class SplineConv(MessagePassing):
         if not x[0].is_cuda:
             warnings.warn(
                 'We do not recommend using the non-optimized CPU version of '
-                '`SplineConv`. If possible, please move your data to GPU.')
+                '`SplineConv`. If possible, please move your data to GPU.',
+                stacklevel=2)

         # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
         out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
torch_geometric/nn/conv/ssg_conv.py
CHANGED
@@ -100,7 +100,7 @@ class SSGConv(MessagePassing):
                     self.add_self_loops, self.flow, dtype=x.dtype)

         h = x * self.alpha
-        for
+        for _ in range(self.K):
             # propagate_type: (x: Tensor, edge_weight: OptTensor)
             x = self.propagate(edge_index, x=x, edge_weight=edge_weight)
             h = h + (1 - self.alpha) / self.K * x
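A note on the recurring change above: loop counters that are never read inside the loop body are renamed to `_`, the conventional Python name for a deliberately unused variable. The old names are truncated in this diff view, so the "before" below is only illustrative; the pattern is what flake8-bugbear / Ruff flag as rule B007:

    # Before: `k` is bound on every iteration but never read,
    # which linters flag as a misleading unused variable (B007).
    total = 0
    for k in range(5):
        total += 1

    # After: `_` signals that only the iteration count matters.
    total = 0
    for _ in range(5):
        total += 1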
torch_geometric/nn/data_parallel.py
CHANGED
@@ -57,10 +57,11 @@ class DataParallel(torch.nn.DataParallel):
                  follow_batch=None, exclude_keys=None):
         super().__init__(module, device_ids, output_device)

-        warnings.warn(
-
-
-
+        warnings.warn(
+            "'DataParallel' is usually much slower than "
+            "'DistributedDataParallel' even on a single machine. "
+            "Please consider switching to 'DistributedDataParallel' "
+            "for multi-GPU training.", stacklevel=2)

         self.src_device = torch.device(f'cuda:{self.device_ids[0]}')
         self.follow_batch = follow_batch or []
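Most hunks in this release add `stacklevel=2` to `warnings.warn` calls. By default, Python attributes a warning to the line that calls `warn`, i.e. a line inside PyG itself; `stacklevel=2` moves the attribution one frame up the call stack, to the user code that triggered it. A minimal, self-contained illustration (the function name is made up for the example):

    import warnings

    def library_function():
        # Reported against this line inside the library:
        warnings.warn("default attribution")
        # Reported against the caller's line, which is usually
        # the code the user can actually change:
        warnings.warn("caller attribution", stacklevel=2)

    library_function()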
torch_geometric/nn/fx.py
CHANGED
@@ -130,11 +130,13 @@ class Transformer:
         # (node-level, edge-level) by filling `self._state`:
         for node in list(self.graph.nodes):
             if node.op == 'call_function' and 'training' in node.kwargs:
-                warnings.warn(
-
-
-
-
+                warnings.warn(
+                    f"Found function '{node.name}' with keyword "
+                    f"argument 'training'. During FX tracing, this "
+                    f"will likely be baked in as a constant value. "
+                    f"Consider replacing this function by a module "
+                    f"to properly encapsulate its training flag.",
+                    stacklevel=2)

             if node.op == 'placeholder':
                 if node.name not in self._state:
torch_geometric/nn/models/attentive_fp.py
CHANGED
@@ -160,7 +160,7 @@ class AttentiveFP(torch.nn.Module):
         edge_index = torch.stack([row, batch], dim=0)

         out = global_add_pool(x, batch).relu_()
-        for
+        for _ in range(self.num_timesteps):
             h = F.elu_(self.mol_conv((x, out), edge_index))
             h = F.dropout(h, p=self.dropout, training=self.training)
             out = self.mol_gru(h, out).relu_()
torch_geometric/nn/models/deep_graph_infomax.py
CHANGED
@@ -106,7 +106,7 @@ class DeepGraphInfomax(torch.nn.Module):
         """
         from sklearn.linear_model import LogisticRegression

-        clf = LogisticRegression(solver=solver,
+        clf = LogisticRegression(*args, solver=solver,
                                  **kwargs).fit(train_z.detach().cpu().numpy(),
                                                train_y.detach().cpu().numpy())
         return clf.score(test_z.detach().cpu().numpy(),
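The same one-line fix appears in `DeepGraphInfomax.test`, `MetaPath2Vec.test` and `Node2Vec.test` below: each accepts `*args` destined for `sklearn.linear_model.LogisticRegression`, but previously dropped them silently; they are now forwarded. A sketch of the pattern with a simplified signature (not the exact PyG one):

    from sklearn.linear_model import LogisticRegression

    def test(train_z, train_y, test_z, test_y, solver='lbfgs',
             *args, **kwargs):
        # Forward positional extras as well as keyword arguments,
        # so nothing the caller passes is silently ignored:
        clf = LogisticRegression(*args, solver=solver, **kwargs)
        clf.fit(train_z, train_y)
        return clf.score(test_z, test_y)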
torch_geometric/nn/models/glem.py
CHANGED
@@ -37,20 +37,28 @@ class GLEM(torch.nn.Module):
     See `examples/llm_plus_gnn/glem.py` for example usage.
     """
     def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        lm_to_use: str = 'prajjwal1/bert-tiny',
+        gnn_to_use: basic_gnn = GraphSAGE,
+        out_channels: int = 47,
+        gnn_loss: Optional[nn.Module] = None,
+        lm_loss: Optional[nn.Module] = None,
+        alpha: float = 0.5,
+        beta: float = 0.5,
+        lm_dtype: torch.dtype = torch.bfloat16,
+        lm_use_lora: bool = True,
+        lora_target_modules: Optional[Union[List[str], str]] = None,
+        device: Optional[Union[str, torch.device]] = None,
     ):
         super().__init__()
+
+        if gnn_loss is None:
+            gnn_loss = nn.CrossEntropyLoss(reduction='mean')
+        if lm_loss is None:
+            lm_loss = nn.CrossEntropyLoss(reduction='mean')
+        if device is None:
+            device = torch.device('cpu')
+
         self.device = device
         self.lm_loss = lm_loss
         self.gnn = gnn_to_use
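The interesting part of the `GLEM` rewrite is that the loss modules and the device move out of the parameter defaults and into the body behind `is None` checks. Python evaluates default values once, at function definition time, so a default such as `nn.CrossEntropyLoss(...)` written in the signature is a single module instance shared by every `GLEM` object (the removed lines are truncated above, but the added checks imply defaults of that kind). The standard sentinel pattern, sketched with a toy class:

    import torch

    class Model:
        def __init__(self, loss=None, device=None):
            # Construct defaults per instance, inside the body, so that
            # two Model() objects never share the same mutable state:
            if loss is None:
                loss = torch.nn.CrossEntropyLoss(reduction='mean')
            if device is None:
                device = torch.device('cpu')
            self.loss = loss
            self.device = device

    a, b = Model(), Model()
    assert a.loss is not b.loss  # each instance owns its own default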
torch_geometric/nn/models/gpse.py
CHANGED
@@ -716,11 +716,18 @@ class GPSENodeEncoder(torch.nn.Module):


 @torch.no_grad()
-def gpse_process(
-
-
-
-
+def gpse_process(
+        model: Module,
+        data: Data,
+        rand_type: str,
+        use_vn: bool = True,
+        bernoulli_thresh: float = 0.5,
+        neighbor_loader: bool = False,
+        num_neighbors: Optional[List[int]] = None,
+        fillval: int = 5,
+        layers_mp: int = None,
+        **kwargs,
+) -> torch.Tensor:
     r"""Processes the data using the :class:`GPSE` model to generate and append
     GPSE encodings. Identical to :obj:`gpse_process_batch`, but operates on a
     single :class:`~torch_geometric.data.Dataset` object.
@@ -784,6 +791,8 @@ def gpse_process(model: Module, data: Data, rand_type: str,
     if layers_mp is None:
         raise ValueError('Please provide the number of message-passing '
                          'layers as "layers_mp".')
+
+    num_neighbors = num_neighbors or [30, 20, 10]
     diff = layers_mp - len(num_neighbors)
     if fillval > 0 and diff > 0:
         num_neighbors += [fillval] * diff
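`num_neighbors = num_neighbors or [30, 20, 10]` applies the same sentinel idea to a mutable list default: the signature now takes `Optional[List[int]] = None` and the concrete list is built per call, which matters because `num_neighbors += [fillval] * diff` mutates it two lines later. One caveat of defaulting with `or` is that it treats every falsy value as missing, so an explicitly passed empty list is replaced as well. A small illustration with a hypothetical function:

    from typing import List, Optional

    def resolve(num_neighbors: Optional[List[int]] = None) -> List[int]:
        # `or` substitutes the default for None *and* for []:
        return num_neighbors or [30, 20, 10]

    assert resolve(None) == [30, 20, 10]
    assert resolve([]) == [30, 20, 10]  # [] is also treated as missing
    # An explicit `if num_neighbors is None:` check would keep a passed-in [].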
@@ -792,7 +801,7 @@ def gpse_process(model: Module, data: Data, rand_type: str,
                             shuffle=False, pin_memory=True, **kwargs)
     out_list = []
     pbar = trange(data.num_nodes, position=2)
-    for
+    for batch in loader:
         out, _ = model(batch.to(device))
         out = out[:batch.batch_size].to("cpu", non_blocking=True)
         out_list.append(out)
@@ -806,12 +815,18 @@ def gpse_process(model: Module, data: Data, rand_type: str,


 @torch.no_grad()
-def gpse_process_batch(
-
-
-
-
-
+def gpse_process_batch(
+        model: GPSE,
+        batch,
+        rand_type: str,
+        use_vn: bool = True,
+        bernoulli_thresh: float = 0.5,
+        neighbor_loader: bool = False,
+        num_neighbors: Optional[List[int]] = None,
+        fillval: int = 5,
+        layers_mp: int = None,
+        **kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Process a batch of data using the :class:`GPSE` model to generate and
     append :class:`GPSE` encodings. Identical to `gpse_process`, but operates
     on a batch of :class:`~torch_geometric.data.Data` objects.
@@ -881,6 +896,8 @@ def gpse_process_batch(model: GPSE, batch, rand_type: str, use_vn: bool = True,
     if layers_mp is None:
         raise ValueError('Please provide the number of message-passing '
                          'layers as "layers_mp".')
+
+    num_neighbors = num_neighbors or [30, 20, 10]
     diff = layers_mp - len(num_neighbors)
     if fillval > 0 and diff > 0:
         num_neighbors += [fillval] * diff
@@ -889,7 +906,7 @@ def gpse_process_batch(model: GPSE, batch, rand_type: str, use_vn: bool = True,
                             shuffle=False, pin_memory=True, **kwargs)
     out_list = []
     pbar = trange(batch.num_nodes, position=2)
-    for
+    for batch in loader:
         out, _ = model(batch.to(device))
         out = out[:batch.batch_size].to('cpu', non_blocking=True)
         out_list.append(out)
torch_geometric/nn/models/graph_unet.py
CHANGED
@@ -64,7 +64,7 @@ class GraphUNet(torch.nn.Module):
         in_channels = channels if sum_res else 2 * channels

         self.up_convs = torch.nn.ModuleList()
-        for
+        for _ in range(depth - 1):
             self.up_convs.append(GCNConv(in_channels, channels, improved=True))
         self.up_convs.append(GCNConv(in_channels, out_channels, improved=True))

torch_geometric/nn/models/metapath2vec.py
CHANGED
@@ -233,7 +233,7 @@ class MetaPath2Vec(torch.nn.Module):
         """
         from sklearn.linear_model import LogisticRegression

-        clf = LogisticRegression(solver=solver,
+        clf = LogisticRegression(*args, solver=solver,
                                  **kwargs).fit(train_z.detach().cpu().numpy(),
                                                train_y.detach().cpu().numpy())
         return clf.score(test_z.detach().cpu().numpy(),
torch_geometric/nn/models/mlp.py
CHANGED
@@ -99,8 +99,10 @@ class MLP(torch.nn.Module):
         act_first = act_first or kwargs.get("relu_first", False)
         batch_norm = kwargs.get("batch_norm", None)
         if batch_norm is not None and isinstance(batch_norm, bool):
-            warnings.warn(
-
+            warnings.warn(
+                "Argument `batch_norm` is deprecated, "
+                "please use `norm` to specify normalization layer.",
+                stacklevel=2)
             norm = 'batch_norm' if batch_norm else None
         batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None)
         norm_kwargs = batch_norm_kwargs or {}
torch_geometric/nn/models/node2vec.py
CHANGED
@@ -181,7 +181,7 @@ class Node2Vec(torch.nn.Module):
         """
         from sklearn.linear_model import LogisticRegression

-        clf = LogisticRegression(solver=solver,
+        clf = LogisticRegression(*args, solver=solver,
                                  **kwargs).fit(train_z.detach().cpu().numpy(),
                                                train_y.detach().cpu().numpy())
         return clf.score(test_z.detach().cpu().numpy(),
torch_geometric/nn/models/rev_gnn.py
CHANGED
@@ -249,7 +249,7 @@ class GroupAddRev(InvertibleModule):
         else:
             assert num_groups is not None, "Please specific 'num_groups'"
         self.convs = torch.nn.ModuleList([conv])
-        for
+        for _ in range(num_groups - 1):
             conv = copy.deepcopy(self.convs[0])
             if hasattr(conv, 'reset_parameters'):
                 conv.reset_parameters()
torch_geometric/nn/models/signed_gcn.py
CHANGED
@@ -45,7 +45,7 @@ class SignedGCN(torch.nn.Module):
         self.conv1 = SignedConv(in_channels, hidden_channels // 2,
                                 first_aggr=True)
         self.convs = torch.nn.ModuleList()
-        for
+        for _ in range(num_layers - 1):
             self.convs.append(
                 SignedConv(hidden_channels // 2, hidden_channels // 2,
                            first_aggr=False))
torch_geometric/nn/nlp/llm.py
CHANGED
@@ -94,7 +94,8 @@ class LLM(torch.nn.Module):
         self.word_embedding = self.llm.model.get_input_embeddings()

         if 'max_memory' not in kwargs:  # Pure CPU:
-            warnings.warn("LLM is being used on CPU, which may be slow")
+            warnings.warn("LLM is being used on CPU, which may be slow",
+                          stacklevel=2)
             self.device = torch.device('cpu')
             self.autocast_context = nullcontext()
         else:
torch_geometric/nn/pool/__init__.py
CHANGED
@@ -163,8 +163,10 @@ def knn_graph(
     :rtype: :class:`torch.Tensor`
     """
     if batch is not None and x.device != batch.device:
-        warnings.warn(
-
+        warnings.warn(
+            "Input tensor 'x' and 'batch' are on different devices "
+            "in 'knn_graph'. Performing blocking device transfer",
+            stacklevel=2)
         batch = batch.to(x.device)

     if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
@@ -285,8 +287,10 @@ def radius_graph(
     inputs to GPU before proceeding.
     """
     if batch is not None and x.device != batch.device:
-        warnings.warn(
-
+        warnings.warn(
+            "Input tensor 'x' and 'batch' are on different devices "
+            "in 'radius_graph'. Performing blocking device transfer",
+            stacklevel=2)
         batch = batch.to(x.device)

     if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
torch_geometric/nn/pool/knn.py
CHANGED
@@ -91,9 +91,10 @@ class KNNIndex:
             if hasattr(self.index, 'reserveMemory'):
                 self.index.reserveMemory(self.reserve)
             else:
-                warnings.warn(
-
-
+                warnings.warn(
+                    f"'{self.index.__class__.__name__}' "
+                    f"does not support pre-allocation of "
+                    f"memory", stacklevel=2)

         self.index.train(emb)

@@ -135,14 +136,16 @@ class KNNIndex:
         query_k = min(query_k, self.numel)

         if k > 2048:  # `faiss` supports up-to `k=2048`:
-            warnings.warn(
-
-
+            warnings.warn(
+                f"Capping 'k' to faiss' upper limit of 2048 "
+                f"(got {k}). This may cause some relevant items to "
+                f"not be retrieved.", stacklevel=2)
         elif query_k > 2048:
-            warnings.warn(
-
-
-
+            warnings.warn(
+                f"Capping 'k' to faiss' upper limit of 2048 "
+                f"(got {k} which got extended to {query_k} due to "
+                f"the exclusion of existing links). This may cause "
+                f"some relevant items to not be retrieved.", stacklevel=2)
             query_k = 2048

         score, index = self.index.search(emb.detach(), query_k)
torch_geometric/nn/to_hetero_module.py
CHANGED
@@ -108,9 +108,10 @@ class ToHeteroMessagePassing(torch.nn.Module):

         if (not hasattr(module, 'reset_parameters')
                 and sum([p.numel() for p in module.parameters()]) > 0):
-            warnings.warn(
-
-
+            warnings.warn(
+                f"'{module}' will be duplicated, but its parameters "
+                f"cannot be reset. To suppress this warning, add a "
+                f"'reset_parameters()' method to '{module}'", stacklevel=2)

         convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}
         self.hetero_module = HeteroConv(convs, aggr)
torch_geometric/nn/to_hetero_transformer.py
CHANGED
@@ -157,7 +157,7 @@ class ToHeteroTransformer(Transformer):
                 f"There exist node types ({unused_node_types}) whose "
                 f"representations do not get updated during message passing "
                 f"as they do not occur as destination type in any edge type. "
-                f"This may lead to unexpected behavior.")
+                f"This may lead to unexpected behavior.", stacklevel=2)

         names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
         for name in names:
@@ -166,7 +166,7 @@ class ToHeteroTransformer(Transformer):
                     f"The type '{name}' contains invalid characters which "
                     f"may lead to unexpected behavior. To avoid any issues, "
                     f"ensure that your types only contain letters, numbers "
-                    f"and underscores.")
+                    f"and underscores.", stacklevel=2)

     def placeholder(self, node: Node, target: Any, name: str):
         # Adds a `get` call to the input dictionary for every node-type or
@@ -379,7 +379,7 @@ class ToHeteroTransformer(Transformer):
                 warnings.warn(
                     f"'{target}' will be duplicated, but its parameters "
                     f"cannot be reset. To suppress this warning, add a "
-                    f"'reset_parameters()' method to '{target}'")
+                    f"'reset_parameters()' method to '{target}'", stacklevel=2)

         return module_dict

torch_geometric/nn/to_hetero_with_bases_transformer.py
CHANGED
@@ -165,7 +165,7 @@ class ToHeteroWithBasesTransformer(Transformer):
                 f"There exist node types ({unused_node_types}) whose "
                 f"representations do not get updated during message passing "
                 f"as they do not occur as destination type in any edge type. "
-                f"This may lead to unexpected behavior.")
+                f"This may lead to unexpected behavior.", stacklevel=2)

         names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
         for name in names:
@@ -174,7 +174,7 @@ class ToHeteroWithBasesTransformer(Transformer):
                     f"The type '{name}' contains invalid characters which "
                     f"may lead to unexpected behavior. To avoid any issues, "
                     f"ensure that your types only contain letters, numbers "
-                    f"and underscores.")
+                    f"and underscores.", stacklevel=2)

     def transform(self) -> GraphModule:
         self._node_offset_dict_initialized = False
@@ -361,7 +361,7 @@ class HeteroBasisConv(torch.nn.Module):
                 warnings.warn(
                     f"'{conv}' will be duplicated, but its parameters cannot "
                     f"be reset. To suppress this warning, add a "
-                    f"'reset_parameters()' method to '{conv}'")
+                    f"'reset_parameters()' method to '{conv}'", stacklevel=2)
             torch.nn.init.xavier_uniform_(conv.edge_type_weight)

     def forward(self, edge_type: Tensor, *args, **kwargs) -> Tensor:
torch_geometric/sampler/base.py
CHANGED
@@ -1,7 +1,7 @@
 import copy
 import math
 import warnings
-from abc import ABC
+from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
 from enum import Enum
@@ -369,9 +369,10 @@ class HeteroSamplerOutput(CastMixin):
                     out.edge[edge_type] = None

             else:
-                warnings.warn(
-
-
+                warnings.warn(
+                    f"Cannot convert to bidirectional graph "
+                    f"since the edge type {edge_type} does not "
+                    f"seem to have a reverse edge type", stacklevel=2)

         return out

@@ -622,6 +623,7 @@ class BaseSampler(ABC):
     As such, it is recommended to limit the amount of information stored in
     the sampler.
     """
+    @abstractmethod
     def sample_from_nodes(
         self,
         index: NodeSamplerInput,
@@ -642,6 +644,7 @@ class BaseSampler(ABC):
         """
         raise NotImplementedError

+    @abstractmethod
     def sample_from_edges(
         self,
         index: EdgeSamplerInput,
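Marking `sample_from_nodes` and `sample_from_edges` as abstract changes when missing implementations fail: since `BaseSampler` derives from `ABC`, Python now refuses to even instantiate a subclass that does not override both methods, rather than raising `NotImplementedError` at call time. A toy illustration of the mechanism (not the PyG classes):

    from abc import ABC, abstractmethod

    class Base(ABC):
        @abstractmethod
        def sample(self):
            raise NotImplementedError

    class Incomplete(Base):
        pass

    class Complete(Base):
        def sample(self):
            return 42

    Complete()  # constructs fine
    try:
        Incomplete()
    except TypeError as err:
        print(err)  # fails at construction, not at first call

This is also why `HGTSampler` below picks up an empty `sample_from_edges` stub: without one, the class could no longer be instantiated at all.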
torch_geometric/sampler/hgt_sampler.py
CHANGED
@@ -1,12 +1,15 @@
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union

 import torch

 from torch_geometric.data import Data, HeteroData
 from torch_geometric.sampler import (
     BaseSampler,
+    EdgeSamplerInput,
     HeteroSamplerOutput,
+    NegativeSampling,
     NodeSamplerInput,
+    SamplerOutput,
 )
 from torch_geometric.sampler.utils import remap_keys, to_hetero_csc
 from torch_geometric.typing import (
@@ -76,6 +79,13 @@ class HGTSampler(BaseSampler):
             metadata=(inputs.input_id, inputs.time),
         )

+    def sample_from_edges(
+        self,
+        index: EdgeSamplerInput,
+        neg_sampling: Optional[NegativeSampling] = None,
+    ) -> Union[HeteroSamplerOutput, SamplerOutput]:
+        pass
+
     @property
     def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
         return self.perm
torch_geometric/sampler/neighbor_sampler.py
CHANGED
@@ -52,16 +52,18 @@ class NeighborSampler(BaseSampler):
     ):
         if not directed:
             subgraph_type = SubgraphType.induced
-            warnings.warn(
-
-
+            warnings.warn(
+                f"The usage of the 'directed' argument in "
+                f"'{self.__class__.__name__}' is deprecated. Use "
+                f"`subgraph_type='induced'` instead.", stacklevel=2)

         if (not torch_geometric.typing.WITH_PYG_LIB and sys.platform == 'linux'
                 and subgraph_type != SubgraphType.induced):
-            warnings.warn(
-
-
-
+            warnings.warn(
+                f"Using '{self.__class__.__name__}' without a "
+                f"'pyg-lib' installation is deprecated and will be "
+                f"removed soon. Please install 'pyg-lib' for "
+                f"accelerated neighborhood sampling", stacklevel=2)

         self.data_type = DataType.from_data(data)

@@ -806,7 +808,7 @@ def neg_sample(
         out = out.view(num_samples, seed.numel())
         mask = node_time[out] > seed_time  # holds all invalid samples.
         neg_sampling_complete = False
-        for
+        for _ in range(5):  # pragma: no cover
             num_invalid = int(mask.sum())
             if num_invalid == 0:
                 neg_sampling_complete = True
torch_geometric/testing/decorators.py
CHANGED
@@ -252,8 +252,9 @@ def withDevice(func: Callable) -> Callable:
     if device:
         backend = os.getenv('TORCH_BACKEND')
         if backend is None:
-            warnings.warn(
-
+            warnings.warn(
+                f"Please specify the backend via 'TORCH_BACKEND' in"
+                f"order to test against '{device}'", stacklevel=2)
         else:
             import_module(backend)
         devices.append(pytest.param(torch.device(device), id=device))
torch_geometric/transforms/add_gpse.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Any
+
 from torch.nn import Module

 from torch_geometric.data import Data
@@ -22,13 +24,20 @@ class AddGPSE(BaseTransform):
             (default: :obj:`NormalSE`)

     """
-    def __init__(
-
+    def __init__(
+        self,
+        model: Module,
+        use_vn: bool = True,
+        rand_type: str = 'NormalSE',
+    ):
         self.model = model
         self.use_vn = use_vn
         self.vn = VirtualNode()
         self.rand_type = rand_type

+    def forward(self, data: Data) -> Any:
+        pass
+
     def __call__(self, data: Data) -> Data:
         from torch_geometric.nn.models.gpse import gpse_process

torch_geometric/transforms/add_metapaths.py
CHANGED
@@ -108,13 +108,15 @@ class AddMetaPaths(BaseTransform):
         **kwargs: bool,
     ) -> None:
         if 'drop_orig_edges' in kwargs:
-            warnings.warn(
-
+            warnings.warn(
+                "'drop_orig_edges' is deprecated. Use "
+                "'drop_orig_edge_types' instead", stacklevel=2)
             drop_orig_edge_types = kwargs['drop_orig_edges']

         if 'drop_unconnected_nodes' in kwargs:
-            warnings.warn(
-
+            warnings.warn(
+                "'drop_unconnected_nodes' is deprecated. Use "
+                "'drop_unconnected_node_types' instead", stacklevel=2)
             drop_unconnected_node_types = kwargs['drop_unconnected_nodes']

         for path in metapaths:
@@ -144,7 +146,7 @@ class AddMetaPaths(BaseTransform):
         if self.max_sample is not None:
             edge_index, edge_weight = self._sample(edge_index, edge_weight)

-        for
+        for edge_type in metapath[1:]:
             edge_index2, edge_weight2 = self._edge_index(data, edge_type)

             edge_index, edge_weight = edge_index.matmul(
@@ -276,7 +278,7 @@ class AddRandomMetaPaths(BaseTransform):
         row = start = torch.randperm(num_nodes)[:num_starts].repeat(
             self.walks_per_node[j])

-        for
+        for edge_type in metapath:
             edge_index = EdgeIndex(
                 data[edge_type].edge_index,
                 sparse_size=data[edge_type].size(),