pyg-nightly 2.7.0.dev20250606__py3-none-any.whl → 2.7.0.dev20250608__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. {pyg_nightly-2.7.0.dev20250606.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/METADATA +3 -2
  2. {pyg_nightly-2.7.0.dev20250606.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/RECORD +84 -84
  3. torch_geometric/__init__.py +5 -4
  4. torch_geometric/_compile.py +3 -2
  5. torch_geometric/contrib/__init__.py +1 -1
  6. torch_geometric/data/data.py +3 -3
  7. torch_geometric/data/database.py +4 -0
  8. torch_geometric/data/dataset.py +9 -6
  9. torch_geometric/data/hetero_data.py +7 -6
  10. torch_geometric/data/hypergraph_data.py +1 -1
  11. torch_geometric/data/in_memory_dataset.py +2 -2
  12. torch_geometric/data/large_graph_indexer.py +1 -1
  13. torch_geometric/data/lightning/datamodule.py +28 -20
  14. torch_geometric/data/storage.py +1 -1
  15. torch_geometric/datasets/dbp15k.py +1 -1
  16. torch_geometric/datasets/molecule_net.py +3 -2
  17. torch_geometric/datasets/tag_dataset.py +1 -1
  18. torch_geometric/datasets/wikics.py +2 -1
  19. torch_geometric/deprecation.py +1 -1
  20. torch_geometric/distributed/rpc.py +2 -2
  21. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  22. torch_geometric/explain/algorithm/graphmask_explainer.py +7 -7
  23. torch_geometric/explain/explainer.py +1 -1
  24. torch_geometric/graphgym/config.py +3 -2
  25. torch_geometric/graphgym/imports.py +4 -2
  26. torch_geometric/graphgym/logger.py +1 -1
  27. torch_geometric/graphgym/models/encoder.py +2 -2
  28. torch_geometric/graphgym/utils/comp_budget.py +2 -1
  29. torch_geometric/hash_tensor.py +5 -4
  30. torch_geometric/io/fs.py +5 -4
  31. torch_geometric/loader/ibmb_loader.py +4 -4
  32. torch_geometric/loader/mixin.py +2 -1
  33. torch_geometric/loader/prefetch.py +3 -2
  34. torch_geometric/nn/aggr/fused.py +1 -1
  35. torch_geometric/nn/conv/appnp.py +1 -1
  36. torch_geometric/nn/conv/eg_conv.py +7 -7
  37. torch_geometric/nn/conv/gen_conv.py +1 -1
  38. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  39. torch_geometric/nn/conv/hetero_conv.py +2 -1
  40. torch_geometric/nn/conv/meshcnn_conv.py +6 -4
  41. torch_geometric/nn/conv/message_passing.py +3 -2
  42. torch_geometric/nn/conv/sg_conv.py +1 -1
  43. torch_geometric/nn/conv/spline_conv.py +2 -1
  44. torch_geometric/nn/conv/ssg_conv.py +1 -1
  45. torch_geometric/nn/data_parallel.py +5 -4
  46. torch_geometric/nn/fx.py +7 -5
  47. torch_geometric/nn/models/attentive_fp.py +1 -1
  48. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  49. torch_geometric/nn/models/glem.py +20 -12
  50. torch_geometric/nn/models/gpse.py +30 -13
  51. torch_geometric/nn/models/graph_unet.py +1 -1
  52. torch_geometric/nn/models/metapath2vec.py +1 -1
  53. torch_geometric/nn/models/mlp.py +4 -2
  54. torch_geometric/nn/models/node2vec.py +1 -1
  55. torch_geometric/nn/models/rev_gnn.py +1 -1
  56. torch_geometric/nn/models/signed_gcn.py +1 -1
  57. torch_geometric/nn/nlp/llm.py +2 -1
  58. torch_geometric/nn/pool/__init__.py +8 -4
  59. torch_geometric/nn/pool/knn.py +13 -10
  60. torch_geometric/nn/to_hetero_module.py +4 -3
  61. torch_geometric/nn/to_hetero_transformer.py +3 -3
  62. torch_geometric/nn/to_hetero_with_bases_transformer.py +3 -3
  63. torch_geometric/sampler/base.py +7 -4
  64. torch_geometric/sampler/hgt_sampler.py +11 -1
  65. torch_geometric/sampler/neighbor_sampler.py +10 -8
  66. torch_geometric/testing/decorators.py +3 -2
  67. torch_geometric/testing/distributed.py +1 -1
  68. torch_geometric/transforms/add_gpse.py +11 -2
  69. torch_geometric/transforms/add_metapaths.py +8 -6
  70. torch_geometric/transforms/base_transform.py +2 -1
  71. torch_geometric/transforms/gdc.py +7 -8
  72. torch_geometric/transforms/largest_connected_components.py +1 -1
  73. torch_geometric/transforms/normalize_features.py +3 -3
  74. torch_geometric/transforms/random_link_split.py +1 -1
  75. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  76. torch_geometric/typing.py +13 -9
  77. torch_geometric/utils/_scatter.py +8 -6
  78. torch_geometric/utils/_spmm.py +15 -12
  79. torch_geometric/utils/convert.py +2 -2
  80. torch_geometric/utils/embedding.py +5 -3
  81. torch_geometric/utils/geodesic.py +4 -3
  82. torch_geometric/utils/sparse.py +3 -2
  83. {pyg_nightly-2.7.0.dev20250606.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/WHEEL +0 -0
  84. {pyg_nightly-2.7.0.dev20250606.dist-info → pyg_nightly-2.7.0.dev20250608.dist-info}/licenses/LICENSE +0 -0
@@ -90,7 +90,7 @@ class SGConv(MessagePassing):
90
90
  edge_index, edge_weight, x.size(self.node_dim), False,
91
91
  self.add_self_loops, self.flow, dtype=x.dtype)
92
92
 
93
- for k in range(self.K):
93
+ for _ in range(self.K):
94
94
  # propagate_type: (x: Tensor, edge_weight: OptTensor)
95
95
  x = self.propagate(edge_index, x=x, edge_weight=edge_weight)
96
96
  if self.cached:
@@ -132,7 +132,8 @@ class SplineConv(MessagePassing):
132
132
  if not x[0].is_cuda:
133
133
  warnings.warn(
134
134
  'We do not recommend using the non-optimized CPU version of '
135
- '`SplineConv`. If possible, please move your data to GPU.')
135
+ '`SplineConv`. If possible, please move your data to GPU.',
136
+ stacklevel=2)
136
137
 
137
138
  # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
138
139
  out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
@@ -100,7 +100,7 @@ class SSGConv(MessagePassing):
100
100
  self.add_self_loops, self.flow, dtype=x.dtype)
101
101
 
102
102
  h = x * self.alpha
103
- for k in range(self.K):
103
+ for _ in range(self.K):
104
104
  # propagate_type: (x: Tensor, edge_weight: OptTensor)
105
105
  x = self.propagate(edge_index, x=x, edge_weight=edge_weight)
106
106
  h = h + (1 - self.alpha) / self.K * x
@@ -57,10 +57,11 @@ class DataParallel(torch.nn.DataParallel):
57
57
  follow_batch=None, exclude_keys=None):
58
58
  super().__init__(module, device_ids, output_device)
59
59
 
60
- warnings.warn("'DataParallel' is usually much slower than "
61
- "'DistributedDataParallel' even on a single machine. "
62
- "Please consider switching to 'DistributedDataParallel' "
63
- "for multi-GPU training.")
60
+ warnings.warn(
61
+ "'DataParallel' is usually much slower than "
62
+ "'DistributedDataParallel' even on a single machine. "
63
+ "Please consider switching to 'DistributedDataParallel' "
64
+ "for multi-GPU training.", stacklevel=2)
64
65
 
65
66
  self.src_device = torch.device(f'cuda:{self.device_ids[0]}')
66
67
  self.follow_batch = follow_batch or []
torch_geometric/nn/fx.py CHANGED
@@ -130,11 +130,13 @@ class Transformer:
130
130
  # (node-level, edge-level) by filling `self._state`:
131
131
  for node in list(self.graph.nodes):
132
132
  if node.op == 'call_function' and 'training' in node.kwargs:
133
- warnings.warn(f"Found function '{node.name}' with keyword "
134
- f"argument 'training'. During FX tracing, this "
135
- f"will likely be baked in as a constant value. "
136
- f"Consider replacing this function by a module "
137
- f"to properly encapsulate its training flag.")
133
+ warnings.warn(
134
+ f"Found function '{node.name}' with keyword "
135
+ f"argument 'training'. During FX tracing, this "
136
+ f"will likely be baked in as a constant value. "
137
+ f"Consider replacing this function by a module "
138
+ f"to properly encapsulate its training flag.",
139
+ stacklevel=2)
138
140
 
139
141
  if node.op == 'placeholder':
140
142
  if node.name not in self._state:
@@ -160,7 +160,7 @@ class AttentiveFP(torch.nn.Module):
160
160
  edge_index = torch.stack([row, batch], dim=0)
161
161
 
162
162
  out = global_add_pool(x, batch).relu_()
163
- for t in range(self.num_timesteps):
163
+ for _ in range(self.num_timesteps):
164
164
  h = F.elu_(self.mol_conv((x, out), edge_index))
165
165
  h = F.dropout(h, p=self.dropout, training=self.training)
166
166
  out = self.mol_gru(h, out).relu_()
@@ -106,7 +106,7 @@ class DeepGraphInfomax(torch.nn.Module):
106
106
  """
107
107
  from sklearn.linear_model import LogisticRegression
108
108
 
109
- clf = LogisticRegression(solver=solver, *args,
109
+ clf = LogisticRegression(*args, solver=solver,
110
110
  **kwargs).fit(train_z.detach().cpu().numpy(),
111
111
  train_y.detach().cpu().numpy())
112
112
  return clf.score(test_z.detach().cpu().numpy(),
@@ -37,20 +37,28 @@ class GLEM(torch.nn.Module):
37
37
  See `examples/llm_plus_gnn/glem.py` for example usage.
38
38
  """
39
39
  def __init__(
40
- self,
41
- lm_to_use: str = 'prajjwal1/bert-tiny',
42
- gnn_to_use: basic_gnn = GraphSAGE,
43
- out_channels: int = 47,
44
- gnn_loss=nn.CrossEntropyLoss(reduction='mean'),
45
- lm_loss=nn.CrossEntropyLoss(reduction='mean'),
46
- alpha: float = 0.5,
47
- beta: float = 0.5,
48
- lm_dtype: torch.dtype = torch.bfloat16,
49
- lm_use_lora: bool = True,
50
- lora_target_modules: Optional[Union[List[str], str]] = None,
51
- device: Union[str, torch.device] = torch.device('cpu'),
40
+ self,
41
+ lm_to_use: str = 'prajjwal1/bert-tiny',
42
+ gnn_to_use: basic_gnn = GraphSAGE,
43
+ out_channels: int = 47,
44
+ gnn_loss: Optional[nn.Module] = None,
45
+ lm_loss: Optional[nn.Module] = None,
46
+ alpha: float = 0.5,
47
+ beta: float = 0.5,
48
+ lm_dtype: torch.dtype = torch.bfloat16,
49
+ lm_use_lora: bool = True,
50
+ lora_target_modules: Optional[Union[List[str], str]] = None,
51
+ device: Optional[Union[str, torch.device]] = None,
52
52
  ):
53
53
  super().__init__()
54
+
55
+ if gnn_loss is None:
56
+ gnn_loss = nn.CrossEntropyLoss(reduction='mean')
57
+ if lm_loss is None:
58
+ lm_loss = nn.CrossEntropyLoss(reduction='mean')
59
+ if device is None:
60
+ device = torch.device('cpu')
61
+
54
62
  self.device = device
55
63
  self.lm_loss = lm_loss
56
64
  self.gnn = gnn_to_use
@@ -716,11 +716,18 @@ class GPSENodeEncoder(torch.nn.Module):
716
716
 
717
717
 
718
718
  @torch.no_grad()
719
- def gpse_process(model: Module, data: Data, rand_type: str,
720
- use_vn: bool = True, bernoulli_thresh: float = 0.5,
721
- neighbor_loader: bool = False,
722
- num_neighbors: List[int] = [30, 20, 10], fillval: int = 5,
723
- layers_mp: int = None, **kwargs) -> torch.Tensor:
719
+ def gpse_process(
720
+ model: Module,
721
+ data: Data,
722
+ rand_type: str,
723
+ use_vn: bool = True,
724
+ bernoulli_thresh: float = 0.5,
725
+ neighbor_loader: bool = False,
726
+ num_neighbors: Optional[List[int]] = None,
727
+ fillval: int = 5,
728
+ layers_mp: int = None,
729
+ **kwargs,
730
+ ) -> torch.Tensor:
724
731
  r"""Processes the data using the :class:`GPSE` model to generate and append
725
732
  GPSE encodings. Identical to :obj:`gpse_process_batch`, but operates on a
726
733
  single :class:`~torch_geometric.data.Dataset` object.
@@ -784,6 +791,8 @@ def gpse_process(model: Module, data: Data, rand_type: str,
784
791
  if layers_mp is None:
785
792
  raise ValueError('Please provide the number of message-passing '
786
793
  'layers as "layers_mp".')
794
+
795
+ num_neighbors = num_neighbors or [30, 20, 10]
787
796
  diff = layers_mp - len(num_neighbors)
788
797
  if fillval > 0 and diff > 0:
789
798
  num_neighbors += [fillval] * diff
@@ -792,7 +801,7 @@ def gpse_process(model: Module, data: Data, rand_type: str,
792
801
  shuffle=False, pin_memory=True, **kwargs)
793
802
  out_list = []
794
803
  pbar = trange(data.num_nodes, position=2)
795
- for i, batch in enumerate(loader):
804
+ for batch in loader:
796
805
  out, _ = model(batch.to(device))
797
806
  out = out[:batch.batch_size].to("cpu", non_blocking=True)
798
807
  out_list.append(out)
@@ -806,12 +815,18 @@ def gpse_process(model: Module, data: Data, rand_type: str,
806
815
 
807
816
 
808
817
  @torch.no_grad()
809
- def gpse_process_batch(model: GPSE, batch, rand_type: str, use_vn: bool = True,
810
- bernoulli_thresh: float = 0.5,
811
- neighbor_loader: bool = False,
812
- num_neighbors: List[int] = [30, 20, 10],
813
- fillval: int = 5, layers_mp: int = None,
814
- **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
818
+ def gpse_process_batch(
819
+ model: GPSE,
820
+ batch,
821
+ rand_type: str,
822
+ use_vn: bool = True,
823
+ bernoulli_thresh: float = 0.5,
824
+ neighbor_loader: bool = False,
825
+ num_neighbors: Optional[List[int]] = None,
826
+ fillval: int = 5,
827
+ layers_mp: int = None,
828
+ **kwargs,
829
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
815
830
  r"""Process a batch of data using the :class:`GPSE` model to generate and
816
831
  append :class:`GPSE` encodings. Identical to `gpse_process`, but operates
817
832
  on a batch of :class:`~torch_geometric.data.Data` objects.
@@ -881,6 +896,8 @@ def gpse_process_batch(model: GPSE, batch, rand_type: str, use_vn: bool = True,
881
896
  if layers_mp is None:
882
897
  raise ValueError('Please provide the number of message-passing '
883
898
  'layers as "layers_mp".')
899
+
900
+ num_neighbors = num_neighbors or [30, 20, 10]
884
901
  diff = layers_mp - len(num_neighbors)
885
902
  if fillval > 0 and diff > 0:
886
903
  num_neighbors += [fillval] * diff
@@ -889,7 +906,7 @@ def gpse_process_batch(model: GPSE, batch, rand_type: str, use_vn: bool = True,
889
906
  shuffle=False, pin_memory=True, **kwargs)
890
907
  out_list = []
891
908
  pbar = trange(batch.num_nodes, position=2)
892
- for i, batch in enumerate(loader):
909
+ for batch in loader:
893
910
  out, _ = model(batch.to(device))
894
911
  out = out[:batch.batch_size].to('cpu', non_blocking=True)
895
912
  out_list.append(out)
@@ -64,7 +64,7 @@ class GraphUNet(torch.nn.Module):
64
64
  in_channels = channels if sum_res else 2 * channels
65
65
 
66
66
  self.up_convs = torch.nn.ModuleList()
67
- for i in range(depth - 1):
67
+ for _ in range(depth - 1):
68
68
  self.up_convs.append(GCNConv(in_channels, channels, improved=True))
69
69
  self.up_convs.append(GCNConv(in_channels, out_channels, improved=True))
70
70
 
@@ -233,7 +233,7 @@ class MetaPath2Vec(torch.nn.Module):
233
233
  """
234
234
  from sklearn.linear_model import LogisticRegression
235
235
 
236
- clf = LogisticRegression(solver=solver, *args,
236
+ clf = LogisticRegression(*args, solver=solver,
237
237
  **kwargs).fit(train_z.detach().cpu().numpy(),
238
238
  train_y.detach().cpu().numpy())
239
239
  return clf.score(test_z.detach().cpu().numpy(),
@@ -99,8 +99,10 @@ class MLP(torch.nn.Module):
99
99
  act_first = act_first or kwargs.get("relu_first", False)
100
100
  batch_norm = kwargs.get("batch_norm", None)
101
101
  if batch_norm is not None and isinstance(batch_norm, bool):
102
- warnings.warn("Argument `batch_norm` is deprecated, "
103
- "please use `norm` to specify normalization layer.")
102
+ warnings.warn(
103
+ "Argument `batch_norm` is deprecated, "
104
+ "please use `norm` to specify normalization layer.",
105
+ stacklevel=2)
104
106
  norm = 'batch_norm' if batch_norm else None
105
107
  batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None)
106
108
  norm_kwargs = batch_norm_kwargs or {}
@@ -181,7 +181,7 @@ class Node2Vec(torch.nn.Module):
181
181
  """
182
182
  from sklearn.linear_model import LogisticRegression
183
183
 
184
- clf = LogisticRegression(solver=solver, *args,
184
+ clf = LogisticRegression(*args, solver=solver,
185
185
  **kwargs).fit(train_z.detach().cpu().numpy(),
186
186
  train_y.detach().cpu().numpy())
187
187
  return clf.score(test_z.detach().cpu().numpy(),
@@ -249,7 +249,7 @@ class GroupAddRev(InvertibleModule):
249
249
  else:
250
250
  assert num_groups is not None, "Please specific 'num_groups'"
251
251
  self.convs = torch.nn.ModuleList([conv])
252
- for i in range(num_groups - 1):
252
+ for _ in range(num_groups - 1):
253
253
  conv = copy.deepcopy(self.convs[0])
254
254
  if hasattr(conv, 'reset_parameters'):
255
255
  conv.reset_parameters()
@@ -45,7 +45,7 @@ class SignedGCN(torch.nn.Module):
45
45
  self.conv1 = SignedConv(in_channels, hidden_channels // 2,
46
46
  first_aggr=True)
47
47
  self.convs = torch.nn.ModuleList()
48
- for i in range(num_layers - 1):
48
+ for _ in range(num_layers - 1):
49
49
  self.convs.append(
50
50
  SignedConv(hidden_channels // 2, hidden_channels // 2,
51
51
  first_aggr=False))
@@ -94,7 +94,8 @@ class LLM(torch.nn.Module):
94
94
  self.word_embedding = self.llm.model.get_input_embeddings()
95
95
 
96
96
  if 'max_memory' not in kwargs: # Pure CPU:
97
- warnings.warn("LLM is being used on CPU, which may be slow")
97
+ warnings.warn("LLM is being used on CPU, which may be slow",
98
+ stacklevel=2)
98
99
  self.device = torch.device('cpu')
99
100
  self.autocast_context = nullcontext()
100
101
  else:
@@ -163,8 +163,10 @@ def knn_graph(
163
163
  :rtype: :class:`torch.Tensor`
164
164
  """
165
165
  if batch is not None and x.device != batch.device:
166
- warnings.warn("Input tensor 'x' and 'batch' are on different devices "
167
- "in 'knn_graph'. Performing blocking device transfer")
166
+ warnings.warn(
167
+ "Input tensor 'x' and 'batch' are on different devices "
168
+ "in 'knn_graph'. Performing blocking device transfer",
169
+ stacklevel=2)
168
170
  batch = batch.to(x.device)
169
171
 
170
172
  if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
@@ -285,8 +287,10 @@ def radius_graph(
285
287
  inputs to GPU before proceeding.
286
288
  """
287
289
  if batch is not None and x.device != batch.device:
288
- warnings.warn("Input tensor 'x' and 'batch' are on different devices "
289
- "in 'radius_graph'. Performing blocking device transfer")
290
+ warnings.warn(
291
+ "Input tensor 'x' and 'batch' are on different devices "
292
+ "in 'radius_graph'. Performing blocking device transfer",
293
+ stacklevel=2)
290
294
  batch = batch.to(x.device)
291
295
 
292
296
  if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
@@ -91,9 +91,10 @@ class KNNIndex:
91
91
  if hasattr(self.index, 'reserveMemory'):
92
92
  self.index.reserveMemory(self.reserve)
93
93
  else:
94
- warnings.warn(f"'{self.index.__class__.__name__}' "
95
- f"does not support pre-allocation of "
96
- f"memory")
94
+ warnings.warn(
95
+ f"'{self.index.__class__.__name__}' "
96
+ f"does not support pre-allocation of "
97
+ f"memory", stacklevel=2)
97
98
 
98
99
  self.index.train(emb)
99
100
 
@@ -135,14 +136,16 @@ class KNNIndex:
135
136
  query_k = min(query_k, self.numel)
136
137
 
137
138
  if k > 2048: # `faiss` supports up-to `k=2048`:
138
- warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
139
- f"(got {k}). This may cause some relevant items to "
140
- f"not be retrieved.")
139
+ warnings.warn(
140
+ f"Capping 'k' to faiss' upper limit of 2048 "
141
+ f"(got {k}). This may cause some relevant items to "
142
+ f"not be retrieved.", stacklevel=2)
141
143
  elif query_k > 2048:
142
- warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
143
- f"(got {k} which got extended to {query_k} due to "
144
- f"the exclusion of existing links). This may cause "
145
- f"some relevant items to not be retrieved.")
144
+ warnings.warn(
145
+ f"Capping 'k' to faiss' upper limit of 2048 "
146
+ f"(got {k} which got extended to {query_k} due to "
147
+ f"the exclusion of existing links). This may cause "
148
+ f"some relevant items to not be retrieved.", stacklevel=2)
146
149
  query_k = 2048
147
150
 
148
151
  score, index = self.index.search(emb.detach(), query_k)
@@ -108,9 +108,10 @@ class ToHeteroMessagePassing(torch.nn.Module):
108
108
 
109
109
  if (not hasattr(module, 'reset_parameters')
110
110
  and sum([p.numel() for p in module.parameters()]) > 0):
111
- warnings.warn(f"'{module}' will be duplicated, but its parameters "
112
- f"cannot be reset. To suppress this warning, add a "
113
- f"'reset_parameters()' method to '{module}'")
111
+ warnings.warn(
112
+ f"'{module}' will be duplicated, but its parameters "
113
+ f"cannot be reset. To suppress this warning, add a "
114
+ f"'reset_parameters()' method to '{module}'", stacklevel=2)
114
115
 
115
116
  convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}
116
117
  self.hetero_module = HeteroConv(convs, aggr)
@@ -157,7 +157,7 @@ class ToHeteroTransformer(Transformer):
157
157
  f"There exist node types ({unused_node_types}) whose "
158
158
  f"representations do not get updated during message passing "
159
159
  f"as they do not occur as destination type in any edge type. "
160
- f"This may lead to unexpected behavior.")
160
+ f"This may lead to unexpected behavior.", stacklevel=2)
161
161
 
162
162
  names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
163
163
  for name in names:
@@ -166,7 +166,7 @@ class ToHeteroTransformer(Transformer):
166
166
  f"The type '{name}' contains invalid characters which "
167
167
  f"may lead to unexpected behavior. To avoid any issues, "
168
168
  f"ensure that your types only contain letters, numbers "
169
- f"and underscores.")
169
+ f"and underscores.", stacklevel=2)
170
170
 
171
171
  def placeholder(self, node: Node, target: Any, name: str):
172
172
  # Adds a `get` call to the input dictionary for every node-type or
@@ -379,7 +379,7 @@ class ToHeteroTransformer(Transformer):
379
379
  warnings.warn(
380
380
  f"'{target}' will be duplicated, but its parameters "
381
381
  f"cannot be reset. To suppress this warning, add a "
382
- f"'reset_parameters()' method to '{target}'")
382
+ f"'reset_parameters()' method to '{target}'", stacklevel=2)
383
383
 
384
384
  return module_dict
385
385
 
@@ -165,7 +165,7 @@ class ToHeteroWithBasesTransformer(Transformer):
165
165
  f"There exist node types ({unused_node_types}) whose "
166
166
  f"representations do not get updated during message passing "
167
167
  f"as they do not occur as destination type in any edge type. "
168
- f"This may lead to unexpected behavior.")
168
+ f"This may lead to unexpected behavior.", stacklevel=2)
169
169
 
170
170
  names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
171
171
  for name in names:
@@ -174,7 +174,7 @@ class ToHeteroWithBasesTransformer(Transformer):
174
174
  f"The type '{name}' contains invalid characters which "
175
175
  f"may lead to unexpected behavior. To avoid any issues, "
176
176
  f"ensure that your types only contain letters, numbers "
177
- f"and underscores.")
177
+ f"and underscores.", stacklevel=2)
178
178
 
179
179
  def transform(self) -> GraphModule:
180
180
  self._node_offset_dict_initialized = False
@@ -361,7 +361,7 @@ class HeteroBasisConv(torch.nn.Module):
361
361
  warnings.warn(
362
362
  f"'{conv}' will be duplicated, but its parameters cannot "
363
363
  f"be reset. To suppress this warning, add a "
364
- f"'reset_parameters()' method to '{conv}'")
364
+ f"'reset_parameters()' method to '{conv}'", stacklevel=2)
365
365
  torch.nn.init.xavier_uniform_(conv.edge_type_weight)
366
366
 
367
367
  def forward(self, edge_type: Tensor, *args, **kwargs) -> Tensor:
@@ -1,7 +1,7 @@
1
1
  import copy
2
2
  import math
3
3
  import warnings
4
- from abc import ABC
4
+ from abc import ABC, abstractmethod
5
5
  from collections import defaultdict
6
6
  from dataclasses import dataclass
7
7
  from enum import Enum
@@ -369,9 +369,10 @@ class HeteroSamplerOutput(CastMixin):
369
369
  out.edge[edge_type] = None
370
370
 
371
371
  else:
372
- warnings.warn(f"Cannot convert to bidirectional graph "
373
- f"since the edge type {edge_type} does not "
374
- f"seem to have a reverse edge type")
372
+ warnings.warn(
373
+ f"Cannot convert to bidirectional graph "
374
+ f"since the edge type {edge_type} does not "
375
+ f"seem to have a reverse edge type", stacklevel=2)
375
376
 
376
377
  return out
377
378
 
@@ -622,6 +623,7 @@ class BaseSampler(ABC):
622
623
  As such, it is recommended to limit the amount of information stored in
623
624
  the sampler.
624
625
  """
626
+ @abstractmethod
625
627
  def sample_from_nodes(
626
628
  self,
627
629
  index: NodeSamplerInput,
@@ -642,6 +644,7 @@ class BaseSampler(ABC):
642
644
  """
643
645
  raise NotImplementedError
644
646
 
647
+ @abstractmethod
645
648
  def sample_from_edges(
646
649
  self,
647
650
  index: EdgeSamplerInput,
@@ -1,12 +1,15 @@
1
- from typing import Dict, List, Union
1
+ from typing import Dict, List, Optional, Union
2
2
 
3
3
  import torch
4
4
 
5
5
  from torch_geometric.data import Data, HeteroData
6
6
  from torch_geometric.sampler import (
7
7
  BaseSampler,
8
+ EdgeSamplerInput,
8
9
  HeteroSamplerOutput,
10
+ NegativeSampling,
9
11
  NodeSamplerInput,
12
+ SamplerOutput,
10
13
  )
11
14
  from torch_geometric.sampler.utils import remap_keys, to_hetero_csc
12
15
  from torch_geometric.typing import (
@@ -76,6 +79,13 @@ class HGTSampler(BaseSampler):
76
79
  metadata=(inputs.input_id, inputs.time),
77
80
  )
78
81
 
82
+ def sample_from_edges(
83
+ self,
84
+ index: EdgeSamplerInput,
85
+ neg_sampling: Optional[NegativeSampling] = None,
86
+ ) -> Union[HeteroSamplerOutput, SamplerOutput]:
87
+ pass
88
+
79
89
  @property
80
90
  def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
81
91
  return self.perm
@@ -52,16 +52,18 @@ class NeighborSampler(BaseSampler):
52
52
  ):
53
53
  if not directed:
54
54
  subgraph_type = SubgraphType.induced
55
- warnings.warn(f"The usage of the 'directed' argument in "
56
- f"'{self.__class__.__name__}' is deprecated. Use "
57
- f"`subgraph_type='induced'` instead.")
55
+ warnings.warn(
56
+ f"The usage of the 'directed' argument in "
57
+ f"'{self.__class__.__name__}' is deprecated. Use "
58
+ f"`subgraph_type='induced'` instead.", stacklevel=2)
58
59
 
59
60
  if (not torch_geometric.typing.WITH_PYG_LIB and sys.platform == 'linux'
60
61
  and subgraph_type != SubgraphType.induced):
61
- warnings.warn(f"Using '{self.__class__.__name__}' without a "
62
- f"'pyg-lib' installation is deprecated and will be "
63
- f"removed soon. Please install 'pyg-lib' for "
64
- f"accelerated neighborhood sampling")
62
+ warnings.warn(
63
+ f"Using '{self.__class__.__name__}' without a "
64
+ f"'pyg-lib' installation is deprecated and will be "
65
+ f"removed soon. Please install 'pyg-lib' for "
66
+ f"accelerated neighborhood sampling", stacklevel=2)
65
67
 
66
68
  self.data_type = DataType.from_data(data)
67
69
 
@@ -806,7 +808,7 @@ def neg_sample(
806
808
  out = out.view(num_samples, seed.numel())
807
809
  mask = node_time[out] > seed_time # holds all invalid samples.
808
810
  neg_sampling_complete = False
809
- for i in range(5): # pragma: no cover
811
+ for _ in range(5): # pragma: no cover
810
812
  num_invalid = int(mask.sum())
811
813
  if num_invalid == 0:
812
814
  neg_sampling_complete = True
@@ -252,8 +252,9 @@ def withDevice(func: Callable) -> Callable:
252
252
  if device:
253
253
  backend = os.getenv('TORCH_BACKEND')
254
254
  if backend is None:
255
- warnings.warn(f"Please specify the backend via 'TORCH_BACKEND' in"
256
- f"order to test against '{device}'")
255
+ warnings.warn(
256
+ f"Please specify the backend via 'TORCH_BACKEND' in"
257
+ f"order to test against '{device}'", stacklevel=2)
257
258
  else:
258
259
  import_module(backend)
259
260
  devices.append(pytest.param(torch.device(device), id=device))
@@ -73,7 +73,7 @@ def assert_run_mproc(
73
73
  ]
74
74
  results = []
75
75
 
76
- for p, q in zip(procs, queues):
76
+ for p, _ in zip(procs, queues):
77
77
  p.start()
78
78
 
79
79
  for p, q in zip(procs, queues):
@@ -1,3 +1,5 @@
1
+ from typing import Any
2
+
1
3
  from torch.nn import Module
2
4
 
3
5
  from torch_geometric.data import Data
@@ -22,13 +24,20 @@ class AddGPSE(BaseTransform):
22
24
  (default: :obj:`NormalSE`)
23
25
 
24
26
  """
25
- def __init__(self, model: Module, use_vn: bool = True,
26
- rand_type: str = 'NormalSE'):
27
+ def __init__(
28
+ self,
29
+ model: Module,
30
+ use_vn: bool = True,
31
+ rand_type: str = 'NormalSE',
32
+ ):
27
33
  self.model = model
28
34
  self.use_vn = use_vn
29
35
  self.vn = VirtualNode()
30
36
  self.rand_type = rand_type
31
37
 
38
+ def forward(self, data: Data) -> Any:
39
+ pass
40
+
32
41
  def __call__(self, data: Data) -> Data:
33
42
  from torch_geometric.nn.models.gpse import gpse_process
34
43
 
@@ -108,13 +108,15 @@ class AddMetaPaths(BaseTransform):
108
108
  **kwargs: bool,
109
109
  ) -> None:
110
110
  if 'drop_orig_edges' in kwargs:
111
- warnings.warn("'drop_orig_edges' is deprecated. Use "
112
- "'drop_orig_edge_types' instead")
111
+ warnings.warn(
112
+ "'drop_orig_edges' is deprecated. Use "
113
+ "'drop_orig_edge_types' instead", stacklevel=2)
113
114
  drop_orig_edge_types = kwargs['drop_orig_edges']
114
115
 
115
116
  if 'drop_unconnected_nodes' in kwargs:
116
- warnings.warn("'drop_unconnected_nodes' is deprecated. Use "
117
- "'drop_unconnected_node_types' instead")
117
+ warnings.warn(
118
+ "'drop_unconnected_nodes' is deprecated. Use "
119
+ "'drop_unconnected_node_types' instead", stacklevel=2)
118
120
  drop_unconnected_node_types = kwargs['drop_unconnected_nodes']
119
121
 
120
122
  for path in metapaths:
@@ -144,7 +146,7 @@ class AddMetaPaths(BaseTransform):
144
146
  if self.max_sample is not None:
145
147
  edge_index, edge_weight = self._sample(edge_index, edge_weight)
146
148
 
147
- for i, edge_type in enumerate(metapath[1:]):
149
+ for edge_type in metapath[1:]:
148
150
  edge_index2, edge_weight2 = self._edge_index(data, edge_type)
149
151
 
150
152
  edge_index, edge_weight = edge_index.matmul(
@@ -276,7 +278,7 @@ class AddRandomMetaPaths(BaseTransform):
276
278
  row = start = torch.randperm(num_nodes)[:num_starts].repeat(
277
279
  self.walks_per_node[j])
278
280
 
279
- for i, edge_type in enumerate(metapath):
281
+ for edge_type in metapath:
280
282
  edge_index = EdgeIndex(
281
283
  data[edge_type].edge_index,
282
284
  sparse_size=data[edge_type].size(),