pyg-nightly 2.7.0.dev20250214__py3-none-any.whl → 2.7.0.dev20250216__py3-none-any.whl
- {pyg_nightly-2.7.0.dev20250214.dist-info → pyg_nightly-2.7.0.dev20250216.dist-info}/METADATA +2 -1
- {pyg_nightly-2.7.0.dev20250214.dist-info → pyg_nightly-2.7.0.dev20250216.dist-info}/RECORD +10 -7
- torch_geometric/__init__.py +4 -1
- torch_geometric/hash_tensor.py +133 -0
- torch_geometric/nn/attention/__init__.py +2 -0
- torch_geometric/nn/attention/sgformer.py +78 -0
- torch_geometric/nn/models/__init__.py +2 -0
- torch_geometric/nn/models/sgformer.py +190 -0
- torch_geometric/typing.py +41 -0
- {pyg_nightly-2.7.0.dev20250214.dist-info → pyg_nightly-2.7.0.dev20250216.dist-info}/WHEEL +0 -0
{pyg_nightly-2.7.0.dev20250214.dist-info → pyg_nightly-2.7.0.dev20250216.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: pyg-nightly
-Version: 2.7.0.dev20250214
+Version: 2.7.0.dev20250216
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
@@ -22,6 +22,7 @@ Requires-Dist: psutil>=5.8.0
 Requires-Dist: pyparsing
 Requires-Dist: requests
 Requires-Dist: tqdm
+Requires-Dist: xxhash
 Requires-Dist: matplotlib ; extra == "benchmark"
 Requires-Dist: networkx ; extra == "benchmark"
 Requires-Dist: pandas ; extra == "benchmark"
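The only dependency change besides the version bump is the new unconditional xxhash requirement. It backs the new torch_geometric/hash_tensor.py module (shown further below), which hashes non-tensor keys such as strings into stable 64-bit integers. A minimal sketch of that scheme, mirroring the as_key_tensor fallback path in the new module:

import torch
import xxhash

# Hash arbitrary string keys into non-negative int64 values, as done in
# `as_key_tensor` below. Masking with 0x7FFFFFFFFFFFFFFF clears the sign
# bit so every hash fits into a signed 64-bit tensor.
keys = ['alice', 'bob', 'carol']
hashed = torch.tensor(
    [xxhash.xxh64(key).intdigest() & 0x7FFFFFFFFFFFFFFF for key in keys],
    dtype=torch.int64,
)
print(hashed.dtype)  # torch.int64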
{pyg_nightly-2.7.0.dev20250214.dist-info → pyg_nightly-2.7.0.dev20250216.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=
+torch_geometric/__init__.py,sha256=QlYg3d7vRWC0k_OzPCmAanyTUBiTvikPuzM9feT0REk,1978
 torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -9,6 +9,7 @@ torch_geometric/deprecation.py,sha256=dWRymDIUkUVI2MeEmBG5WF4R6jObZeseSBV9G6FNfj
 torch_geometric/device.py,sha256=tU5-_lBNVbVHl_kUmWPwiG5mQ1pyapwMF4JkmtNN3MM,1224
 torch_geometric/edge_index.py,sha256=BsLh5tOZRjjSYDkjqOFAdBuvMaDg7EWaaLELYsUL0Z8,70048
 torch_geometric/experimental.py,sha256=JbtNNEXjFGI8hZ9raM6-qrZURP6Z5nlDK8QicZUIbz0,4756
+torch_geometric/hash_tensor.py,sha256=VBuz9n16ouSk2u4DifpCIW5MXiuoH5UGO4rPj-Celjw,4418
 torch_geometric/home.py,sha256=EV54B4Dmiv61GDbkCwtCfWGWJ4eFGwZ8s3KOgGjwYgY,790
 torch_geometric/index.py,sha256=9ChzWFCwj2slNcVBOgfV-wQn-KscJe_y7502w-Vf76w,24045
 torch_geometric/inspector.py,sha256=nKi5o4Mn6xsG0Ex1GudTEQt_EqnF9mcMqGtp7Shh9sQ,19336
@@ -18,7 +19,7 @@ torch_geometric/logging.py,sha256=HmHHLiCcM64k-6UYNOSfXPIeSGNAyiGGcn8cD8tlyuQ,85
 torch_geometric/resolver.py,sha256=fn-_6mCpI2xv7eDZnIFcYrHOn0IrwbkWFLDb9laQrWI,1270
 torch_geometric/seed.py,sha256=MJLbVwpb9i8mK3oi32sS__Cq-dRq_afTeoOL_HoA9ko,372
 torch_geometric/template.py,sha256=rqjDWgcSAgTCiV4bkOjWRPaO4PpUdC_RXigzxxBqAu8,1060
-torch_geometric/typing.py,sha256=
+torch_geometric/typing.py,sha256=mtSM6QhCsrohstnyvqMuxEIajCkhcvkQKOU4uVu-nDs,15596
 torch_geometric/warnings.py,sha256=t114CbkrmiqkXaavx5g7OO52dLdktf-U__B5QqYIQvI,413
 torch_geometric/contrib/__init__.py,sha256=0pWkmXfZtbdr-AKwlii5LTFggTEH-MCrSKpZxrtPlVs,352
 torch_geometric/contrib/datasets/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2ezDjbj9n9nCpvtk,23
@@ -330,9 +331,10 @@ torch_geometric/nn/aggr/set_transformer.py,sha256=FG7_JizpFX14M6VSCwLSjYXYdJ1ZiQ
 torch_geometric/nn/aggr/sort.py,sha256=bvOOWnFkNOBOZih4rqVZQsjfeDX3vmXo1bpPSFD846w,2507
 torch_geometric/nn/aggr/utils.py,sha256=CLJ-ZrVWYIOBpdhQBLAz94dj3cMKKKc3qwGr4DFbiCU,8338
 torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nqoyv8tZ6Y35qMYTf9w,1126
-torch_geometric/nn/attention/__init__.py,sha256=
+torch_geometric/nn/attention/__init__.py,sha256=wLKTmlfP7qL9sZHy4cmDFHEtdwa-MEKE1dT51L1_w10,192
 torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
 torch_geometric/nn/attention/qformer.py,sha256=7J-pWm_vpumK38IC-iCBz4oqL-BEIofEIxJ0wfjWq9A,2338
+torch_geometric/nn/attention/sgformer.py,sha256=QT_3bVOPsqPY0fWVdwyr8E0LL9u9TwC_sv0_2EggdpA,2893
 torch_geometric/nn/conv/__init__.py,sha256=37zTdt0gfSAUPMtwXjZg5mWx_itojJVFNODYR1h1ch0,3515
 torch_geometric/nn/conv/agnn_conv.py,sha256=5nEPLx_BBHcDaO6HWzLuHfXc0Yd_reKynAOH0Iq09lU,3077
 torch_geometric/nn/conv/antisymmetric_conv.py,sha256=dhA6sCETy1jlXReYJZBSyToOcL_mZ1wL10fMIb8Ppuw,4387
@@ -424,7 +426,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
 torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
 torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
 torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
-torch_geometric/nn/models/__init__.py,sha256=
+torch_geometric/nn/models/__init__.py,sha256=gmBRXrbjkxLv_g0hfI87bSXMgwFoIo3XgslpGGdMs3g,2197
 torch_geometric/nn/models/attentive_fp.py,sha256=tkgvw28wg9-JqHIfBllfCwTHrZIUiv85yZJcDqjz3z0,6634
 torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
 torch_geometric/nn/models/basic_gnn.py,sha256=PGa0RUMyvrNy_5yRI2jX_zwPsmZXwOQWfsWvxOiHsSk,31225
@@ -456,6 +458,7 @@ torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3W
 torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
 torch_geometric/nn/models/rev_gnn.py,sha256=1b6wU-6YTuLsWn5p8c5LXQm2KugEAVcEYJKZbWTDvgQ,11796
 torch_geometric/nn/models/schnet.py,sha256=0aaHrVtxApdvn3RHCGLQJW1MbIb--CSYUrx9O3hDOZM,16656
+torch_geometric/nn/models/sgformer.py,sha256=QacmFjlTSUgk6T5Dk4NfG9zmxHZqr0ggx1ElNrPMJNc,5878
 torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6lASrKwUEAI,9841
 torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
 torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
@@ -630,6 +633,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
 torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.dev20250214.dist-info/WHEEL,sha256=
-pyg_nightly-2.7.0.dev20250214.dist-info/METADATA,sha256=
-pyg_nightly-2.7.0.dev20250214.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20250216.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
+pyg_nightly-2.7.0.dev20250216.dist-info/METADATA,sha256=nrJ478w_yYVSpIyqlfR5lxdYTu6NOIzCv_atkMY5w5w,62999
+pyg_nightly-2.7.0.dev20250216.dist-info/RECORD,,
torch_geometric/__init__.py
CHANGED
@@ -7,6 +7,7 @@ from ._compile import compile, is_compiling
 from ._onnx import is_in_onnx_export
 from .index import Index
 from .edge_index import EdgeIndex
+from .hash_tensor import HashTensor
 from .seed import seed_everything
 from .home import get_home_dir, set_home_dir
 from .device import is_mps_available, is_xpu_available, device
@@ -30,11 +31,12 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250214'
+__version__ = '2.7.0.dev20250216'
 
 __all__ = [
     'Index',
     'EdgeIndex',
+    'HashTensor',
     'seed_everything',
     'get_home_dir',
     'set_home_dir',
@@ -67,4 +69,5 @@ if torch_geometric.typing.WITH_PT24:
         EdgeIndex,
         torch_geometric.edge_index.SortOrder,
         torch_geometric.edge_index.CatMetadata,
+        HashTensor,
     ])
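With these changes, HashTensor joins Index and EdgeIndex in the package's top-level exports, and on PyTorch >= 2.4 it is also added to the same safe-globals list for serialization as EdgeIndex. A quick import-level sanity check, assuming the new nightly is installed:

import torch_geometric
from torch_geometric import HashTensor

print(torch_geometric.__version__)  # '2.7.0.dev20250216'
print(HashTensor)                   # <class 'torch_geometric.hash_tensor.HashTensor'>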
torch_geometric/hash_tensor.py
ADDED
@@ -0,0 +1,133 @@
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
+
+import torch
+import torch.utils._pytree as pytree
+import xxhash
+from torch import Tensor
+
+from torch_geometric.typing import CPUHashMap, CUDAHashMap
+
+HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}
+
+
+def as_key_tensor(
+    key: Any,
+    *,
+    device: Optional[torch.device] = None,
+) -> Tensor:
+    try:
+        key = torch.as_tensor(key, device=device)
+    except Exception:
+        key = torch.tensor([
+            xxhash.xxh64(item).intdigest() & 0x7FFFFFFFFFFFFFFF for item in key
+        ], dtype=torch.int64, device=device)
+
+    if key.element_size() == 1:
+        key = key.view(torch.uint8)
+    elif key.element_size() == 2:
+        key = key.view(torch.int16)
+    elif key.element_size() == 4:
+        key = key.view(torch.int32)
+    elif key.element_size() == 8:
+        key = key.view(torch.int64)
+    else:
+        raise ValueError(f"Received invalid dtype '{key.dtype}' with "
+                         f"{key.element_size()} bytes")
+
+    return key
+
+
+class HashTensor(Tensor):
+    _map: Union[Tensor, CPUHashMap, CUDAHashMap]
+    _value: Optional[Tensor]
+
+    @staticmethod
+    def __new__(
+        cls: Type,
+        key: Any,
+        value: Optional[Any] = None,
+        *,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+    ) -> 'HashTensor':
+
+        if value is not None:
+            value = torch.as_tensor(value, dtype=dtype, device=device)
+            device = value.device
+
+        key = as_key_tensor(key, device=device)
+        device = key.device
+
+        if key.dim() != 1:
+            raise ValueError(f"'key' data in '{cls.__name__}' needs to be "
+                             f"one-dimensional (got {key.dim()} dimensions)")
+
+        if not key.is_contiguous():
+            raise ValueError(f"'key' data in '{cls.__name__}' needs to be "
+                             f"contiguous")
+
+        if value is not None:
+            if key.device != value.device:
+                raise ValueError(f"'key' and 'value' data in '{cls.__name__}' "
+                                 f"are expected to be on the same device (got "
+                                 f"'{key.device}' and '{value.device}')")
+
+            if key.numel() != value.size(0):
+                raise ValueError(f"'key' and 'value' data in '{cls.__name__}' "
+                                 f"are expected to have the same size in the "
+                                 f"first dimension (got {key.size(0)} and "
+                                 f"{value.size(0)})")
+
+            dtype = value.dtype
+            size = value.size()
+            stride = value.stride()
+            layout = value.layout
+            requires_grad = value.requires_grad
+        else:
+            dtype = dtype or torch.int64
+            size = torch.Size([key.numel()])
+            stride = (1, )
+            layout = torch.strided
+            requires_grad = False
+
+        out = Tensor._make_wrapper_subclass(  # type: ignore
+            cls,
+            size=size,
+            strides=stride,
+            dtype=dtype,
+            device=device,
+            layout=layout,
+            requires_grad=requires_grad,
+        )
+
+        out._value = value
+
+        return out
+
+    def as_tensor(self) -> Tensor:
+        r"""Zero-copies the :class:`HashTensor` representation back to a
+        :class:`torch.Tensor` representation.
+        """
+        if self._value is not None:
+            return self._value
+        return torch.arange(self.size(0), dtype=self.dtype, device=self.device)
+
+    @classmethod
+    def __torch_dispatch__(
+        cls: Type,
+        func: Callable[..., Any],
+        types: Iterable[Type[Any]],
+        args: Iterable[Tuple[Any, ...]] = (),
+        kwargs: Optional[Dict[Any, Any]] = None,
+    ) -> Any:
+        # Hold a number of `HANDLED_FUNCTIONS` that implement specific
+        # functions for valid `HashTensor` routines.
+        if func in HANDLED_FUNCTIONS:
+            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))
+
+        # For all other PyTorch functions, we treat them as vanilla tensors.
+        args = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(), args)
+        if kwargs is not None:
+            kwargs = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(),
+                                          kwargs)
+        return func(*args, **(kwargs or {}))
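In this first snapshot, the class stores its values and routes any unhandled op back through plain tensors; the _map field and the HANDLED_FUNCTIONS registry are scaffolding for lookup kernels that are not wired up yet. A short usage sketch based only on what the file above defines:

import torch

from torch_geometric import HashTensor

# Keys may be any one-dimensional tensor-like input; values are optional.
key = torch.tensor([1000, 2000, 3000])
value = torch.rand(3, 8)
hashed = HashTensor(key, value)

# Non-tensor keys (e.g., strings) go through the `xxhash` fallback in
# `as_key_tensor`:
hashed = HashTensor(['alice', 'bob', 'carol'], value)

# Any op without a `HANDLED_FUNCTIONS` entry is dispatched on the
# underlying value tensor via `as_tensor`:
out = hashed.sum(dim=-1)
print(out.shape)  # torch.Size([3])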
torch_geometric/nn/attention/sgformer.py
ADDED
@@ -0,0 +1,78 @@
+import torch
+from torch import Tensor
+
+
+class SGFormerAttention(torch.nn.Module):
+    r"""The simple global attention mechanism from the
+    `"SGFormer: Simplifying and Empowering Transformers for
+    Large-Graph Representations"
+    <https://arxiv.org/abs/2306.10759>`_ paper.
+
+    Args:
+        channels (int): Size of each input sample.
+        heads (int, optional): Number of parallel attention heads.
+            (default: :obj:`1`)
+        head_channels (int, optional): Size of each attention head.
+            (default: :obj:`64`)
+        qkv_bias (bool, optional): If specified, adds a bias to the query,
+            key and value projections in the self-attention.
+            (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        channels: int,
+        heads: int = 1,
+        head_channels: int = 64,
+        qkv_bias: bool = False,
+    ) -> None:
+        super().__init__()
+        assert channels % heads == 0
+        assert heads == 1, 'The number of heads is fixed at 1.'
+        if head_channels is None:
+            head_channels = channels // heads
+
+        self.heads = heads
+        self.head_channels = head_channels
+
+        inner_channels = head_channels * heads
+        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # feature transformation
+        qs = self.q(x).reshape(-1, self.heads, self.head_channels)
+        ks = self.k(x).reshape(-1, self.heads, self.head_channels)
+        vs = self.v(x).reshape(-1, self.heads, self.head_channels)
+
+        # normalize input
+        qs = qs / torch.norm(qs, p=2)  # [N, H, M]
+        ks = ks / torch.norm(ks, p=2)  # [L, H, M]
+        N = qs.shape[0]
+
+        # numerator
+        kvs = torch.einsum("lhm,lhd->hmd", ks, vs)
+        attention_num = torch.einsum("nhm,hmd->nhd", qs, kvs)  # [N, H, D]
+        attention_num += N * vs
+
+        # denominator
+        all_ones = torch.ones([ks.shape[0]]).to(ks.device)
+        ks_sum = torch.einsum("lhm,l->hm", ks, all_ones)
+        attention_normalizer = torch.einsum("nhm,hm->nh", qs, ks_sum)  # [N, H]
+
+        # attentive aggregated results
+        attention_normalizer = torch.unsqueeze(
+            attention_normalizer, len(attention_normalizer.shape))  # [N, H, 1]
+        attention_normalizer += torch.ones_like(attention_normalizer) * N
+        attn_output = attention_num / attention_normalizer  # [N, H, D]
+
+        return attn_output.mean(dim=1)
+
+    def reset_parameters(self):
+        self.q.reset_parameters()
+        self.k.reset_parameters()
+        self.v.reset_parameters()
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'heads={self.heads}, '
+                f'head_channels={self.head_channels})')
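This is linear attention: keys are contracted with values first (kvs has shape [H, M, D]), so the N x N attention matrix is never materialized and cost stays linear in the number of nodes; the N * vs term in the numerator and the + N term in the denominator appear to mix each node's own value back in as a self/residual component. A minimal smoke test of the layer as written:

import torch

from torch_geometric.nn.attention import SGFormerAttention

attn = SGFormerAttention(channels=64)  # heads=1, head_channels=64.
x = torch.randn(100, 64)               # 100 nodes, 64 features each.

out = attn(x)
print(out.shape)  # torch.Size([100, 64]); heads are averaged out.
print(attn)       # SGFormerAttention(heads=1, head_channels=64)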
torch_geometric/nn/models/__init__.py
CHANGED
@@ -32,6 +32,7 @@ from .g_retriever import GRetriever
 from .git_mol import GITMol
 from .molecule_gpt import MoleculeGPT
 from .glem import GLEM
+from .sgformer import SGFormer
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                       captum_output_to_dicts)
@@ -82,4 +83,5 @@ __all__ = classes = [
     'GITMol',
     'MoleculeGPT',
     'GLEM',
+    'SGFormer',
 ]
torch_geometric/nn/models/sgformer.py
ADDED
@@ -0,0 +1,190 @@
+import torch
+import torch.nn.functional as F
+
+from torch_geometric.nn.attention import SGFormerAttention
+from torch_geometric.nn.conv import GCNConv
+
+
+class GraphModule(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        num_layers=2,
+        dropout=0.5,
+    ):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        self.fcs = torch.nn.ModuleList()
+        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
+
+        self.bns = torch.nn.ModuleList()
+        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
+        for _ in range(num_layers):
+            self.convs.append(GCNConv(hidden_channels, hidden_channels))
+            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
+
+        self.dropout = dropout
+        self.activation = F.relu
+
+    def reset_parameters(self):
+        for conv in self.convs:
+            conv.reset_parameters()
+        for bn in self.bns:
+            bn.reset_parameters()
+        for fc in self.fcs:
+            fc.reset_parameters()
+
+    def forward(self, x, edge_index):
+        x = self.fcs[0](x)
+        x = self.bns[0](x)
+        x = self.activation(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        last_x = x
+
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            x = self.bns[i + 1](x)
+            x = self.activation(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = x + last_x
+        return x
+
+
+class SGModule(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        num_layers=2,
+        num_heads=1,
+        dropout=0.5,
+    ):
+        super().__init__()
+
+        self.attns = torch.nn.ModuleList()
+        self.fcs = torch.nn.ModuleList()
+        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
+        self.bns = torch.nn.ModuleList()
+        self.bns.append(torch.nn.LayerNorm(hidden_channels))
+        for _ in range(num_layers):
+            self.attns.append(
+                SGFormerAttention(hidden_channels, num_heads, hidden_channels))
+            self.bns.append(torch.nn.LayerNorm(hidden_channels))
+
+        self.dropout = dropout
+        self.activation = F.relu
+
+    def reset_parameters(self):
+        for attn in self.attns:
+            attn.reset_parameters()
+        for bn in self.bns:
+            bn.reset_parameters()
+        for fc in self.fcs:
+            fc.reset_parameters()
+
+    def forward(self, x):
+        layer_ = []
+
+        # input MLP layer
+        x = self.fcs[0](x)
+        x = self.bns[0](x)
+        x = self.activation(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # store as residual link
+        layer_.append(x)
+
+        for i, attn in enumerate(self.attns):
+            x = attn(x)
+            x = (x + layer_[i]) / 2.
+            x = self.bns[i + 1](x)
+            x = self.activation(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            layer_.append(x)
+
+        return x
+
+
+class SGFormer(torch.nn.Module):
+    r"""The SGFormer model from the
+    `"SGFormer: Simplifying and Empowering Transformers for
+    Large-Graph Representations"
+    <https://arxiv.org/abs/2306.10759>`_ paper.
+
+    Args:
+        in_channels (int): Input channels.
+        hidden_channels (int): Hidden channels.
+        out_channels (int): Output channels.
+        trans_num_layers (int): The number of layers for all-pair attention.
+            (default: :obj:`2`)
+        trans_num_heads (int): The number of heads for attention.
+            (default: :obj:`1`)
+        trans_dropout (float): Global dropout rate.
+            (default: :obj:`0.5`)
+        gnn_num_layers (int): The number of layers for the GNN.
+            (default: :obj:`3`)
+        gnn_dropout (float): GNN dropout rate.
+            (default: :obj:`0.5`)
+        graph_weight (float): The weight to balance the global attention and
+            GNN modules. (default: :obj:`0.5`)
+        aggregate (str): Aggregation type (:obj:`add` or :obj:`cat`).
+            (default: :obj:`add`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        trans_num_layers: int = 2,
+        trans_num_heads: int = 1,
+        trans_dropout: float = 0.5,
+        gnn_num_layers: int = 3,
+        gnn_dropout: float = 0.5,
+        graph_weight: float = 0.5,
+        aggregate: str = 'add',
+    ):
+        super().__init__()
+        self.trans_conv = SGModule(
+            in_channels,
+            hidden_channels,
+            trans_num_layers,
+            trans_num_heads,
+            trans_dropout,
+        )
+        self.graph_conv = GraphModule(
+            in_channels,
+            hidden_channels,
+            gnn_num_layers,
+            gnn_dropout,
+        )
+        self.graph_weight = graph_weight
+
+        self.aggregate = aggregate
+
+        if aggregate == 'add':
+            self.fc = torch.nn.Linear(hidden_channels, out_channels)
+        elif aggregate == 'cat':
+            self.fc = torch.nn.Linear(2 * hidden_channels, out_channels)
+        else:
+            raise ValueError(f'Invalid aggregate type: {aggregate}')
+
+        self.params1 = list(self.trans_conv.parameters())
+        self.params2 = list(self.graph_conv.parameters())
+        self.params2.extend(list(self.fc.parameters()))
+
+    def reset_parameters(self) -> None:
+        self.trans_conv.reset_parameters()
+        self.graph_conv.reset_parameters()
+        self.fc.reset_parameters()
+
+    def forward(self, x, edge_index):
+        x1 = self.trans_conv(x)
+        x2 = self.graph_conv(x, edge_index)
+        if self.aggregate == 'add':
+            x = self.graph_weight * x2 + (1 - self.graph_weight) * x1
+        else:
+            x = torch.cat((x1, x2), dim=1)
+        x = self.fc(x)
+        return F.log_softmax(x, dim=-1)
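The model runs both branches over the full graph and fuses them either by a weighted sum ('add', with graph_weight on the GNN side) or by concatenation ('cat') before the final classifier, returning per-node log-probabilities. A small end-to-end sketch on random data; the shapes and class count here are made up for illustration:

import torch

from torch_geometric.nn.models import SGFormer

model = SGFormer(in_channels=16, hidden_channels=64, out_channels=7)
x = torch.randn(200, 16)                      # 200 nodes, 16 features.
edge_index = torch.randint(0, 200, (2, 800))  # 800 random edges.

model.eval()  # Disable dropout for a deterministic forward pass.
out = model(x, edge_index)
print(out.shape)  # torch.Size([200, 7]); per-node log-probabilities.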
torch_geometric/typing.py
CHANGED
@@ -9,6 +9,11 @@ import numpy as np
 import torch
 from torch import Tensor
 
+try:
+    from typing import TypeAlias  # type: ignore
+except ImportError:
+    from typing_extensions import TypeAlias
+
 WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2
 WITH_PT21 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 1
 WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
@@ -64,6 +69,16 @@ try:
         pyg_lib.sampler.neighbor_sample).parameters)
     WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(
         pyg_lib.sampler.neighbor_sample).parameters)
+    try:
+        torch.classes.pyg.CPUHashMap
+        WITH_CPU_HASH_MAP = True
+    except Exception:
+        WITH_CPU_HASH_MAP = False
+    try:
+        torch.classes.pyg.CUDAHashMap
+        WITH_CUDA_HASH_MAP = True
+    except Exception:
+        WITH_CUDA_HASH_MAP = False
 except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
         warnings.warn(f"An issue occurred while importing 'pyg-lib'. "
@@ -78,6 +93,32 @@ except Exception as e:
     WITH_METIS = False
     WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
     WITH_WEIGHTED_NEIGHBOR_SAMPLE = False
+    WITH_CPU_HASH_MAP = False
+    WITH_CUDA_HASH_MAP = False
+
+if WITH_CPU_HASH_MAP:
+    CPUHashMap: TypeAlias = torch.classes.pyg.CPUHashMap
+else:
+
+    class CPUHashMap:  # type: ignore
+        def __init__(self, key: Tensor) -> None:
+            raise ImportError("'CPUHashMap' requires 'pyg-lib'")
+
+        def get(self, query: Tensor) -> Tensor:
+            raise ImportError("'CPUHashMap' requires 'pyg-lib'")
+
+
+if WITH_CUDA_HASH_MAP:
+    CUDAHashMap: TypeAlias = torch.classes.pyg.CUDAHashMap
+else:
+
+    class CUDAHashMap:  # type: ignore
+        def __init__(self, key: Tensor) -> None:
+            raise ImportError("'CUDAHashMap' requires 'pyg-lib'")
+
+        def get(self, query: Tensor) -> Tensor:
+            raise ImportError("'CUDAHashMap' requires 'pyg-lib'")
+
 
 try:
     import torch_scatter  # noqa
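Like the existing WITH_* flags, detection runs once at import time, and the stub classes defer the failure to first use rather than to import. A sketch of guarded downstream usage; the CPUHashMap constructor and get signatures below are taken from the stubs above and only assumed to match the pyg-lib TorchScript class:

import torch

import torch_geometric.typing as pyg_typing

if pyg_typing.WITH_CPU_HASH_MAP:
    # `pyg-lib` registered the TorchScript hash-map class.
    hash_map = pyg_typing.CPUHashMap(torch.tensor([10, 20, 30]))
    print(hash_map.get(torch.tensor([20, 30])))
else:
    print('pyg-lib hash maps unavailable.')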
{pyg_nightly-2.7.0.dev20250214.dist-info → pyg_nightly-2.7.0.dev20250216.dist-info}/WHEEL
RENAMED
File without changes