pyg-nightly 2.7.0.dev20250214__py3-none-any.whl → 2.7.0.dev20250216__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250214
3
+ Version: 2.7.0.dev20250216
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -22,6 +22,7 @@ Requires-Dist: psutil>=5.8.0
22
22
  Requires-Dist: pyparsing
23
23
  Requires-Dist: requests
24
24
  Requires-Dist: tqdm
25
+ Requires-Dist: xxhash
25
26
  Requires-Dist: matplotlib ; extra == "benchmark"
26
27
  Requires-Dist: networkx ; extra == "benchmark"
27
28
  Requires-Dist: pandas ; extra == "benchmark"
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=wEIwjfiE7YtSKe1hUelXwhBp2yBtwx2u2gRooK6Gb_s,1904
1
+ torch_geometric/__init__.py,sha256=QlYg3d7vRWC0k_OzPCmAanyTUBiTvikPuzM9feT0REk,1978
2
2
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -9,6 +9,7 @@ torch_geometric/deprecation.py,sha256=dWRymDIUkUVI2MeEmBG5WF4R6jObZeseSBV9G6FNfj
9
9
  torch_geometric/device.py,sha256=tU5-_lBNVbVHl_kUmWPwiG5mQ1pyapwMF4JkmtNN3MM,1224
10
10
  torch_geometric/edge_index.py,sha256=BsLh5tOZRjjSYDkjqOFAdBuvMaDg7EWaaLELYsUL0Z8,70048
11
11
  torch_geometric/experimental.py,sha256=JbtNNEXjFGI8hZ9raM6-qrZURP6Z5nlDK8QicZUIbz0,4756
12
+ torch_geometric/hash_tensor.py,sha256=VBuz9n16ouSk2u4DifpCIW5MXiuoH5UGO4rPj-Celjw,4418
12
13
  torch_geometric/home.py,sha256=EV54B4Dmiv61GDbkCwtCfWGWJ4eFGwZ8s3KOgGjwYgY,790
13
14
  torch_geometric/index.py,sha256=9ChzWFCwj2slNcVBOgfV-wQn-KscJe_y7502w-Vf76w,24045
14
15
  torch_geometric/inspector.py,sha256=nKi5o4Mn6xsG0Ex1GudTEQt_EqnF9mcMqGtp7Shh9sQ,19336
@@ -18,7 +19,7 @@ torch_geometric/logging.py,sha256=HmHHLiCcM64k-6UYNOSfXPIeSGNAyiGGcn8cD8tlyuQ,85
18
19
  torch_geometric/resolver.py,sha256=fn-_6mCpI2xv7eDZnIFcYrHOn0IrwbkWFLDb9laQrWI,1270
19
20
  torch_geometric/seed.py,sha256=MJLbVwpb9i8mK3oi32sS__Cq-dRq_afTeoOL_HoA9ko,372
20
21
  torch_geometric/template.py,sha256=rqjDWgcSAgTCiV4bkOjWRPaO4PpUdC_RXigzxxBqAu8,1060
21
- torch_geometric/typing.py,sha256=SzuZPdeYLt7_lFUHHcAlaggxqLA0VZ_kx8s0iy_tMIw,14429
22
+ torch_geometric/typing.py,sha256=mtSM6QhCsrohstnyvqMuxEIajCkhcvkQKOU4uVu-nDs,15596
22
23
  torch_geometric/warnings.py,sha256=t114CbkrmiqkXaavx5g7OO52dLdktf-U__B5QqYIQvI,413
23
24
  torch_geometric/contrib/__init__.py,sha256=0pWkmXfZtbdr-AKwlii5LTFggTEH-MCrSKpZxrtPlVs,352
24
25
  torch_geometric/contrib/datasets/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2ezDjbj9n9nCpvtk,23
@@ -330,9 +331,10 @@ torch_geometric/nn/aggr/set_transformer.py,sha256=FG7_JizpFX14M6VSCwLSjYXYdJ1ZiQ
330
331
  torch_geometric/nn/aggr/sort.py,sha256=bvOOWnFkNOBOZih4rqVZQsjfeDX3vmXo1bpPSFD846w,2507
331
332
  torch_geometric/nn/aggr/utils.py,sha256=CLJ-ZrVWYIOBpdhQBLAz94dj3cMKKKc3qwGr4DFbiCU,8338
332
333
  torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nqoyv8tZ6Y35qMYTf9w,1126
333
- torch_geometric/nn/attention/__init__.py,sha256=1lCB7zh7uM6FkpW81S9U4CvxTwpCkz59KatPTIE9UmA,127
334
+ torch_geometric/nn/attention/__init__.py,sha256=wLKTmlfP7qL9sZHy4cmDFHEtdwa-MEKE1dT51L1_w10,192
334
335
  torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
335
336
  torch_geometric/nn/attention/qformer.py,sha256=7J-pWm_vpumK38IC-iCBz4oqL-BEIofEIxJ0wfjWq9A,2338
337
+ torch_geometric/nn/attention/sgformer.py,sha256=QT_3bVOPsqPY0fWVdwyr8E0LL9u9TwC_sv0_2EggdpA,2893
336
338
  torch_geometric/nn/conv/__init__.py,sha256=37zTdt0gfSAUPMtwXjZg5mWx_itojJVFNODYR1h1ch0,3515
337
339
  torch_geometric/nn/conv/agnn_conv.py,sha256=5nEPLx_BBHcDaO6HWzLuHfXc0Yd_reKynAOH0Iq09lU,3077
338
340
  torch_geometric/nn/conv/antisymmetric_conv.py,sha256=dhA6sCETy1jlXReYJZBSyToOcL_mZ1wL10fMIb8Ppuw,4387
@@ -424,7 +426,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
424
426
  torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
425
427
  torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
426
428
  torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
427
- torch_geometric/nn/models/__init__.py,sha256=vWMKzGBVxA1Fm0uGDLnH4jzYgfhK34CQTRJ-xi5pf5k,2150
429
+ torch_geometric/nn/models/__init__.py,sha256=gmBRXrbjkxLv_g0hfI87bSXMgwFoIo3XgslpGGdMs3g,2197
428
430
  torch_geometric/nn/models/attentive_fp.py,sha256=tkgvw28wg9-JqHIfBllfCwTHrZIUiv85yZJcDqjz3z0,6634
429
431
  torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
430
432
  torch_geometric/nn/models/basic_gnn.py,sha256=PGa0RUMyvrNy_5yRI2jX_zwPsmZXwOQWfsWvxOiHsSk,31225
@@ -456,6 +458,7 @@ torch_geometric/nn/models/re_net.py,sha256=pz66q5b5BoGDNVQvpEGS2RGoeKvpjkYAv9r3W
456
458
  torch_geometric/nn/models/rect.py,sha256=2F3XyyvHTAEuqfJpiNB5M8pSGy738LhPiom5I-SDWqM,2808
457
459
  torch_geometric/nn/models/rev_gnn.py,sha256=1b6wU-6YTuLsWn5p8c5LXQm2KugEAVcEYJKZbWTDvgQ,11796
458
460
  torch_geometric/nn/models/schnet.py,sha256=0aaHrVtxApdvn3RHCGLQJW1MbIb--CSYUrx9O3hDOZM,16656
461
+ torch_geometric/nn/models/sgformer.py,sha256=QacmFjlTSUgk6T5Dk4NfG9zmxHZqr0ggx1ElNrPMJNc,5878
459
462
  torch_geometric/nn/models/signed_gcn.py,sha256=J40CnedFIqtKI1LhW1ITSEFRbA_XiJZL6lASrKwUEAI,9841
460
463
  torch_geometric/nn/models/tgn.py,sha256=kEGdfLJybkbMT4UMoAh2nCzfX3_nDjfm1cicuPHEwAM,11878
461
464
  torch_geometric/nn/models/visnet.py,sha256=97OFMCsPDEI5BCSi7RhoRcU2CNRp7zck2tEzrltFZj4,43192
@@ -630,6 +633,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
630
633
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
631
634
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
632
635
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
633
- pyg_nightly-2.7.0.dev20250214.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
634
- pyg_nightly-2.7.0.dev20250214.dist-info/METADATA,sha256=7oAMP5u6qz3XuWwE8lms2zjkk_lGmt3Q7-3FswFCbUk,62977
635
- pyg_nightly-2.7.0.dev20250214.dist-info/RECORD,,
636
+ pyg_nightly-2.7.0.dev20250216.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
637
+ pyg_nightly-2.7.0.dev20250216.dist-info/METADATA,sha256=nrJ478w_yYVSpIyqlfR5lxdYTu6NOIzCv_atkMY5w5w,62999
638
+ pyg_nightly-2.7.0.dev20250216.dist-info/RECORD,,
@@ -7,6 +7,7 @@ from ._compile import compile, is_compiling
7
7
  from ._onnx import is_in_onnx_export
8
8
  from .index import Index
9
9
  from .edge_index import EdgeIndex
10
+ from .hash_tensor import HashTensor
10
11
  from .seed import seed_everything
11
12
  from .home import get_home_dir, set_home_dir
12
13
  from .device import is_mps_available, is_xpu_available, device
@@ -30,11 +31,12 @@ from .lazy_loader import LazyLoader
30
31
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
31
32
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
32
33
 
33
- __version__ = '2.7.0.dev20250214'
34
+ __version__ = '2.7.0.dev20250216'
34
35
 
35
36
  __all__ = [
36
37
  'Index',
37
38
  'EdgeIndex',
39
+ 'HashTensor',
38
40
  'seed_everything',
39
41
  'get_home_dir',
40
42
  'set_home_dir',
@@ -67,4 +69,5 @@ if torch_geometric.typing.WITH_PT24:
67
69
  EdgeIndex,
68
70
  torch_geometric.edge_index.SortOrder,
69
71
  torch_geometric.edge_index.CatMetadata,
72
+ HashTensor,
70
73
  ])
@@ -0,0 +1,133 @@
1
+ from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
2
+
3
+ import torch
4
+ import torch.utils._pytree as pytree
5
+ import xxhash
6
+ from torch import Tensor
7
+
8
+ from torch_geometric.typing import CPUHashMap, CUDAHashMap
9
+
10
# Registry of `HashTensor`-specific implementations, keyed by the PyTorch
# function they replace. `HashTensor.__torch_dispatch__` consults it before
# falling back to plain-tensor execution.
HANDLED_FUNCTIONS: Dict[Callable, Callable] = dict()


def as_key_tensor(
    key: Any,
    *,
    device: Optional[torch.device] = None,
) -> Tensor:
    r"""Converts :obj:`key` into an integer tensor that can serve as hash keys.

    Inputs understood by :meth:`torch.as_tensor` are converted directly.
    Anything else is treated as an iterable of items that are hashed via
    :obj:`xxhash.xxh64`, with the sign bit cleared so that every digest fits
    into a non-negative :obj:`torch.int64`.

    The converted tensor is then bit-reinterpreted via :meth:`Tensor.view`
    as the integer dtype of the same element size (:obj:`torch.uint8`,
    :obj:`torch.int16`, :obj:`torch.int32` or :obj:`torch.int64`). For
    example, floating-point keys are represented by their raw bit patterns.

    Args:
        key: The keys to convert.
        device (torch.device, optional): The device of the returned tensor.
            (default: :obj:`None`)

    Raises:
        ValueError: If the element size of the converted keys is not 1, 2, 4
            or 8 bytes.
    """
    try:
        out = torch.as_tensor(key, device=device)
    except Exception:
        # Not tensor-like (e.g., strings): hash each item individually.
        digests = [
            xxhash.xxh64(item).intdigest() & 0x7FFFFFFFFFFFFFFF
            for item in key
        ]
        out = torch.tensor(digests, dtype=torch.int64, device=device)

    int_dtype_by_size = {
        1: torch.uint8,
        2: torch.int16,
        4: torch.int32,
        8: torch.int64,
    }
    num_bytes = out.element_size()
    if num_bytes not in int_dtype_by_size:
        raise ValueError(f"Received invalid dtype '{out.dtype}' with "
                         f"{out.element_size()} bytes")

    return out.view(int_dtype_by_size[num_bytes])
38
+
39
+
40
class HashTensor(Tensor):
    r"""A :class:`torch.Tensor` subclass whose rows are associated with keys.

    A :class:`HashTensor` pairs a one-dimensional :obj:`key` tensor with an
    optional :obj:`value` tensor whose first dimension is aligned with
    :obj:`key`. If :obj:`value` is omitted, the tensor behaves like
    :obj:`torch.arange(len(key))`, *i.e.*, every key is associated with its
    position.

    PyTorch operations without a dedicated entry in
    :obj:`HANDLED_FUNCTIONS` are executed on the plain :meth:`as_tensor`
    representation.

    Args:
        key: The keys. Anything accepted by :meth:`torch.as_tensor` is used
            as-is (bit-reinterpreted as an integer dtype). Other iterables,
            such as strings, are hashed to non-negative :obj:`torch.int64`
            values.
        value (optional): The values associated with each key. Its first
            dimension needs to match the number of keys.
            (default: :obj:`None`)
        dtype (torch.dtype, optional): The data type of :obj:`value`. If
            :obj:`value` is :obj:`None`, this is the data type of the
            implicit positional values and defaults to :obj:`torch.int64`.
            (default: :obj:`None`)
        device (torch.device, optional): The device to place :obj:`key` and
            :obj:`value` on. If :obj:`None`, the device of :obj:`value` (or
            of :obj:`key` if :obj:`value` is :obj:`None`) is used.
            (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`key` is not one-dimensional or not contiguous,
            if :obj:`value` is zero-dimensional, or if :obj:`key` and
            :obj:`value` differ in device or in their leading size.
    """
    # NOTE(review): `_map` is declared but never assigned in `__new__` in
    # this revision, and the validated `key` tensor is not retained.
    _map: Union[Tensor, CPUHashMap, CUDAHashMap]
    # The explicit value tensor, or `None` for implicit positional values.
    _value: Optional[Tensor]

    @staticmethod
    def __new__(
        cls: Type,
        key: Any,
        value: Optional[Any] = None,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> 'HashTensor':

        if value is not None:
            value = torch.as_tensor(value, dtype=dtype, device=device)
            device = value.device

        # Keys follow the device of `value` (or the requested device):
        key = as_key_tensor(key, device=device)
        device = key.device

        if key.dim() != 1:
            raise ValueError(f"'key' data in '{cls.__name__}' needs to be "
                             f"one-dimensional (got {key.dim()} dimensions)")

        if not key.is_contiguous():
            raise ValueError(f"'key' data in '{cls.__name__}' needs to be "
                             f"contiguous")

        if value is not None:
            if key.device != value.device:
                raise ValueError(f"'key' and 'value' data in '{cls.__name__}' "
                                 f"are expected to be on the same device (got "
                                 f"'{key.device}' and '{value.device}')")

            # `value.size(0)` below would raise an opaque IndexError for
            # scalar values, so reject them explicitly:
            if value.dim() == 0:
                raise ValueError(f"'value' data in '{cls.__name__}' needs to "
                                 f"be at least one-dimensional (got a scalar)")

            if key.numel() != value.size(0):
                raise ValueError(f"'key' and 'value' data in '{cls.__name__}' "
                                 f"are expected to have the same size in the "
                                 f"first dimension (got {key.size(0)} and "
                                 f"{value.size(0)})")

            # The wrapper mirrors the metadata of the explicit value tensor:
            dtype = value.dtype
            size = value.size()
            stride = value.stride()
            layout = value.layout
            requires_grad = value.requires_grad
        else:
            # Implicit positional values `arange(len(key))`:
            dtype = dtype or torch.int64
            size = torch.Size([key.numel()])
            stride = (1, )
            layout = torch.strided
            requires_grad = False

        out = Tensor._make_wrapper_subclass(  # type: ignore
            cls,
            size=size,
            strides=stride,
            dtype=dtype,
            device=device,
            layout=layout,
            requires_grad=requires_grad,
        )

        out._value = value

        return out

    def as_tensor(self) -> Tensor:
        r"""Returns the :class:`torch.Tensor` representation of the
        :class:`HashTensor`.

        If a :obj:`value` tensor was given on construction, it is returned
        as-is (without copying). Otherwise, the implicit positional values
        :obj:`torch.arange(self.size(0))` are materialized.
        """
        if self._value is not None:
            return self._value
        return torch.arange(self.size(0), dtype=self.dtype, device=self.device)

    @classmethod
    def __torch_dispatch__(
        cls: Type,
        func: Callable[..., Any],
        types: Iterable[Type[Any]],
        args: Iterable[Tuple[Any, ...]] = (),
        kwargs: Optional[Dict[Any, Any]] = None,
    ) -> Any:
        # Hold a number of `HANDLED_FUNCTIONS` that implement specific
        # functions for valid `HashTensor` routines.
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))

        # For all other PyTorch functions, we treat them as vanilla tensors.
        args = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(), args)
        if kwargs is not None:
            kwargs = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(),
                                          kwargs)
        return func(*args, **(kwargs or {}))
@@ -1,7 +1,9 @@
1
1
  from .performer import PerformerAttention
2
2
  from .qformer import QFormer
3
+ from .sgformer import SGFormerAttention
3
4
 
4
5
  __all__ = [
5
6
  'PerformerAttention',
6
7
  'QFormer',
8
+ 'SGFormerAttention',
7
9
  ]
@@ -0,0 +1,78 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+
5
class SGFormerAttention(torch.nn.Module):
    r"""The simple global attention mechanism from the
    `"SGFormer: Simplifying and Empowering Transformers for
    Large-Graph Representations"
    <https://arxiv.org/abs/2306.10759>`_ paper.

    Given queries :math:`\mathbf{Q}`, keys :math:`\mathbf{K}` and values
    :math:`\mathbf{V}` of all :math:`N` input rows, computes an all-pair
    attention in time linear in :math:`N`:

    .. math::
        \mathbf{Z} = \frac{\tilde{\mathbf{Q}} (\tilde{\mathbf{K}}^{\top}
        \mathbf{V}) + N \mathbf{V}}{\tilde{\mathbf{Q}} (\tilde{\mathbf{K}}^{\top}
        \mathbf{1}) + N}

    Here, :math:`\tilde{\mathbf{Q}}` and :math:`\tilde{\mathbf{K}}` are
    :math:`\mathbf{Q}` and :math:`\mathbf{K}` divided by their global
    (Frobenius) norm, and the division is applied row-wise.

    Args:
        channels (int): Size of each input sample.
        heads (int, optional): Number of parallel attention heads. Currently
            only :obj:`1` is supported. (default: :obj:`1`)
        head_channels (int, optional): Size of each attention head. If set
            to :obj:`None`, it is set to :obj:`channels // heads`.
            (default: :obj:`64`)
        qkv_bias (bool, optional): If specified, add bias to query, key
            and value in the self attention. (default: :obj:`False`)
    """
    def __init__(
        self,
        channels: int,
        heads: int = 1,
        head_channels: int = 64,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()
        assert channels % heads == 0
        assert heads == 1, 'The number of heads are fixed as 1.'
        if head_channels is None:
            head_channels = channels // heads

        self.heads = heads
        self.head_channels = head_channels

        inner_channels = head_channels * heads
        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)

    def forward(self, x: Tensor) -> Tensor:
        r"""Applies global attention over all rows of :obj:`x`.

        Args:
            x (torch.Tensor): The input features with :obj:`channels` as the
                last dimension; all leading dimensions are flattened into
                :math:`N` rows.

        Returns:
            torch.Tensor: The attended features of shape
            :obj:`[N, head_channels]` (averaged over heads).
        """
        H, C = self.heads, self.head_channels

        # Feature transformation:
        qs = self.q(x).reshape(-1, H, C)  # [N, H, M]
        ks = self.k(x).reshape(-1, H, C)  # [N, H, M]
        vs = self.v(x).reshape(-1, H, C)  # [N, H, D]

        # Normalize queries and keys by their norm over *all* entries:
        qs = qs / torch.norm(qs, p=2)
        ks = ks / torch.norm(ks, p=2)
        N = qs.size(0)

        # Numerator: Q (K^T V) + N * V, computed without an N x N matrix.
        kvs = torch.einsum('lhm,lhd->hmd', ks, vs)  # [H, M, D]
        attention_num = torch.einsum('nhm,hmd->nhd', qs, kvs) + N * vs

        # Denominator: Q (K^T 1) + N. Summing `ks` over rows directly keeps
        # its dtype/device; a default float32 `torch.ones` vector makes
        # `einsum` fail for float64/half inputs.
        ks_sum = ks.sum(dim=0)  # [H, M]
        attention_normalizer = torch.einsum('nhm,hm->nh', qs, ks_sum)
        attention_normalizer = attention_normalizer.unsqueeze(-1) + N

        attn_output = attention_num / attention_normalizer  # [N, H, D]

        return attn_output.mean(dim=1)

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.q.reset_parameters()
        self.k.reset_parameters()
        self.v.reset_parameters()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'heads={self.heads}, '
                f'head_channels={self.head_channels})')
@@ -32,6 +32,7 @@ from .g_retriever import GRetriever
32
32
  from .git_mol import GITMol
33
33
  from .molecule_gpt import MoleculeGPT
34
34
  from .glem import GLEM
35
+ from .sgformer import SGFormer
35
36
  # Deprecated:
36
37
  from torch_geometric.explain.algorithm.captum import (to_captum_input,
37
38
  captum_output_to_dicts)
@@ -82,4 +83,5 @@ __all__ = classes = [
82
83
  'GITMol',
83
84
  'MoleculeGPT',
84
85
  'GLEM',
86
+ 'SGFormer',
85
87
  ]
@@ -0,0 +1,190 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from torch_geometric.nn.attention import SGFormerAttention
5
+ from torch_geometric.nn.conv import GCNConv
6
+
7
+
8
class GraphModule(torch.nn.Module):
    r"""The local (message passing) branch of :class:`SGFormer`.

    Applies an input linear projection followed by :obj:`num_layers`
    :class:`~torch_geometric.nn.conv.GCNConv` layers. Each stage is followed
    by batch normalization, ReLU and dropout. The output of every
    convolution stage is summed with the output of the input projection (an
    initial residual connection).

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden (and output) sample.
        num_layers (int, optional): Number of :class:`GCNConv` layers.
            (default: :obj:`2`)
        dropout (float, optional): Dropout probability.
            (default: :obj:`0.5`)
    """
    def __init__(
        self,
        in_channels,
        hidden_channels,
        num_layers=2,
        dropout=0.5,
    ):
        super().__init__()

        # Construction order: input projection first, then the convolutions.
        # This order can matter for seeded weight initialization.
        input_lin = torch.nn.Linear(in_channels, hidden_channels)
        convs = [
            GCNConv(hidden_channels, hidden_channels)
            for _ in range(num_layers)
        ]

        self.convs = torch.nn.ModuleList(convs)
        self.fcs = torch.nn.ModuleList([input_lin])

        # bns[0] follows the input projection; bns[i] follows convs[i - 1].
        norms = [torch.nn.BatchNorm1d(hidden_channels)]
        norms.extend(
            torch.nn.BatchNorm1d(hidden_channels) for _ in range(num_layers))
        self.bns = torch.nn.ModuleList(norms)

        self.dropout = dropout
        self.activation = F.relu

    def reset_parameters(self):
        r"""Resets convolutions, then normalizations, then the projection."""
        for module in (*self.convs, *self.bns, *self.fcs):
            module.reset_parameters()

    def _norm_act_dropout(self, x, idx):
        # Shared post-processing of every stage: bns[idx] -> ReLU -> dropout.
        x = self.bns[idx](x)
        x = self.activation(x)
        return F.dropout(x, p=self.dropout, training=self.training)

    def forward(self, x, edge_index):
        r"""Runs the GNN branch.

        Args:
            x (torch.Tensor): Input node features (last dimension
                :obj:`in_channels`).
            edge_index: The graph connectivity, passed to every
                :class:`GCNConv` layer.
        """
        x0 = self._norm_act_dropout(self.fcs[0](x), 0)
        out = x0
        for idx, conv in enumerate(self.convs, start=1):
            out = self._norm_act_dropout(conv(out, edge_index), idx) + x0
        return out
53
+
54
+
55
class SGModule(torch.nn.Module):
    r"""The global (all-pair attention) branch of :class:`SGFormer`.

    Applies an input linear projection followed by :obj:`num_layers`
    :class:`SGFormerAttention` layers. Each attention output is averaged
    with its input. Every stage is followed by layer normalization, ReLU and
    dropout.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden (and output) sample.
        num_layers (int, optional): Number of attention layers.
            (default: :obj:`2`)
        num_heads (int, optional): Number of attention heads.
            (default: :obj:`1`)
        dropout (float, optional): Dropout probability.
            (default: :obj:`0.5`)
    """
    def __init__(
        self,
        in_channels,
        hidden_channels,
        num_layers=2,
        num_heads=1,
        dropout=0.5,
    ):
        super().__init__()

        # Construction order: input projection first, then attention layers.
        # This order can matter for seeded weight initialization.
        input_lin = torch.nn.Linear(in_channels, hidden_channels)
        attns = [
            SGFormerAttention(hidden_channels, num_heads, hidden_channels)
            for _ in range(num_layers)
        ]

        self.attns = torch.nn.ModuleList(attns)
        self.fcs = torch.nn.ModuleList([input_lin])

        # bns[0] follows the input projection; bns[i] follows attns[i - 1].
        norms = [torch.nn.LayerNorm(hidden_channels)]
        norms.extend(
            torch.nn.LayerNorm(hidden_channels) for _ in range(num_layers))
        self.bns = torch.nn.ModuleList(norms)

        self.dropout = dropout
        self.activation = F.relu

    def reset_parameters(self):
        r"""Resets attention layers, then normalizations, then the
        projection."""
        for module in (*self.attns, *self.bns, *self.fcs):
            module.reset_parameters()

    def _norm_act_dropout(self, x, idx):
        # Shared post-processing of every stage: bns[idx] -> ReLU -> dropout.
        x = self.bns[idx](x)
        x = self.activation(x)
        return F.dropout(x, p=self.dropout, training=self.training)

    def forward(self, x):
        r"""Runs the attention branch on input features :obj:`x` (last
        dimension :obj:`in_channels`)."""
        x = self._norm_act_dropout(self.fcs[0](x), 0)
        for idx, attn in enumerate(self.attns, start=1):
            # Residual link: average the attention output with its input,
            # i.e. with the output of the previous stage.
            x = (attn(x) + x) / 2.
            x = self._norm_act_dropout(x, idx)
        return x
108
+
109
+
110
class SGFormer(torch.nn.Module):
    r"""The sgformer module from the
    `"SGFormer: Simplifying and Empowering Transformers for
    Large-Graph Representations"
    <https://arxiv.org/abs/2306.10759>`_ paper.

    Combines a global branch (stacked :class:`SGFormerAttention` layers) with
    a local branch (stacked :class:`~torch_geometric.nn.conv.GCNConv`
    layers). The merged representation is mapped by a final linear layer
    and normalized with :meth:`~torch.nn.functional.log_softmax`.

    Args:
        in_channels (int): Input channels.
        hidden_channels (int): Hidden channels.
        out_channels (int): Output channels.
        trans_num_layers (int): The number of layers for all-pair attention.
            (default: :obj:`2`)
        trans_num_heads (int): The number of heads for attention.
            (default: :obj:`1`)
        trans_dropout (float): Global dropout rate.
            (default: :obj:`0.5`)
        gnn_num_layers (int): The number of layers for GNN.
            (default: :obj:`3`)
        gnn_dropout (float): GNN dropout rate.
            (default: :obj:`0.5`)
        graph_weight (float): The weight of the GNN branch when
            :obj:`aggregate="add"`. The attention branch is weighted by
            :obj:`1 - graph_weight`. (default: :obj:`0.5`)
        aggregate (str): How to merge both branches, either :obj:`"add"`
            (weighted sum) or :obj:`"cat"` (concatenation).
            (default: :obj:`"add"`)

    Raises:
        ValueError: If :obj:`aggregate` is neither :obj:`"add"` nor
            :obj:`"cat"`.
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        trans_num_layers: int = 2,
        trans_num_heads: int = 1,
        trans_dropout: float = 0.5,
        gnn_num_layers: int = 3,
        gnn_dropout: float = 0.5,
        graph_weight: float = 0.5,
        aggregate: str = 'add',
    ):
        super().__init__()
        self.trans_conv = SGModule(
            in_channels,
            hidden_channels,
            num_layers=trans_num_layers,
            num_heads=trans_num_heads,
            dropout=trans_dropout,
        )
        self.graph_conv = GraphModule(
            in_channels,
            hidden_channels,
            num_layers=gnn_num_layers,
            dropout=gnn_dropout,
        )
        self.graph_weight = graph_weight
        self.aggregate = aggregate

        # Concatenation doubles the input width of the final layer:
        if aggregate == 'add':
            fc_in_channels = hidden_channels
        elif aggregate == 'cat':
            fc_in_channels = 2 * hidden_channels
        else:
            raise ValueError(f'Invalid aggregate type:{aggregate}')
        self.fc = torch.nn.Linear(fc_in_channels, out_channels)

        # Parameter groups: attention branch vs. GNN branch plus output layer
        # (presumably for separate optimizer settings; verify with callers).
        self.params1 = list(self.trans_conv.parameters())
        self.params2 = [
            *self.graph_conv.parameters(),
            *self.fc.parameters(),
        ]

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        for module in (self.trans_conv, self.graph_conv, self.fc):
            module.reset_parameters()

    def forward(self, x, edge_index):
        r"""Runs both branches, merges them and returns log-softmax outputs.

        Args:
            x (torch.Tensor): Input node features (last dimension
                :obj:`in_channels`).
            edge_index: The graph connectivity used by the GNN branch.
        """
        h_trans = self.trans_conv(x)
        h_graph = self.graph_conv(x, edge_index)

        if self.aggregate == 'add':
            w = self.graph_weight
            h = w * h_graph + (1 - w) * h_trans
        else:
            h = torch.cat([h_trans, h_graph], dim=1)

        return F.log_softmax(self.fc(h), dim=-1)
torch_geometric/typing.py CHANGED
@@ -9,6 +9,11 @@ import numpy as np
9
9
  import torch
10
10
  from torch import Tensor
11
11
 
12
+ try:
13
+ from typing import TypeAlias # type: ignore
14
+ except ImportError:
15
+ from typing_extensions import TypeAlias
16
+
12
17
  WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2
13
18
  WITH_PT21 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 1
14
19
  WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
@@ -64,6 +69,16 @@ try:
64
69
  pyg_lib.sampler.neighbor_sample).parameters)
65
70
  WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(
66
71
  pyg_lib.sampler.neighbor_sample).parameters)
72
+ try:
73
+ torch.classes.pyg.CPUHashMap
74
+ WITH_CPU_HASH_MAP = True
75
+ except Exception:
76
+ WITH_CPU_HASH_MAP = False
77
+ try:
78
+ torch.classes.pyg.CUDAHashMap
79
+ WITH_CUDA_HASH_MAP = True
80
+ except Exception:
81
+ WITH_CUDA_HASH_MAP = False
67
82
  except Exception as e:
68
83
  if not isinstance(e, ImportError): # pragma: no cover
69
84
  warnings.warn(f"An issue occurred while importing 'pyg-lib'. "
@@ -78,6 +93,32 @@ except Exception as e:
78
93
  WITH_METIS = False
79
94
  WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
80
95
  WITH_WEIGHTED_NEIGHBOR_SAMPLE = False
96
+ WITH_CPU_HASH_MAP = False
97
+ WITH_CUDA_HASH_MAP = False
98
+
99
if WITH_CPU_HASH_MAP:
    CPUHashMap: TypeAlias = torch.classes.pyg.CPUHashMap
else:

    class CPUHashMap:  # type: ignore
        r"""Placeholder for :obj:`torch.classes.pyg.CPUHashMap`, used when
        :obj:`pyg-lib` does not provide it.

        Every use raises an :class:`ImportError`. Extra arguments are
        accepted so that the missing dependency, rather than a signature
        mismatch (:class:`TypeError`), is what gets reported.
        """
        def __init__(self, key: Tensor, *args, **kwargs) -> None:
            raise ImportError("'CPUHashMap' requires 'pyg-lib'")

        def get(self, query: Tensor, *args, **kwargs) -> Tensor:
            raise ImportError("'CPUHashMap' requires 'pyg-lib'")
109
+
110
+
111
if WITH_CUDA_HASH_MAP:
    CUDAHashMap: TypeAlias = torch.classes.pyg.CUDAHashMap
else:

    class CUDAHashMap:  # type: ignore
        r"""Placeholder for :obj:`torch.classes.pyg.CUDAHashMap`, used when
        :obj:`pyg-lib` does not provide it.

        Every use raises an :class:`ImportError`. Extra arguments are
        accepted so that the missing dependency, rather than a signature
        mismatch (:class:`TypeError`), is what gets reported.
        """
        def __init__(self, key: Tensor, *args, **kwargs) -> None:
            raise ImportError("'CUDAHashMap' requires 'pyg-lib'")

        def get(self, query: Tensor, *args, **kwargs) -> Tensor:
            raise ImportError("'CUDAHashMap' requires 'pyg-lib'")
121
+
81
122
 
82
123
  try:
83
124
  import torch_scatter # noqa