pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +8 -3
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +53 -20
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
@@ -0,0 +1,134 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import List, Optional, Tuple, Union
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch.nn.functional as F
|
6
|
+
from torch import Tensor
|
7
|
+
|
8
|
+
|
9
|
+
class PoolingStrategy(Enum):
    r"""How :class:`SentenceTransformer` reduces the model output to one
    embedding per input sequence (dispatched in its :meth:`forward`)."""
    # Average of token embeddings, weighted by the attention mask.
    MEAN = 'mean'
    # Embedding of the last attended token of each sequence.
    LAST = 'last'
    # Embedding at sequence position 0.
    CLS = 'cls'
    # No pooling: the model's `last_hidden_state` is returned as-is
    # (before the L2 normalization applied in `forward`).
    LAST_HIDDEN_STATE = 'last_hidden_state'
|
14
|
+
|
15
|
+
|
16
|
+
class SentenceTransformer(torch.nn.Module):
    r"""Encodes text into embeddings with a HuggingFace :obj:`transformers`
    model.

    The tokenizer and model are loaded with
    :meth:`AutoTokenizer.from_pretrained` and
    :meth:`AutoModel.from_pretrained`. Token embeddings are reduced according
    to :obj:`pooling_strategy` and then L2-normalized along dimension 1.

    Args:
        model_name (str): HuggingFace model name or local path.
        pooling_strategy (PoolingStrategy or str, optional): How token
            embeddings are reduced: :obj:`"mean"`, :obj:`"last"`,
            :obj:`"cls"` or :obj:`"last_hidden_state"`.
            (default: :obj:`"mean"`)
    """
    def __init__(
        self,
        model_name: str,
        pooling_strategy: Union[PoolingStrategy, str] = 'mean',
    ) -> None:
        super().__init__()

        self.model_name = model_name
        self.pooling_strategy = PoolingStrategy(pooling_strategy)

        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Some tokenizers ship without a pad token; reuse EOS so that batched
        # tokenization with `padding=True` works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        r"""Embeds a batch of already-tokenized sequences.

        Args:
            input_ids (torch.Tensor): Token ids, as produced by the tokenizer.
            attention_mask (torch.Tensor): Attention mask matching
                :obj:`input_ids` (non-zero entries mark attended tokens).

        Returns:
            The pooled embeddings, L2-normalized along dimension 1.
        """
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)

        emb = out[0]  # First element contains all token embeddings.
        if self.pooling_strategy == PoolingStrategy.MEAN:
            emb = mean_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST:
            emb = last_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE:
            # No pooling. NOTE(review): this output presumably keeps a
            # per-token axis, so the normalization below runs over dim 1 of
            # that tensor (likely the sequence axis) — confirm intent.
            emb = out.last_hidden_state
        else:
            assert self.pooling_strategy == PoolingStrategy.CLS
            emb = emb[:, 0, :]

        emb = F.normalize(emb, p=2, dim=1)
        return emb

    def get_input_ids(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Tokenizes :obj:`text` in chunks of :obj:`batch_size`.

        Each chunk is tokenized with ``padding=True``, so chunks can come back
        with different sequence lengths. Shorter chunks are therefore padded
        (with the tokenizer's pad token id and a zero attention mask, on the
        tokenizer's padding side) to a common length before concatenation.

        Args:
            text (List[str]): The input strings.
            batch_size (int, optional): Chunk size for tokenization. If
                :obj:`None`, all strings are tokenized at once.
            output_device (torch.device or str, optional): Device of the
                returned tensors.

        Returns:
            A tuple :obj:`(input_ids, attention_mask)` with one row per input
            string. Empty input yields tensors with zero rows.
        """
        is_empty = len(text) == 0
        # Tokenize a placeholder so the outputs keep a valid trailing shape.
        text = ['dummy'] if is_empty else text

        batch_size = len(text) if batch_size is None else batch_size

        input_ids: List[Tensor] = []
        attention_masks: List[Tensor] = []
        for start in range(0, len(text), batch_size):
            token = self.tokenizer(
                text[start:start + batch_size],
                padding=True,
                truncation=True,
                return_tensors='pt',
            )
            input_ids.append(token.input_ids.to(self.device))
            attention_masks.append(token.attention_mask.to(self.device))

        # Align sequence lengths across chunks so they can be concatenated.
        max_len = max(ids.size(1) for ids in input_ids)
        pad_id = self.tokenizer.pad_token_id
        pad_id = 0 if pad_id is None else pad_id
        pad_left = getattr(self.tokenizer, 'padding_side', 'right') == 'left'

        def _pad(x: Tensor, value: int) -> Tensor:
            diff = max_len - x.size(1)
            if diff == 0:
                return x
            return F.pad(x, (diff, 0) if pad_left else (0, diff), value=value)

        input_ids = [_pad(ids, pad_id) for ids in input_ids]
        attention_masks = [_pad(mask, 0) for mask in attention_masks]

        def _out(x: List[Tensor]) -> Tensor:
            out = torch.cat(x, dim=0) if len(x) > 1 else x[0]
            out = out[:0] if is_empty else out
            return out.to(output_device)

        return _out(input_ids), _out(attention_masks)

    @property
    def device(self) -> torch.device:
        r"""Device of the first parameter of the wrapped model."""
        return next(iter(self.model.parameters())).device

    @torch.no_grad()
    def encode(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tensor:
        r"""Tokenizes and embeds :obj:`text` without tracking gradients.

        Args:
            text (List[str]): The input strings.
            batch_size (int, optional): Chunk size. If :obj:`None`, all
                strings are processed at once.
            output_device (torch.device or str, optional): Device of the
                returned embeddings.

        Returns:
            Embeddings concatenated along dimension 0. Empty input yields a
            tensor with zero rows.

        .. note::
            NOTE(review): with :obj:`"last_hidden_state"` pooling and several
            chunks, per-chunk outputs may differ in sequence length and fail
            to concatenate — confirm whether this combination is supported.
        """
        is_empty = len(text) == 0
        # Embed a placeholder so the output keeps a valid trailing shape.
        text = ['dummy'] if is_empty else text

        batch_size = len(text) if batch_size is None else batch_size

        embs: List[Tensor] = []
        for start in range(0, len(text), batch_size):
            token = self.tokenizer(
                text[start:start + batch_size],
                padding=True,
                truncation=True,
                return_tensors='pt',
            )

            emb = self(
                input_ids=token.input_ids.to(self.device),
                attention_mask=token.attention_mask.to(self.device),
            ).to(output_device)

            embs.append(emb)

        out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
        out = out[:0] if is_empty else out
        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'
|
119
|
+
|
120
|
+
|
121
|
+
def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    r"""Averages token embeddings over dimension 1, counting only positions
    selected by :obj:`attention_mask`.

    The mask is broadcast to :obj:`emb`'s shape and cast to its dtype. The
    per-row token count is clamped to at least ``1e-9``, so fully masked rows
    yield zeros instead of NaNs.
    """
    weights = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
    summed = (emb * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
|
124
|
+
|
125
|
+
|
126
|
+
def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    r"""Selects the embedding of the last attended token of every row.

    If the final position is attended in every row (as with left padding,
    typical for decoder LLMs), the last position is taken directly. Otherwise
    the index of the last attended token is derived from the per-row mask
    sum.
    """
    num_rows = attention_mask.size(0)
    if attention_mask[:, -1].sum() == num_rows:
        return emb[:, -1]

    last_idx = attention_mask.sum(dim=1) - 1
    rows = torch.arange(emb.size(0), device=emb.device)
    return emb[rows, last_idx]
|
@@ -0,0 +1,33 @@
|
|
1
|
+
from typing import Optional, Union
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from torch import Tensor
|
5
|
+
|
6
|
+
|
7
|
+
class VisionTransformer(torch.nn.Module):
    r"""Thin wrapper around a HuggingFace Swin image encoder.

    Args:
        model_name (str): Name or path from which the :class:`SwinConfig` is
            loaded.

    .. note::
        Only the configuration is read from :obj:`model_name`; the
        :class:`SwinModel` is constructed from that config rather than via
        :meth:`from_pretrained`, so checkpoint weights are not loaded here.
        NOTE(review): confirm this is intended.
    """
    def __init__(
        self,
        model_name: str,
    ) -> None:
        super().__init__()
        self.model_name = model_name

        from transformers import SwinConfig, SwinModel

        config = SwinConfig.from_pretrained(model_name)
        self.config = config
        self.model = SwinModel(config)

    @torch.no_grad()
    def forward(
        self,
        images: Tensor,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tensor:
        r"""Returns the encoder's :obj:`last_hidden_state` for :obj:`images`,
        moved to :obj:`output_device`. Runs without gradient tracking."""
        hidden = self.model(images).last_hidden_state
        return hidden.to(output_device)

    @property
    def device(self) -> torch.device:
        r"""Device of the first parameter of the wrapped model."""
        params = self.model.parameters()
        return next(params).device

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f'{cls_name}(model_name={self.model_name})'
|
@@ -88,7 +88,7 @@ class BatchNorm(torch.nn.Module):
|
|
88
88
|
return self.module(x)
|
89
89
|
|
90
90
|
def __repr__(self):
|
91
|
-
return f'{self.__class__.__name__}({self.module.
|
91
|
+
return f'{self.__class__.__name__}({self.module.extra_repr()})'
|
92
92
|
|
93
93
|
|
94
94
|
class HeteroBatchNorm(torch.nn.Module):
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Final, Iterable, Mapping, Optional,
|
1
|
+
from typing import Final, Iterable, Mapping, Optional, Tuple, Union
|
2
2
|
|
3
3
|
import torch
|
4
4
|
from torch.nn import Parameter
|
@@ -11,7 +11,7 @@ Key = Union[str, Tuple[str, ...]]
|
|
11
11
|
# internal representation and converts it back to `.` in the external
|
12
12
|
# representation. It also allows passing tuples as keys.
|
13
13
|
class ParameterDict(torch.nn.ParameterDict):
|
14
|
-
CLASS_ATTRS: Final[
|
14
|
+
CLASS_ATTRS: Final[Tuple[str, ...]] = set(dir(torch.nn.ParameterDict))
|
15
15
|
|
16
16
|
def __init__(
|
17
17
|
self,
|
@@ -7,18 +7,19 @@ from torch import Tensor
|
|
7
7
|
import torch_geometric.typing
|
8
8
|
from torch_geometric.typing import OptTensor, torch_cluster
|
9
9
|
|
10
|
-
from .asap import ASAPooling
|
11
10
|
from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x
|
12
|
-
from .edge_pool import EdgePooling
|
13
11
|
from .glob import global_add_pool, global_max_pool, global_mean_pool
|
14
12
|
from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,
|
15
13
|
ApproxMIPSKNNIndex)
|
16
14
|
from .graclus import graclus
|
17
15
|
from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
|
18
|
-
from .mem_pool import MemPooling
|
19
|
-
from .pan_pool import PANPooling
|
20
|
-
from .sag_pool import SAGPooling
|
21
16
|
from .topk_pool import TopKPooling
|
17
|
+
from .sag_pool import SAGPooling
|
18
|
+
from .edge_pool import EdgePooling
|
19
|
+
from .cluster_pool import ClusterPooling
|
20
|
+
from .asap import ASAPooling
|
21
|
+
from .pan_pool import PANPooling
|
22
|
+
from .mem_pool import MemPooling
|
22
23
|
from .voxel_grid import voxel_grid
|
23
24
|
from .approx_knn import approx_knn, approx_knn_graph
|
24
25
|
|
@@ -218,6 +219,13 @@ def radius(
|
|
218
219
|
Automatically calculated if not given. (default: :obj:`None`)
|
219
220
|
|
220
221
|
:rtype: :class:`torch.Tensor`
|
222
|
+
|
223
|
+
.. warning::
|
224
|
+
|
225
|
+
The CPU implementation of :meth:`radius` with :obj:`max_num_neighbors`
|
226
|
+
is biased towards certain quadrants.
|
227
|
+
Consider setting :obj:`max_num_neighbors` to :obj:`None` or moving
|
228
|
+
inputs to GPU before proceeding.
|
221
229
|
"""
|
222
230
|
if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
|
223
231
|
return torch_cluster.radius(x, y, r, batch_x, batch_y,
|
@@ -268,6 +276,13 @@ def radius_graph(
|
|
268
276
|
Automatically calculated if not given. (default: :obj:`None`)
|
269
277
|
|
270
278
|
:rtype: :class:`torch.Tensor`
|
279
|
+
|
280
|
+
.. warning::
|
281
|
+
|
282
|
+
The CPU implementation of :meth:`radius_graph` with
|
283
|
+
:obj:`max_num_neighbors` is biased towards certain quadrants.
|
284
|
+
Consider setting :obj:`max_num_neighbors` to :obj:`None` or moving
|
285
|
+
inputs to GPU before proceeding.
|
271
286
|
"""
|
272
287
|
if batch is not None and x.device != batch.device:
|
273
288
|
warnings.warn("Input tensor 'x' and 'batch' are on different devices "
|
@@ -330,6 +345,7 @@ __all__ = [
|
|
330
345
|
'TopKPooling',
|
331
346
|
'SAGPooling',
|
332
347
|
'EdgePooling',
|
348
|
+
'ClusterPooling',
|
333
349
|
'ASAPooling',
|
334
350
|
'PANPooling',
|
335
351
|
'MemPooling',
|
@@ -0,0 +1,145 @@
|
|
1
|
+
from typing import NamedTuple, Optional, Tuple
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import torch.nn.functional as F
|
5
|
+
from torch import Tensor
|
6
|
+
|
7
|
+
from torch_geometric.utils import (
|
8
|
+
dense_to_sparse,
|
9
|
+
one_hot,
|
10
|
+
to_dense_adj,
|
11
|
+
to_scipy_sparse_matrix,
|
12
|
+
)
|
13
|
+
|
14
|
+
|
15
|
+
class UnpoolInfo(NamedTuple):
    r"""Bookkeeping returned by :class:`ClusterPooling` for unpooling."""
    # Edge indices the pooling step operated on (self-loops removed).
    edge_index: Tensor
    # Cluster id of every input node (connected-component label).
    cluster: Tensor
    # Batch vector of the input nodes.
    batch: Tensor


class ClusterPooling(torch.nn.Module):
    r"""The cluster pooling operator from the "Edge-Based Graph Component
    Pooling" paper.

    :class:`ClusterPooling` computes a score for each edge.
    Based on the selected edges, graph clusters are calculated and compressed
    to one node using the injective :obj:`"sum"` aggregation function.
    Edges are remapped based on the nodes created by each cluster and the
    original edges.

    Args:
        in_channels (int): Size of each input sample.
        edge_score_method (str, optional): The function to apply
            to compute the edge score from raw edge scores (:obj:`"tanh"`,
            :obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`)
        dropout (float, optional): The probability with
            which to drop edge scores during training. (default: :obj:`0.0`)
        threshold (float, optional): Edges whose score is strictly greater
            than this value are contracted. If set to :obj:`None`, it is
            inferred from :obj:`edge_score_method`: :obj:`0.5` for
            :obj:`"sigmoid"`, :obj:`0.0` otherwise. (default: :obj:`None`)

    .. note::
        :obj:`"log_softmax"` scores are never positive, so with the inferred
        threshold of :obj:`0.0` no edge is ever contracted. Pass an explicit
        negative :obj:`threshold` when using :obj:`"log_softmax"`.
    """
    def __init__(
        self,
        in_channels: int,
        edge_score_method: str = 'tanh',
        dropout: float = 0.0,
        threshold: Optional[float] = None,
    ):
        super().__init__()
        assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax']

        if threshold is None:
            threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0

        self.in_channels = in_channels
        self.edge_score_method = edge_score_method
        self.dropout = dropout
        self.threshold = threshold

        # Scores an edge from the concatenated features of its two endpoints.
        self.lin = torch.nn.Linear(2 * in_channels, 1)

    @property
    def threshhold(self) -> float:
        r"""Backward-compatible alias of :obj:`threshold` (former spelling)."""
        return self.threshold

    @threshhold.setter
    def threshhold(self, value: float) -> None:
        self.threshold = value

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.lin.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): Batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example.

        Return types:
            * **x** *(torch.Tensor)* - The pooled node features.
            * **edge_index** *(torch.Tensor)* - The coarsened edge indices.
            * **batch** *(torch.Tensor)* - The coarsened batch vector.
            * **unpool_info** *(UnpoolInfo)* - Information that can be consumed
              for unpooling.
        """
        # Self-loops are dropped before scoring.
        mask = edge_index[0] != edge_index[1]
        edge_index = edge_index[:, mask]

        edge_attr = torch.cat(
            [x[edge_index[0]], x[edge_index[1]]],
            dim=-1,
        )
        edge_score = self.lin(edge_attr).view(-1)
        edge_score = F.dropout(edge_score, p=self.dropout,
                               training=self.training)

        if self.edge_score_method == 'tanh':
            edge_score = edge_score.tanh()
        elif self.edge_score_method == 'sigmoid':
            edge_score = edge_score.sigmoid()
        else:
            assert self.edge_score_method == 'log_softmax'
            # Normalized over all edges at once (dim=0).
            edge_score = F.log_softmax(edge_score, dim=0)

        return self._merge_edges(x, edge_index, batch, edge_score)

    def _merge_edges(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_score: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
        r"""Contracts high-scoring edges and builds the coarsened graph."""
        from scipy.sparse.csgraph import connected_components

        # Contract every edge whose score strictly exceeds the threshold.
        edge_contract = edge_index[:, edge_score > self.threshold]

        # Clusters are the weakly connected components of the contracted
        # graph; every node gets a component label.
        adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
        _, cluster_np = connected_components(adj, directed=True,
                                             connection="weak")

        cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)
        C = one_hot(cluster)  # Node-to-cluster assignment matrix.
        A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
        S = to_dense_adj(edge_index, edge_attr=edge_score,
                         max_num_nodes=x.size(0)).squeeze(0)

        # Nodes touched by no contracted edge form singleton clusters; give
        # them a unit self-weight so their own features enter the aggregation.
        A_contract = to_dense_adj(edge_contract,
                                  max_num_nodes=x.size(0)).squeeze(0)
        nodes_single = ((A_contract.sum(dim=-1) +
                         A_contract.sum(dim=-2)) == 0).nonzero()
        S[nodes_single, nodes_single] = 1.0

        # Aggregate node features into clusters, weighted by the scores in S.
        x_out = (S @ C).t() @ x
        # Clusters are linked if any original edge connects their members;
        # the diagonal (intra-cluster edges) is cleared.
        edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
        # Each cluster takes the batch index of its member nodes.
        batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)
        unpool_info = UnpoolInfo(edge_index, cluster, batch)

        return x_out, edge_index_out, batch_out, unpool_info

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.in_channels})'
|
torch_geometric/nn/pool/pool.py
CHANGED
@@ -5,12 +5,18 @@ import torch
|
|
5
5
|
from torch_geometric.utils import coalesce, remove_self_loops, scatter
|
6
6
|
|
7
7
|
|
8
|
-
def pool_edge(
|
8
|
+
def pool_edge(
|
9
|
+
cluster,
|
10
|
+
edge_index,
|
11
|
+
edge_attr: Optional[torch.Tensor] = None,
|
12
|
+
reduce: Optional[str] = 'sum',
|
13
|
+
):
|
9
14
|
num_nodes = cluster.size(0)
|
10
15
|
edge_index = cluster[edge_index.view(-1)].view(2, -1)
|
11
16
|
edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
|
12
17
|
if edge_index.numel() > 0:
|
13
|
-
edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes
|
18
|
+
edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
|
19
|
+
reduce=reduce)
|
14
20
|
return edge_index, edge_attr
|
15
21
|
|
16
22
|
|
@@ -3,11 +3,12 @@ from typing import List, Optional, Union
|
|
3
3
|
import torch
|
4
4
|
from torch import Tensor
|
5
5
|
|
6
|
+
import torch_geometric.typing
|
6
7
|
from torch_geometric.utils.repeat import repeat
|
7
8
|
|
8
|
-
|
9
|
+
if torch_geometric.typing.WITH_TORCH_CLUSTER:
|
9
10
|
from torch_cluster import grid_cluster
|
10
|
-
|
11
|
+
else:
|
11
12
|
grid_cluster = None
|
12
13
|
|
13
14
|
|
torch_geometric/nn/resolver.py
CHANGED
@@ -1,35 +1,22 @@
|
|
1
1
|
import typing
|
2
|
-
from typing import *
|
3
2
|
|
4
3
|
import torch
|
5
4
|
from torch import Tensor
|
6
5
|
|
7
6
|
import torch_geometric.typing
|
8
|
-
|
7
|
+
{% for module in modules %}
|
8
|
+
from {{module}} import *
|
9
|
+
{%- endfor %}
|
9
10
|
|
10
11
|
|
11
|
-
|
12
|
-
|
13
|
-
{%- for
|
14
|
-
|
15
|
-
self.{{child.name}}.reset_parameters()
|
12
|
+
def forward(
|
13
|
+
self,
|
14
|
+
{%- for param in signature.param_dict.values() %}
|
15
|
+
{{param.name}}: {{param.type_repr}},
|
16
16
|
{%- endfor %}
|
17
|
+
) -> {{signature.return_type_repr}}:
|
17
18
|
|
18
|
-
def forward(self, {{ input_types|join(', ') }}) -> {{return_type}}:
|
19
19
|
{%- for child in children %}
|
20
|
-
|
20
|
+
{{child.return_names|join(', ')}} = self.{{child.name}}({{child.param_names|join(', ')}})
|
21
21
|
{%- endfor %}
|
22
|
-
|
23
|
-
|
24
|
-
def __getitem__(self, idx: int) -> torch.nn.Module:
|
25
|
-
return getattr(self, self._module_names[idx])
|
26
|
-
|
27
|
-
def __len__(self) -> int:
|
28
|
-
return {{children|length}}
|
29
|
-
|
30
|
-
def __repr__(self) -> str:
|
31
|
-
module_reprs = [
|
32
|
-
f' ({i}) - {self[i]}: {self._module_descs[i]}'
|
33
|
-
for i in range(len(self))
|
34
|
-
]
|
35
|
-
return 'Sequential(\n{}\n)'.format('\n'.join(module_reprs))
|
22
|
+
return {{children[-1].return_names|join(', ')}}
|