pyg-nightly 2.7.0.dev20250607__py3-none-any.whl → 2.7.0.dev20250608__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (79)
  1. {pyg_nightly-2.7.0.dev20250607.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/METADATA +3 -2
  2. {pyg_nightly-2.7.0.dev20250607.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/RECORD +79 -79
  3. torch_geometric/__init__.py +5 -4
  4. torch_geometric/_compile.py +3 -2
  5. torch_geometric/contrib/__init__.py +1 -1
  6. torch_geometric/data/data.py +3 -3
  7. torch_geometric/data/database.py +4 -0
  8. torch_geometric/data/dataset.py +9 -6
  9. torch_geometric/data/hetero_data.py +7 -6
  10. torch_geometric/data/hypergraph_data.py +1 -1
  11. torch_geometric/data/in_memory_dataset.py +2 -2
  12. torch_geometric/data/large_graph_indexer.py +1 -1
  13. torch_geometric/data/lightning/datamodule.py +28 -20
  14. torch_geometric/data/storage.py +1 -1
  15. torch_geometric/datasets/dbp15k.py +1 -1
  16. torch_geometric/datasets/molecule_net.py +3 -2
  17. torch_geometric/datasets/tag_dataset.py +1 -1
  18. torch_geometric/datasets/wikics.py +2 -1
  19. torch_geometric/deprecation.py +1 -1
  20. torch_geometric/distributed/rpc.py +2 -2
  21. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  22. torch_geometric/explain/algorithm/graphmask_explainer.py +7 -7
  23. torch_geometric/explain/explainer.py +1 -1
  24. torch_geometric/graphgym/config.py +3 -2
  25. torch_geometric/graphgym/imports.py +4 -2
  26. torch_geometric/graphgym/logger.py +1 -1
  27. torch_geometric/graphgym/models/encoder.py +2 -2
  28. torch_geometric/hash_tensor.py +5 -4
  29. torch_geometric/io/fs.py +5 -4
  30. torch_geometric/loader/ibmb_loader.py +4 -4
  31. torch_geometric/loader/mixin.py +2 -1
  32. torch_geometric/loader/prefetch.py +3 -2
  33. torch_geometric/nn/aggr/fused.py +1 -1
  34. torch_geometric/nn/conv/appnp.py +1 -1
  35. torch_geometric/nn/conv/gen_conv.py +1 -1
  36. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  37. torch_geometric/nn/conv/hetero_conv.py +2 -1
  38. torch_geometric/nn/conv/meshcnn_conv.py +6 -4
  39. torch_geometric/nn/conv/message_passing.py +3 -2
  40. torch_geometric/nn/conv/sg_conv.py +1 -1
  41. torch_geometric/nn/conv/spline_conv.py +2 -1
  42. torch_geometric/nn/conv/ssg_conv.py +1 -1
  43. torch_geometric/nn/data_parallel.py +5 -4
  44. torch_geometric/nn/fx.py +7 -5
  45. torch_geometric/nn/models/attentive_fp.py +1 -1
  46. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  47. torch_geometric/nn/models/glem.py +20 -12
  48. torch_geometric/nn/models/gpse.py +2 -2
  49. torch_geometric/nn/models/graph_unet.py +1 -1
  50. torch_geometric/nn/models/metapath2vec.py +1 -1
  51. torch_geometric/nn/models/mlp.py +4 -2
  52. torch_geometric/nn/models/node2vec.py +1 -1
  53. torch_geometric/nn/models/rev_gnn.py +1 -1
  54. torch_geometric/nn/models/signed_gcn.py +1 -1
  55. torch_geometric/nn/nlp/llm.py +2 -1
  56. torch_geometric/nn/pool/__init__.py +8 -4
  57. torch_geometric/nn/pool/knn.py +13 -10
  58. torch_geometric/nn/to_hetero_module.py +4 -3
  59. torch_geometric/nn/to_hetero_transformer.py +3 -3
  60. torch_geometric/nn/to_hetero_with_bases_transformer.py +3 -3
  61. torch_geometric/sampler/base.py +7 -4
  62. torch_geometric/sampler/hgt_sampler.py +11 -1
  63. torch_geometric/sampler/neighbor_sampler.py +10 -8
  64. torch_geometric/testing/decorators.py +3 -2
  65. torch_geometric/testing/distributed.py +1 -1
  66. torch_geometric/transforms/add_gpse.py +11 -2
  67. torch_geometric/transforms/add_metapaths.py +8 -6
  68. torch_geometric/transforms/base_transform.py +2 -1
  69. torch_geometric/transforms/largest_connected_components.py +1 -1
  70. torch_geometric/transforms/random_link_split.py +1 -1
  71. torch_geometric/typing.py +13 -9
  72. torch_geometric/utils/_scatter.py +8 -6
  73. torch_geometric/utils/_spmm.py +15 -12
  74. torch_geometric/utils/convert.py +2 -2
  75. torch_geometric/utils/embedding.py +5 -3
  76. torch_geometric/utils/geodesic.py +4 -3
  77. torch_geometric/utils/sparse.py +3 -2
  78. {pyg_nightly-2.7.0.dev20250607.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/WHEEL +0 -0
  79. {pyg_nightly-2.7.0.dev20250607.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/licenses/LICENSE +0 -0
torch_geometric/nn/models/deep_graph_infomax.py CHANGED
@@ -106,7 +106,7 @@ class DeepGraphInfomax(torch.nn.Module):
         """
         from sklearn.linear_model import LogisticRegression
 
-        clf = LogisticRegression(solver=solver, *args,
+        clf = LogisticRegression(*args, solver=solver,
                                  **kwargs).fit(train_z.detach().cpu().numpy(),
                                                train_y.detach().cpu().numpy())
         return clf.score(test_z.detach().cpu().numpy(),
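This `LogisticRegression` change (repeated below for `MetaPath2Vec` and `Node2Vec`) moves iterable unpacking ahead of the keyword argument. A minimal sketch (plain Python, not PyG code) of why the new ordering reads better, even though both spellings are legal:

    def fit(penalty='l2', solver='lbfgs', max_iter=100):
        return penalty, solver, max_iter

    args = ('l1',)
    fit(*args, solver='liblinear')  # reads in evaluation order
    fit(solver='liblinear', *args)  # same call, but the unpacked positionals
                                    # still bind first, which is easy to misread

Both calls return `('l1', 'liblinear', 100)`; the fix is purely about readability and avoiding surprising `TypeError`s when a positional collides with a keyword.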
torch_geometric/nn/models/glem.py CHANGED
@@ -37,20 +37,28 @@ class GLEM(torch.nn.Module):
     See `examples/llm_plus_gnn/glem.py` for example usage.
     """
     def __init__(
-            self,
-            lm_to_use: str = 'prajjwal1/bert-tiny',
-            gnn_to_use: basic_gnn = GraphSAGE,
-            out_channels: int = 47,
-            gnn_loss=nn.CrossEntropyLoss(reduction='mean'),
-            lm_loss=nn.CrossEntropyLoss(reduction='mean'),
-            alpha: float = 0.5,
-            beta: float = 0.5,
-            lm_dtype: torch.dtype = torch.bfloat16,
-            lm_use_lora: bool = True,
-            lora_target_modules: Optional[Union[List[str], str]] = None,
-            device: Union[str, torch.device] = torch.device('cpu'),
+        self,
+        lm_to_use: str = 'prajjwal1/bert-tiny',
+        gnn_to_use: basic_gnn = GraphSAGE,
+        out_channels: int = 47,
+        gnn_loss: Optional[nn.Module] = None,
+        lm_loss: Optional[nn.Module] = None,
+        alpha: float = 0.5,
+        beta: float = 0.5,
+        lm_dtype: torch.dtype = torch.bfloat16,
+        lm_use_lora: bool = True,
+        lora_target_modules: Optional[Union[List[str], str]] = None,
+        device: Optional[Union[str, torch.device]] = None,
     ):
         super().__init__()
+
+        if gnn_loss is None:
+            gnn_loss = nn.CrossEntropyLoss(reduction='mean')
+        if lm_loss is None:
+            lm_loss = nn.CrossEntropyLoss(reduction='mean')
+        if device is None:
+            device = torch.device('cpu')
+
         self.device = device
         self.lm_loss = lm_loss
         self.gnn = gnn_to_use
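The `GLEM` signature change applies the standard `None`-sentinel pattern for object-valued defaults. A minimal sketch (toy classes, not PyG code) of the pitfall it avoids: a default such as `nn.CrossEntropyLoss(reduction='mean')` is constructed once, at function-definition time, and then silently shared by every instance:

    import torch.nn as nn

    class Bad:
        def __init__(self, loss=nn.CrossEntropyLoss()):  # built once, shared
            self.loss = loss

    class Good:
        def __init__(self, loss=None):
            self.loss = loss if loss is not None else nn.CrossEntropyLoss()

    a, b = Bad(), Bad()
    assert a.loss is b.loss      # the very same module object is reused
    x, y = Good(), Good()
    assert x.loss is not y.loss  # each instance gets a fresh criterion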
torch_geometric/nn/models/gpse.py CHANGED
@@ -801,7 +801,7 @@ def gpse_process(
                         shuffle=False, pin_memory=True, **kwargs)
     out_list = []
     pbar = trange(data.num_nodes, position=2)
-    for i, batch in enumerate(loader):
+    for batch in loader:
         out, _ = model(batch.to(device))
         out = out[:batch.batch_size].to("cpu", non_blocking=True)
         out_list.append(out)
@@ -906,7 +906,7 @@ def gpse_process_batch(
                         shuffle=False, pin_memory=True, **kwargs)
     out_list = []
     pbar = trange(batch.num_nodes, position=2)
-    for i, batch in enumerate(loader):
+    for batch in loader:
         out, _ = model(batch.to(device))
         out = out[:batch.batch_size].to('cpu', non_blocking=True)
         out_list.append(out)
torch_geometric/nn/models/graph_unet.py CHANGED
@@ -64,7 +64,7 @@ class GraphUNet(torch.nn.Module):
         in_channels = channels if sum_res else 2 * channels
 
         self.up_convs = torch.nn.ModuleList()
-        for i in range(depth - 1):
+        for _ in range(depth - 1):
             self.up_convs.append(GCNConv(in_channels, channels, improved=True))
         self.up_convs.append(GCNConv(in_channels, out_channels, improved=True))
 
torch_geometric/nn/models/metapath2vec.py CHANGED
@@ -233,7 +233,7 @@ class MetaPath2Vec(torch.nn.Module):
         """
         from sklearn.linear_model import LogisticRegression
 
-        clf = LogisticRegression(solver=solver, *args,
+        clf = LogisticRegression(*args, solver=solver,
                                  **kwargs).fit(train_z.detach().cpu().numpy(),
                                                train_y.detach().cpu().numpy())
         return clf.score(test_z.detach().cpu().numpy(),
torch_geometric/nn/models/mlp.py CHANGED
@@ -99,8 +99,10 @@ class MLP(torch.nn.Module):
         act_first = act_first or kwargs.get("relu_first", False)
         batch_norm = kwargs.get("batch_norm", None)
         if batch_norm is not None and isinstance(batch_norm, bool):
-            warnings.warn("Argument `batch_norm` is deprecated, "
-                          "please use `norm` to specify normalization layer.")
+            warnings.warn(
+                "Argument `batch_norm` is deprecated, "
+                "please use `norm` to specify normalization layer.",
+                stacklevel=2)
             norm = 'batch_norm' if batch_norm else None
         batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None)
         norm_kwargs = batch_norm_kwargs or {}
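Most hunks in this release add `stacklevel=2` to `warnings.warn`. A short illustration (hypothetical function names) of what that changes: the reported source location becomes the caller of the deprecated API instead of the library line that emits the warning, which is what users need in order to find and fix their own code:

    import warnings

    def deprecated_api():
        warnings.warn("use `new_api` instead", DeprecationWarning,
                      stacklevel=2)

    def user_code():
        deprecated_api()  # with stacklevel=2, the warning points at this line

    user_code()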
torch_geometric/nn/models/node2vec.py CHANGED
@@ -181,7 +181,7 @@ class Node2Vec(torch.nn.Module):
         """
         from sklearn.linear_model import LogisticRegression
 
-        clf = LogisticRegression(solver=solver, *args,
+        clf = LogisticRegression(*args, solver=solver,
                                  **kwargs).fit(train_z.detach().cpu().numpy(),
                                                train_y.detach().cpu().numpy())
         return clf.score(test_z.detach().cpu().numpy(),
torch_geometric/nn/models/rev_gnn.py CHANGED
@@ -249,7 +249,7 @@ class GroupAddRev(InvertibleModule):
         else:
             assert num_groups is not None, "Please specific 'num_groups'"
             self.convs = torch.nn.ModuleList([conv])
-            for i in range(num_groups - 1):
+            for _ in range(num_groups - 1):
                 conv = copy.deepcopy(self.convs[0])
                 if hasattr(conv, 'reset_parameters'):
                     conv.reset_parameters()
torch_geometric/nn/models/signed_gcn.py CHANGED
@@ -45,7 +45,7 @@ class SignedGCN(torch.nn.Module):
         self.conv1 = SignedConv(in_channels, hidden_channels // 2,
                                 first_aggr=True)
         self.convs = torch.nn.ModuleList()
-        for i in range(num_layers - 1):
+        for _ in range(num_layers - 1):
             self.convs.append(
                 SignedConv(hidden_channels // 2, hidden_channels // 2,
                            first_aggr=False))
torch_geometric/nn/nlp/llm.py CHANGED
@@ -94,7 +94,8 @@ class LLM(torch.nn.Module):
         self.word_embedding = self.llm.model.get_input_embeddings()
 
         if 'max_memory' not in kwargs:  # Pure CPU:
-            warnings.warn("LLM is being used on CPU, which may be slow")
+            warnings.warn("LLM is being used on CPU, which may be slow",
+                          stacklevel=2)
             self.device = torch.device('cpu')
             self.autocast_context = nullcontext()
         else:
torch_geometric/nn/pool/__init__.py CHANGED
@@ -163,8 +163,10 @@ def knn_graph(
     :rtype: :class:`torch.Tensor`
     """
    if batch is not None and x.device != batch.device:
-        warnings.warn("Input tensor 'x' and 'batch' are on different devices "
-                      "in 'knn_graph'. Performing blocking device transfer")
+        warnings.warn(
+            "Input tensor 'x' and 'batch' are on different devices "
+            "in 'knn_graph'. Performing blocking device transfer",
+            stacklevel=2)
         batch = batch.to(x.device)
 
     if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
@@ -285,8 +287,10 @@ def radius_graph(
         inputs to GPU before proceeding.
     """
     if batch is not None and x.device != batch.device:
-        warnings.warn("Input tensor 'x' and 'batch' are on different devices "
-                      "in 'radius_graph'. Performing blocking device transfer")
+        warnings.warn(
+            "Input tensor 'x' and 'batch' are on different devices "
+            "in 'radius_graph'. Performing blocking device transfer",
+            stacklevel=2)
         batch = batch.to(x.device)
 
     if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
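Both pooling hunks guard the same situation: when `x` and `batch` live on different devices, the subsequent `batch.to(x.device)` is a blocking, synchronizing copy, hence the new warning. A standalone sketch of the check in isolation:

    import torch

    x = torch.randn(8, 3)                     # node features (possibly on GPU)
    batch = torch.zeros(8, dtype=torch.long)  # graph assignment vector (CPU)
    if batch is not None and x.device != batch.device:
        batch = batch.to(x.device)            # blocking transfer, hence warned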
torch_geometric/nn/pool/knn.py CHANGED
@@ -91,9 +91,10 @@ class KNNIndex:
             if hasattr(self.index, 'reserveMemory'):
                 self.index.reserveMemory(self.reserve)
             else:
-                warnings.warn(f"'{self.index.__class__.__name__}' "
-                              f"does not support pre-allocation of "
-                              f"memory")
+                warnings.warn(
+                    f"'{self.index.__class__.__name__}' "
+                    f"does not support pre-allocation of "
+                    f"memory", stacklevel=2)
 
             self.index.train(emb)
 
@@ -135,14 +136,16 @@ class KNNIndex:
         query_k = min(query_k, self.numel)
 
         if k > 2048:  # `faiss` supports up-to `k=2048`:
-            warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
-                          f"(got {k}). This may cause some relevant items to "
-                          f"not be retrieved.")
+            warnings.warn(
+                f"Capping 'k' to faiss' upper limit of 2048 "
+                f"(got {k}). This may cause some relevant items to "
+                f"not be retrieved.", stacklevel=2)
         elif query_k > 2048:
-            warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
-                          f"(got {k} which got extended to {query_k} due to "
-                          f"the exclusion of existing links). This may cause "
-                          f"some relevant items to not be retrieved.")
+            warnings.warn(
+                f"Capping 'k' to faiss' upper limit of 2048 "
+                f"(got {k} which got extended to {query_k} due to "
+                f"the exclusion of existing links). This may cause "
+                f"some relevant items to not be retrieved.", stacklevel=2)
             query_k = 2048
 
         score, index = self.index.search(emb.detach(), query_k)
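For context on the `KNNIndex` hunks: faiss supports at most `k=2048` neighbors per query (as the inline comment notes), so both branches clamp the effective `k` and warn. A toy sketch of the clamping logic in isolation; the constant and helper names here are hypothetical, not PyG's:

    FAISS_MAX_K = 2048  # faiss' upper limit for k-selection

    def effective_k(k: int, num_excluded: int = 0) -> int:
        # `query_k` grows when existing links must be filtered out afterwards:
        query_k = k + num_excluded
        return min(query_k, FAISS_MAX_K)

    assert effective_k(4096) == 2048
    assert effective_k(2000, num_excluded=100) == 2048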
torch_geometric/nn/to_hetero_module.py CHANGED
@@ -108,9 +108,10 @@ class ToHeteroMessagePassing(torch.nn.Module):
 
         if (not hasattr(module, 'reset_parameters')
                 and sum([p.numel() for p in module.parameters()]) > 0):
-            warnings.warn(f"'{module}' will be duplicated, but its parameters "
-                          f"cannot be reset. To suppress this warning, add a "
-                          f"'reset_parameters()' method to '{module}'")
+            warnings.warn(
+                f"'{module}' will be duplicated, but its parameters "
+                f"cannot be reset. To suppress this warning, add a "
+                f"'reset_parameters()' method to '{module}'", stacklevel=2)
 
         convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}
         self.hetero_module = HeteroConv(convs, aggr)
torch_geometric/nn/to_hetero_transformer.py CHANGED
@@ -157,7 +157,7 @@ class ToHeteroTransformer(Transformer):
                 f"There exist node types ({unused_node_types}) whose "
                 f"representations do not get updated during message passing "
                 f"as they do not occur as destination type in any edge type. "
-                f"This may lead to unexpected behavior.")
+                f"This may lead to unexpected behavior.", stacklevel=2)
 
         names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
         for name in names:
@@ -166,7 +166,7 @@ class ToHeteroTransformer(Transformer):
                     f"The type '{name}' contains invalid characters which "
                     f"may lead to unexpected behavior. To avoid any issues, "
                     f"ensure that your types only contain letters, numbers "
-                    f"and underscores.")
+                    f"and underscores.", stacklevel=2)
 
     def placeholder(self, node: Node, target: Any, name: str):
         # Adds a `get` call to the input dictionary for every node-type or
@@ -379,7 +379,7 @@ class ToHeteroTransformer(Transformer):
             warnings.warn(
                 f"'{target}' will be duplicated, but its parameters "
                 f"cannot be reset. To suppress this warning, add a "
-                f"'reset_parameters()' method to '{target}'")
+                f"'reset_parameters()' method to '{target}'", stacklevel=2)
 
         return module_dict
 
torch_geometric/nn/to_hetero_with_bases_transformer.py CHANGED
@@ -165,7 +165,7 @@ class ToHeteroWithBasesTransformer(Transformer):
                 f"There exist node types ({unused_node_types}) whose "
                 f"representations do not get updated during message passing "
                 f"as they do not occur as destination type in any edge type. "
-                f"This may lead to unexpected behavior.")
+                f"This may lead to unexpected behavior.", stacklevel=2)
 
         names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
         for name in names:
@@ -174,7 +174,7 @@ class ToHeteroWithBasesTransformer(Transformer):
                     f"The type '{name}' contains invalid characters which "
                     f"may lead to unexpected behavior. To avoid any issues, "
                     f"ensure that your types only contain letters, numbers "
-                    f"and underscores.")
+                    f"and underscores.", stacklevel=2)
 
     def transform(self) -> GraphModule:
         self._node_offset_dict_initialized = False
@@ -361,7 +361,7 @@ class HeteroBasisConv(torch.nn.Module):
             warnings.warn(
                 f"'{conv}' will be duplicated, but its parameters cannot "
                 f"be reset. To suppress this warning, add a "
-                f"'reset_parameters()' method to '{conv}'")
+                f"'reset_parameters()' method to '{conv}'", stacklevel=2)
             torch.nn.init.xavier_uniform_(conv.edge_type_weight)
 
     def forward(self, edge_type: Tensor, *args, **kwargs) -> Tensor:
torch_geometric/sampler/base.py CHANGED
@@ -1,7 +1,7 @@
 import copy
 import math
 import warnings
-from abc import ABC
+from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
 from enum import Enum
@@ -369,9 +369,10 @@ class HeteroSamplerOutput(CastMixin):
                     out.edge[edge_type] = None
 
             else:
-                warnings.warn(f"Cannot convert to bidirectional graph "
-                              f"since the edge type {edge_type} does not "
-                              f"seem to have a reverse edge type")
+                warnings.warn(
+                    f"Cannot convert to bidirectional graph "
+                    f"since the edge type {edge_type} does not "
+                    f"seem to have a reverse edge type", stacklevel=2)
 
         return out
 
@@ -622,6 +623,7 @@ class BaseSampler(ABC):
     As such, it is recommended to limit the amount of information stored in
     the sampler.
     """
+    @abstractmethod
     def sample_from_nodes(
         self,
         index: NodeSamplerInput,
@@ -642,6 +644,7 @@ class BaseSampler(ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
     def sample_from_edges(
         self,
         index: EdgeSamplerInput,
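`sampler/base.py` now marks `sample_from_nodes` and `sample_from_edges` as `@abstractmethod` (and `base_transform.py` below does the same for `BaseTransform.forward`, which is why `HGTSampler` and `AddGPSE` gain stub overrides). A minimal sketch of the behavioral difference: a subclass that forgets to implement an abstract method now fails loudly at instantiation rather than at call time:

    from abc import ABC, abstractmethod

    class Sampler(ABC):  # toy class mirroring the pattern above, not PyG's
        @abstractmethod
        def sample_from_nodes(self, index):
            raise NotImplementedError

    class IncompleteSampler(Sampler):
        pass

    try:
        IncompleteSampler()
    except TypeError as e:
        print(e)  # "Can't instantiate abstract class IncompleteSampler ..."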
torch_geometric/sampler/hgt_sampler.py CHANGED
@@ -1,12 +1,15 @@
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 
 from torch_geometric.data import Data, HeteroData
 from torch_geometric.sampler import (
     BaseSampler,
+    EdgeSamplerInput,
     HeteroSamplerOutput,
+    NegativeSampling,
     NodeSamplerInput,
+    SamplerOutput,
 )
 from torch_geometric.sampler.utils import remap_keys, to_hetero_csc
 from torch_geometric.typing import (
@@ -76,6 +79,13 @@ class HGTSampler(BaseSampler):
             metadata=(inputs.input_id, inputs.time),
         )
 
+    def sample_from_edges(
+        self,
+        index: EdgeSamplerInput,
+        neg_sampling: Optional[NegativeSampling] = None,
+    ) -> Union[HeteroSamplerOutput, SamplerOutput]:
+        pass
+
     @property
     def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
         return self.perm
torch_geometric/sampler/neighbor_sampler.py CHANGED
@@ -52,16 +52,18 @@ class NeighborSampler(BaseSampler):
     ):
         if not directed:
             subgraph_type = SubgraphType.induced
-            warnings.warn(f"The usage of the 'directed' argument in "
-                          f"'{self.__class__.__name__}' is deprecated. Use "
-                          f"`subgraph_type='induced'` instead.")
+            warnings.warn(
+                f"The usage of the 'directed' argument in "
+                f"'{self.__class__.__name__}' is deprecated. Use "
+                f"`subgraph_type='induced'` instead.", stacklevel=2)
 
         if (not torch_geometric.typing.WITH_PYG_LIB and sys.platform == 'linux'
                 and subgraph_type != SubgraphType.induced):
-            warnings.warn(f"Using '{self.__class__.__name__}' without a "
-                          f"'pyg-lib' installation is deprecated and will be "
-                          f"removed soon. Please install 'pyg-lib' for "
-                          f"accelerated neighborhood sampling")
+            warnings.warn(
+                f"Using '{self.__class__.__name__}' without a "
+                f"'pyg-lib' installation is deprecated and will be "
+                f"removed soon. Please install 'pyg-lib' for "
+                f"accelerated neighborhood sampling", stacklevel=2)
 
         self.data_type = DataType.from_data(data)
 
@@ -806,7 +808,7 @@ def neg_sample(
     out = out.view(num_samples, seed.numel())
     mask = node_time[out] > seed_time  # holds all invalid samples.
     neg_sampling_complete = False
-    for i in range(5):  # pragma: no cover
+    for _ in range(5):  # pragma: no cover
         num_invalid = int(mask.sum())
         if num_invalid == 0:
             neg_sampling_complete = True
torch_geometric/testing/decorators.py CHANGED
@@ -252,8 +252,9 @@ def withDevice(func: Callable) -> Callable:
     if device:
         backend = os.getenv('TORCH_BACKEND')
         if backend is None:
-            warnings.warn(f"Please specify the backend via 'TORCH_BACKEND' in"
-                          f"order to test against '{device}'")
+            warnings.warn(
+                f"Please specify the backend via 'TORCH_BACKEND' in"
+                f"order to test against '{device}'", stacklevel=2)
         else:
             import_module(backend)
         devices.append(pytest.param(torch.device(device), id=device))
torch_geometric/testing/distributed.py CHANGED
@@ -73,7 +73,7 @@ def assert_run_mproc(
     ]
     results = []
 
-    for p, q in zip(procs, queues):
+    for p, _ in zip(procs, queues):
         p.start()
 
     for p, q in zip(procs, queues):
torch_geometric/transforms/add_gpse.py CHANGED
@@ -1,3 +1,5 @@
+from typing import Any
+
 from torch.nn import Module
 
 from torch_geometric.data import Data
@@ -22,13 +24,20 @@ class AddGPSE(BaseTransform):
         (default: :obj:`NormalSE`)
 
     """
-    def __init__(self, model: Module, use_vn: bool = True,
-                 rand_type: str = 'NormalSE'):
+    def __init__(
+        self,
+        model: Module,
+        use_vn: bool = True,
+        rand_type: str = 'NormalSE',
+    ):
         self.model = model
         self.use_vn = use_vn
         self.vn = VirtualNode()
         self.rand_type = rand_type
 
+    def forward(self, data: Data) -> Any:
+        pass
+
     def __call__(self, data: Data) -> Data:
         from torch_geometric.nn.models.gpse import gpse_process
 
torch_geometric/transforms/add_metapaths.py CHANGED
@@ -108,13 +108,15 @@ class AddMetaPaths(BaseTransform):
         **kwargs: bool,
     ) -> None:
         if 'drop_orig_edges' in kwargs:
-            warnings.warn("'drop_orig_edges' is deprecated. Use "
-                          "'drop_orig_edge_types' instead")
+            warnings.warn(
+                "'drop_orig_edges' is deprecated. Use "
+                "'drop_orig_edge_types' instead", stacklevel=2)
             drop_orig_edge_types = kwargs['drop_orig_edges']
 
         if 'drop_unconnected_nodes' in kwargs:
-            warnings.warn("'drop_unconnected_nodes' is deprecated. Use "
-                          "'drop_unconnected_node_types' instead")
+            warnings.warn(
+                "'drop_unconnected_nodes' is deprecated. Use "
+                "'drop_unconnected_node_types' instead", stacklevel=2)
             drop_unconnected_node_types = kwargs['drop_unconnected_nodes']
 
         for path in metapaths:
@@ -144,7 +146,7 @@ class AddMetaPaths(BaseTransform):
         if self.max_sample is not None:
             edge_index, edge_weight = self._sample(edge_index, edge_weight)
 
-        for i, edge_type in enumerate(metapath[1:]):
+        for edge_type in metapath[1:]:
             edge_index2, edge_weight2 = self._edge_index(data, edge_type)
 
             edge_index, edge_weight = edge_index.matmul(
@@ -276,7 +278,7 @@ class AddRandomMetaPaths(BaseTransform):
         row = start = torch.randperm(num_nodes)[:num_starts].repeat(
             self.walks_per_node[j])
 
-        for i, edge_type in enumerate(metapath):
+        for edge_type in metapath:
             edge_index = EdgeIndex(
                 data[edge_type].edge_index,
                 sparse_size=data[edge_type].size(),
torch_geometric/transforms/base_transform.py CHANGED
@@ -1,5 +1,5 @@
 import copy
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Any
 
 
@@ -31,6 +31,7 @@ class BaseTransform(ABC):
         # Shallow-copy the data so that we prevent in-place data modification.
         return self.forward(copy.copy(data))
 
+    @abstractmethod
     def forward(self, data: Any) -> Any:
         pass
 
torch_geometric/transforms/largest_connected_components.py CHANGED
@@ -47,7 +47,7 @@ class LargestConnectedComponents(BaseTransform):
             return data
 
         _, count = np.unique(component, return_counts=True)
-        subset_np = np.in1d(component, count.argsort()[-self.num_components:])
+        subset_np = np.isin(component, count.argsort()[-self.num_components:])
        subset = torch.from_numpy(subset_np)
         subset = subset.to(data.edge_index.device, torch.bool)
 
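`np.in1d` is deprecated in recent NumPy releases in favor of `np.isin`, which behaves identically for the 1-D inputs used here. A quick check of the replacement call (toy data, not PyG's):

    import numpy as np

    component = np.array([0, 1, 1, 2, 0])  # component id per node
    largest = np.array([1])                # ids of the components to keep
    mask = np.isin(component, largest)
    print(mask)  # [False  True  True False False]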
torch_geometric/transforms/random_link_split.py CHANGED
@@ -245,7 +245,7 @@ class RandomLinkSplit(BaseTransform):
                 warnings.warn(
                     f"There are not enough negative edges to satisfy "
                     "the provided sampling ratio. The ratio will be "
-                    f"adjusted to {ratio:.2f}.")
+                    f"adjusted to {ratio:.2f}.", stacklevel=2)
                 num_neg_train = int((num_neg_train / num_neg) * num_neg_found)
                 num_neg_val = int((num_neg_val / num_neg) * num_neg_found)
                 num_neg_test = num_neg_found - num_neg_train - num_neg_val
torch_geometric/typing.py CHANGED
@@ -81,8 +81,9 @@ try:
     WITH_CUDA_HASH_MAP = False
 except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
-        warnings.warn(f"An issue occurred while importing 'pyg-lib'. "
-                      f"Disabling its usage. Stacktrace: {e}")
+        warnings.warn(
+            f"An issue occurred while importing 'pyg-lib'. "
+            f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
     pyg_lib = object
     WITH_PYG_LIB = False
     WITH_GMM = False
@@ -125,8 +126,9 @@ try:
     WITH_TORCH_SCATTER = True
 except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
-        warnings.warn(f"An issue occurred while importing 'torch-scatter'. "
-                      f"Disabling its usage. Stacktrace: {e}")
+        warnings.warn(
+            f"An issue occurred while importing 'torch-scatter'. "
+            f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
     torch_scatter = object
     WITH_TORCH_SCATTER = False
 
@@ -136,8 +138,9 @@ try:
     WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__
 except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
-        warnings.warn(f"An issue occurred while importing 'torch-cluster'. "
-                      f"Disabling its usage. Stacktrace: {e}")
+        warnings.warn(
+            f"An issue occurred while importing 'torch-cluster'. "
+            f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
     WITH_TORCH_CLUSTER = False
     WITH_TORCH_CLUSTER_BATCH_SIZE = False
 
@@ -154,7 +157,7 @@ except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
         warnings.warn(
             f"An issue occurred while importing 'torch-spline-conv'. "
-            f"Disabling its usage. Stacktrace: {e}")
+            f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
     WITH_TORCH_SPLINE_CONV = False
 
 try:
@@ -163,8 +166,9 @@ try:
     WITH_TORCH_SPARSE = True
 except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
-        warnings.warn(f"An issue occurred while importing 'torch-sparse'. "
-                      f"Disabling its usage. Stacktrace: {e}")
+        warnings.warn(
+            f"An issue occurred while importing 'torch-sparse'. "
+            f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
     WITH_TORCH_SPARSE = False
 
     class SparseStorage:  # type: ignore
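All five `typing.py` hunks touch the same optional-dependency guard, shown here in condensed standalone form (mirroring the diff, not copied verbatim): an `ImportError` silently disables the extension, while any other failure is surfaced as a warning, now attributed to the importer via `stacklevel=2`:

    import warnings

    try:
        import torch_scatter  # optional accelerator package
        WITH_TORCH_SCATTER = True
    except Exception as e:
        if not isinstance(e, ImportError):
            warnings.warn(
                f"An issue occurred while importing 'torch-scatter'. "
                f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
        torch_scatter = object  # placeholder so attribute access still fails loudly
        WITH_TORCH_SCATTER = False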
torch_geometric/utils/_scatter.py CHANGED
@@ -88,9 +88,10 @@ def scatter(
 
         if (src.is_cuda and src.requires_grad and not is_compiling()
                 and not is_in_onnx_export()):
-            warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
-                          f"can be accelerated via the 'torch-scatter'"
-                          f" package, but it was not found")
+            warnings.warn(
+                f"The usage of `scatter(reduce='{reduce}')` "
+                f"can be accelerated via the 'torch-scatter'"
+                f" package, but it was not found", stacklevel=2)
 
         index = broadcast(index, src, dim)
         if not is_in_onnx_export():
@@ -120,9 +121,10 @@ def scatter(
                 or not src.is_cuda):
 
             if src.is_cuda and not is_compiling():
-                warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
-                              f"can be accelerated via the 'torch-scatter'"
-                              f" package, but it was not found")
+                warnings.warn(
+                    f"The usage of `scatter(reduce='{reduce}')` "
+                    f"can be accelerated via the 'torch-scatter'"
+                    f" package, but it was not found", stacklevel=2)
 
             index = broadcast(index, src, dim)
             # We initialize with `one` here to match `scatter_mul` output:
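The `_scatter.py` warnings fire when the optional 'torch-scatter' package is missing for CUDA inputs. A rough standalone sketch (not PyG's implementation) of the kind of segment reduction it accelerates, using core PyTorch's `scatter_reduce_` as the fallback:

    import torch

    src = torch.tensor([1., 2., 3., 4.])  # values to reduce
    index = torch.tensor([0, 0, 1, 1])    # segment id per value
    out = torch.zeros(2).scatter_reduce_(0, index, src, reduce='sum')
    print(out)  # tensor([3., 7.])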