hjxdl-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
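Per the wheel name and the paths above, the distribution is published as hjxdl while the importable top-level package is hdl (top_level.txt carries a single entry, presumably hdl). A minimal editorial import sketch, assuming the runtime dependencies (PyTorch and PyTorch Geometric for the layers shown in the hunks below) are installed separately, since the listing alone does not declare them:

# Editorial sketch, not part of the package: check that the layer modules shown
# in the hunks below are importable after `pip install hjxdl==0.0.1`.
from hdl.layers.graph.chiral_graph import GCNConv, GINEConv, DMPNNConv
from hdl.layers.graph.tetra import TETRA_UPDATE_DICT

print(sorted(TETRA_UPDATE_DICT))  # ['tetra_pd', 'tetra_permute', 'tetra_permute_concat']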
hdl/layers/graph/chiral_graph.py
@@ -0,0 +1,230 @@
+ from torch_geometric.nn import MessagePassing
+ from torch_geometric.utils import degree
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .tetra import (
+     # get_tetra_update,
+     TETRA_UPDATE_DICT
+ )
+
+
+ class GCNConv(MessagePassing):
+     def __init__(
+         self,
+         # args,
+         hidden_size,
+         tetra,
+         message
+     ):
+         super(GCNConv, self).__init__(aggr='add')
+         self.linear = nn.Linear(hidden_size, hidden_size)
+         self.batch_norm = nn.BatchNorm1d(hidden_size)
+         self.tetra = tetra  # bool
+         if self.tetra:
+             # self.tetra_update = get_tetra_update(args)
+             self.tetra_update = TETRA_UPDATE_DICT[message](hidden_size)
+
+     def forward(
+         self,
+         x,
+         edge_index,
+         edge_attr,
+         parity_atoms
+     ):
+
+         # no edge updates
+         x = self.linear(x)
+
+         # Compute normalization
+         row, col = edge_index
+         deg = degree(col, x.size(0), dtype=x.dtype) + 1
+         deg_inv_sqrt = deg.pow(-0.5)
+         norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
+         x_new = self.propagate(edge_index, x=x, edge_attr=edge_attr, norm=norm)
+
+         if self.tetra:
+             tetra_ids = parity_atoms.nonzero().squeeze(1)
+             if tetra_ids.nelement() != 0:
+                 x_new[tetra_ids] = self.tetra_message(x, edge_index, edge_attr, tetra_ids, parity_atoms)
+         x = x_new + F.relu(x)
+
+         return self.batch_norm(x), edge_attr
+
+     def message(self, x_j, edge_attr, norm):
+         return norm.view(-1, 1) * F.relu(x_j + edge_attr)
+
+     def tetra_message(self, x, edge_index, edge_attr, tetra_ids, parity_atoms):
+
+         row, col = edge_index
+         tetra_nei_ids = torch.cat([row[col == i].unsqueeze(0) for i in range(x.size(0)) if i in tetra_ids])
+
+         # calculate pseudo tetra degree aligned with GCN method
+         deg = degree(col, x.size(0), dtype=x.dtype)
+         t_deg = deg[tetra_nei_ids]
+         t_deg_inv_sqrt = t_deg.pow(-0.5)
+         t_norm = 0.5 * t_deg_inv_sqrt.mean(dim=1)
+
+         # switch entries for -1 rdkit labels
+         ccw_mask = parity_atoms[tetra_ids] == -1
+         tetra_nei_ids[ccw_mask] = tetra_nei_ids.clone()[ccw_mask][:, [1, 0, 2, 3]]
+
+         # calculate reps
+         edge_ids = torch.cat([tetra_nei_ids.view(1, -1), tetra_ids.repeat_interleave(4).unsqueeze(0)], dim=0)
+         # dense_edge_attr = to_dense_adj(edge_index, batch=None, edge_attr=edge_attr).squeeze(0)
+         # edge_reps = dense_edge_attr[edge_ids[0], edge_ids[1], :].view(tetra_nei_ids.size(0), 4, -1)
+         attr_ids = [torch.where((a == edge_index.t()).all(dim=1))[0] for a in edge_ids.t()]
+         edge_reps = edge_attr[attr_ids, :].view(tetra_nei_ids.size(0), 4, -1)
+         reps = x[tetra_nei_ids] + edge_reps
+
+         return t_norm.unsqueeze(-1) * self.tetra_update(reps)
+
+
+ class GINEConv(MessagePassing):
+     def __init__(
+         self,
+         # args,
+         hidden_size,
+         tetra,
+         message
+     ):
+         super(GINEConv, self).__init__(aggr="add")
+         self.eps = nn.Parameter(torch.Tensor([0]))
+         self.mlp = nn.Sequential(nn.Linear(hidden_size, 2 * hidden_size),
+                                  nn.BatchNorm1d(2 * hidden_size),
+                                  nn.ReLU(),
+                                  nn.Linear(2 * hidden_size, hidden_size))
+         self.batch_norm = nn.BatchNorm1d(hidden_size)
+         self.tetra = tetra
+         if self.tetra:
+             # self.tetra_update = get_tetra_update(args)
+             self.tetra_update = TETRA_UPDATE_DICT[message](hidden_size)
+
+     def forward(self, x, edge_index, edge_attr, parity_atoms):
+         # no edge updates
+         x_new = self.propagate(edge_index, x=x, edge_attr=edge_attr)
+
+         if self.tetra:
+             tetra_ids = parity_atoms.nonzero().squeeze(1)
+             if tetra_ids.nelement() != 0:
+                 x_new[tetra_ids] = self.tetra_message(x, edge_index, edge_attr, tetra_ids, parity_atoms)
+
+         x = self.mlp((1 + self.eps) * x + x_new)
+         return self.batch_norm(x), edge_attr
+
+     def message(self, x_j, edge_attr):
+         return F.relu(x_j + edge_attr)
+
+     def tetra_message(self, x, edge_index, edge_attr, tetra_ids, parity_atoms):
+
+         row, col = edge_index
+         tetra_nei_ids = torch.cat([row[col == i].unsqueeze(0) for i in range(x.size(0)) if i in tetra_ids])
+
+         # switch entries for -1 rdkit labels
+         ccw_mask = parity_atoms[tetra_ids] == -1
+         tetra_nei_ids[ccw_mask] = tetra_nei_ids.clone()[ccw_mask][:, [1, 0, 2, 3]]
+
+         # calculate reps
+         edge_ids = torch.cat([tetra_nei_ids.view(1, -1), tetra_ids.repeat_interleave(4).unsqueeze(0)], dim=0)
+         # dense_edge_attr = to_dense_adj(edge_index, batch=None, edge_attr=edge_attr).squeeze(0)
+         # edge_reps = dense_edge_attr[edge_ids[0], edge_ids[1], :].view(tetra_nei_ids.size(0), 4, -1)
+         attr_ids = [torch.where((a == edge_index.t()).all(dim=1))[0] for a in edge_ids.t()]
+         edge_reps = edge_attr[attr_ids, :].view(tetra_nei_ids.size(0), 4, -1)
+         reps = x[tetra_nei_ids] + edge_reps
+
+         return self.tetra_update(reps)
+
+
+ class DMPNNConv(MessagePassing):
+     def __init__(
+         self,
+         # args,
+         hidden_size,
+         tetra,
+         message
+     ):
+         super(DMPNNConv, self).__init__(aggr='add')
+         self.lin = nn.Linear(hidden_size, hidden_size)
+         self.mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size),
+                                  nn.BatchNorm1d(hidden_size),
+                                  nn.ReLU())
+         self.tetra = tetra
+         if self.tetra:
+             # self.tetra_update = get_tetra_update(args)
+             self.tetra_update = TETRA_UPDATE_DICT[message](hidden_size)
+
+     def forward(self, x, edge_index, edge_attr, parity_atoms, parity_bond_index):
+         row, col = edge_index
+         a_message = self.propagate(edge_index, x=None, edge_attr=edge_attr)
+
+         if self.tetra:
+             tetra_ids = parity_atoms.nonzero().squeeze(1)
+             if tetra_ids.nelement() != 0:
+                 a_message[tetra_ids] = self.tetra_message(x, edge_index, edge_attr, tetra_ids, parity_atoms, parity_bond_index)
+
+         rev_message = torch.flip(edge_attr.view(edge_attr.size(0) // 2, 2, -1), dims=[1]).view(edge_attr.size(0), -1)
+         return a_message, self.mlp(a_message[row] - rev_message)
+
+     def message(self, x_j, edge_attr):
+         return F.relu(self.lin(edge_attr))
+
+     def tetra_message(self, x, edge_index, edge_attr, tetra_ids, parity_atoms, parity_bond_index):
+         edge_reps = edge_attr[parity_bond_index, :].view(parity_bond_index.size(0)//4, 4, -1)
+
+         return self.tetra_update(edge_reps)
+         # print('1')
+         row, col = edge_index
+
+         col_ids = torch.cat(
+             [(col == i).nonzero() for i in tetra_ids]
+         ).squeeze().unsqueeze(0)
+         tetra_nei_ids = row[col_ids].reshape(-1, 4)
+
+         # tetra_nei_ids = torch.cat([
+         #     row[col == i].unsqueeze(0)
+         #     for i in tetra_ids
+         # ])
+
+         # print('2')
+         # switch entries for -1 rdkit labels
+         ccw_mask = parity_atoms[tetra_ids] == -1
+         tetra_nei_ids[ccw_mask] = tetra_nei_ids.clone()[ccw_mask][:, [1, 0, 2, 3]]
+
+         # calculate reps
+         edge_ids = torch.cat([tetra_nei_ids.view(1, -1), tetra_ids.repeat_interleave(4).unsqueeze(0)], dim=0)
+         # dense_edge_attr = to_dense_adj(edge_index, batch=None, edge_attr=edge_attr).squeeze(0)
+         # edge_reps = dense_edge_attr[edge_ids[0], edge_ids[1], :].view(tetra_nei_ids.size(0), 4, -1)
+         # edge_index_T = edge_index.t()
+         # edge_ids_T = edge_ids.t()
+
+         # attr_ids = [
+         #     torch.where(
+         #         (a == edge_index_T).all(dim=1)
+         #     )[0]
+         #     for a in edge_ids_T
+         # ]
+         # attr_ids = torch.cat([(edge_index_T == i).nonzero() for i in edge_ids_T])[:, 0].unique()
+
+         edge_index_T = edge_index.t()
+         edge_ids_T = edge_ids.t()
+
+         c0 = torch.cartesian_prod(
+             edge_index_T[:, 0], edge_ids_T[:, 0]
+         )
+         c1 = torch.cartesian_prod(
+             edge_index_T[:, 1], edge_ids_T[:, 1]
+         )
+         diff = torch.abs(c0[:, 0] - c0[:, 1]) \
+             + torch.abs(c1[:, 0] - c1[:, 1])
+
+         attr_ids = torch.div(
+             (diff == 0).nonzero(as_tuple=True)[0],
+             edge_ids.size(1),
+             rounding_mode='floor'
+         )
+
+         edge_reps = edge_attr[attr_ids, :].view(tetra_nei_ids.size(0), 4, -1)
+
+         return self.tetra_update(edge_reps)
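A minimal editorial usage sketch for the file above (not part of the package): it exercises the package's own GCNConv, which is distinct from torch_geometric.nn.GCNConv, on a toy four-node graph with chirality disabled. With tetra=True the message argument must be a key of TETRA_UPDATE_DICT ('tetra_permute', 'tetra_permute_concat' or 'tetra_pd'); also note that DMPNNConv.tetra_message returns early, so everything after its first return is unreachable as shipped. Assumes torch and torch_geometric are installed and the wheel imports as hdl.

# Editorial sketch under the assumptions stated above, not from the package itself.
import torch
from hdl.layers.graph.chiral_graph import GCNConv

hidden = 8
conv = GCNConv(hidden_size=hidden, tetra=False, message='tetra_permute')

x = torch.randn(4, hidden)                       # node features
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],   # source nodes
                           [1, 0, 2, 1, 3, 2]])  # target nodes
edge_attr = torch.randn(edge_index.size(1), hidden)
parity_atoms = torch.zeros(4, dtype=torch.long)  # no chiral centres in this toy graph

x_out, edge_attr_out = conv(x, edge_index, edge_attr, parity_atoms)
print(x_out.shape)  # torch.Size([4, 8])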
hdl/layers/graph/gcn.py
@@ -0,0 +1,16 @@
+ from torch import nn
+ from torch_geometric.nn import GCNConv
+
+
+ class GraphConv(nn.Module):
+     def __init__(self, num_features, num_out_features):
+         # Init parent
+         super(GraphConv, self).__init__()
+
+         # GCN layers
+         self.conv = GCNConv(num_features, num_out_features)
+
+     def forward(self, x, edge_index):
+
+         hidden = self.conv(x, edge_index)
+         return hidden
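GraphConv above is a thin wrapper around torch_geometric's GCNConv, so it only needs node features and an edge index. A short editorial sketch of how it would be called (assumed toy data, not from the package):

# Editorial sketch, assuming torch and torch_geometric are available.
import torch
from hdl.layers.graph.gcn import GraphConv

layer = GraphConv(num_features=16, num_out_features=32)
x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])

out = layer(x, edge_index)
print(out.shape)  # torch.Size([5, 32])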
hdl/layers/graph/gin.py
@@ -0,0 +1,45 @@
+ import torch
+ from torch import nn
+
+ from torch_geometric.nn import MessagePassing
+ from torch_geometric.utils import add_self_loops
+
+ num_atom_type = 119  # including the extra mask tokens
+ num_chirality_tag = 3
+
+ num_bond_type = 5  # including aromatic and self-loop edge
+ num_bond_direction = 3
+
+
+ class GINEConv(MessagePassing):
+     def __init__(self, emb_dim):
+         super(GINEConv, self).__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(emb_dim, 2*emb_dim),
+             nn.ReLU(),
+             nn.Linear(2 * emb_dim, emb_dim)
+         )
+         self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
+         self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
+         nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
+         nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
+
+     def forward(self, x, edge_index, edge_attr):
+         # add self loops in the edge space
+         edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]
+
+         # add features corresponding to self-loop edges.
+         self_loop_attr = torch.zeros(x.size(0), 2)
+         self_loop_attr[:, 0] = 4  # bond type for self-loop edge
+         self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
+         edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
+
+         edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
+
+         return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
+
+     def message(self, x_j, edge_attr):
+         return x_j + edge_attr
+
+     def update(self, aggr_out):
+         return self.mlp(aggr_out)
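This GINEConv variant embeds integer bond features rather than projecting continuous ones, so edge_attr is expected as a long tensor whose first column is a bond-type index (below num_bond_type) and whose second column is a bond-direction index (below num_bond_direction). An editorial sketch under those assumptions:

# Editorial sketch, not from the package; toy integer bond features.
import torch
from hdl.layers.graph.gin import GINEConv

emb_dim = 16
conv = GINEConv(emb_dim)

x = torch.randn(3, emb_dim)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_attr = torch.tensor([[0, 0], [0, 0], [1, 2], [1, 2]])  # [bond type, bond direction]

out = conv(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([3, 16])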
hdl/layers/graph/tetra.py
@@ -0,0 +1,158 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import copy
+
+
+ class TetraPermuter(nn.Module):
+
+     def __init__(
+         self,
+         hidden,
+         # device
+     ):
+         super(TetraPermuter, self).__init__()
+
+         self.W_bs = nn.ModuleList([copy.deepcopy(nn.Linear(hidden, hidden)) for _ in range(4)])
+         # self.device = device
+         self.drop = nn.Dropout(p=0.2)
+         self.reset_parameters()
+         self.mlp_out = nn.Sequential(nn.Linear(hidden, hidden),
+                                      nn.BatchNorm1d(hidden),
+                                      nn.ReLU(),
+                                      nn.Linear(hidden, hidden))
+
+         self.tetra_perms = torch.tensor([[0, 1, 2, 3],
+                                          [0, 2, 3, 1],
+                                          [0, 3, 1, 2],
+                                          [1, 0, 3, 2],
+                                          [1, 2, 0, 3],
+                                          [1, 3, 2, 0],
+                                          [2, 0, 1, 3],
+                                          [2, 1, 3, 0],
+                                          [2, 3, 0, 1],
+                                          [3, 0, 2, 1],
+                                          [3, 1, 0, 2],
+                                          [3, 2, 1, 0]])
+
+     def reset_parameters(self):
+         gain = 0.5
+         for W_b in self.W_bs:
+             nn.init.xavier_uniform_(W_b.weight, gain=gain)
+             gain += 0.5
+
+     def forward(self, x):
+
+         nei_messages_list = [self.drop(F.tanh(l(t))) for l, t in zip(self.W_bs, torch.split(x[:, self.tetra_perms, :], 1, dim=-2))]
+         nei_messages = torch.sum(self.drop(F.relu(torch.cat(nei_messages_list, dim=-2).sum(dim=-2))), dim=-2)
+
+         return self.mlp_out(nei_messages / 3.)
+
+
+ class ConcatTetraPermuter(nn.Module):
+
+     def __init__(
+         self,
+         hidden,
+         # device
+     ):
+         super(ConcatTetraPermuter, self).__init__()
+
+         self.W_bs = nn.Linear(hidden * 4, hidden)
+         torch.nn.init.xavier_normal_(self.W_bs.weight, gain=1.0)
+         self.hidden = hidden
+         # self.device = device
+         self.drop = nn.Dropout(p=0.2)
+         self.mlp_out = nn.Sequential(nn.Linear(hidden, hidden),
+                                      nn.BatchNorm1d(hidden),
+                                      nn.ReLU(),
+                                      nn.Linear(hidden, hidden))
+
+         tetra_perms = torch.tensor([
+             [0, 1, 2, 3],
+             [0, 2, 3, 1],
+             [0, 3, 1, 2],
+             [1, 0, 3, 2],
+             [1, 2, 0, 3],
+             [1, 3, 2, 0],
+             [2, 0, 1, 3],
+             [2, 1, 3, 0],
+             [2, 3, 0, 1],
+             [3, 0, 2, 1],
+             [3, 1, 0, 2],
+             [3, 2, 1, 0]
+         ])
+         self.register_buffer('tetra_perms', tetra_perms)
+
+     def forward(self, x):
+
+         nei_messages = self.drop(
+             F.relu(
+                 self.W_bs(
+                     x[
+                         :,
+                         self.tetra_perms,
+                         :
+                     ].flatten(start_dim=2)
+                 )
+             )
+         )
+         nei_messages_sum = nei_messages.sum(dim=-2) / 3.
+         if nei_messages_sum.size(0) == 1:
+             nei_messages_sum_repeat = torch.repeat_interleave(nei_messages_sum, 2, dim=0)
+             return self.mlp_out(nei_messages_sum_repeat)[:1, ...]
+         return self.mlp_out(nei_messages_sum)
+
+
+ class TetraDifferencesProduct(nn.Module):
+
+     def __init__(
+         self,
+         hidden
+     ):
+         super(TetraDifferencesProduct, self).__init__()
+
+         self.mlp_out = nn.Sequential(nn.Linear(hidden, hidden),
+                                      nn.BatchNorm1d(hidden),
+                                      nn.ReLU(),
+                                      nn.Linear(hidden, hidden))
+         self.register_buffer('indices', torch.arange(4))
+
+     def forward(self, x):
+
+         # indices = torch.arange(4).to(x.device)
+         message_tetra_nbs = [
+             x.index_select(dim=1, index=i).squeeze(1)
+             for i in self.indices
+         ]
+         message_tetra = torch.ones_like(message_tetra_nbs[0])
+
+         # note: this will zero out reps for chiral centers with multiple carbon neighbors on first pass
+         for i in range(4):
+             for j in range(i + 1, 4):
+                 message_tetra = torch.mul(message_tetra, (message_tetra_nbs[i] - message_tetra_nbs[j]))
+         message_tetra = torch.sign(message_tetra) * torch.pow(torch.abs(message_tetra) + 1e-6, 1 / 6)
+         return self.mlp_out(message_tetra)
+
+
+ # def get_tetra_update(
+ #     hidden_size,
+ #     device,
+ #     message,
+ # ):
+
+ #     if message == 'tetra_permute':
+ #         return TetraPermuter(hidden_size, device)
+ #     elif message == 'tetra_permute_concat':
+ #         return ConcatTetraPermuter(hidden_size, device)
+ #     elif message == 'tetra_pd':
+ #         return TetraDifferencesProduct(hidden_size)
+ #     else:
+ #         raise ValueError("Invalid message type.")
+
+
+ TETRA_UPDATE_DICT = {
+     'tetra_permute': TetraPermuter,
+     'tetra_permute_concat': ConcatTetraPermuter,
+     'tetra_pd': TetraDifferencesProduct
+ }
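All three updaters registered in TETRA_UPDATE_DICT consume the stacked representations of a tetrahedral centre's four neighbours, shaped [num_centers, 4, hidden], and return one vector per centre; this is the tensor the chiral_graph.py layers above pass in as reps. An editorial sketch (not part of the package), using at least two centres so the BatchNorm1d layers have valid batch statistics in training mode:

# Editorial sketch, assuming torch is installed and the wheel imports as hdl.
import torch
from hdl.layers.graph.tetra import TETRA_UPDATE_DICT

hidden = 8
reps = torch.randn(3, 4, hidden)  # 3 chiral centres, 4 neighbour representations each

for name, cls in TETRA_UPDATE_DICT.items():
    update = cls(hidden)
    out = update(reps)
    print(name, out.shape)  # each prints torch.Size([3, 8])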
hdl/layers/graph/transformer.py
@@ -0,0 +1,188 @@
+ import math
+ from typing import Union, Tuple, Optional
+ from torch_geometric.typing import PairTensor, Adj, OptTensor
+
+ import torch
+ from torch import Tensor
+ import torch.nn.functional as F
+ from torch.nn import Linear
+ from torch_geometric.nn.conv import MessagePassing
+ from torch_geometric.utils import softmax
+
+
+ class TransformerConv(MessagePassing):
+     r"""The graph transformer operator from the `"Masked Label Prediction:
+     Unified Message Passing Model for Semi-Supervised Classification"
+     <https://arxiv.org/abs/2009.03509>`_ paper
+     .. math::
+         \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
+         \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},
+     where the attention coefficients :math:`\alpha_{i,j}` are computed via
+     multi-head dot product attention:
+     .. math::
+         \alpha_{i,j} = \textrm{softmax} \left(
+         \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
+         {\sqrt{d}} \right)
+     Args:
+         in_channels (int or tuple): Size of each input sample. A tuple
+             corresponds to the sizes of source and target dimensionalities.
+         out_channels (int): Size of each output sample.
+         heads (int, optional): Number of multi-head-attentions.
+             (default: :obj:`1`)
+         concat (bool, optional): If set to :obj:`False`, the multi-head
+             attentions are averaged instead of concatenated.
+             (default: :obj:`True`)
+         beta (bool, optional): If set, will combine aggregation and
+             skip information via
+             .. math::
+                 \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
+                 (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
+                 \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}
+             with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
+             [ \mathbf{x}_i, \mathbf{m}_i, \mathbf{x}_i - \mathbf{m}_i ])`
+             (default: :obj:`False`)
+         dropout (float, optional): Dropout probability of the normalized
+             attention coefficients which exposes each node to a stochastically
+             sampled neighborhood during training. (default: :obj:`0`)
+         edge_dim (int, optional): Edge feature dimensionality (in case
+             there are any). Edge features are added to the keys after
+             linear transformation, that is, prior to computing the
+             attention dot product. They are also added to final values
+             after the same linear transformation. The model is:
+             .. math::
+                 \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
+                 \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
+                 \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
+                 \right),
+             where the attention coefficients :math:`\alpha_{i,j}` are now
+             computed via:
+             .. math::
+                 \alpha_{i,j} = \textrm{softmax} \left(
+                 \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
+                 (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
+                 {\sqrt{d}} \right)
+             (default :obj:`None`)
+         bias (bool, optional): If set to :obj:`False`, the layer will not learn
+             an additive bias. (default: :obj:`True`)
+         root_weight (bool, optional): If set to :obj:`False`, the layer will
+             not add the transformed root node features to the output and the
+             option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
+         **kwargs (optional): Additional arguments of
+             :class:`torch_geometric.nn.conv.MessagePassing`.
+     """
+     _alpha: OptTensor
+
+     def __init__(self, in_channels: Union[int, Tuple[int,
+                                                      int]], out_channels: int,
+                  heads: int = 1, concat: bool = True, beta: bool = False,
+                  dropout: float = 0., edge_dim: Optional[int] = None,
+                  bias: bool = True, root_weight: bool = True, **kwargs):
+         kwargs.setdefault('aggr', 'add')
+         super(TransformerConv, self).__init__(node_dim=0, **kwargs)
+
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.heads = heads
+         self.beta = beta and root_weight
+         self.root_weight = root_weight
+         self.concat = concat
+         self.dropout = dropout
+         self.edge_dim = edge_dim
+
+         if isinstance(in_channels, int):
+             in_channels = (in_channels, in_channels)
+
+         self.lin_key = Linear(in_channels[0], heads * out_channels)
+         self.lin_query = Linear(in_channels[1], heads * out_channels)
+         self.lin_value = Linear(in_channels[0], heads * out_channels)
+         if edge_dim is not None:
+             self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
+         else:
+             self.lin_edge = self.register_parameter('lin_edge', None)
+
+         if concat:
+             self.lin_skip = Linear(in_channels[1], heads * out_channels,
+                                    bias=bias)
+             if self.beta:
+                 self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)
+             else:
+                 self.lin_beta = self.register_parameter('lin_beta', None)
+         else:
+             self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)
+             if self.beta:
+                 self.lin_beta = Linear(3 * out_channels, 1, bias=False)
+             else:
+                 self.lin_beta = self.register_parameter('lin_beta', None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         self.lin_key.reset_parameters()
+         self.lin_query.reset_parameters()
+         self.lin_value.reset_parameters()
+         if self.edge_dim:
+             self.lin_edge.reset_parameters()
+         self.lin_skip.reset_parameters()
+         if self.beta:
+             self.lin_beta.reset_parameters()
+
+     def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
+                 edge_attr: OptTensor = None):
+         """"""
+
+         if isinstance(x, Tensor):
+             x: PairTensor = (x, x)
+
+         # propagate_type: (x: PairTensor, edge_attr: OptTensor)
+         out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
+
+         if self.concat:
+             out = out.view(
+                 -1,
+                 self.heads * self.out_channels
+             )
+         else:
+             out = out.mean(dim=1)
+
+         if self.root_weight:
+             x_r = self.lin_skip(x[1])
+             if self.lin_beta is not None:
+                 beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
+                 beta = beta.sigmoid()
+                 out = beta * x_r + (1 - beta) * out
+             else:
+                 out += x_r
+
+         return out
+
+     def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor,
+                 index: Tensor, ptr: OptTensor,
+                 size_i: Optional[int]) -> Tensor:
+
+         query = self.lin_query(x_i).view(-1, self.heads, self.out_channels)
+         key = self.lin_key(x_j).view(-1, self.heads, self.out_channels)
+
+         if self.lin_edge is not None:
+             assert edge_attr is not None
+             edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
+                                                       self.out_channels)
+             key += edge_attr
+
+         alpha = (query * key).sum(dim=-1) / math.sqrt(self.out_channels)
+         alpha = softmax(alpha, index, ptr, size_i)
+         alpha = F.dropout(alpha, p=self.dropout, training=self.training)
+
+         out = self.lin_value(x_j).view(-1, self.heads, self.out_channels)
+         if edge_attr is not None:
+             out += edge_attr
+
+         out *= alpha.view(-1, self.heads, 1)
+         return out
+
+     def __repr__(self):
+         return '{}({}, {}, heads={})'.format(
+             self.__class__.__name__,
+             self.in_channels,
+             self.out_channels,
+             self.heads
+         )
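The class above tracks the torch_geometric TransformerConv operator, including optional per-edge features via edge_dim. A brief editorial sketch of a forward pass (toy data, assumed environment, not from the package):

# Editorial sketch: with concat=True and 4 heads the output width is heads * out_channels.
import torch
from hdl.layers.graph.transformer import TransformerConv

conv = TransformerConv(in_channels=16, out_channels=8, heads=4, edge_dim=6)

x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])
edge_attr = torch.randn(5, 6)  # one 6-dimensional feature vector per edge

out = conv(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([5, 32])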