hjxdl-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/models/fast_transformer.py ADDED
@@ -0,0 +1,234 @@
+ import torch
+ # import torch.nn.functional as F
+ from torch import nn, einsum
+
+ from einops import rearrange, reduce
+ from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding
+
+ # helper functions
+
+
+ def exists(val):
+     return val is not None
+
+
+ def default(val, d):
+     return val if exists(val) else d
+
+
+ # helper classes
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.fn = fn
+
+     def forward(self, x, **kwargs):
+         x = self.norm(x)
+         return self.fn(x, **kwargs)
+
+ # blocks
+
+
+ def FeedForward(dim, mult=4):
+     return nn.Sequential(
+         nn.Linear(dim, dim * mult),
+         nn.GELU(),
+         nn.Linear(dim * mult, dim)
+     )
+
+
+ class FastAttention(nn.Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         heads=8,
+         dim_head=64,
+         max_seq_len=None,
+         pos_emb=None
+     ):
+         super().__init__()
+         inner_dim = heads * dim_head
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+
+         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+         # rotary positional embedding
+
+         assert not (exists(pos_emb) and not exists(max_seq_len)), \
+             'max_seq_len must be passed in if using rotary positional embeddings'
+
+         self.pos_emb = pos_emb
+         self.max_seq_len = max_seq_len
+
+         # if using relative positional encoding, make sure to reduce pairs of
+         # consecutive feature dimension before doing projection to attention logits
+
+         kv_attn_proj_divisor = 1 if not exists(pos_emb) else 2
+
+         # for projecting queries to query attention logits
+         self.to_q_attn_logits = nn.Linear(dim_head, 1, bias=False)
+         self.to_k_attn_logits = nn.Linear(
+             dim_head // kv_attn_proj_divisor,
+             1,
+             bias=False
+         )  # for projecting keys to key attention logits
+
+         # final transformation of values to "r" as in the paper
+
+         self.to_r = nn.Linear(dim_head // kv_attn_proj_divisor, dim_head)
+
+         self.to_out = nn.Linear(inner_dim, dim)
+
+     def forward(self, x, mask=None):
+         n, device, h, use_rotary_emb = x.shape[1], x.device, self.heads, exists(self.pos_emb)
+
+         qkv = self.to_qkv(x).chunk(3, dim=-1)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
+
+         mask_value = -torch.finfo(x.dtype).max
+         mask = rearrange(mask, 'b n -> b () n')
+
+         # if relative positional encoding is needed
+
+         if use_rotary_emb:
+             freqs = self.pos_emb(torch.arange(self.max_seq_len, device=device), cache_key=self.max_seq_len)
+             freqs = rearrange(freqs[:n], 'n d -> () () n d')
+             q_aggr, k_aggr, v_aggr = map(lambda t: apply_rotary_emb(freqs, t), (q, k, v))
+         else:
+             q_aggr, k_aggr, v_aggr = q, k, v
+
+         # calculate query attention logits
+
+         q_attn_logits = rearrange(self.to_q_attn_logits(q), 'b h n () -> b h n') * self.scale
+         q_attn_logits = q_attn_logits.masked_fill(~mask, mask_value)
+         q_attn = q_attn_logits.softmax(dim=-1)
+
+         # calculate global query token
+
+         global_q = einsum('b h n, b h n d -> b h d', q_attn, q_aggr)
+         global_q = rearrange(global_q, 'b h d -> b h () d')
+
+         # bias keys with global query token
+
+         k = k * global_q
+
+         # if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension
+
+         if use_rotary_emb:
+             k = reduce(k, 'b h n (d r) -> b h n d', 'sum', r=2)
+
+         # now calculate key attention logits
+
+         k_attn_logits = rearrange(self.to_k_attn_logits(k), 'b h n () -> b h n') * self.scale
+         k_attn_logits = k_attn_logits.masked_fill(~mask, mask_value)
+         k_attn = k_attn_logits.softmax(dim=-1)
+
+         # calculate global key token
+
+         global_k = einsum('b h n, b h n d -> b h d', k_attn, k_aggr)
+         global_k = rearrange(global_k, 'b h d -> b h () d')
+
+         # bias the values
+
+         u = v_aggr * global_k
+
+         # if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension
+
+         if use_rotary_emb:
+             u = reduce(u, 'b h n (d r) -> b h n d', 'sum', r=2)
+
+         # transformation step
+
+         r = self.to_r(u)
+
+         # paper then says to add the queries as a residual
+
+         r = r + q
+
+         # combine heads
+
+         r = rearrange(r, 'b h n d -> b n (h d)')
+         return self.to_out(r)
+
+
+ # main class
+ class FastTransformer(nn.Module):
+     def __init__(
+         self,
+         *,
+         num_tokens,
+         dim,
+         depth,
+         max_seq_len,
+         heads=8,
+         dim_head=64,
+         ff_mult=4,
+         absolute_pos_emb=False
+     ):
+         super().__init__()
+         self.token_emb = nn.Embedding(num_tokens, dim)
+
+         # positional embeddings
+
+         self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if absolute_pos_emb else None
+
+         layer_pos_emb = None
+         if not absolute_pos_emb:
+             assert (dim_head % 4) == 0, 'dimension of the head must be divisible by 4 to use rotary embeddings'
+             layer_pos_emb = RotaryEmbedding(dim_head // 2)
+
+         # layers
+
+         self.layers = nn.ModuleList([])
+
+         for _ in range(depth):
+             attn = FastAttention(
+                 dim,
+                 dim_head=dim_head,
+                 heads=heads,
+                 pos_emb=layer_pos_emb,
+                 max_seq_len=max_seq_len
+             )
+             ff = FeedForward(dim, mult=ff_mult)
+
+             self.layers.append(nn.ModuleList([
+                 PreNorm(dim, attn),
+                 PreNorm(dim, ff)
+             ]))
+
+         # weight tie projections across all layers
+
+         first_block, _ = self.layers[0]
+         for block, _ in self.layers[1:]:
+             block.fn.to_q_attn_logits = first_block.fn.to_q_attn_logits
+             block.fn.to_k_attn_logits = first_block.fn.to_k_attn_logits
+
+         # to logits
+
+         self.to_logits = nn.Sequential(
+             nn.LayerNorm(dim),
+             nn.Linear(dim, num_tokens)
+         )
+
+     def forward(
+         self,
+         x,
+         mask=None
+     ):
+         n, device = x.shape[1], x.device
+         if mask is None:
+             mask = torch.ones_like(x).bool().to(device)
+         x = self.token_emb(x)
+
+         if exists(self.abs_pos_emb):
+             pos_emb = self.abs_pos_emb(torch.arange(n, device=device))
+             x = x + rearrange(pos_emb, 'n d -> () n d')
+
+         for attn, ff in self.layers:
+             x = attn(x, mask=mask) + x
+             x = ff(x) + x
+
+         return self.to_logits(x)
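A minimal usage sketch for FastTransformer as defined above; the vocabulary size, model width, and sequence length are illustrative values, not taken from the package.

import torch
from hdl.models.fast_transformer import FastTransformer

# toy configuration; with absolute_pos_emb=False (the default) rotary embeddings are used,
# so dim_head must be divisible by 4
model = FastTransformer(
    num_tokens=256,
    dim=128,
    depth=2,
    max_seq_len=64,
    heads=4,
    dim_head=32,
)

tokens = torch.randint(0, 256, (2, 64))        # (batch, seq_len) token ids
mask = torch.ones(2, 64, dtype=torch.bool)     # True marks valid positions
logits = model(tokens, mask=mask)              # (2, 64, 256) per-token logits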
hdl/models/ginet.py ADDED
@@ -0,0 +1,189 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ # from torch_geometric.nn import MessagePassing
+ # from torch_geometric.utils import add_self_loops
+ from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
+
+ from hdl.layers.graph.gin import GINEConv
+ from hdl.layers.general.linear import (
+     # BNReLULinear,
+     BNReLULinearBlock,
+ )
+ from hdl.models.utils import load_model
+ from hdl.ops.utils import get_activation
+
+
+ __all__ = [
+     "GINet",
+     "GINMLPR",
+ ]
+
+
+ num_atom_type = 119  # including the extra mask tokens
+ num_chirality_tag = 3
+
+ num_bond_type = 5  # including aromatic and self-loop edge
+ num_bond_direction = 3
+
+
+ class GINet(nn.Module):
+     """
+     Args:
+         num_layer (int): the number of GNN layers
+         emb_dim (int): dimensionality of node embeddings
+         feat_dim (int): dimensionality of the pooled graph feature
+         drop_ratio (float): dropout rate
+         pool (str): graph pooling, one of 'mean', 'max' or 'add'
+     Output:
+         pooled graph feature and projection head output
+     """
+     def __init__(
+         self,
+         num_layer=5,
+         emb_dim=300,
+         feat_dim=512,
+         drop_ratio=0,
+         pool='mean'
+     ):
+         super(GINet, self).__init__()
+         self.init_args = {
+             'num_layer': num_layer,
+             'emb_dim': emb_dim,
+             'feat_dim': feat_dim,
+             'drop_ratio': drop_ratio,
+             'pool': pool
+         }
+         self.num_layer = num_layer
+         self.emb_dim = emb_dim
+         self.feat_dim = feat_dim
+         self.drop_ratio = drop_ratio
+
+         self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
+         self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
+         nn.init.xavier_uniform_(self.x_embedding1.weight.data)
+         nn.init.xavier_uniform_(self.x_embedding2.weight.data)
+
+         # List of MLPs
+         self.gnns = nn.ModuleList()
+         for layer in range(num_layer):
+             self.gnns.append(GINEConv(emb_dim))
+
+         # List of batchnorms
+         self.batch_norms = nn.ModuleList()
+         for layer in range(num_layer):
+             self.batch_norms.append(nn.BatchNorm1d(emb_dim))
+
+         if pool == 'mean':
+             self.pool = global_mean_pool
+         elif pool == 'max':
+             self.pool = global_max_pool
+         elif pool == 'add':
+             self.pool = global_add_pool
+
+         self.feat_lin = nn.Linear(
+             self.emb_dim,
+             self.feat_dim
+         )
+
+         self.out_lin = nn.Sequential(
+             nn.Linear(self.feat_dim, self.feat_dim),
+             nn.ReLU(inplace=True),
+             nn.Linear(
+                 self.feat_dim,
+                 self.feat_dim // 2
+             )
+         )
+
+     def forward(self, data):
+         x = data.x
+         edge_index = data.edge_index
+         edge_attr = data.edge_attr
+
+         h = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])
+
+         for layer in range(self.num_layer):
+             h = self.gnns[layer](h, edge_index, edge_attr)
+             h = self.batch_norms[layer](h)
+             if layer == self.num_layer - 1:
+                 h = F.dropout(h, self.drop_ratio, training=self.training)
+             else:
+                 h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
+
+         h = self.pool(h, data.batch)
+         h = self.feat_lin(h)
+         out = self.out_lin(h)
+
+         return h, out
+
+
+ class GINMLPR(nn.Module):
+     def __init__(
+         self,
+         num_layer=5,
+         emb_dim=300,
+         feat_dim=512,
+         out_dim=1,
+         drop_ratio=0,
+         pool='mean',
+         ckpt_file: str = None,
+         num_smiles: int = 1,
+     ) -> None:
+         super().__init__()
+         self.init_args = {
+             "num_layer": num_layer,
+             "emb_dim": emb_dim,
+             "feat_dim": feat_dim,
+             "out_dim": out_dim,
+             "drop_ratio": drop_ratio,
+             "pool": pool,
+             "ckpt_file": ckpt_file,
+             "num_smiles": num_smiles
+         }
+         self.gins = nn.ModuleList([])
+         for _ in range(num_smiles):
+             self.gins.append(
+                 GINet(
+                     num_layer=num_layer,
+                     emb_dim=emb_dim,
+                     feat_dim=feat_dim,
+                     drop_ratio=drop_ratio,
+                     pool=pool,
+                 )
+             )
+         self.ckpt_file = ckpt_file
+         self.num_smiles = num_smiles
+
+         self.ffn = BNReLULinearBlock(
+             in_features=feat_dim // 2 * num_smiles,
+             out_features=out_dim,
+             num_layers=num_layer,
+             hidden_size=feat_dim // 2
+         )
+         self.out_act = get_activation('sigmoid')
+
+         if ckpt_file is not None:
+             self.load_ckpt()
+
+     def load_ckpt(self):
+         if self.ckpt_file is not None:
+             for i in range(self.num_smiles):
+                 load_model(
+                     self.ckpt_file,
+                     model=self.gins[i]
+                 )
+
+     def forward(
+         self,
+         data
+     ):
+         out_list = []
+         for data_i, gin in zip(data, self.gins):
+             out_list.append(gin(data_i[0])[1])
+         out = torch.hstack(out_list)  # (batch_size, feat_dim // 2 * num_smiles)
+         out = self.ffn(out)
+         out = self.out_act(out)
+
+         return out
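A sketch of calling GINet on a torch_geometric batch; the two-atom graph is fabricated, and it is assumed that GINEConv from hdl.layers.graph.gin follows the MolCLR convention of long-typed node features (atom type, chirality tag) and edge features (bond type, bond direction).

import torch
from torch_geometric.data import Data, Batch
from hdl.models.ginet import GINet

model = GINet(num_layer=5, emb_dim=300, feat_dim=512, pool='mean')

# two atoms joined by one (bidirectional) bond
graph = Data(
    x=torch.tensor([[6, 0], [8, 0]], dtype=torch.long),
    edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long),
    edge_attr=torch.tensor([[0, 0], [0, 0]], dtype=torch.long),
)
batch = Batch.from_data_list([graph, graph])

h, out = model(batch)   # h: (2, 512) pooled graph features, out: (2, 256) projection-head output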
hdl/models/linear.py ADDED
@@ -0,0 +1,137 @@
+ import typing as t
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as nnfunc
+ import numpy as np
+
+ from hdl.layers.general.linear import (
+     BNReLULinearBlock,
+     BNReLULinear
+ )
+ # from hdl.ops.utils import get_activation
+
+
+ class MMIterLinear(nn.Module):
+     _NAME = 'mumc_linear'
+
+     def __init__(
+         self,
+         num_fp_bits: int,
+         num_in_feats: int,
+         nums_classes: t.List[int] = [3, 3],
+         target_names: t.List[str] = None,
+         hidden_size: int = 128,
+         num_hidden_layers: int = 10,
+         activation: str = 'elu',
+         out_act: str = 'softmax',
+         hard_select: bool = False,
+         iterative: bool = True,
+         **kwargs,
+     ):
+         super().__init__()
+
+         if target_names is None:
+             self.target_names = list(range(len(nums_classes)))
+         else:
+             self.target_names = target_names
+
+         self.init_args = {
+             'num_fp_bits': num_fp_bits,
+             'num_in_feats': num_in_feats,
+             'nums_classes': nums_classes,
+             'target_names': target_names,
+             'hidden_size': hidden_size,
+             'num_hidden_layers': num_hidden_layers,
+             'activation': activation,
+             'out_act': out_act,
+             'hard_select': hard_select,
+             'iterative': iterative,
+             **kwargs
+         }
+         self.hard_select = hard_select
+         self.iterative = iterative
+         # use the resolved names so a None target_names argument does not break here
+         self._freeze_classifier = [True] * len(self.target_names)
+
+         # self.w1 = BNReLULinear(num_fp_bits, num_in_feats, activation)
+         self.w1 = nn.Linear(num_fp_bits, num_in_feats)
+         # self.w2 = BNReLULinear(num_fp_bits, num_in_feats, activation)
+         self.w2 = nn.Linear(num_fp_bits, num_in_feats)
+         # self.w3 = BNReLULinear(num_fp_bits, num_in_feats, activation)
+         self.w3 = nn.Linear(num_fp_bits, num_in_feats)
+
+         nums_in_feats = [num_in_feats]
+         if iterative:
+             nums_in_feats.extend(nums_classes)
+             # np.int was removed in NumPy 1.24; the builtin int is equivalent here
+             nums_in_feats = np.cumsum(np.array(nums_in_feats, dtype=int))[:-1]
+         else:
+             nums_in_feats = nums_in_feats * len(nums_classes)
+
+         if isinstance(out_act, str):
+             self.out_acts = [out_act] * len(nums_classes)
+         else:
+             self.out_acts = out_act
+
+         self.classifiers = nn.ModuleList([
+             nn.Sequential(
+                 BNReLULinearBlock(
+                     num_in,
+                     hidden_size,
+                     num_hidden_layers,
+                     hidden_size,
+                     activation,
+                     **kwargs
+                 ),
+                 BNReLULinear(
+                     hidden_size,
+                     num_out,
+                     out_act,
+                     **kwargs
+                 )
+             )
+             for num_in, num_out, out_act in zip(
+                 nums_in_feats, nums_classes, self.out_acts
+             )
+         ])
+
+     @property
+     def freeze_classifier(self):
+         return self._freeze_classifier
+
+     @freeze_classifier.setter
+     def freeze_classifier(self, freeze: t.List = []):
+         self._freeze_classifier = freeze
+         self.change_classifier_grad([not f for f in freeze])
+
+     def change_classifier_grad(self, requires_grads: t.List = []):
+         for requires_grad, classifier in zip(requires_grads, self.classifiers):
+             for param in classifier.parameters():
+                 param.requires_grad = requires_grad
+
+     def forward(self, fps, target_tensors=None, teach=True):
+         result_dict = {}
+         fp1, fp2, fp3 = fps
+         fp1 = self.w1(fp1)
+         fp2 = self.w2(fp2)
+         fp3 = self.w3(fp3)
+         X = fp3 - (fp1 + fp2)
+         if target_tensors is None:
+             target_tensors = [None] * len(self.target_names)
+         for classifier, target_name, target_tensor in zip(
+             self.classifiers, self.target_names, target_tensors
+         ):
+             result = classifier(X)
+             result_dict[target_name] = result
+             if self.iterative:
+                 if teach:
+                     # teacher forcing needs the ground-truth tensor for this target
+                     assert target_tensor is not None
+                     X = torch.cat((X, target_tensor), -1)
+                 else:
+                     if not self.hard_select:
+                         X = torch.cat((X, result), -1)
+                     else:
+                         X = torch.cat(
+                             (X, nnfunc.gumbel_softmax(result, tau=1, hard=True)),
+                             -1
+                         )
+         return result_dict
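A sketch of the iterative multi-task call pattern of MMIterLinear; the fingerprint length, class counts, and target names are illustrative, and the BNReLULinearBlock/BNReLULinear signatures are assumed to be exactly as used above.

import torch
from hdl.models.linear import MMIterLinear

model = MMIterLinear(
    num_fp_bits=2048,
    num_in_feats=256,
    nums_classes=[3, 3],
    target_names=['target_a', 'target_b'],
)

fps = [torch.rand(8, 2048) for _ in range(3)]         # fp1, fp2, fp3
targets = [torch.zeros(8, 3), torch.zeros(8, 3)]      # placeholder one-hot targets for teacher forcing
out = model(fps, target_tensors=targets, teach=True)  # dict: target name -> (8, 3) class scores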
hdl/models/model_dict.py ADDED
@@ -0,0 +1,18 @@
+ from hdl.layers.general.linear import (
+     MultiTaskMultiClassBlock,
+     MuMcHardBlock
+ )
+ from .linear import MMIterLinear
+ from .chiral_gnn import GNN
+ from .ginet import GINet
+ from .ginet import GINMLPR
+
+
+ MODEL_DICT = {
+     'rxn_trans': MultiTaskMultiClassBlock,
+     'rxn_trans_hard': MuMcHardBlock,
+     'mmiter_linear': MMIterLinear,
+     'chiral_gnn': GNN,
+     'ginet': GINet,
+     'ginmlpr': GINMLPR
+ }
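The registry is presumably resolved by name from a training config; a minimal lookup sketch with illustrative keyword arguments:

from hdl.models.model_dict import MODEL_DICT

model_cls = MODEL_DICT['ginet']   # -> GINet
model = model_cls(num_layer=5, emb_dim=300, feat_dim=512, pool='mean')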
hdl/models/norm_flows.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ import torch.nn as nn
+
+
+ class NormalizingFlowModel(nn.Module):
+
+     def __init__(self, prior, flows):
+         super().__init__()
+         self.prior = prior
+         self.flows = nn.ModuleList(flows)
+
+     def forward(self, x):
+         m, _ = x.shape
+         # keep the running log-determinant on the same device as the input
+         log_det = torch.zeros(m, device=x.device)
+         for flow in self.flows:
+             x, ld = flow.forward(x)
+             log_det += ld
+         z, prior_logprob = x, self.prior.log_prob(x)
+         return z, prior_logprob, log_det
+
+     def inverse(self, z):
+         m, _ = z.shape
+         log_det = torch.zeros(m, device=z.device)
+         for flow in self.flows[::-1]:
+             z, ld = flow.inverse(z)
+             log_det += ld
+         x = z
+         return x, log_det
+
+     def sample(self, n_samples):
+         z = self.prior.sample((n_samples,))
+         x, _ = self.inverse(z)
+         return x
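NormalizingFlowModel only assumes each flow exposes forward/inverse returning the transformed tensor plus a per-sample log-determinant; a toy sketch with a standard-normal prior and an illustrative AffineFlow that is not part of the package:

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from hdl.models.norm_flows import NormalizingFlowModel

class AffineFlow(nn.Module):
    # elementwise scale-and-shift flow with an analytic log-determinant
    def __init__(self, dim):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.shift = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        z = x * self.log_scale.exp() + self.shift
        return z, self.log_scale.sum().expand(x.shape[0])

    def inverse(self, z):
        x = (z - self.shift) * (-self.log_scale).exp()
        return x, (-self.log_scale).sum().expand(z.shape[0])

prior = MultivariateNormal(torch.zeros(2), torch.eye(2))
model = NormalizingFlowModel(prior, [AffineFlow(2), AffineFlow(2)])

x = torch.randn(16, 2)
z, prior_logprob, log_det = model(x)
loss = -(prior_logprob + log_det).mean()   # negative log-likelihood objective
samples = model.sample(32)                 # (32, 2)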
hdl/models/optim_dict.py ADDED
@@ -0,0 +1,16 @@
+ from torch.optim import (
+     Adadelta,
+     Adam,
+     SGD,
+     RMSprop,
+ )
+ from hdl.optims.nadam import Nadam
+
+
+ OPTIM_DICT = {
+     'adam': Adam,
+     'adadelta': Adadelta,
+     'sgd': SGD,
+     'rmsprop': RMSprop,
+     'nadam': Nadam,
+ }
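A lookup sketch mirroring how training code would presumably resolve an optimizer by name; the model and learning rate are illustrative.

import torch
from hdl.models.optim_dict import OPTIM_DICT

model = torch.nn.Linear(4, 2)   # any nn.Module
optimizer = OPTIM_DICT['adam'](model.parameters(), lr=1e-3)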
hdl/models/rxn.py ADDED
@@ -0,0 +1,63 @@
+ import pkg_resources
+ from transformers import BertModel
+ import torch
+ # from torch import nn
+
+ from hdl.layers.general.linear import (
+     MultiTaskMultiClassBlock,
+     MuMcHardBlock
+ )
+ # from hdl.data.seq.rxn import rxn_model
+
+
+ def get_rxn_model(
+     model_path: str = None
+ ):
+     if model_path is None:
+         model_path = pkg_resources.resource_filename(
+             "rxnfp",
+             "models/transformers/bert_ft"
+         )
+     model = BertModel.from_pretrained(model_path)
+     model = model.eval().cpu()
+
+     return model
+
+
+ def build_rxn_mu(
+     nums_classes,
+     hard=False,
+     hidden_size=128,
+     nums_hidden_layers=10,
+     encoder=get_rxn_model(),  # note: default evaluated at module import time, loading the BERT encoder eagerly
+     # freeze_encoder=True,
+     device_id: int = 0,
+     **kwargs
+ ):
+     if not hard:
+         model = MultiTaskMultiClassBlock(
+             encoder=encoder,
+             nums_classes=nums_classes,
+             hidden_size=hidden_size,
+             num_hidden_layers=nums_hidden_layers,
+             # freeze_encoder=freeze_encoder,
+             **kwargs
+         )
+     else:
+         model = MuMcHardBlock(
+             encoder=encoder,
+             nums_classes=nums_classes,
+             hidden_size=hidden_size,
+             num_hidden_layers=nums_hidden_layers,
+             # freeze_encoder=freeze_encoder,
+             **kwargs
+         )
+     device = torch.device(f'cuda:{device_id}') \
+         if torch.cuda.is_available() \
+         else torch.device('cpu')
+
+     model = model.to(device)
+
+     # if torch.cuda.device_count() > 1:
+     #     model = nn.DataParallel(model)
+     return model, device
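A sketch of building the reaction classifier; it assumes the rxnfp package with its pretrained bert_ft weights is installed, because the default encoder argument calls get_rxn_model() when the module is imported. The class counts are illustrative.

from hdl.models.rxn import build_rxn_mu

model, device = build_rxn_mu(
    nums_classes=[3, 5],   # two targets with 3 and 5 classes
    hard=False,            # use the soft MultiTaskMultiClassBlock head
    hidden_size=128,
    nums_hidden_layers=10,
)
print(device)   # cuda:0 if a GPU is visible, otherwise cpu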