pyg-nightly 2.6.0.dev20240511__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Files changed (205)
  1. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +30 -31
  2. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +205 -181
  3. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +26 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +16 -14
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/data.py +13 -8
  12. torch_geometric/data/database.py +15 -7
  13. torch_geometric/data/dataset.py +14 -6
  14. torch_geometric/data/feature_store.py +13 -22
  15. torch_geometric/data/graph_store.py +0 -4
  16. torch_geometric/data/hetero_data.py +4 -4
  17. torch_geometric/data/in_memory_dataset.py +2 -4
  18. torch_geometric/data/large_graph_indexer.py +677 -0
  19. torch_geometric/data/lightning/datamodule.py +4 -4
  20. torch_geometric/data/storage.py +15 -5
  21. torch_geometric/data/summary.py +14 -4
  22. torch_geometric/data/temporal.py +1 -2
  23. torch_geometric/datasets/__init__.py +11 -1
  24. torch_geometric/datasets/actor.py +9 -11
  25. torch_geometric/datasets/airfrans.py +15 -18
  26. torch_geometric/datasets/airports.py +10 -12
  27. torch_geometric/datasets/amazon.py +8 -11
  28. torch_geometric/datasets/amazon_book.py +9 -10
  29. torch_geometric/datasets/amazon_products.py +9 -10
  30. torch_geometric/datasets/aminer.py +8 -9
  31. torch_geometric/datasets/aqsol.py +10 -13
  32. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  33. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  34. torch_geometric/datasets/ba_shapes.py +5 -6
  35. torch_geometric/datasets/bitcoin_otc.py +1 -1
  36. torch_geometric/datasets/brca_tgca.py +1 -1
  37. torch_geometric/datasets/dblp.py +2 -1
  38. torch_geometric/datasets/dbp15k.py +2 -2
  39. torch_geometric/datasets/fake.py +1 -3
  40. torch_geometric/datasets/flickr.py +2 -1
  41. torch_geometric/datasets/freebase.py +1 -1
  42. torch_geometric/datasets/gdelt_lite.py +3 -2
  43. torch_geometric/datasets/ged_dataset.py +3 -2
  44. torch_geometric/datasets/git_mol_dataset.py +263 -0
  45. torch_geometric/datasets/gnn_benchmark_dataset.py +6 -5
  46. torch_geometric/datasets/hgb_dataset.py +8 -8
  47. torch_geometric/datasets/imdb.py +2 -1
  48. torch_geometric/datasets/last_fm.py +2 -1
  49. torch_geometric/datasets/linkx_dataset.py +4 -3
  50. torch_geometric/datasets/lrgb.py +3 -5
  51. torch_geometric/datasets/malnet_tiny.py +4 -3
  52. torch_geometric/datasets/mnist_superpixels.py +2 -3
  53. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  54. torch_geometric/datasets/molecule_net.py +7 -1
  55. torch_geometric/datasets/motif_generator/base.py +0 -1
  56. torch_geometric/datasets/neurograph.py +1 -3
  57. torch_geometric/datasets/ogb_mag.py +1 -1
  58. torch_geometric/datasets/opf.py +239 -0
  59. torch_geometric/datasets/ose_gvcs.py +1 -1
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  62. torch_geometric/datasets/pcqm4m.py +2 -1
  63. torch_geometric/datasets/ppi.py +1 -1
  64. torch_geometric/datasets/qm9.py +4 -3
  65. torch_geometric/datasets/reddit.py +2 -1
  66. torch_geometric/datasets/reddit2.py +2 -1
  67. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  68. torch_geometric/datasets/s3dis.py +2 -2
  69. torch_geometric/datasets/shapenet.py +3 -3
  70. torch_geometric/datasets/shrec2016.py +2 -2
  71. torch_geometric/datasets/tag_dataset.py +350 -0
  72. torch_geometric/datasets/upfd.py +2 -1
  73. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  74. torch_geometric/datasets/webkb.py +2 -2
  75. torch_geometric/datasets/wikics.py +1 -1
  76. torch_geometric/datasets/wikidata.py +3 -2
  77. torch_geometric/datasets/wikipedia_network.py +2 -2
  78. torch_geometric/datasets/word_net.py +2 -2
  79. torch_geometric/datasets/yelp.py +2 -1
  80. torch_geometric/datasets/zinc.py +1 -1
  81. torch_geometric/device.py +42 -0
  82. torch_geometric/distributed/local_feature_store.py +3 -2
  83. torch_geometric/distributed/local_graph_store.py +2 -1
  84. torch_geometric/distributed/partition.py +9 -8
  85. torch_geometric/edge_index.py +17 -8
  86. torch_geometric/explain/algorithm/base.py +0 -1
  87. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  88. torch_geometric/explain/explanation.py +2 -2
  89. torch_geometric/graphgym/checkpoint.py +2 -1
  90. torch_geometric/graphgym/logger.py +4 -4
  91. torch_geometric/graphgym/loss.py +1 -1
  92. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  93. torch_geometric/index.py +20 -7
  94. torch_geometric/inspector.py +6 -2
  95. torch_geometric/io/fs.py +28 -2
  96. torch_geometric/io/npz.py +2 -1
  97. torch_geometric/io/off.py +2 -2
  98. torch_geometric/io/sdf.py +2 -2
  99. torch_geometric/io/tu.py +2 -3
  100. torch_geometric/loader/__init__.py +4 -0
  101. torch_geometric/loader/cluster.py +9 -3
  102. torch_geometric/loader/graph_saint.py +2 -1
  103. torch_geometric/loader/ibmb_loader.py +12 -4
  104. torch_geometric/loader/mixin.py +1 -1
  105. torch_geometric/loader/neighbor_loader.py +1 -1
  106. torch_geometric/loader/neighbor_sampler.py +2 -2
  107. torch_geometric/loader/prefetch.py +1 -1
  108. torch_geometric/loader/rag_loader.py +107 -0
  109. torch_geometric/loader/zip_loader.py +10 -0
  110. torch_geometric/metrics/__init__.py +11 -2
  111. torch_geometric/metrics/link_pred.py +159 -34
  112. torch_geometric/nn/aggr/__init__.py +2 -0
  113. torch_geometric/nn/aggr/attention.py +0 -2
  114. torch_geometric/nn/aggr/base.py +2 -4
  115. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  116. torch_geometric/nn/aggr/set_transformer.py +1 -1
  117. torch_geometric/nn/attention/__init__.py +5 -1
  118. torch_geometric/nn/attention/qformer.py +71 -0
  119. torch_geometric/nn/conv/collect.jinja +6 -3
  120. torch_geometric/nn/conv/cugraph/base.py +0 -1
  121. torch_geometric/nn/conv/edge_conv.py +3 -2
  122. torch_geometric/nn/conv/gat_conv.py +35 -7
  123. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  124. torch_geometric/nn/conv/general_conv.py +1 -1
  125. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  126. torch_geometric/nn/conv/hetero_conv.py +3 -3
  127. torch_geometric/nn/conv/hgt_conv.py +1 -1
  128. torch_geometric/nn/conv/message_passing.py +100 -82
  129. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  130. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  131. torch_geometric/nn/conv/spline_conv.py +4 -4
  132. torch_geometric/nn/conv/x_conv.py +3 -2
  133. torch_geometric/nn/dense/linear.py +5 -4
  134. torch_geometric/nn/fx.py +3 -3
  135. torch_geometric/nn/model_hub.py +3 -1
  136. torch_geometric/nn/models/__init__.py +10 -2
  137. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  138. torch_geometric/nn/models/dimenet_utils.py +5 -7
  139. torch_geometric/nn/models/g_retriever.py +230 -0
  140. torch_geometric/nn/models/git_mol.py +336 -0
  141. torch_geometric/nn/models/glem.py +385 -0
  142. torch_geometric/nn/models/gnnff.py +0 -1
  143. torch_geometric/nn/models/graph_unet.py +12 -3
  144. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  145. torch_geometric/nn/models/lightgcn.py +1 -1
  146. torch_geometric/nn/models/metapath2vec.py +3 -4
  147. torch_geometric/nn/models/molecule_gpt.py +222 -0
  148. torch_geometric/nn/models/node2vec.py +1 -2
  149. torch_geometric/nn/models/schnet.py +2 -1
  150. torch_geometric/nn/models/signed_gcn.py +3 -3
  151. torch_geometric/nn/module_dict.py +2 -2
  152. torch_geometric/nn/nlp/__init__.py +9 -0
  153. torch_geometric/nn/nlp/llm.py +322 -0
  154. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  155. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  156. torch_geometric/nn/norm/batch_norm.py +1 -1
  157. torch_geometric/nn/parameter_dict.py +2 -2
  158. torch_geometric/nn/pool/__init__.py +7 -5
  159. torch_geometric/nn/pool/cluster_pool.py +145 -0
  160. torch_geometric/nn/pool/connect/base.py +0 -1
  161. torch_geometric/nn/pool/edge_pool.py +1 -1
  162. torch_geometric/nn/pool/graclus.py +4 -2
  163. torch_geometric/nn/pool/select/base.py +0 -1
  164. torch_geometric/nn/pool/voxel_grid.py +3 -2
  165. torch_geometric/nn/resolver.py +1 -1
  166. torch_geometric/nn/sequential.jinja +10 -23
  167. torch_geometric/nn/sequential.py +203 -77
  168. torch_geometric/nn/summary.py +1 -1
  169. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  170. torch_geometric/profile/__init__.py +2 -0
  171. torch_geometric/profile/nvtx.py +66 -0
  172. torch_geometric/profile/profiler.py +24 -15
  173. torch_geometric/resolver.py +1 -1
  174. torch_geometric/sampler/base.py +34 -13
  175. torch_geometric/sampler/neighbor_sampler.py +11 -10
  176. torch_geometric/testing/decorators.py +17 -22
  177. torch_geometric/transforms/__init__.py +2 -0
  178. torch_geometric/transforms/add_metapaths.py +4 -4
  179. torch_geometric/transforms/add_positional_encoding.py +1 -1
  180. torch_geometric/transforms/delaunay.py +65 -14
  181. torch_geometric/transforms/face_to_edge.py +32 -3
  182. torch_geometric/transforms/gdc.py +7 -6
  183. torch_geometric/transforms/laplacian_lambda_max.py +2 -2
  184. torch_geometric/transforms/mask.py +5 -1
  185. torch_geometric/transforms/node_property_split.py +1 -2
  186. torch_geometric/transforms/pad.py +7 -6
  187. torch_geometric/transforms/random_link_split.py +1 -1
  188. torch_geometric/transforms/remove_self_loops.py +36 -0
  189. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  190. torch_geometric/transforms/virtual_node.py +2 -1
  191. torch_geometric/typing.py +31 -5
  192. torch_geometric/utils/__init__.py +5 -1
  193. torch_geometric/utils/_negative_sampling.py +1 -1
  194. torch_geometric/utils/_normalize_edge_index.py +46 -0
  195. torch_geometric/utils/_scatter.py +37 -12
  196. torch_geometric/utils/_subgraph.py +4 -0
  197. torch_geometric/utils/_tree_decomposition.py +2 -2
  198. torch_geometric/utils/augmentation.py +1 -1
  199. torch_geometric/utils/convert.py +5 -5
  200. torch_geometric/utils/geodesic.py +24 -22
  201. torch_geometric/utils/hetero.py +1 -1
  202. torch_geometric/utils/map.py +1 -1
  203. torch_geometric/utils/smiles.py +66 -28
  204. torch_geometric/utils/sparse.py +25 -10
  205. torch_geometric/visualization/graph.py +3 -4

torch_geometric/nn/aggr/patch_transformer.py (new file)
@@ -0,0 +1,143 @@
+import math
+from typing import List, Optional, Union
+
+import torch
+from torch import Tensor
+
+from torch_geometric.experimental import disable_dynamic_shapes
+from torch_geometric.nn.aggr import Aggregation
+from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
+from torch_geometric.nn.encoding import PositionalEncoding
+from torch_geometric.utils import scatter
+
+
+class PatchTransformerAggregation(Aggregation):
+    r"""Performs patch transformer aggregation in which the elements to
+    aggregate are processed by multi-head attention blocks across patches, as
+    described in the `"Simplifying Temporal Heterogeneous Network for
+    Continuous-Time Link Prediction"
+    <https://dl.acm.org/doi/pdf/10.1145/3583780.3615059>`_ paper.
+
+    Args:
+        in_channels (int): Size of each input sample.
+        out_channels (int): Size of each output sample.
+        patch_size (int): Number of elements in a patch.
+        hidden_channels (int): Intermediate size of each sample.
+        num_transformer_blocks (int, optional): Number of transformer blocks
+            (default: :obj:`1`).
+        heads (int, optional): Number of multi-head-attentions.
+            (default: :obj:`1`)
+        dropout (float, optional): Dropout probability of attention weights.
+            (default: :obj:`0.0`)
+        aggr (str or list[str], optional): The aggregation module, *e.g.*,
+            :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
+            :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        patch_size: int,
+        hidden_channels: int,
+        num_transformer_blocks: int = 1,
+        heads: int = 1,
+        dropout: float = 0.0,
+        aggr: Union[str, List[str]] = 'mean',
+    ) -> None:
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.patch_size = patch_size
+        self.aggrs = [aggr] if isinstance(aggr, str) else aggr
+
+        assert len(self.aggrs) > 0
+        for aggr in self.aggrs:
+            assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']
+
+        self.lin = torch.nn.Linear(in_channels, hidden_channels)
+        self.pad_projector = torch.nn.Linear(
+            patch_size * hidden_channels,
+            hidden_channels,
+        )
+        self.pe = PositionalEncoding(hidden_channels)
+
+        self.blocks = torch.nn.ModuleList([
+            MultiheadAttentionBlock(
+                channels=hidden_channels,
+                heads=heads,
+                layer_norm=True,
+                dropout=dropout,
+            ) for _ in range(num_transformer_blocks)
+        ])
+
+        self.fc = torch.nn.Linear(
+            hidden_channels * len(self.aggrs),
+            out_channels,
+        )
+
+    def reset_parameters(self) -> None:
+        self.lin.reset_parameters()
+        self.pad_projector.reset_parameters()
+        self.pe.reset_parameters()
+        for block in self.blocks:
+            block.reset_parameters()
+        self.fc.reset_parameters()
+
+    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
+    def forward(
+        self,
+        x: Tensor,
+        index: Tensor,
+        ptr: Optional[Tensor] = None,
+        dim_size: Optional[int] = None,
+        dim: int = -2,
+        max_num_elements: Optional[int] = None,
+    ) -> Tensor:
+
+        if max_num_elements is None:
+            if ptr is not None:
+                count = ptr.diff()
+            else:
+                count = scatter(torch.ones_like(index), index, dim=0,
+                                dim_size=dim_size, reduce='sum')
+            max_num_elements = int(count.max()) + 1
+
+        # Set `max_num_elements` to a multiple of `patch_size`:
+        max_num_elements = (math.floor(max_num_elements / self.patch_size) *
+                            self.patch_size)
+
+        x = self.lin(x)
+
+        # TODO If groups are heavily unbalanced, this will create a lot of
+        # "empty" patches. Try to figure out a way to fix this.
+        # [batch_size, num_patches * patch_size, hidden_channels]
+        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
+                                   max_num_elements=max_num_elements)
+
+        # [batch_size, num_patches, patch_size * hidden_channels]
+        x = x.view(x.size(0), max_num_elements // self.patch_size,
+                   self.patch_size * x.size(-1))
+
+        # [batch_size, num_patches, hidden_channels]
+        x = self.pad_projector(x)
+
+        x = x + self.pe(torch.arange(x.size(1), device=x.device))
+
+        # [batch_size, num_patches, hidden_channels]
+        for block in self.blocks:
+            x = block(x, x)
+
+        # [batch_size, hidden_channels]
+        outs: List[Tensor] = []
+        for aggr in self.aggrs:
+            out = getattr(torch, aggr)(x, dim=1)
+            outs.append(out[0] if isinstance(out, tuple) else out)
+        out = torch.cat(outs, dim=1) if len(outs) > 1 else outs[0]
+
+        # [batch_size, out_channels]
+        return self.fc(out)
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({self.in_channels}, '
+                f'{self.out_channels}, patch_size={self.patch_size})')
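
For orientation, a minimal usage sketch of the new PatchTransformerAggregation added above. It is not part of the diff; it assumes the constructor and forward signatures shown in the hunk and that the class is re-exported from torch_geometric.nn.aggr (the file list shows a matching torch_geometric/nn/aggr/__init__.py change). Shapes and values are illustrative only:

    import torch
    from torch_geometric.nn.aggr import PatchTransformerAggregation

    aggr = PatchTransformerAggregation(
        in_channels=16,        # feature size of each element
        out_channels=32,       # feature size of each aggregated group
        patch_size=8,          # elements per attention patch
        hidden_channels=64,
        heads=2,
        aggr=['mean', 'max'],  # readouts are concatenated before the final linear
    )

    x = torch.randn(1000, 16)                              # element features
    index = torch.randint(0, 100, (1000, )).sort().values  # sorted group ids
    out = aggr(x, index, dim_size=100)                     # -> [100, 32]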

torch_geometric/nn/aggr/set_transformer.py
@@ -38,7 +38,7 @@ class SetTransformerAggregation(Aggregation):
             (default: :obj:`1`)
         concat (bool, optional): If set to :obj:`False`, the seed embeddings
             are averaged instead of concatenated. (default: :obj:`True`)
-        norm (str, optional): If set to :obj:`True`, will apply layer
+        layer_norm (str, optional): If set to :obj:`True`, will apply layer
             normalization. (default: :obj:`False`)
         dropout (float, optional): Dropout probability of attention weights.
             (default: :obj:`0`)

torch_geometric/nn/attention/__init__.py
@@ -1,3 +1,7 @@
 from .performer import PerformerAttention
+from .qformer import QFormer
 
-__all__ = ['PerformerAttention']
+__all__ = [
+    'PerformerAttention',
+    'QFormer',
+]

torch_geometric/nn/attention/qformer.py (new file)
@@ -0,0 +1,71 @@
+from typing import Callable
+
+import torch
+
+
+class QFormer(torch.nn.Module):
+    r"""The Querying Transformer (Q-Former) from
+    `"BLIP-2: Bootstrapping Language-Image Pre-training
+    with Frozen Image Encoders and Large Language Models"
+    <https://arxiv.org/pdf/2301.12597>`_ paper.
+
+    Args:
+        input_dim (int): The number of features in the input.
+        hidden_dim (int): The dimension of the fnn in the encoder layer.
+        output_dim (int): The final output dimension.
+        num_heads (int): The number of multi-attention-heads.
+        num_layers (int): The number of sub-encoder-layers in the encoder.
+        dropout (int): The dropout value in each encoder layer.
+
+
+    .. note::
+        This is a simplified version of the original Q-Former implementation.
+    """
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_heads: int,
+        num_layers: int,
+        dropout: float = 0.0,
+        activation: Callable = torch.nn.ReLU(),
+    ) -> None:
+
+        super().__init__()
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+
+        self.layer_norm = torch.nn.LayerNorm(input_dim)
+        self.encoder_layer = torch.nn.TransformerEncoderLayer(
+            d_model=input_dim,
+            nhead=num_heads,
+            dim_feedforward=hidden_dim,
+            dropout=dropout,
+            activation=activation,
+            batch_first=True,
+        )
+        self.encoder = torch.nn.TransformerEncoder(
+            self.encoder_layer,
+            num_layers=num_layers,
+        )
+        self.project = torch.nn.Linear(input_dim, output_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Input sequence to the encoder layer.
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, sequence length :math:`N`,
+                and feature dimension :math:`F`.
+        """
+        x = self.layer_norm(x)
+        x = self.encoder(x)
+        out = self.project(x)
+        return out
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'num_heads={self.num_heads}, '
+                f'num_layers={self.num_layers})')
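
A minimal, illustrative usage sketch of the QFormer module added above (not part of the diff). It relies only on the constructor and forward signatures shown in the hunk and on the [batch_size, seq_len, features] layout documented in its forward pass:

    import torch
    from torch_geometric.nn.attention import QFormer

    model = QFormer(input_dim=64, hidden_dim=128, output_dim=32,
                    num_heads=4, num_layers=2, dropout=0.1)

    x = torch.randn(8, 10, 64)   # [batch_size, seq_len, input_dim]
    out = model(x)               # -> [8, 10, 32]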

torch_geometric/nn/conv/collect.jinja
@@ -98,13 +98,16 @@ def {{collect_name}}(
 
 {%- if 'edge_weight' in collect_param_dict and
        collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %}
-        assert edge_weight is not None
+        if torch.jit.is_scripting():
+            assert edge_weight is not None
 {%- elif 'edge_attr' in collect_param_dict and
        collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %}
-        assert edge_attr is not None
+        if torch.jit.is_scripting():
+            assert edge_attr is not None
 {%- elif 'edge_type' in collect_param_dict and
        collect_param_dict['edge_type'].type_repr.endswith('Tensor') %}
-        assert edge_type is not None
+        if torch.jit.is_scripting():
+            assert edge_type is not None
 {%- endif %}
 
     # Collect user-defined arguments:

torch_geometric/nn/conv/cugraph/base.py
@@ -36,7 +36,6 @@ class CuGraphModule(torch.nn.Module):  # pragma: no cover
 
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
-        pass
 
     def get_cugraph(
         self,

torch_geometric/nn/conv/edge_conv.py
@@ -3,13 +3,14 @@ from typing import Callable, Optional, Union
 import torch
 from torch import Tensor
 
+import torch_geometric.typing
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.inits import reset
 from torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import knn
-except ImportError:
+else:
     knn = None
 
 

torch_geometric/nn/conv/gat_conv.py
@@ -37,9 +37,8 @@ class GATConv(MessagePassing):
     <https://arxiv.org/abs/1710.10903>`_ paper.
 
     .. math::
-        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}_{s}\mathbf{x}_{i} +
-        \sum_{j \in \mathcal{N}(i)}
-        \alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j},
+        \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}}
+        \alpha_{i,j}\mathbf{\Theta}_t\mathbf{x}_{j},
 
     where the attention coefficients :math:`\alpha_{i,j}` are computed as
 
@@ -108,6 +107,8 @@ class GATConv(MessagePassing):
             :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`)
         bias (bool, optional): If set to :obj:`False`, the layer will not learn
             an additive bias. (default: :obj:`True`)
+        residual (bool, optional): If set to :obj:`True`, the layer will add
+            a learnable skip-connection. (default: :obj:`False`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
 
@@ -138,6 +139,7 @@ class GATConv(MessagePassing):
         edge_dim: Optional[int] = None,
         fill_value: Union[float, Tensor, str] = 'mean',
         bias: bool = True,
+        residual: bool = False,
         **kwargs,
     ):
         kwargs.setdefault('aggr', 'add')
@@ -152,6 +154,7 @@ class GATConv(MessagePassing):
         self.add_self_loops = add_self_loops
         self.edge_dim = edge_dim
         self.fill_value = fill_value
+        self.residual = residual
 
         # In case we are operating in bipartite graphs, we apply separate
         # transformations 'lin_src' and 'lin_dst' to source and target nodes:
@@ -177,10 +180,22 @@ class GATConv(MessagePassing):
             self.lin_edge = None
             self.register_parameter('att_edge', None)
 
-        if bias and concat:
-            self.bias = Parameter(torch.empty(heads * out_channels))
-        elif bias and not concat:
-            self.bias = Parameter(torch.empty(out_channels))
+        # The number of output channels:
+        total_out_channels = out_channels * (heads if concat else 1)
+
+        if residual:
+            self.res = Linear(
+                in_channels
+                if isinstance(in_channels, int) else in_channels[1],
+                total_out_channels,
+                bias=False,
+                weight_initializer='glorot',
+            )
+        else:
+            self.register_parameter('res', None)
+
+        if bias:
+            self.bias = Parameter(torch.empty(total_out_channels))
         else:
             self.register_parameter('bias', None)
 
@@ -196,6 +211,8 @@ class GATConv(MessagePassing):
         self.lin_dst.reset_parameters()
         if self.lin_edge is not None:
             self.lin_edge.reset_parameters()
+        if self.res is not None:
+            self.res.reset_parameters()
         glorot(self.att_src)
         glorot(self.att_dst)
         glorot(self.att_edge)
@@ -271,11 +288,16 @@ class GATConv(MessagePassing):
 
         H, C = self.heads, self.out_channels
 
+        res: Optional[Tensor] = None
+
         # We first transform the input node features. If a tuple is passed, we
         # transform source and target node features via separate weights:
         if isinstance(x, Tensor):
             assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
 
+            if self.res is not None:
+                res = self.res(x)
+
             if self.lin is not None:
                 x_src = x_dst = self.lin(x).view(-1, H, C)
             else:
@@ -289,6 +311,9 @@ class GATConv(MessagePassing):
             x_src, x_dst = x
             assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
 
+            if x_dst is not None and self.res is not None:
+                res = self.res(x_dst)
+
             if self.lin is not None:
                 # If the module is initialized as non-bipartite, we expect that
                 # source and destination node features have the same shape and
@@ -345,6 +370,9 @@ class GATConv(MessagePassing):
         else:
             out = out.mean(dim=1)
 
+        if res is not None:
+            out = out + res
+
         if self.bias is not None:
             out = out + self.bias
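
In effect, GATConv gains an optional residual=True flag that adds a learned linear skip-connection from the (destination) node features to the layer output. A hedged usage sketch, not taken from the diff (the same option is added to GATv2Conv in the hunks below):

    import torch
    from torch_geometric.nn import GATConv

    conv = GATConv(in_channels=16, out_channels=32, heads=4,
                   concat=True, residual=True)

    x = torch.randn(100, 16)                      # node features
    edge_index = torch.randint(0, 100, (2, 500))  # random edges, for shape only
    out = conv(x, edge_index)                     # -> [100, 4 * 32]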

torch_geometric/nn/conv/gatv2_conv.py
@@ -41,8 +41,7 @@ class GATv2Conv(MessagePassing):
     In contrast, in :class:`GATv2`, every node can attend to any other node.
 
     .. math::
-        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}_{s}\mathbf{x}_{i} +
-        \sum_{j \in \mathcal{N}(i)}
+        \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}}
         \alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j},
 
     where the attention coefficients :math:`\alpha_{i,j}` are computed as
@@ -111,6 +110,8 @@ class GATv2Conv(MessagePassing):
             will be applied to the source and the target node of every edge,
             *i.e.* :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`.
             (default: :obj:`False`)
+        residual (bool, optional): If set to :obj:`True`, the layer will add
+            a learnable skip-connection. (default: :obj:`False`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
 
@@ -142,6 +143,7 @@ class GATv2Conv(MessagePassing):
         fill_value: Union[float, Tensor, str] = 'mean',
         bias: bool = True,
         share_weights: bool = False,
+        residual: bool = False,
         **kwargs,
     ):
         super().__init__(node_dim=0, **kwargs)
@@ -155,6 +157,7 @@ class GATv2Conv(MessagePassing):
         self.add_self_loops = add_self_loops
         self.edge_dim = edge_dim
         self.fill_value = fill_value
+        self.residual = residual
         self.share_weights = share_weights
 
         if isinstance(in_channels, int):
@@ -182,10 +185,22 @@ class GATv2Conv(MessagePassing):
         else:
             self.lin_edge = None
 
-        if bias and concat:
-            self.bias = Parameter(torch.empty(heads * out_channels))
-        elif bias and not concat:
-            self.bias = Parameter(torch.empty(out_channels))
+        # The number of output channels:
+        total_out_channels = out_channels * (heads if concat else 1)
+
+        if residual:
+            self.res = Linear(
+                in_channels
+                if isinstance(in_channels, int) else in_channels[1],
+                total_out_channels,
+                bias=False,
+                weight_initializer='glorot',
+            )
+        else:
+            self.register_parameter('res', None)
+
+        if bias:
+            self.bias = Parameter(torch.empty(total_out_channels))
         else:
             self.register_parameter('bias', None)
 
@@ -197,6 +212,8 @@ class GATv2Conv(MessagePassing):
         self.lin_r.reset_parameters()
         if self.lin_edge is not None:
             self.lin_edge.reset_parameters()
+        if self.res is not None:
+            self.res.reset_parameters()
         glorot(self.att)
         zeros(self.bias)
 
@@ -256,10 +273,16 @@ class GATv2Conv(MessagePassing):
         """
         H, C = self.heads, self.out_channels
 
+        res: Optional[Tensor] = None
+
         x_l: OptTensor = None
         x_r: OptTensor = None
         if isinstance(x, Tensor):
             assert x.dim() == 2
+
+            if self.res is not None:
+                res = self.res(x)
+
             x_l = self.lin_l(x).view(-1, H, C)
             if self.share_weights:
                 x_r = x_l
@@ -268,6 +291,10 @@ class GATv2Conv(MessagePassing):
         else:
             x_l, x_r = x[0], x[1]
             assert x[0].dim() == 2
+
+            if x_r is not None and self.res is not None:
+                res = self.res(x_r)
+
             x_l = self.lin_l(x_l).view(-1, H, C)
             if x_r is not None:
                 x_r = self.lin_r(x_r).view(-1, H, C)
@@ -306,6 +333,9 @@ class GATv2Conv(MessagePassing):
         else:
             out = out.mean(dim=1)
 
+        if res is not None:
+            out = out + res
+
         if self.bias is not None:
             out = out + self.bias
 

torch_geometric/nn/conv/general_conv.py
@@ -70,7 +70,7 @@ class GeneralConv(MessagePassing):
         self,
         in_channels: Union[int, Tuple[int, int]],
         out_channels: Optional[int],
-        in_edge_channels: int = None,
+        in_edge_channels: Optional[int] = None,
         aggr: str = "add",
         skip_linear: str = False,
         directed_msg: bool = True,

torch_geometric/nn/conv/gravnet_conv.py
@@ -4,14 +4,15 @@ from typing import Optional, Union
 import torch
 from torch import Tensor
 
+import torch_geometric.typing
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.typing import OptPairTensor  # noqa
 from torch_geometric.typing import OptTensor, PairOptTensor, PairTensor
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import knn
-except ImportError:
+else:
     knn = None
 
 

torch_geometric/nn/conv/hetero_conv.py
@@ -70,8 +70,8 @@ class HeteroConv(torch.nn.Module):
         for edge_type, module in convs.items():
             check_add_self_loops(module, [edge_type])
 
-        src_node_types = set([key[0] for key in convs.keys()])
-        dst_node_types = set([key[-1] for key in convs.keys()])
+        src_node_types = {key[0] for key in convs.keys()}
+        dst_node_types = {key[-1] for key in convs.keys()}
         if len(src_node_types - dst_node_types) > 0:
             warnings.warn(
                 f"There exist node types ({src_node_types - dst_node_types}) "
@@ -102,7 +102,7 @@ class HeteroConv(torch.nn.Module):
                 individual edge type, either as a :class:`torch.Tensor` of
                 shape :obj:`[2, num_edges]` or a
                 :class:`torch_sparse.SparseTensor`.
-            *args_dict (optional): Additional forward arguments of invididual
+            *args_dict (optional): Additional forward arguments of individual
                 :class:`torch_geometric.nn.conv.MessagePassing` layers.
             **kwargs_dict (optional): Additional forward arguments of
                 individual :class:`torch_geometric.nn.conv.MessagePassing`

torch_geometric/nn/conv/hgt_conv.py
@@ -67,7 +67,7 @@ class HGTConv(MessagePassing):
             for i, edge_type in enumerate(metadata[1])
         }
 
-        self.dst_node_types = set([key[-1] for key in self.edge_types])
+        self.dst_node_types = {key[-1] for key in self.edge_types}
 
         self.kqv_lin = HeteroDictLinear(self.in_channels,
                                         self.out_channels * 3)