pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226):
  1. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +13 -7
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +317 -65
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +3 -5
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +329 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +56 -22
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -0,0 +1,134 @@
1
from enum import Enum
from typing import Iterator, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
7
+
8
+
9
class PoolingStrategy(Enum):
    r"""How :class:`SentenceTransformer` reduces the token embeddings
    produced by the language model."""
    MEAN = 'mean'  # Attention-mask-weighted average over tokens.
    LAST = 'last'  # Embedding of the last attended token per sequence.
    CLS = 'cls'  # Embedding at the first token position.
    LAST_HIDDEN_STATE = 'last_hidden_state'  # Un-pooled token embeddings.
14
+
15
+
16
class SentenceTransformer(torch.nn.Module):
    r"""Embeds text with a Hugging Face model and pools the resulting token
    embeddings according to :obj:`pooling_strategy`.

    Args:
        model_name (str): Model identifier passed to both
            :meth:`transformers.AutoTokenizer.from_pretrained` and
            :meth:`transformers.AutoModel.from_pretrained`.
        pooling_strategy (PoolingStrategy or str, optional): How to reduce
            token embeddings (:obj:`"mean"`, :obj:`"last"`, :obj:`"cls"` or
            :obj:`"last_hidden_state"`). (default: :obj:`"mean"`)
    """
    def __init__(
        self,
        model_name: str,
        pooling_strategy: Union[PoolingStrategy, str] = 'mean',
    ) -> None:
        super().__init__()

        self.model_name = model_name
        self.pooling_strategy = PoolingStrategy(pooling_strategy)

        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Some tokenizers define no padding token; fall back to the
        # end-of-sequence token so that batched padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        r"""Runs the model on tokenized input and pools the output.

        Args:
            input_ids (torch.Tensor): Token indices from the tokenizer.
            attention_mask (torch.Tensor): Mask marking attended tokens.

        Returns:
            The pooled embeddings, L2-normalized along dimension :obj:`1`.
        """
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)

        emb = out[0]  # First element contains all token embeddings.
        if self.pooling_strategy == PoolingStrategy.MEAN:
            emb = mean_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST:
            emb = last_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE:
            emb = out.last_hidden_state
        else:
            assert self.pooling_strategy == PoolingStrategy.CLS
            emb = emb[:, 0, :]

        # NOTE(review): for LAST_HIDDEN_STATE `emb` is presumably
        # (batch, seq, dim), so `dim=1` normalizes across tokens rather than
        # per token -- confirm this is intended.
        emb = F.normalize(emb, p=2, dim=1)
        return emb

    def _tokenize(
        self,
        text: List[str],
        batch_size: int,
    ) -> Iterator[Tuple[Tensor, Tensor]]:
        r"""Yields :obj:`(input_ids, attention_mask)` for each chunk of
        :obj:`batch_size` strings, moved to the model device."""
        for start in range(0, len(text), batch_size):
            token = self.tokenizer(
                text[start:start + batch_size],
                padding=True,
                truncation=True,
                return_tensors='pt',
            )
            yield (token.input_ids.to(self.device),
                   token.attention_mask.to(self.device))

    def _concat_tokens(self, chunks: List[Tensor], pad_value: int) -> Tensor:
        r"""Concatenates per-chunk token tensors along dimension :obj:`0`.

        Each tokenizer call pads only to the longest string of its own
        chunk, so chunks may differ in width. They are padded to a common
        width (on the tokenizer's padding side) before concatenation.
        """
        if len(chunks) == 1:
            return chunks[0]
        max_len = max(chunk.size(1) for chunk in chunks)
        left = getattr(self.tokenizer, 'padding_side', 'right') == 'left'
        padded: List[Tensor] = []
        for chunk in chunks:
            diff = max_len - chunk.size(1)
            pad = (diff, 0) if left else (0, diff)
            padded.append(F.pad(chunk, pad, value=pad_value))
        return torch.cat(padded, dim=0)

    def get_input_ids(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Tokenizes :obj:`text` without running the model.

        Args:
            text (List[str]): The input strings.
            batch_size (int, optional): Number of strings tokenized per
                tokenizer call. All at once if :obj:`None`.
                (default: :obj:`None`)
            output_device (torch.device or str, optional): Passed to
                :meth:`torch.Tensor.to` for the returned tensors.
                (default: :obj:`None`)

        Returns:
            A tuple :obj:`(input_ids, attention_mask)`. Empty :obj:`text`
            yields tensors with zero rows.
        """
        is_empty = len(text) == 0
        # Tokenize a placeholder so that empty input still produces tensors
        # of the right dtype; they are sliced down to zero rows below.
        text = ['dummy'] if is_empty else text

        batch_size = len(text) if batch_size is None else batch_size

        input_ids: List[Tensor] = []
        attention_masks: List[Tensor] = []
        for ids, mask in self._tokenize(text, batch_size):
            input_ids.append(ids)
            attention_masks.append(mask)

        pad_id = self.tokenizer.pad_token_id
        pad_id = 0 if pad_id is None else pad_id

        def _out(x: List[Tensor], pad_value: int) -> Tensor:
            out = self._concat_tokens(x, pad_value)
            out = out[:0] if is_empty else out
            return out.to(output_device)

        return _out(input_ids, pad_id), _out(attention_masks, 0)

    @property
    def device(self) -> torch.device:
        r"""The device of the first parameter of the wrapped model."""
        return next(iter(self.model.parameters())).device

    @torch.no_grad()
    def encode(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tensor:
        r"""Embeds :obj:`text` without tracking gradients.

        Args:
            text (List[str]): The input strings.
            batch_size (int, optional): Number of strings embedded per
                forward pass. All at once if :obj:`None`.
                (default: :obj:`None`)
            output_device (torch.device or str, optional): Passed to
                :meth:`torch.Tensor.to` for each batch of embeddings.
                (default: :obj:`None`)

        Returns:
            The concatenated embeddings. Empty :obj:`text` yields a tensor
            with zero rows.
        """
        is_empty = len(text) == 0
        # Embed a placeholder so that empty input still produces a tensor of
        # the right shape/dtype; it is sliced down to zero rows below.
        text = ['dummy'] if is_empty else text

        batch_size = len(text) if batch_size is None else batch_size

        embs: List[Tensor] = []
        for input_ids, attention_mask in self._tokenize(text, batch_size):
            emb = self(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).to(output_device)
            embs.append(emb)

        out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
        out = out[:0] if is_empty else out
        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'
119
+
120
+
121
def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    r"""Averages token embeddings over dimension :obj:`1`, weighting each
    position by :obj:`attention_mask`.

    :obj:`attention_mask` is broadcast to :obj:`emb.size()` via a trailing
    singleton dimension, presumably with :obj:`emb` shaped
    :obj:`(batch, seq, dim)` and the mask :obj:`(batch, seq)`.
    The denominator is clamped to :obj:`1e-9`, so fully masked rows yield
    zeros instead of dividing by zero.
    """
    weights = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
    numerator = torch.sum(emb * weights, dim=1)
    denominator = torch.clamp(torch.sum(weights, dim=1), min=1e-9)
    return numerator / denominator
124
+
125
+
126
def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    r"""Picks, for every sequence, the embedding of its last attended token.

    If every row of :obj:`attention_mask` attends to the final position, the
    batch is treated as left-padded, which is always used for decoder LLMs.
    The final column of :obj:`emb` is then returned. Otherwise (right
    padding), the last attended token of a row is assumed to sit at index
    :obj:`attention_mask.sum() - 1`.
    """
    num_rows = attention_mask.size(0)
    if attention_mask[:, -1].sum() == num_rows:
        return emb[:, -1]

    last_index = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(emb.size(0), device=emb.device)
    return emb[row_index, last_index]
@@ -0,0 +1,33 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+
7
+ class VisionTransformer(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ model_name: str,
11
+ ) -> None:
12
+ super().__init__()
13
+ self.model_name = model_name
14
+
15
+ from transformers import SwinConfig, SwinModel
16
+
17
+ self.config = SwinConfig.from_pretrained(model_name)
18
+ self.model = SwinModel(self.config)
19
+
20
+ @torch.no_grad()
21
+ def forward(
22
+ self,
23
+ images: Tensor,
24
+ output_device: Optional[Union[torch.device, str]] = None,
25
+ ) -> Tensor:
26
+ return self.model(images).last_hidden_state.to(output_device)
27
+
28
+ @property
29
+ def device(self) -> torch.device:
30
+ return next(iter(self.model.parameters())).device
31
+
32
+ def __repr__(self) -> str:
33
+ return f'{self.__class__.__name__}(model_name={self.model_name})'
@@ -88,7 +88,7 @@ class BatchNorm(torch.nn.Module):
88
88
  return self.module(x)
89
89
 
90
90
  def __repr__(self):
91
- return f'{self.__class__.__name__}({self.module.num_features})'
91
+ return f'{self.__class__.__name__}({self.module.extra_repr()})'
92
92
 
93
93
 
94
94
  class HeteroBatchNorm(torch.nn.Module):
@@ -1,4 +1,4 @@
1
- from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union
1
+ from typing import Final, Iterable, Mapping, Optional, Tuple, Union
2
2
 
3
3
  import torch
4
4
  from torch.nn import Parameter
@@ -11,7 +11,7 @@ Key = Union[str, Tuple[str, ...]]
11
11
  # internal representation and converts it back to `.` in the external
12
12
  # representation. It also allows passing tuples as keys.
13
13
  class ParameterDict(torch.nn.ParameterDict):
14
- CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ParameterDict))
14
+ CLASS_ATTRS: Final[Tuple[str, ...]] = set(dir(torch.nn.ParameterDict))
15
15
 
16
16
  def __init__(
17
17
  self,
@@ -7,18 +7,19 @@ from torch import Tensor
7
7
  import torch_geometric.typing
8
8
  from torch_geometric.typing import OptTensor, torch_cluster
9
9
 
10
- from .asap import ASAPooling
11
10
  from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x
12
- from .edge_pool import EdgePooling
13
11
  from .glob import global_add_pool, global_max_pool, global_mean_pool
14
12
  from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,
15
13
  ApproxMIPSKNNIndex)
16
14
  from .graclus import graclus
17
15
  from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
18
- from .mem_pool import MemPooling
19
- from .pan_pool import PANPooling
20
- from .sag_pool import SAGPooling
21
16
  from .topk_pool import TopKPooling
17
+ from .sag_pool import SAGPooling
18
+ from .edge_pool import EdgePooling
19
+ from .cluster_pool import ClusterPooling
20
+ from .asap import ASAPooling
21
+ from .pan_pool import PANPooling
22
+ from .mem_pool import MemPooling
22
23
  from .voxel_grid import voxel_grid
23
24
  from .approx_knn import approx_knn, approx_knn_graph
24
25
 
@@ -218,6 +219,13 @@ def radius(
218
219
  Automatically calculated if not given. (default: :obj:`None`)
219
220
 
220
221
  :rtype: :class:`torch.Tensor`
222
+
223
+ .. warning::
224
+
225
+ The CPU implementation of :meth:`radius` with :obj:`max_num_neighbors`
226
+ is biased towards certain quadrants.
227
+ Consider setting :obj:`max_num_neighbors` to :obj:`None` or moving
228
+ inputs to GPU before proceeding.
221
229
  """
222
230
  if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
223
231
  return torch_cluster.radius(x, y, r, batch_x, batch_y,
@@ -268,6 +276,13 @@ def radius_graph(
268
276
  Automatically calculated if not given. (default: :obj:`None`)
269
277
 
270
278
  :rtype: :class:`torch.Tensor`
279
+
280
+ .. warning::
281
+
282
+ The CPU implementation of :meth:`radius_graph` with
283
+ :obj:`max_num_neighbors` is biased towards certain quadrants.
284
+ Consider setting :obj:`max_num_neighbors` to :obj:`None` or moving
285
+ inputs to GPU before proceeding.
271
286
  """
272
287
  if batch is not None and x.device != batch.device:
273
288
  warnings.warn("Input tensor 'x' and 'batch' are on different devices "
@@ -330,6 +345,7 @@ __all__ = [
330
345
  'TopKPooling',
331
346
  'SAGPooling',
332
347
  'EdgePooling',
348
+ 'ClusterPooling',
333
349
  'ASAPooling',
334
350
  'PANPooling',
335
351
  'MemPooling',
@@ -0,0 +1,145 @@
1
+ from typing import NamedTuple, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import Tensor
6
+
7
+ from torch_geometric.utils import (
8
+ dense_to_sparse,
9
+ one_hot,
10
+ to_dense_adj,
11
+ to_scipy_sparse_matrix,
12
+ )
13
+
14
+
15
+ class UnpoolInfo(NamedTuple):
16
+ edge_index: Tensor
17
+ cluster: Tensor
18
+ batch: Tensor
19
+
20
+
21
+ class ClusterPooling(torch.nn.Module):
22
+ r"""The cluster pooling operator from the `"Edge-Based Graph Component
23
+ Pooling" <paper url>`_ paper.
24
+
25
+ :class:`ClusterPooling` computes a score for each edge.
26
+ Based on the selected edges, graph clusters are calculated and compressed
27
+ to one node using the injective :obj:`"sum"` aggregation function.
28
+ Edges are remapped based on the nodes created by each cluster and the
29
+ original edges.
30
+
31
+ Args:
32
+ in_channels (int): Size of each input sample.
33
+ edge_score_method (str, optional): The function to apply
34
+ to compute the edge score from raw edge scores (:obj:`"tanh"`,
35
+ :obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`)
36
+ dropout (float, optional): The probability with
37
+ which to drop edge scores during training. (default: :obj:`0.0`)
38
+ threshold (float, optional): The threshold of edge scores. If set to
39
+ :obj:`None`, will be automatically inferred depending on
40
+ :obj:`edge_score_method`. (default: :obj:`None`)
41
+ """
42
+ def __init__(
43
+ self,
44
+ in_channels: int,
45
+ edge_score_method: str = 'tanh',
46
+ dropout: float = 0.0,
47
+ threshold: Optional[float] = None,
48
+ ):
49
+ super().__init__()
50
+ assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax']
51
+
52
+ if threshold is None:
53
+ threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0
54
+
55
+ self.in_channels = in_channels
56
+ self.edge_score_method = edge_score_method
57
+ self.dropout = dropout
58
+ self.threshhold = threshold
59
+
60
+ self.lin = torch.nn.Linear(2 * in_channels, 1)
61
+
62
+ def reset_parameters(self):
63
+ r"""Resets all learnable parameters of the module."""
64
+ self.lin.reset_parameters()
65
+
66
+ def forward(
67
+ self,
68
+ x: Tensor,
69
+ edge_index: Tensor,
70
+ batch: Tensor,
71
+ ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
72
+ r"""Forward pass.
73
+
74
+ Args:
75
+ x (torch.Tensor): The node features.
76
+ edge_index (torch.Tensor): The edge indices.
77
+ batch (torch.Tensor): Batch vector
78
+ :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
79
+ each node to a specific example.
80
+
81
+ Return types:
82
+ * **x** *(torch.Tensor)* - The pooled node features.
83
+ * **edge_index** *(torch.Tensor)* - The coarsened edge indices.
84
+ * **batch** *(torch.Tensor)* - The coarsened batch vector.
85
+ * **unpool_info** *(UnpoolInfo)* - Information that can be consumed
86
+ for unpooling.
87
+ """
88
+ mask = edge_index[0] != edge_index[1]
89
+ edge_index = edge_index[:, mask]
90
+
91
+ edge_attr = torch.cat(
92
+ [x[edge_index[0]], x[edge_index[1]]],
93
+ dim=-1,
94
+ )
95
+ edge_score = self.lin(edge_attr).view(-1)
96
+ edge_score = F.dropout(edge_score, p=self.dropout,
97
+ training=self.training)
98
+
99
+ if self.edge_score_method == 'tanh':
100
+ edge_score = edge_score.tanh()
101
+ elif self.edge_score_method == 'sigmoid':
102
+ edge_score = edge_score.sigmoid()
103
+ else:
104
+ assert self.edge_score_method == 'log_softmax'
105
+ edge_score = F.log_softmax(edge_score, dim=0)
106
+
107
+ return self._merge_edges(x, edge_index, batch, edge_score)
108
+
109
+ def _merge_edges(
110
+ self,
111
+ x: Tensor,
112
+ edge_index: Tensor,
113
+ batch: Tensor,
114
+ edge_score: Tensor,
115
+ ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
116
+
117
+ from scipy.sparse.csgraph import connected_components
118
+
119
+ edge_contract = edge_index[:, edge_score > self.threshhold]
120
+
121
+ adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
122
+ _, cluster_np = connected_components(adj, directed=True,
123
+ connection="weak")
124
+
125
+ cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)
126
+ C = one_hot(cluster)
127
+ A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
128
+ S = to_dense_adj(edge_index, edge_attr=edge_score,
129
+ max_num_nodes=x.size(0)).squeeze(0)
130
+
131
+ A_contract = to_dense_adj(edge_contract,
132
+ max_num_nodes=x.size(0)).squeeze(0)
133
+ nodes_single = ((A_contract.sum(dim=-1) +
134
+ A_contract.sum(dim=-2)) == 0).nonzero()
135
+ S[nodes_single, nodes_single] = 1.0
136
+
137
+ x_out = (S @ C).t() @ x
138
+ edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
139
+ batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)
140
+ unpool_info = UnpoolInfo(edge_index, cluster, batch)
141
+
142
+ return x_out, edge_index_out, batch_out, unpool_info
143
+
144
+ def __repr__(self) -> str:
145
+ return f'{self.__class__.__name__}({self.in_channels})'
@@ -66,7 +66,6 @@ class Connect(torch.nn.Module):
66
66
  """
67
67
  def reset_parameters(self):
68
68
  r"""Resets all learnable parameters of the module."""
69
- pass
70
69
 
71
70
  def forward(
72
71
  self,
@@ -58,7 +58,7 @@ class EdgePooling(torch.nn.Module):
58
58
  self,
59
59
  in_channels: int,
60
60
  edge_score_method: Optional[Callable] = None,
61
- dropout: Optional[float] = 0.0,
61
+ dropout: float = 0.0,
62
62
  add_to_edge_score: float = 0.5,
63
63
  ):
64
64
  super().__init__()
@@ -2,9 +2,11 @@ from typing import Optional
2
2
 
3
3
  from torch import Tensor
4
4
 
5
- try:
5
+ import torch_geometric.typing
6
+
7
+ if torch_geometric.typing.WITH_TORCH_CLUSTER:
6
8
  from torch_cluster import graclus_cluster
7
- except ImportError:
9
+ else:
8
10
  graclus_cluster = None
9
11
 
10
12
 
@@ -5,12 +5,18 @@ import torch
5
5
  from torch_geometric.utils import coalesce, remove_self_loops, scatter
6
6
 
7
7
 
8
- def pool_edge(cluster, edge_index, edge_attr: Optional[torch.Tensor] = None):
8
+ def pool_edge(
9
+ cluster,
10
+ edge_index,
11
+ edge_attr: Optional[torch.Tensor] = None,
12
+ reduce: Optional[str] = 'sum',
13
+ ):
9
14
  num_nodes = cluster.size(0)
10
15
  edge_index = cluster[edge_index.view(-1)].view(2, -1)
11
16
  edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
12
17
  if edge_index.numel() > 0:
13
- edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes)
18
+ edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
19
+ reduce=reduce)
14
20
  return edge_index, edge_attr
15
21
 
16
22
 
@@ -80,7 +80,6 @@ class Select(torch.nn.Module):
80
80
  """
81
81
  def reset_parameters(self):
82
82
  r"""Resets all learnable parameters of the module."""
83
- pass
84
83
 
85
84
  def forward(self, *args, **kwargs) -> SelectOutput:
86
85
  raise NotImplementedError
@@ -3,11 +3,12 @@ from typing import List, Optional, Union
3
3
  import torch
4
4
  from torch import Tensor
5
5
 
6
+ import torch_geometric.typing
6
7
  from torch_geometric.utils.repeat import repeat
7
8
 
8
- try:
9
+ if torch_geometric.typing.WITH_TORCH_CLUSTER:
9
10
  from torch_cluster import grid_cluster
10
- except ImportError:
11
+ else:
11
12
  grid_cluster = None
12
13
 
13
14
 
@@ -166,5 +166,5 @@ def lr_scheduler_resolver(
166
166
  return obj
167
167
  return cls
168
168
 
169
- choices = set(cls.__name__ for cls in classes)
169
+ choices = {cls.__name__ for cls in classes}
170
170
  raise ValueError(f"Could not resolve '{query}' among choices {choices}")
@@ -1,35 +1,22 @@
1
1
  import typing
2
- from typing import *
3
2
 
4
3
  import torch
5
4
  from torch import Tensor
6
5
 
7
6
  import torch_geometric.typing
8
- from torch_geometric.typing import *
7
+ {% for module in modules %}
8
+ from {{module}} import *
9
+ {%- endfor %}
9
10
 
10
11
 
11
- class Sequential(torch.nn.Module):
12
- def reset_parameters(self) -> None:
13
- {%- for child in children %}
14
- if hasattr(self.{{child.name}}, 'reset_parameters'):
15
- self.{{child.name}}.reset_parameters()
12
+ def forward(
13
+ self,
14
+ {%- for param in signature.param_dict.values() %}
15
+ {{param.name}}: {{param.type_repr}},
16
16
  {%- endfor %}
17
+ ) -> {{signature.return_type_repr}}:
17
18
 
18
- def forward(self, {{ input_types|join(', ') }}) -> {{return_type}}:
19
19
  {%- for child in children %}
20
- {{child.return_names|join(', ')}} = self.{{child.name}}({{child.param_names|join(', ')}})
20
+ {{child.return_names|join(', ')}} = self.{{child.name}}({{child.param_names|join(', ')}})
21
21
  {%- endfor %}
22
- return {{children[-1].return_names|join(', ')}}
23
-
24
- def __getitem__(self, idx: int) -> torch.nn.Module:
25
- return getattr(self, self._module_names[idx])
26
-
27
- def __len__(self) -> int:
28
- return {{children|length}}
29
-
30
- def __repr__(self) -> str:
31
- module_reprs = [
32
- f' ({i}) - {self[i]}: {self._module_descs[i]}'
33
- for i in range(len(self))
34
- ]
35
- return 'Sequential(\n{}\n)'.format('\n'.join(module_reprs))
22
+ return {{children[-1].return_names|join(', ')}}