hjxdl 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,234 @@
|
|
1
|
+
import torch
|
2
|
+
# import torch.nn.functional as F
|
3
|
+
from torch import nn, einsum
|
4
|
+
|
5
|
+
from einops import rearrange, reduce
|
6
|
+
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding
|
7
|
+
|
8
|
+
# helper functions
|
9
|
+
|
10
|
+
|
11
|
+
def exists(val):
    """Return True when *val* is anything other than ``None``.

    Falsy values such as ``0`` or ``""`` still count as present.
    """
    return not (val is None)


def default(val, d):
    """Return *val* unless it is ``None``, in which case return the fallback *d*."""
    if exists(val):
        return val
    return d
|
17
|
+
|
18
|
+
|
19
|
+
# helper classes
|
20
|
+
class PreNorm(nn.Module):
    """Apply ``LayerNorm`` to the input, then delegate to a wrapped callable.

    Args:
        dim: size of the last (normalized) dimension.
        fn: callable invoked as ``fn(normalized_x, **kwargs)``; it is registered
            as a submodule when it is an ``nn.Module``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalize *x* over its last dimension and pass extra kwargs through to ``fn``."""
        return self.fn(self.norm(x), **kwargs)
|
29
|
+
|
30
|
+
# blocks
|
31
|
+
|
32
|
+
|
33
|
+
def FeedForward(dim, mult=4):
    """Build a position-wise MLP ``dim -> dim * mult -> dim`` with a GELU in between.

    Args:
        dim: input and output feature size.
        mult: expansion factor of the hidden layer.

    Returns:
        ``nn.Sequential(Linear, GELU, Linear)``.
    """
    hidden = dim * mult
    layers = [nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim)]
    return nn.Sequential(*layers)
|
39
|
+
|
40
|
+
|
41
|
+
class FastAttention(nn.Module):
    """Linear-complexity additive attention (Fastformer-style), optionally rotary.

    Queries are pooled into one global query per head using learned attention
    logits. The global query multiplies the keys elementwise, and the result is
    pooled into a global key. The global key then multiplies the values. The
    outcome is projected ("r"), the queries are added back as a residual, and
    the heads are merged.

    Args:
        dim: feature size of the input and output.
        heads: number of attention heads.
        dim_head: feature size per head.
        max_seq_len: longest sequence supported; required when ``pos_emb`` is set.
        pos_emb: optional rotary embedding module, called as
            ``pos_emb(positions, cache_key=max_seq_len)``.
    """

    def __init__(
        self,
        dim,
        *,
        heads=8,
        dim_head=64,
        max_seq_len=None,
        pos_emb=None
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # rotary positional embedding needs a fixed-length frequency table
        assert not (exists(pos_emb) and not exists(max_seq_len)), \
            'max_seq_len must be passed in if to use rotary positional embeddings'

        self.pos_emb = pos_emb
        self.max_seq_len = max_seq_len

        # With rotary embeddings, pairs of consecutive feature dimensions are
        # summed before projecting to attention logits, which halves the
        # projected width.
        kv_attn_proj_divisor = 1 if not exists(pos_emb) else 2

        # projects each query to a scalar query-attention logit
        self.to_q_attn_logits = nn.Linear(dim_head, 1, bias=False)
        # projects each (reduced) key to a scalar key-attention logit
        self.to_k_attn_logits = nn.Linear(
            dim_head // kv_attn_proj_divisor,
            1,
            bias=False
        )

        # final transformation of values to "r" as in the paper
        self.to_r = nn.Linear(dim_head // kv_attn_proj_divisor, dim_head)

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, mask=None):
        """Attend over a (batch, seq, dim) input.

        Args:
            x: tensor of shape (batch, seq, dim).
            mask: optional (batch, seq) tensor. Truthy entries are attended to,
                and falsy entries are excluded from both softmaxes. ``None``
                (the default) attends to every position. Non-boolean masks are
                converted with ``.bool()``.

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        b, n = x.shape[0], x.shape[1]
        device, h = x.device, self.heads
        use_rotary_emb = exists(self.pos_emb)

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        mask_value = -torch.finfo(x.dtype).max
        # A missing mask means "attend everywhere". rearrange(None, ...) would
        # otherwise crash here.
        if mask is None:
            mask = torch.ones(b, n, dtype=torch.bool, device=device)
        else:
            # `~mask` and masked_fill need a boolean mask
            mask = mask.bool()
        mask = rearrange(mask, 'b n -> b () n')  # broadcast over heads

        # rotary embeddings are applied to the copies used for aggregation
        if use_rotary_emb:
            freqs = self.pos_emb(torch.arange(self.max_seq_len, device=device), cache_key=self.max_seq_len)
            freqs = rearrange(freqs[:n], 'n d -> () () n d')
            q_aggr, k_aggr, v_aggr = map(lambda t: apply_rotary_emb(freqs, t), (q, k, v))
        else:
            q_aggr, k_aggr, v_aggr = q, k, v

        # query attention: one scalar logit per position, softmax over the sequence
        q_attn_logits = rearrange(self.to_q_attn_logits(q), 'b h n () -> b h n') * self.scale
        q_attn_logits = q_attn_logits.masked_fill(~mask, mask_value)
        q_attn = q_attn_logits.softmax(dim=-1)

        # global query token: attention-weighted sum of the (rotated) queries
        global_q = einsum('b h n, b h n d -> b h d', q_attn, q_aggr)
        global_q = rearrange(global_q, 'b h d -> b h () d')

        # bias keys with global query token
        k = k * global_q

        # with rotary embeddings, sum adjacent feature pairs
        if use_rotary_emb:
            k = reduce(k, 'b h n (d r) -> b h n d', 'sum', r=2)

        # key attention logits
        k_attn_logits = rearrange(self.to_k_attn_logits(k), 'b h n () -> b h n') * self.scale
        k_attn_logits = k_attn_logits.masked_fill(~mask, mask_value)
        k_attn = k_attn_logits.softmax(dim=-1)

        # global key token: attention-weighted sum of the (rotated) keys
        global_k = einsum('b h n, b h n d -> b h d', k_attn, k_aggr)
        global_k = rearrange(global_k, 'b h d -> b h () d')

        # bias the values with the global key token
        u = v_aggr * global_k

        # with rotary embeddings, sum adjacent feature pairs
        if use_rotary_emb:
            u = reduce(u, 'b h n (d r) -> b h n d', 'sum', r=2)

        # transformation step, then add the queries back as a residual
        r = self.to_r(u)
        r = r + q

        # combine heads
        r = rearrange(r, 'b h n d -> b n (h d)')
        return self.to_out(r)
|
155
|
+
|
156
|
+
|
157
|
+
# main class
|
158
|
+
class FastTransformer(nn.Module):
    """Token-level Transformer made of ``FastAttention`` and feed-forward blocks.

    Args:
        num_tokens: vocabulary size (rows of the input embedding and width of
            the output logits).
        dim: model width.
        depth: number of (attention, feed-forward) layer pairs; may be 0.
        max_seq_len: longest supported sequence.
        heads: number of attention heads.
        dim_head: feature size per head; must be divisible by 4 when rotary
            embeddings are used.
        ff_mult: hidden-size multiplier of each feed-forward block.
        absolute_pos_emb: when True, use a learned absolute position embedding.
            Otherwise a single ``RotaryEmbedding`` is shared by every attention
            layer.
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        heads=8,
        dim_head=64,
        ff_mult=4,
        absolute_pos_emb=False
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        # positional embeddings
        self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if absolute_pos_emb else None

        layer_pos_emb = None
        if not absolute_pos_emb:
            assert (dim_head % 4) == 0, 'dimension of the head must be divisible by 4 to use rotary embeddings'
            layer_pos_emb = RotaryEmbedding(dim_head // 2)

        # layers
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attn = FastAttention(
                dim,
                dim_head=dim_head,
                heads=heads,
                pos_emb=layer_pos_emb,
                max_seq_len=max_seq_len
            )
            ff = FeedForward(dim, mult=ff_mult)
            self.layers.append(nn.ModuleList([
                PreNorm(dim, attn),
                PreNorm(dim, ff)
            ]))

        # Weight-tie the attention-logit projections across all layers. The
        # guard keeps depth == 0 from raising IndexError on self.layers[0].
        if len(self.layers) > 0:
            first_block, _ = self.layers[0]
            for block, _ in self.layers[1:]:
                block.fn.to_q_attn_logits = first_block.fn.to_q_attn_logits
                block.fn.to_k_attn_logits = first_block.fn.to_k_attn_logits

        # to logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(
        self,
        x,
        mask=None
    ):
        """Return token logits for integer token ids *x* of shape (batch, seq).

        Args:
            x: token ids of shape (batch, seq).
            mask: optional (batch, seq) mask; defaults to all True.

        Returns:
            Logits of shape (batch, seq, num_tokens).
        """
        n, device = x.shape[1], x.device
        if mask is None:
            # same shape and device as x
            mask = torch.ones_like(x, dtype=torch.bool)
        x = self.token_emb(x)

        if exists(self.abs_pos_emb):
            pos_emb = self.abs_pos_emb(torch.arange(n, device=device))
            x = x + rearrange(pos_emb, 'n d -> () n d')

        for attn, ff in self.layers:
            x = attn(x, mask=mask) + x
            x = ff(x) + x

        return self.to_logits(x)
|
hdl/models/ginet.py
ADDED
@@ -0,0 +1,189 @@
|
|
1
|
+
import torch
|
2
|
+
from torch import nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
|
5
|
+
# from torch_geometric.nn import MessagePassing
|
6
|
+
# from torch_geometric.utils import add_self_loops
|
7
|
+
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
|
8
|
+
|
9
|
+
from hdl.layers.graph.gin import GINEConv
|
10
|
+
from hdl.layers.general.linear import (
|
11
|
+
# BNReLULinear,
|
12
|
+
BNReLULinearBlock,
|
13
|
+
)
|
14
|
+
from hdl.models.utils import load_model
|
15
|
+
from hdl.ops.utils import get_activation
|
16
|
+
|
17
|
+
|
18
|
+
# Public API of this module.
__all__ = ["GINet", "GINMLPR"]
|
22
|
+
|
23
|
+
|
24
|
+
# Atom-type vocabulary (119, including the extra mask tokens) and chirality-tag vocabulary.
num_atom_type, num_chirality_tag = 119, 3

# Bond-type vocabulary (5, including aromatic and self-loop edge) and bond-direction vocabulary.
num_bond_type, num_bond_direction = 5, 3
|
29
|
+
|
30
|
+
|
31
|
+
class GINet(nn.Module):
    """GINE-convolution graph network producing graph-level features.

    ``data.x`` is indexed as two integer columns. Column 0 goes through an
    embedding of size ``num_atom_type`` and column 1 through one of size
    ``num_chirality_tag``. The two embeddings are summed, passed through
    ``num_layer`` GINEConv + BatchNorm layers (ReLU on all but the last, dropout
    on all), pooled per graph with ``data.batch``, and projected.

    Args:
        num_layer (int): number of GNN layers.
        emb_dim (int): node embedding size.
        feat_dim (int): size of the pooled graph feature ``h``.
        drop_ratio (float): dropout rate.
        pool (str): graph readout, one of ``'mean'``, ``'max'``, ``'add'``.

    Raises:
        ValueError: if ``pool`` is not a supported readout.

    Output:
        ``(h, out)``: graph features of size ``feat_dim`` and their projection
        of size ``feat_dim // 2``.
    """
    def __init__(
        self,
        num_layer=5,
        emb_dim=300,
        feat_dim=512,
        drop_ratio=0,
        pool='mean'
    ):
        super(GINet, self).__init__()
        pool_fns = {
            'mean': global_mean_pool,
            'max': global_max_pool,
            'add': global_add_pool,
        }
        # Fail fast: an unknown pool used to leave self.pool unset and crash in forward().
        if pool not in pool_fns:
            raise ValueError(
                f"pool must be one of {sorted(pool_fns)}, got {pool!r}"
            )

        self.init_args = {
            'num_layer': num_layer,
            'emb_dim': emb_dim,
            'feat_dim': feat_dim,
            'drop_ratio': drop_ratio,
            'pool': pool
        }
        self.num_layer = num_layer
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio

        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
        nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        # one GINEConv per layer
        self.gnns = nn.ModuleList([GINEConv(emb_dim) for _ in range(num_layer)])

        # one batch norm per layer
        self.batch_norms = nn.ModuleList([nn.BatchNorm1d(emb_dim) for _ in range(num_layer)])

        self.pool = pool_fns[pool]

        self.feat_lin = nn.Linear(
            self.emb_dim,
            self.feat_dim
        )

        self.out_lin = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(
                self.feat_dim,
                self.feat_dim // 2
            )
        )

    def forward(self, data):
        """Encode a batched graph object exposing ``x``, ``edge_index``, ``edge_attr`` and ``batch``.

        Returns:
            ``(h, out)`` with ``h = feat_lin(pooled)`` and ``out = out_lin(h)``.
        """
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr

        h = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])

        last = self.num_layer - 1
        for layer, (gnn, norm) in enumerate(zip(self.gnns, self.batch_norms)):
            h = norm(gnn(h, edge_index, edge_attr))
            # no ReLU after the final layer
            if layer != last:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)

        h = self.pool(h, data.batch)
        h = self.feat_lin(h)
        out = self.out_lin(h)

        return h, out
|
119
|
+
|
120
|
+
|
121
|
+
class GINMLPR(nn.Module):
    """Predictor over ``num_smiles`` graphs, each encoded by its own ``GINet``.

    The ``feat_dim // 2`` outputs of all encoders are concatenated, passed
    through a ``BNReLULinearBlock`` head, and then through the activation
    returned by ``get_activation('sigmoid')``.

    Args:
        num_layer: GNN depth of each encoder; also used as ``num_layers`` of the head.
        emb_dim: node embedding size of each encoder.
        feat_dim: graph feature size of each encoder.
        out_dim: output width of the head.
        drop_ratio: dropout rate of each encoder.
        pool: graph readout of each encoder.
        ckpt_file: optional checkpoint loaded into every encoder on construction.
        num_smiles: number of input graphs / encoders.
    """
    def __init__(
        self,
        num_layer=5,
        emb_dim=300,
        feat_dim=512,
        out_dim=1,
        drop_ratio=0,
        pool='mean',
        ckpt_file: str = None,
        num_smiles: int = 1,
    ) -> None:
        super().__init__()
        self.init_args = dict(
            num_layer=num_layer,
            emb_dim=emb_dim,
            feat_dim=feat_dim,
            out_dim=out_dim,
            drop_ratio=drop_ratio,
            pool=pool,
            ckpt_file=ckpt_file,
            num_smiles=num_smiles,
        )
        self.gins = nn.ModuleList([
            GINet(
                num_layer=num_layer,
                emb_dim=emb_dim,
                feat_dim=feat_dim,
                drop_ratio=drop_ratio,
                pool=pool,
            )
            for _ in range(num_smiles)
        ])
        self.ckpt_file = ckpt_file
        self.num_smiles = num_smiles

        head_width = feat_dim // 2
        self.ffn = BNReLULinearBlock(
            in_features=head_width * num_smiles,
            out_features=out_dim,
            num_layers=num_layer,
            hidden_size=head_width
        )
        self.out_act = get_activation('sigmoid')

        if ckpt_file is not None:
            self.load_ckpt()

    def load_ckpt(self):
        """Load ``self.ckpt_file`` into each of the ``num_smiles`` encoders; no-op when unset."""
        if self.ckpt_file is None:
            return
        for idx in range(self.num_smiles):
            load_model(self.ckpt_file, model=self.gins[idx])

    def forward(self, data):
        """Encode each input graph and predict.

        Args:
            data: iterable whose i-th element is indexed with ``[0]`` and fed to
                encoder i. Presumably each element is a (graph_batch, ...) tuple;
                TODO confirm against the collate function. ``zip`` stops at
                the shorter of ``data`` and the encoders.
        """
        feats = [gin(item[0])[1] for item, gin in zip(data, self.gins)]
        joined = torch.hstack(feats)  # (batch_size, feat_dim//2 * num_smiles)
        return self.out_act(self.ffn(joined))
|
189
|
+
|
hdl/models/linear.py
ADDED
@@ -0,0 +1,137 @@
|
|
1
|
+
import typing as t
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from torch import nn
|
5
|
+
from torch.nn import functional as nnfunc
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
from hdl.layers.general.linear import (
|
9
|
+
BNReLULinearBlock,
|
10
|
+
BNReLULinear
|
11
|
+
)
|
12
|
+
# from hdl.ops.utils import get_activation
|
13
|
+
|
14
|
+
|
15
|
+
class MMIterLinear(nn.Module):
    """Multi-task, multi-class MLP over three fingerprint inputs, optionally chained.

    Each fingerprint is projected by its own ``nn.Linear`` and the projections
    are combined as ``w3(fp3) - (w1(fp1) + w2(fp2))``. One classifier per target
    is applied. When ``iterative`` is True, classifier ``i`` receives that
    combined feature concatenated with the fed-back outputs (or, with teacher
    forcing, the targets) of classifiers ``0..i-1``.

    Args:
        num_fp_bits: width of each fingerprint input.
        num_in_feats: width of each fingerprint projection / combined feature.
        nums_classes: classes per target; ``None`` means ``[3, 3]``.
        target_names: keys of the returned dict; defaults to ``0..len(nums_classes)-1``.
        hidden_size: hidden width of each classifier.
        num_hidden_layers: depth passed to each classifier's ``BNReLULinearBlock``.
        activation: hidden activation name.
        out_act: output activation name, or a list with one name per target.
        hard_select: when iterating without teacher forcing, feed back a hard
            (one-hot) Gumbel-softmax sample instead of the raw classifier output.
        iterative: chain the classifiers as described above.
        **kwargs: forwarded to ``BNReLULinearBlock`` and ``BNReLULinear``.
    """
    _NAME = 'mumc_linear'

    def __init__(
        self,
        num_fp_bits: int,
        num_in_feats: int,
        nums_classes: t.Optional[t.List[int]] = None,
        target_names: t.List[str] = None,
        hidden_size: int = 128,
        num_hidden_layers: int = 10,
        activation: str = 'elu',
        out_act: str = 'softmax',
        hard_select: bool = False,
        iterative: bool = True,
        **kwargs,
    ):
        super().__init__()
        # None instead of a shared mutable default; resolves to the historical [3, 3].
        if nums_classes is None:
            nums_classes = [3, 3]

        if target_names is None:
            self.target_names = list(range(len(nums_classes)))
        else:
            self.target_names = target_names

        self.init_args = {
            'num_fp_bits': num_fp_bits,
            'num_in_feats': num_in_feats,
            'nums_classes': nums_classes,
            'target_names': target_names,
            'hidden_size': hidden_size,
            'num_hidden_layers': num_hidden_layers,
            'activation': activation,
            'out_act': out_act,
            'hard_select': hard_select,
            'iterative': iterative,
            **kwargs
        }
        self.hard_select = hard_select
        self.iterative = iterative
        # Sized from self.target_names: len(target_names) raised TypeError for
        # the default target_names=None.
        # NOTE(review): every flag starts True, but parameters keep
        # requires_grad=True until the freeze_classifier setter is used.
        self._freeze_classifier = [True] * len(self.target_names)

        self.w1 = nn.Linear(num_fp_bits, num_in_feats)
        self.w2 = nn.Linear(num_fp_bits, num_in_feats)
        self.w3 = nn.Linear(num_fp_bits, num_in_feats)

        if iterative:
            # Classifier i sees num_in_feats + sum(nums_classes[:i]) features.
            # np.int was removed in NumPy 1.24; .tolist() yields plain ints.
            widths = np.cumsum(np.asarray([num_in_feats, *nums_classes], dtype=np.int64))
            nums_in_feats = widths[:-1].tolist()
        else:
            nums_in_feats = [num_in_feats] * len(nums_classes)

        if isinstance(out_act, str):
            self.out_acts = [out_act] * len(nums_classes)
        else:
            self.out_acts = out_act

        self.classifiers = nn.ModuleList([
            nn.Sequential(
                BNReLULinearBlock(
                    num_in,
                    hidden_size,
                    num_hidden_layers,
                    hidden_size,
                    activation,
                    **kwargs
                ),
                BNReLULinear(
                    hidden_size,
                    num_out,
                    act,
                    **kwargs
                )
            )
            for num_in, num_out, act in zip(
                nums_in_feats, nums_classes, self.out_acts
            )
        ])

    @property
    def freeze_classifier(self):
        """Per-target flags; assigning True freezes that classifier's parameters."""
        return self._freeze_classifier

    @freeze_classifier.setter
    def freeze_classifier(self, freeze: t.Sequence[bool] = ()):
        self._freeze_classifier = freeze
        self.change_classifier_grad([not f for f in freeze])

    def change_classifier_grad(self, requires_grads: t.Sequence[bool] = ()):
        """Set ``requires_grad`` of every parameter of classifier i to ``requires_grads[i]``.

        Classifiers beyond ``len(requires_grads)`` are left unchanged.
        """
        for requires_grad, classifier in zip(requires_grads, self.classifiers):
            for param in classifier.parameters():
                param.requires_grad = requires_grad

    def forward(self, fps, target_tensors=None, teach=True):
        """Run every classifier and return ``{target_name: output}``.

        Args:
            fps: sequence of three fingerprint tensors ``(fp1, fp2, fp3)``,
                presumably each (batch, num_fp_bits); TODO confirm.
            target_tensors: per-target tensors concatenated onto the features
                (last dim) when ``iterative`` and ``teach`` are both True.
            teach: teacher forcing; feed back targets instead of predictions.

        Raises:
            ValueError: if ``iterative`` and ``teach`` are set but a target
                tensor is missing.
        """
        result_dict = {}
        fp1, fp2, fp3 = fps
        X = self.w3(fp3) - (self.w1(fp1) + self.w2(fp2))
        if target_tensors is None:
            target_tensors = [None] * len(self.target_names)
        for classifier, target_name, target_tensor in zip(
            self.classifiers, self.target_names, target_tensors
        ):
            result = classifier(X)
            result_dict[target_name] = result
            if not self.iterative:
                continue
            if teach:
                # The old `assert target_tensors is not None` could never fire,
                # which led to a cryptic torch.cat error instead.
                if target_tensor is None:
                    raise ValueError(
                        f"teach=True requires a target tensor for target {target_name!r}"
                    )
                feedback = target_tensor
            elif self.hard_select:
                feedback = nnfunc.gumbel_softmax(result, tau=1, hard=True)
            else:
                feedback = result
            X = torch.cat((X, feedback), -1)
        return result_dict
|
hdl/models/model_dict.py
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
from hdl.layers.general.linear import (
|
2
|
+
MultiTaskMultiClassBlock,
|
3
|
+
MuMcHardBlock
|
4
|
+
)
|
5
|
+
from .linear import MMIterLinear
|
6
|
+
from .chiral_gnn import GNN
|
7
|
+
from .ginet import GINet
|
8
|
+
from .ginet import GINMLPR
|
9
|
+
|
10
|
+
|
11
|
+
# Registry mapping model-name strings to model classes.
MODEL_DICT = dict(
    rxn_trans=MultiTaskMultiClassBlock,
    rxn_trans_hard=MuMcHardBlock,
    mmiter_linear=MMIterLinear,
    chiral_gnn=GNN,
    ginet=GINet,
    ginmlpr=GINMLPR,
)
|
hdl/models/norm_flows.py
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
|
4
|
+
|
5
|
+
class NormalizingFlowModel(nn.Module):
    """Chain of invertible flows on top of a base (prior) distribution.

    Each flow must provide ``forward(x) -> (y, log_det)`` and
    ``inverse(y) -> (x, log_det)``. ``prior`` must provide ``log_prob`` and
    ``sample`` (e.g. a ``torch.distributions.Distribution``).

    Args:
        prior: base distribution over the latent space.
        flows: flow modules, applied in order by ``forward`` and in reverse by ``inverse``.
    """

    def __init__(self, prior, flows):
        super().__init__()
        self.prior = prior
        self.flows = nn.ModuleList(flows)

    def forward(self, x):
        """Map data *x* of shape (batch, features) to the latent space.

        Returns:
            ``(z, prior_logprob, log_det)``, where ``log_det`` has one entry per row
            and is the sum of the per-flow log-determinants.
        """
        m, _ = x.shape
        # Follow the input's device/dtype; a CPU accumulator broke CUDA inputs.
        log_det = torch.zeros(m, device=x.device, dtype=x.dtype)
        for flow in self.flows:
            x, ld = flow.forward(x)
            log_det += ld
        z, prior_logprob = x, self.prior.log_prob(x)
        return z, prior_logprob, log_det

    def inverse(self, z):
        """Map latents *z* of shape (batch, features) back to data space.

        Returns:
            ``(x, log_det)`` with the summed inverse log-determinants.
        """
        m, _ = z.shape
        log_det = torch.zeros(m, device=z.device, dtype=z.dtype)
        for flow in self.flows[::-1]:
            z, ld = flow.inverse(z)
            log_det += ld
        x = z
        return x, log_det

    def sample(self, n_samples):
        """Draw *n_samples* latents from the prior and map them to data space."""
        z = self.prior.sample((n_samples,))
        x, _ = self.inverse(z)
        return x
|
hdl/models/optim_dict.py
ADDED
hdl/models/rxn.py
ADDED
@@ -0,0 +1,63 @@
|
|
1
|
+
import pkg_resources
|
2
|
+
from transformers import BertModel
|
3
|
+
import torch
|
4
|
+
# from torch import nn
|
5
|
+
|
6
|
+
from hdl.layers.general.linear import (
|
7
|
+
MultiTaskMultiClassBlock,
|
8
|
+
MuMcHardBlock
|
9
|
+
)
|
10
|
+
# from hdl.data.seq.rxn import rxn_model
|
11
|
+
|
12
|
+
|
13
|
+
def get_rxn_model(
    model_path: str = None
):
    """Load a ``BertModel`` in eval mode on the CPU.

    Args:
        model_path: path passed to ``BertModel.from_pretrained``. Defaults to the
            ``models/transformers/bert_ft`` resource of the ``rxnfp`` package.

    Returns:
        The loaded model.
    """
    path = model_path
    if path is None:
        path = pkg_resources.resource_filename("rxnfp", "models/transformers/bert_ft")
    return BertModel.from_pretrained(path).eval().cpu()
|
25
|
+
|
26
|
+
|
27
|
+
def build_rxn_mu(
    nums_classes,
    hard=False,
    hidden_size=128,
    nums_hidden_layers=10,
    encoder=None,
    device_id: int = 0,
    **kwargs
):
    """Build a multi-task multi-class reaction model and move it to a device.

    Args:
        nums_classes: forwarded as ``nums_classes`` to the block.
        hard: build ``MuMcHardBlock`` instead of ``MultiTaskMultiClassBlock``.
        hidden_size: forwarded as ``hidden_size`` to the block.
        nums_hidden_layers: forwarded as ``num_hidden_layers`` to the block.
        encoder: encoder module. When ``None``, a fresh one is loaded with
            ``get_rxn_model()`` on every call. The previous default
            ``get_rxn_model()`` was evaluated once at import time, loading the
            model as an import side effect and sharing one encoder instance
            across every model built with the default.
        device_id: CUDA device index, used only when CUDA is available.
        **kwargs: forwarded to the block constructor.

    Returns:
        ``(model, device)``: the model already moved to ``device``.
    """
    if encoder is None:
        encoder = get_rxn_model()

    block_cls = MuMcHardBlock if hard else MultiTaskMultiClassBlock
    model = block_cls(
        encoder=encoder,
        nums_classes=nums_classes,
        hidden_size=hidden_size,
        num_hidden_layers=nums_hidden_layers,
        **kwargs
    )

    if torch.cuda.is_available():
        device = torch.device(f'cuda:{device_id}')
    else:
        device = torch.device('cpu')

    model = model.to(device)
    return model, device
|