hjxdl 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,230 @@
|
|
1
|
+
from torch_geometric.nn import MessagePassing
|
2
|
+
from torch_geometric.utils import degree
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch.nn as nn
|
6
|
+
import torch.nn.functional as F
|
7
|
+
|
8
|
+
from .tetra import (
|
9
|
+
# get_tetra_update,
|
10
|
+
TETRA_UPDATE_DICT
|
11
|
+
)
|
12
|
+
|
13
|
+
|
14
|
+
class GCNConv(MessagePassing):
    """GCN-style message passing layer with an optional tetrahedral-chirality update.

    Node features are linearly transformed, then aggregated from neighbours with
    symmetric degree normalisation (the ``+ 1`` on the degree accounts for the
    node itself). When ``tetra`` is enabled, the aggregated message of every atom
    flagged in ``parity_atoms`` is replaced by the output of the tetra update
    module ``TETRA_UPDATE_DICT[message](hidden_size)``.

    Args:
        hidden_size: Width of node and edge features (they are summed in
            :meth:`message`, so both must have this width).
        tetra: If truthy, enable the chirality-aware update for parity atoms.
        message: Key into ``TETRA_UPDATE_DICT`` selecting the tetra update module.
    """

    def __init__(
        self,
        hidden_size,
        tetra,
        message
    ):
        super(GCNConv, self).__init__(aggr='add')
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.batch_norm = nn.BatchNorm1d(hidden_size)
        self.tetra = tetra  # bool
        if self.tetra:
            self.tetra_update = TETRA_UPDATE_DICT[message](hidden_size)

    def forward(
        self,
        x,
        edge_index,
        edge_attr,
        parity_atoms
    ):
        """Run one message-passing step.

        Args:
            x: Node features (last dim ``hidden_size``).
            edge_index: Connectivity unpacked as ``row, col``; degrees are
                counted on ``col``.
            edge_attr: Per-edge features added to the source node features.
            parity_atoms: Per-node parity label. Non-zero marks a chiral centre;
                ``-1`` triggers the neighbour swap in :meth:`tetra_message`.

        Returns:
            ``(batch_norm(aggregate + relu(linear(x))), edge_attr)``. Edge
            features are passed through unchanged.
        """
        # no edge updates
        x = self.linear(x)

        # Symmetric GCN normalisation; +1 counts the node itself.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        x_new = self.propagate(edge_index, x=x, edge_attr=edge_attr, norm=norm)

        if self.tetra:
            tetra_ids = parity_atoms.nonzero().squeeze(1)
            if tetra_ids.nelement() != 0:
                x_new[tetra_ids] = self.tetra_message(x, edge_index, edge_attr, tetra_ids, parity_atoms)
        x = x_new + F.relu(x)

        return self.batch_norm(x), edge_attr

    def message(self, x_j, edge_attr, norm):
        """Normalised message ``norm * relu(x_j + edge_attr)`` for each edge."""
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def tetra_message(self, x, edge_index, edge_attr, tetra_ids, parity_atoms):
        """Chirality-aware replacement message for the atoms in ``tetra_ids``.

        Every centre must have exactly four incoming edges (``col == centre``);
        otherwise ``torch.stack`` raises. Output rows follow the order of
        ``tetra_ids``. ``tetra_ids`` must be non-empty.
        """
        row, col = edge_index

        # Ids of the four incoming edges of each centre, shape (T, 4). Neighbour
        # ids and edge features are both read through these ids, so no second
        # search over edge_index is needed to find the matching edge features.
        tetra_edge_ids = torch.stack(
            [(col == i).nonzero().squeeze(1) for i in tetra_ids.tolist()]
        )

        # switch entries for -1 rdkit labels (boolean-mask indexing returns a copy)
        ccw_mask = parity_atoms[tetra_ids] == -1
        tetra_edge_ids[ccw_mask] = tetra_edge_ids[ccw_mask][:, [1, 0, 2, 3]]

        tetra_nei_ids = row[tetra_edge_ids]  # (T, 4) neighbour atom ids

        # Pseudo tetra degree aligned with the GCN normalisation. The mean over
        # the four neighbours is order-independent, so computing it after the
        # swap does not change it.
        deg = degree(col, x.size(0), dtype=x.dtype)
        t_norm = 0.5 * deg[tetra_nei_ids].pow(-0.5).mean(dim=1)

        # (T, 4, hidden): neighbour features plus the features of the connecting edge
        reps = x[tetra_nei_ids] + edge_attr[tetra_edge_ids]

        return t_norm.unsqueeze(-1) * self.tetra_update(reps)
|
83
|
+
|
84
|
+
|
85
|
+
class GINEConv(MessagePassing):
    """GINE-style layer with an optional tetrahedral-chirality update.

    Messages ``relu(x_j + edge_attr)`` are summed per node. When ``tetra`` is
    enabled, the aggregate of every atom flagged in ``parity_atoms`` is replaced
    by the output of the tetra update module. The node update is
    ``batch_norm(mlp((1 + eps) * x + aggregate))``.

    Args:
        hidden_size: Width of node and edge features.
        tetra: If truthy, enable the chirality-aware update for parity atoms.
        message: Key into ``TETRA_UPDATE_DICT`` selecting the tetra update module.
    """

    def __init__(
        self,
        hidden_size,
        tetra,
        message
    ):
        super(GINEConv, self).__init__(aggr="add")
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.mlp = nn.Sequential(nn.Linear(hidden_size, 2 * hidden_size),
                                 nn.BatchNorm1d(2 * hidden_size),
                                 nn.ReLU(),
                                 nn.Linear(2 * hidden_size, hidden_size))
        self.batch_norm = nn.BatchNorm1d(hidden_size)
        self.tetra = tetra
        if self.tetra:
            self.tetra_update = TETRA_UPDATE_DICT[message](hidden_size)

    def forward(self, x, edge_index, edge_attr, parity_atoms):
        """Run one message-passing step.

        Args:
            x: Node features (last dim ``hidden_size``).
            edge_index: Connectivity unpacked as ``row, col`` in :meth:`tetra_message`.
            edge_attr: Per-edge features added to the source node features.
            parity_atoms: Per-node parity label. Non-zero marks a chiral centre;
                ``-1`` triggers the neighbour swap in :meth:`tetra_message`.

        Returns:
            ``(new_node_features, edge_attr)``. Edge features are passed through unchanged.
        """
        # no edge updates
        x_new = self.propagate(edge_index, x=x, edge_attr=edge_attr)

        if self.tetra:
            tetra_ids = parity_atoms.nonzero().squeeze(1)
            if tetra_ids.nelement() != 0:
                x_new[tetra_ids] = self.tetra_message(x, edge_index, edge_attr, tetra_ids, parity_atoms)

        x = self.mlp((1 + self.eps) * x + x_new)
        return self.batch_norm(x), edge_attr

    def message(self, x_j, edge_attr):
        """Per-edge message ``relu(x_j + edge_attr)``."""
        return F.relu(x_j + edge_attr)

    def tetra_message(self, x, edge_index, edge_attr, tetra_ids, parity_atoms):
        """Chirality-aware replacement message for the atoms in ``tetra_ids``.

        Every centre must have exactly four incoming edges (``col == centre``);
        otherwise ``torch.stack`` raises. Output rows follow the order of
        ``tetra_ids``. ``tetra_ids`` must be non-empty.
        """
        row, col = edge_index

        # Ids of the four incoming edges of each centre, shape (T, 4). Neighbour
        # ids and edge features are both read through these ids, so no second
        # search over edge_index is needed.
        tetra_edge_ids = torch.stack(
            [(col == i).nonzero().squeeze(1) for i in tetra_ids.tolist()]
        )

        # switch entries for -1 rdkit labels (boolean-mask indexing returns a copy)
        ccw_mask = parity_atoms[tetra_ids] == -1
        tetra_edge_ids[ccw_mask] = tetra_edge_ids[ccw_mask][:, [1, 0, 2, 3]]

        # (T, 4, hidden): neighbour features plus the features of the connecting edge
        reps = x[row[tetra_edge_ids]] + edge_attr[tetra_edge_ids]

        return self.tetra_update(reps)
|
138
|
+
|
139
|
+
|
140
|
+
class DMPNNConv(MessagePassing):
    """Directed (edge-centred) MPNN layer with an optional tetrahedral-chirality update.

    Hidden states live on directed edges (``edge_attr``). Each atom sums
    ``relu(lin(edge_attr))`` over its incoming edges. The new edge state is
    ``mlp(a_message[source] - reverse_edge_state)``.

    Args:
        hidden_size: Width of the edge hidden states.
        tetra: If truthy, enable the chirality-aware update for parity atoms.
        message: Key into ``TETRA_UPDATE_DICT`` selecting the tetra update module.
    """

    def __init__(
        self,
        hidden_size,
        tetra,
        message
    ):
        super(DMPNNConv, self).__init__(aggr='add')
        self.lin = nn.Linear(hidden_size, hidden_size)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                 nn.BatchNorm1d(hidden_size),
                                 nn.ReLU())
        self.tetra = tetra
        if self.tetra:
            self.tetra_update = TETRA_UPDATE_DICT[message](hidden_size)

    def forward(self, x, edge_index, edge_attr, parity_atoms, parity_bond_index):
        """Run one directed message-passing step.

        The reverse of edge ``2k`` must be edge ``2k + 1`` (and vice versa),
        because reverse states are obtained by swapping consecutive pairs.

        Returns:
            ``(a_message, new_edge_attr)``: the per-atom aggregate and the
            updated per-edge states.
        """
        row, _ = edge_index
        # Node features are not used by message(); only edge states are aggregated.
        a_message = self.propagate(edge_index, x=None, edge_attr=edge_attr)

        if self.tetra:
            tetra_ids = parity_atoms.nonzero().squeeze(1)
            if tetra_ids.nelement() != 0:
                a_message[tetra_ids] = self.tetra_message(x, edge_index, edge_attr, tetra_ids, parity_atoms, parity_bond_index)

        # Swap each consecutive (u->v, v->u) pair to get every edge's reverse state.
        rev_message = torch.flip(edge_attr.view(edge_attr.size(0) // 2, 2, -1), dims=[1]).view(edge_attr.size(0), -1)
        return a_message, self.mlp(a_message[row] - rev_message)

    def message(self, x_j, edge_attr):
        """Per-edge message ``relu(lin(edge_attr))``; ``x_j`` is unused."""
        return F.relu(self.lin(edge_attr))

    def tetra_message(self, x, edge_index, edge_attr, tetra_ids, parity_atoms, parity_bond_index):
        """Chirality-aware replacement message built from precomputed edge ids.

        ``parity_bond_index`` is read as groups of four edge ids, one group per
        chiral centre. ``x``, ``edge_index``, ``tetra_ids`` and ``parity_atoms``
        are accepted for interface parity with the other layers but are unused.

        NOTE(review): unlike GCNConv/GINEConv, no ``-1`` parity swap is applied
        here. Presumably ``parity_bond_index`` arrives already ordered per
        centre, in ``tetra_ids`` order — confirm against the dataset code.
        """
        edge_reps = edge_attr[parity_bond_index, :].view(parity_bond_index.size(0) // 4, 4, -1)
        return self.tetra_update(edge_reps)
|
hdl/layers/graph/gcn.py
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
from torch import nn
|
2
|
+
from torch_geometric.nn import GCNConv
|
3
|
+
|
4
|
+
|
5
|
+
class GraphConv(nn.Module):
    """Thin wrapper around a single ``torch_geometric`` ``GCNConv`` layer.

    Args:
        num_features: Size of each input node feature vector.
        num_out_features: Size of each output node feature vector.
    """

    def __init__(self, num_features, num_out_features):
        super().__init__()
        # Stored as ``conv`` so state_dict keys match the original layout.
        self.conv = GCNConv(num_features, num_out_features)

    def forward(self, x, edge_index):
        """Apply the wrapped convolution to node features ``x`` over ``edge_index``."""
        return self.conv(x, edge_index)
|
hdl/layers/graph/gin.py
ADDED
@@ -0,0 +1,45 @@
|
|
1
|
+
import torch
|
2
|
+
from torch import nn
|
3
|
+
|
4
|
+
from torch_geometric.nn import MessagePassing
|
5
|
+
from torch_geometric.utils import add_self_loops
|
6
|
+
|
7
|
+
# Vocabulary sizes of the categorical atom and bond features.
num_atom_type: int = 119  # including the extra mask tokens
num_chirality_tag: int = 3

num_bond_type: int = 5  # including aromatic and self-loop edge
num_bond_direction: int = 3
|
12
|
+
|
13
|
+
|
14
|
+
class GINEConv(MessagePassing):
    """GIN convolution with categorical bond embeddings (GINE).

    Each edge carries two integer attributes (column 0: bond type, column 1:
    bond direction), which are embedded and summed. A self loop is added to
    every node with bond type 4 and direction 0. Messages ``x_j + edge_emb``
    are aggregated by ``MessagePassing`` (constructed here with its default
    aggregation) and the result is passed through a two-layer MLP.

    Args:
        emb_dim: Width of node features and of both edge embeddings.
    """

    def __init__(self, emb_dim):
        super(GINEConv, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2 * emb_dim),
            nn.ReLU(),
            nn.Linear(2 * emb_dim, emb_dim)
        )
        self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
        nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        """Add self loops, embed the edge attributes and propagate.

        Args:
            x: Node features (first dim = number of nodes).
            edge_index: Graph connectivity.
            edge_attr: Integer edge attributes with two columns (bond type,
                bond direction); they index the embedding tables.
        """
        num_nodes = x.size(0)

        # add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes=num_nodes)[0]

        # Self-loop edge features are built directly in edge_attr's dtype and
        # device; this avoids a CPU float32 allocation and two conversions per call.
        self_loop_attr = torch.zeros(
            num_nodes, 2, dtype=edge_attr.dtype, device=edge_attr.device
        )
        self_loop_attr[:, 0] = 4  # bond type for self-loop edge
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        """Per-edge message: source features plus the edge embedding."""
        return x_j + edge_attr

    def update(self, aggr_out):
        """Transform the aggregated messages with the MLP."""
        return self.mlp(aggr_out)
|
@@ -0,0 +1,158 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
import copy
|
5
|
+
|
6
|
+
|
7
|
+
class TetraPermuter(nn.Module):
    """Chirality-aware update that sums over 12 neighbour permutations.

    ``x`` is expected as ``(num_centres, 4, hidden)``: four neighbour
    representations per chiral centre. This is what the chiral_graph layers
    pass via ``.view(..., 4, -1)`` — confirm for other callers. The steps are:

    - For each row of ``tetra_perms`` the permuted neighbours go through four
      position-specific linear maps (``W_bs``) with tanh and dropout.
    - The four results are summed, then passed through ReLU and dropout.
    - The 12 permutation results are summed, divided by 3 and fed to ``mlp_out``.

    Args:
        hidden: Feature width of each neighbour representation.
    """

    def __init__(
        self,
        hidden,
    ):
        super(TetraPermuter, self).__init__()

        # Four independent position-specific maps (a freshly built layer needs no deepcopy).
        self.W_bs = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(4)])
        self.drop = nn.Dropout(p=0.2)
        self.reset_parameters()
        self.mlp_out = nn.Sequential(nn.Linear(hidden, hidden),
                                     nn.BatchNorm1d(hidden),
                                     nn.ReLU(),
                                     nn.Linear(hidden, hidden))

        tetra_perms = torch.tensor([[0, 1, 2, 3],
                                    [0, 2, 3, 1],
                                    [0, 3, 1, 2],
                                    [1, 0, 3, 2],
                                    [1, 2, 0, 3],
                                    [1, 3, 2, 0],
                                    [2, 0, 1, 3],
                                    [2, 1, 3, 0],
                                    [2, 3, 0, 1],
                                    [3, 0, 2, 1],
                                    [3, 1, 0, 2],
                                    [3, 2, 1, 0]])
        # Registered as a buffer so it follows .to(device), like ConcatTetraPermuter's.
        # It is non-persistent so state_dicts saved before this change still load.
        self.register_buffer('tetra_perms', tetra_perms, persistent=False)

    def reset_parameters(self):
        """Xavier-initialise the four position maps with gains 0.5, 1.0, 1.5, 2.0."""
        gain = 0.5
        for W_b in self.W_bs:
            nn.init.xavier_uniform_(W_b.weight, gain=gain)
            gain += 0.5

    def forward(self, x):
        """Return the updated representation, one row per centre in ``x``."""
        permuted = x[:, self.tetra_perms, :]  # (T, 12, 4, hidden) for 3-D x
        nei_messages_list = [
            self.drop(torch.tanh(W_b(part)))
            for W_b, part in zip(self.W_bs, torch.split(permuted, 1, dim=-2))
        ]
        nei_messages = torch.sum(self.drop(F.relu(torch.cat(nei_messages_list, dim=-2).sum(dim=-2))), dim=-2)
        nei_messages = nei_messages / 3.

        if nei_messages.size(0) == 1:
            # BatchNorm1d rejects a single row in training mode. As in
            # ConcatTetraPermuter, run a duplicated pair and keep the first row.
            doubled = torch.repeat_interleave(nei_messages, 2, dim=0)
            return self.mlp_out(doubled)[:1, ...]
        return self.mlp_out(nei_messages)
|
50
|
+
|
51
|
+
|
52
|
+
class ConcatTetraPermuter(nn.Module):
    """Chirality-aware update over concatenated neighbour permutations.

    For each of the 12 rows of ``tetra_perms``, the four neighbour
    representations of a centre are concatenated (``4 * hidden`` features).
    Each concatenation is projected back to ``hidden`` by one linear layer,
    then passed through ReLU and dropout. The 12 results are summed, divided
    by 3 and fed to ``mlp_out``.

    Args:
        hidden: Feature width of each neighbour representation.
    """

    def __init__(
        self,
        hidden,
    ):
        super(ConcatTetraPermuter, self).__init__()

        self.W_bs = nn.Linear(hidden * 4, hidden)
        nn.init.xavier_normal_(self.W_bs.weight, gain=1.0)
        self.hidden = hidden
        self.drop = nn.Dropout(p=0.2)
        self.mlp_out = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
        )

        perms = [
            [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2],
            [1, 0, 3, 2], [1, 2, 0, 3], [1, 3, 2, 0],
            [2, 0, 1, 3], [2, 1, 3, 0], [2, 3, 0, 1],
            [3, 0, 2, 1], [3, 1, 0, 2], [3, 2, 1, 0],
        ]
        self.register_buffer('tetra_perms', torch.tensor(perms))

    def forward(self, x):
        """Return the updated representation, one row per centre in ``x``."""
        # (T, 12, 4, hidden) -> (T, 12, 4 * hidden): one concatenated row per permutation
        stacked = x[:, self.tetra_perms, :].flatten(start_dim=2)
        per_perm = self.drop(F.relu(self.W_bs(stacked)))
        pooled = per_perm.sum(dim=-2) / 3.

        if pooled.size(0) != 1:
            return self.mlp_out(pooled)

        # A lone row cannot be batch-normalised in training mode, so run the
        # MLP on a duplicated pair and keep only the first output row.
        doubled = torch.cat([pooled, pooled], dim=0)
        return self.mlp_out(doubled)[:1]
|
105
|
+
|
106
|
+
|
107
|
+
class TetraDifferencesProduct(nn.Module):
    """Chirality-aware update from the product of pairwise neighbour differences.

    For the neighbour representations ``n0..n3`` of a centre (read along dim 1
    of ``x``), the element-wise product of ``(n_i - n_j)`` over all six pairs
    ``i < j`` flips sign when two neighbours are swapped. A signed sixth root
    compresses its magnitude before ``mlp_out``.

    Args:
        hidden: Feature width of each neighbour representation.
    """

    def __init__(
        self,
        hidden
    ):
        super(TetraDifferencesProduct, self).__init__()

        self.mlp_out = nn.Sequential(nn.Linear(hidden, hidden),
                                     nn.BatchNorm1d(hidden),
                                     nn.ReLU(),
                                     nn.Linear(hidden, hidden))
        # Positions of the four neighbours along dim 1. This is a persistent
        # buffer, so it is part of state_dict and kept for checkpoint compatibility.
        self.register_buffer('indices', torch.arange(4))

    def forward(self, x):
        """Return the updated representation, one row per centre in ``x``."""
        # One gather for all four neighbours, then split them along dim 1.
        message_tetra_nbs = x.index_select(dim=1, index=self.indices).unbind(dim=1)
        message_tetra = torch.ones_like(message_tetra_nbs[0])

        # note: this will zero out reps for chiral centers with multiple carbon neighbors on first pass
        for i in range(4):
            for j in range(i + 1, 4):
                message_tetra = message_tetra * (message_tetra_nbs[i] - message_tetra_nbs[j])
        message_tetra = torch.sign(message_tetra) * torch.pow(torch.abs(message_tetra) + 1e-6, 1 / 6)

        if message_tetra.size(0) == 1:
            # BatchNorm1d rejects a single row in training mode. As in
            # ConcatTetraPermuter, run a duplicated pair and keep the first row.
            doubled = torch.repeat_interleave(message_tetra, 2, dim=0)
            return self.mlp_out(doubled)[:1, ...]
        return self.mlp_out(message_tetra)
|
136
|
+
|
137
|
+
|
138
|
+
# def get_tetra_update(
|
139
|
+
# hidden_size,
|
140
|
+
# device,
|
141
|
+
# message,
|
142
|
+
# ):
|
143
|
+
|
144
|
+
# if message == 'tetra_permute':
|
145
|
+
# return TetraPermuter(hidden_size, device)
|
146
|
+
# elif message == 'tetra_permute_concat':
|
147
|
+
# return ConcatTetraPermuter(hidden_size, device)
|
148
|
+
# elif message == 'tetra_pd':
|
149
|
+
# return TetraDifferencesProduct(hidden_size)
|
150
|
+
# else:
|
151
|
+
# raise ValueError("Invalid message type.")
|
152
|
+
|
153
|
+
|
154
|
+
# Maps the ``message`` option string to the tetra update module class; each
# class is instantiated with the hidden size.
TETRA_UPDATE_DICT = dict(
    tetra_permute=TetraPermuter,
    tetra_permute_concat=ConcatTetraPermuter,
    tetra_pd=TetraDifferencesProduct,
)
|
@@ -0,0 +1,188 @@
|
|
1
|
+
import math
|
2
|
+
from typing import Union, Tuple, Optional
|
3
|
+
from torch_geometric.typing import PairTensor, Adj, OptTensor
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from torch import Tensor
|
7
|
+
import torch.nn.functional as F
|
8
|
+
from torch.nn import Linear
|
9
|
+
from torch_geometric.nn.conv import MessagePassing
|
10
|
+
from torch_geometric.utils import softmax
|
11
|
+
|
12
|
+
|
13
|
+
class TransformerConv(MessagePassing):
    r"""The graph transformer operator from the `"Masked Label Prediction:
    Unified Message Passing Model for Semi-Supervised Classification"
    <https://arxiv.org/abs/2009.03509>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed via
    multi-head dot product attention:

    .. math::
        \alpha_{i,j} = \textrm{softmax} \left(
        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
        {\sqrt{d}} \right)

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        beta (bool, optional): If set, will combine aggregation and
            skip information via

            .. math::
                \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
                (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
                \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_j \right)}_{=\mathbf{m}_i}

            with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
            [ \mathbf{x}_i, \mathbf{m}_i, \mathbf{x}_i - \mathbf{m}_i ])`
            (default: :obj:`False`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). Edge features are added to the keys after
            linear transformation, that is, prior to computing the
            attention dot product. They are also added to final values
            after the same linear transformation. The model is:

            .. math::
                \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
                \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
                \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
                \right),

            where the attention coefficients :math:`\alpha_{i,j}` are now
            computed via:

            .. math::
                \alpha_{i,j} = \textrm{softmax} \left(
                \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
                (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
                {\sqrt{d}} \right)

            (default :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output and the
            option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(self, in_channels: Union[int, Tuple[int, int]], out_channels: int,
                 heads: int = 1, concat: bool = True, beta: bool = False,
                 dropout: float = 0., edge_dim: Optional[int] = None,
                 bias: bool = True, root_weight: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(TransformerConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim

        # (source, target) channel sizes; a bare int is shared by both sides.
        channel_pair = (in_channels, in_channels) if isinstance(in_channels, int) else in_channels
        src_channels, dst_channels = channel_pair[0], channel_pair[1]
        projected = heads * out_channels

        # Construction order is kept (key, query, value, edge, skip, beta) so
        # the RNG-driven default initialisation is unchanged.
        self.lin_key = Linear(src_channels, projected)
        self.lin_query = Linear(dst_channels, projected)
        self.lin_value = Linear(src_channels, projected)
        if edge_dim is None:
            self.register_parameter('lin_edge', None)
        else:
            self.lin_edge = Linear(edge_dim, projected, bias=False)

        # The skip path matches the output width: concatenated heads or a single head.
        skip_channels = projected if concat else out_channels
        self.lin_skip = Linear(dst_channels, skip_channels, bias=bias)
        if self.beta:
            self.lin_beta = Linear(3 * skip_channels, 1, bias=False)
        else:
            self.register_parameter('lin_beta', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear layer owned by this module."""
        for projection in (self.lin_key, self.lin_query, self.lin_value):
            projection.reset_parameters()
        if self.edge_dim:
            self.lin_edge.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.beta:
            self.lin_beta.reset_parameters()

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                edge_attr: OptTensor = None):
        """Attend over neighbours and optionally add the (gated) root skip term."""
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)

        # propagate_type: (x: PairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)

        # Merge the heads: concatenate them, or average them over dim 1.
        out = out.view(-1, self.heads * self.out_channels) if self.concat else out.mean(dim=1)

        if not self.root_weight:
            return out

        x_r = self.lin_skip(x[1])
        if self.lin_beta is None:
            out += x_r
            return out

        gate = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1)).sigmoid()
        return gate * x_r + (1 - gate) * out

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        # Parameter names are read by MessagePassing and must not be renamed.
        head_shape = (-1, self.heads, self.out_channels)
        query = self.lin_query(x_i).view(*head_shape)
        key = self.lin_key(x_j).view(*head_shape)

        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(*head_shape)
            key += edge_attr

        scores = (query * key).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(scores, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = self.lin_value(x_j).view(*head_shape)
        if edge_attr is not None:
            out += edge_attr

        out *= alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels}, heads={self.heads})'
|
File without changes
|
File without changes
|