hjxdl-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/layers/general/linear.py
@@ -0,0 +1,641 @@
import typing as t

import torch
from torch import nn
from torch.autograd import Function
import torch_scatter
from torch.utils import checkpoint as tuc

from hdl.ops.utils import get_activation

__all__ = [
    "WeaveLayer",
    "DenseNet",
    "AvgPooling",
    "SumPooling",
    "CasualWeave",
    "DenseLayer"
]


def _bn_function_factory(bn_module):
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, -1)
        bottleneck_output = bn_module(concated_features)
        return bottleneck_output

    return bn_function


class BNReLULinear(nn.Module):
    """
    Linear layer with bn->relu->linear architecture
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        activation: str = 'elu',
        **kwargs
    ):
        """
        Args:
            in_features (int):
                The number of input features
            out_features (int):
                The number of output features
            activation (str):
                The type of activation unit to use in this module,
                default to elu
        """
        super(BNReLULinear, self).__init__()
        self.bn_relu_linear = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Linear(
                in_features,
                out_features,
                bias=False
            ),
            get_activation(
                activation,
                inplace=True,
                **kwargs
            )
        )

    def forward(self, x):
        """The forward method"""
        return self.bn_relu_linear(x)
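Despite its docstring ("bn->relu->linear"), the Sequential in BNReLULinear above applies BatchNorm1d, then a bias-free Linear, then the activation returned by get_activation. A minimal usage sketch, not part of the package, assuming the wheel and its torch dependency are installed:

```python
import torch
from hdl.layers.general.linear import BNReLULinear

layer = BNReLULinear(in_features=16, out_features=32)   # BN -> Linear(bias=False) -> ELU
x = torch.randn(4, 16)   # [batch, in_features]; BatchNorm1d needs batch > 1 in train mode
y = layer(x)             # -> [4, 32]
```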
class SelectAdd(Function):
    """
    Implement the memory efficient version of `a + b.index_select(indices)`
    """

    def __init__(self,
                 indices: torch.Tensor,
                 indices_a: torch.Tensor = None):
        """
        Initializer
        Args:
            indices (torch.Tensor): The indices to select the object `b`
            indices_a (torch.Tensor or None):
                The indices to select the object `a`. Default to None
        """
        self._indices = indices
        self._indices_a = indices_a

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """
        The forward pass
        Args:
            a (torch.Tensor)
            b (torch.Tensor): The input tensors
        Returns:
            torch.Tensor:
                The output tensor
        """
        if self._indices_a is not None:
            return (a.index_select(dim=0, index=self._indices_a) +
                    b.index_select(dim=0, index=self._indices))
        else:
            return a + b.index_select(dim=0, index=self._indices)

    def backward(self, grad_output):
        # For the input a
        if self._indices_a is not None:
            grad_a = torch_scatter.scatter_add(grad_output,
                                               index=self._indices_a,
                                               dim=0)
        else:
            # If a is not index selected, simply clone the gradient
            grad_a = grad_output.clone()
        # For the input b, perform a segment sum
        grad_b = torch_scatter.scatter_add(grad_output,
                                           index=self._indices,
                                           dim=0)
        return grad_a, grad_b
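SelectAdd is written against the legacy torch.autograd.Function API (a stateful `__init__` plus non-static `forward`/`backward`). Newer PyTorch releases (roughly 1.5 onward) refuse to execute non-static Function methods, so the WeaveLayer forward pass below only runs on older torch builds. A rough equivalent under the current static-method API might look like the following sketch; SelectAddFn is an assumption of mine, not code shipped in hjxdl:

```python
import torch
import torch_scatter
from torch.autograd import Function


class SelectAddFn(Function):
    """Static-method variant of SelectAdd (sketch, not part of hjxdl)."""

    @staticmethod
    def forward(ctx, a, b, indices, indices_a=None):
        ctx.save_for_backward(indices)
        ctx.indices_a = indices_a
        ctx.a_rows, ctx.b_rows = a.size(0), b.size(0)
        if indices_a is not None:
            return a.index_select(0, indices_a) + b.index_select(0, indices)
        return a + b.index_select(0, indices)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        if ctx.indices_a is not None:
            # scatter the upstream gradient back onto the rows of `a`
            grad_a = torch_scatter.scatter_add(grad_output, ctx.indices_a,
                                               dim=0, dim_size=ctx.a_rows)
        else:
            grad_a = grad_output.clone()
        # segment-sum the upstream gradient onto the rows of `b`
        grad_b = torch_scatter.scatter_add(grad_output, indices,
                                           dim=0, dim_size=ctx.b_rows)
        return grad_a, grad_b, None, None


# mirrors the call pattern used in WeaveLayer.forward:
# SelectAdd(end_ids, begin_ids)(begin_feat, end_feat)
# out = SelectAddFn.apply(begin_feat, end_feat, end_ids, begin_ids)
```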
class WeaveLayer(nn.Module):
    def __init__(
        self,
        num_in_feat: int,
        num_out_feat: int,
        activation: str = 'relu',
        is_first_layer: bool = False
    ):
        super().__init__()
        self.num_in_feat = num_in_feat
        self.num_out_feat = num_out_feat
        self.activation = activation
        # Broadcasting node features to edges
        if is_first_layer:
            self.broadcast = nn.Linear(self.num_in_feat,
                                       self.num_out_feat * 5)
        else:
            self.broadcast = BNReLULinear(self.num_in_feat,
                                          self.num_out_feat * 5,
                                          self.activation)
        # Gather edge features to node
        self.gather = nn.Sequential(nn.BatchNorm1d(self.num_out_feat),
                                    get_activation(self.activation,
                                                   inplace=True))

        # Update node features
        self.update = BNReLULinear(self.num_out_feat * 2,
                                   self.num_out_feat,
                                   self.activation)

    def forward(
        self,
        n_feat: torch.Tensor,
        adj: torch.Tensor
    ):
        node_broadcast = self.broadcast(n_feat)
        (self_features,
         begin_features_sum,
         end_features_sum,
         begin_features_max,
         end_features_max) = torch.split(node_broadcast,
                                         self.num_out_feat,
                                         dim=-1)
        edge_info = adj._indices()
        begin_ids, end_ids = edge_info[0, :], edge_info[1, :]
        edge_features_max = SelectAdd(end_ids,
                                      begin_ids)(begin_features_max,
                                                 end_features_max)
        edge_features_sum = SelectAdd(end_ids,
                                      begin_ids)(begin_features_sum,
                                                 end_features_sum)
        edge_gathered_sum = self.gather(edge_features_sum)
        edge_gathered_sum = torch_scatter.scatter_add(edge_gathered_sum,
                                                      begin_ids,
                                                      dim=0)
        min_val = edge_features_max.min()
        edge_gathered_max = edge_features_max - min_val
        edge_gathered_max = torch_scatter.scatter_max(edge_gathered_max,
                                                      begin_ids,
                                                      dim=0)[0]
        edge_gathered_max = edge_gathered_max + min_val
        edge_gathered = torch.cat([edge_gathered_max,
                                   edge_gathered_sum],
                                  dim=-1)
        node_update = self.update(edge_gathered)
        outputs = self_features + node_update
        return outputs
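WeaveLayer reads edge endpoints through `adj._indices()`, so the adjacency is expected as a sparse COO tensor whose index matrix holds (begin, end) node ids, while node features are a dense `[num_nodes, num_in_feat]` matrix. A hypothetical input-construction sketch (the forward call is commented out because it goes through the legacy SelectAdd above):

```python
import torch
from hdl.layers.general.linear import WeaveLayer

# Hypothetical 4-node path graph; each undirected edge appears in both directions.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
adj = torch.sparse_coo_tensor(
    edge_index, torch.ones(edge_index.size(1)), size=(4, 4)
).coalesce()                        # WeaveLayer only reads adj._indices()

n_feat = torch.randn(4, 16)         # [num_nodes, num_in_feat]
layer = WeaveLayer(16, 32, is_first_layer=True)
# out = layer(n_feat, adj)          # [4, 32] on a torch build that still accepts
#                                   # legacy (non-static) autograd Functions
```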
class CasualWeave(nn.Module):
    def __init__(
        self,
        num_feat: int,
        hidden_sizes: t.Iterable,
        activation: str = 'elu'
    ):
        super().__init__()
        self.num_feat = num_feat
        self.hidden_sizes = list(hidden_sizes)
        self.activation = activation

        layers = []
        for i, (in_feat, out_feat) in enumerate(
            zip(
                [self.num_feat, ] +
                list(self.hidden_sizes)[:-1],  # in_features
                self.hidden_sizes  # out_features
            )
        ):
            if i == 0:
                layers.append(
                    WeaveLayer(
                        in_feat,
                        out_feat,
                        self.activation,
                        True
                    )
                )
            else:
                layers.append(
                    WeaveLayer(
                        in_feat,
                        out_feat,
                        self.activation
                    )
                )
        self.layers = nn.ModuleList(layers)

    def forward(
        self,
        feat: torch.Tensor,
        adj: torch.Tensor
    ):
        feat_out = feat
        for layer in self.layers:
            feat_out = layer(
                feat_out,
                adj
            )
        return feat_out


class DenseLayer(nn.Module):
    def __init__(
        self,
        num_in_feat: int,
        num_botnec_feat: int,
        num_out_feat: int,
        activation: str = 'elu',
    ):
        super().__init__()
        self.num_in_feat = num_in_feat
        self.num_out_feat = num_out_feat
        self.num_botnec_feat = num_botnec_feat
        self.activation = activation

        self.bottlenec = BNReLULinear(
            self.num_in_feat,
            self.num_botnec_feat,
            self.activation
        )

        self.weave = WeaveLayer(
            self.num_botnec_feat,
            self.num_out_feat,
            self.activation
        )

    def forward(
        self,
        ls_feat: t.List[torch.Tensor],
        adj: torch.Tensor,
    ):
        bn_fn = _bn_function_factory(self.bottlenec)
        feat = tuc.checkpoint(bn_fn, *ls_feat)
        return self.weave(
            feat,
            adj
        )


class DenseNet(nn.Module):
    def __init__(
        self,
        num_feat: int,
        casual_hidden_sizes: t.Iterable,
        num_botnec_feat: int,
        num_k_feat: int,
        num_dense_layers: int,
        num_out_feat: int,
        activation: str = 'elu'
    ):
        super().__init__()
        self.num_feat = num_feat
        self.num_dense_layers = num_dense_layers
        self.casual_hidden_sizes = list(casual_hidden_sizes)
        self.num_out_feat = num_out_feat
        self.activation = activation
        self.num_k_feat = num_k_feat
        self.num_botnec_feat = num_botnec_feat
        self.casual = CasualWeave(
            self.num_feat,
            self.casual_hidden_sizes,
            self.activation
        )
        dense_layers = []
        for i in range(self.num_dense_layers):
            dense_layers.append(
                DenseLayer(
                    self.casual_hidden_sizes[-1] + i * self.num_k_feat,
                    self.num_botnec_feat,
                    self.num_k_feat,
                    self.activation
                )
            )
        self.dense_layers = nn.ModuleList(dense_layers)

        self.output = BNReLULinear(
            (
                self.casual_hidden_sizes[-1] +
                self.num_dense_layers * self.num_k_feat
            ),
            self.num_out_feat,
            self.activation
        )

    def forward(
        self,
        feat,
        adj
    ):
        feat = self.casual(
            feat,
            adj
        )
        ls_feat = [feat, ]
        for dense_layer in self.dense_layers:
            feat_i = dense_layer(
                ls_feat,
                adj
            )
            ls_feat.append(feat_i)
        feat_cat = torch.cat(ls_feat, dim=-1)
        return self.output(feat_cat)
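DenseNet chains a CasualWeave stack with num_dense_layers checkpointed DenseLayers in a DenseNet-style growth pattern: dense layer i consumes casual_hidden_sizes[-1] + i * num_k_feat features and contributes num_k_feat more, and the final BNReLULinear projects the concatenation to num_out_feat. A construction sketch with made-up sizes (running the forward pass has the same legacy-Function caveat as WeaveLayer):

```python
from hdl.layers.general.linear import DenseNet

model = DenseNet(
    num_feat=16,                   # raw node-feature width
    casual_hidden_sizes=[32, 64],  # two stacked WeaveLayers: 16 -> 32 -> 64
    num_botnec_feat=128,           # bottleneck width inside each DenseLayer
    num_k_feat=24,                 # growth rate: features added per DenseLayer
    num_dense_layers=3,            # dense inputs are 64, 88, 112 features wide
    num_out_feat=256,              # final embedding width
)
# forward(feat, adj): feat is [num_nodes, 16], adj a sparse COO adjacency;
# the output is [num_nodes, 256] after concatenating 64 + 3 * 24 = 136 features
# and projecting them through the output BNReLULinear.
```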
class _Pooling(nn.Module):
    def __init__(
        self,
        in_features: int,
        pooling_op: t.Callable = torch_scatter.scatter_mean,
        activation: str = 'elu'
    ):
        """Summary
        Args:
            in_features (int): Description
            pooling_op (t.Callable, optional): Description
            activation (str, optional): Description
        """
        super(_Pooling, self).__init__()
        self.bn_relu = nn.Sequential(
            nn.BatchNorm1d(in_features),
            get_activation(activation, inplace=True)
        )
        self.pooling_op = pooling_op

    def forward(
        self,
        x: torch.Tensor,
        ids: torch.Tensor,
        num_seg: int = None
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): The input tensor, size=[N, in_features]
            ids (torch.Tensor): A tensor of type `torch.long`, size=[N, ]
            num_seg (int): The number of segments (graphs)
        Returns:
            torch.Tensor: Output tensor with size=[num_seg, in_features]
        """

        # performing batch_normalization and activation
        x_bn = self.bn_relu(x)  # size=[N, in_features]

        # performing segment operation
        x_pooled = self.pooling_op(
            x_bn,
            dim=0,
            index=ids,
            dim_size=num_seg
        )  # size=[num_seg, in_features]

        return x_pooled


class AvgPooling(_Pooling):
    """Average pooling layer for graph"""

    def __init__(
        self,
        in_features: int,
        activation: str = 'elu'
    ):
        """ Performing graph level average pooling (with bn_relu)
        Args:
            in_features (int):
                The number of input features
            activation (str):
                The type of activation function to use, default to elu
        """
        super(AvgPooling, self).__init__(
            in_features,
            activation=activation,
            pooling_op=torch_scatter.scatter_mean
        )


class SumPooling(_Pooling):
    """Sum pooling layer for graph"""

    def __init__(
        self,
        in_features: int,
        activation: str = 'elu'
    ):
        """ Performing graph level sum pooling (with bn_relu)
        Args:
            in_features (int):
                The number of input features
            activation (str):
                The type of activation function to use, default to elu
        """
        super(SumPooling, self).__init__(
            in_features,
            activation=activation,
            pooling_op=torch_scatter.scatter_add
        )
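AvgPooling and SumPooling reduce node rows to per-graph rows with torch_scatter, driven by a segment-id vector that maps each node to its graph. A small usage sketch, assuming the wheel plus its torch and torch_scatter dependencies are installed:

```python
import torch
from hdl.layers.general.linear import AvgPooling, SumPooling

x = torch.randn(5, 8)                    # 5 nodes, 8 features each
ids = torch.tensor([0, 0, 1, 1, 1])      # nodes 0-1 belong to graph 0, nodes 2-4 to graph 1

avg_pool = AvgPooling(in_features=8)
graph_mean = avg_pool(x, ids, num_seg=2)   # -> [2, 8], one row per graph

sum_pool = SumPooling(in_features=8)
graph_sum = sum_pool(x, ids, num_seg=2)    # -> [2, 8]
```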
class BNReLULinearBlock(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_layers: int,
        hidden_size: int,
        activation: str = 'elu',
        # out_act: str = 'sigmoid',
        **kwargs
    ):
        super().__init__()

        input_brl = BNReLULinear(
            in_features,
            hidden_size,
            activation
        )

        btn_brl = [
            BNReLULinear(
                hidden_size,
                hidden_size,
                activation
            )
            for _ in range(num_layers - 2)
        ]

        output_brl = BNReLULinear(
            hidden_size,
            out_features,
            activation,
        )
        # self.out_act = get_activation(out_act, **kwargs)

        self.brl_block = nn.Sequential(
            input_brl,
            *btn_brl,
            output_brl,
            # self.out_act
        )

    def forward(self, X):
        return self.brl_block(X)
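BNReLULinearBlock stacks one input projection, num_layers - 2 hidden stages of width hidden_size, and one output projection, so the total number of BNReLULinear stages equals num_layers (for num_layers >= 2). A brief sketch with hypothetical sizes; note that the output stage also ends in an activation:

```python
import torch
from hdl.layers.general.linear import BNReLULinearBlock

# 4 stages in total: 16 -> 64, 64 -> 64, 64 -> 64, 64 -> 10
block = BNReLULinearBlock(
    in_features=16,
    out_features=10,
    num_layers=4,
    hidden_size=64,
)
x = torch.randn(8, 16)
out = block(x)   # -> [8, 10]; passes through the final ELU, so these are not raw logits
```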
class MultiTaskMultiClassBlock(nn.Module):
    _NAME = 'rxn_trans'

    def __init__(
        self,
        encoder: nn.Module = None,
        nums_classes: t.List[int] = [3, 3],
        hidden_size: int = 128,
        num_hidden_layers: int = 10,
        activation: str = 'elu',
        out_act: str = 'softmax',
        **kwargs,
    ):
        super().__init__()
        self.init_args = {
            'encoder': encoder,
            'nums_classes': nums_classes,
            'hidden_size': hidden_size,
            'num_hidden_layers': num_hidden_layers,
            'activation': activation,
            'out_act': out_act,
            **kwargs
        }
        if isinstance(out_act, str):
            self.out_acts = [out_act] * len(nums_classes)
        else:
            self.out_acts = out_act
        self.out_act_funcs = nn.ModuleList(
            [get_activation(act, **kwargs) for act in self.out_acts]
        )

        self.encoder = encoder
        self._freeze_encoder = True
        self.classifiers = nn.ModuleList([
            BNReLULinearBlock(
                256,
                num_class,
                num_hidden_layers,
                hidden_size,
                activation,
                # out_action,
                **kwargs
            )
            for num_class in nums_classes
        ])

    @property
    def freeze_encoder(self):
        return self._freeze_encoder

    @freeze_encoder.setter
    def freeze_encoder(self, freeze: bool):
        self._freeze_encoder = freeze
        self.change_encoder_grad(not freeze)

    def change_encoder_grad(self, requires_grad: bool):
        for param in self.encoder.parameters():
            param.requires_grad = requires_grad

    def forward(self, X):
        embeddings = self.encoder(*X)[0][:, 0, :]
        if self.training:
            outputs = [
                classifier(embeddings)
                for classifier in self.classifiers
            ]
        else:
            outputs = [
                act(classifier(embeddings))
                for classifier, act in zip(self.classifiers, self.out_act_funcs)
            ]

        return outputs


class MuMcHardBlock(nn.Module):
    _NAME = 'rxn_trans_hard'

    def __init__(
        self,
        encoder: nn.Module = None,
        nums_classes: t.List[int] = [3, 3],
        hidden_size: int = 128,
        num_hidden_layers: int = 10,
        activation: str = 'elu',
        out_act: str = 'softmax',
        **kwargs,
    ):
        super().__init__()
        self.init_args = {
            'encoder': encoder,
            'nums_classes': nums_classes,
            'hidden_size': hidden_size,
            'num_hidden_layers': num_hidden_layers,
            'activation': activation,
            'out_act': out_act,
            **kwargs
        }
        if isinstance(out_act, str):
            self.out_acts = [out_act] * len(nums_classes)
        else:
            self.out_acts = out_act
        self.out_act_funcs = nn.ModuleList(
            [get_activation(act, **kwargs) for act in self.out_acts]
        )

        self.encoder = encoder
        self._freeze_encoder = True
        self.classifier = BNReLULinearBlock(
            256,
            hidden_size,
            num_hidden_layers,
            hidden_size,
            activation,
            # out_action,
            **kwargs
        )

        self.out_layers = nn.ModuleList([
            BNReLULinear(
                hidden_size,
                num_classes
            )
            for num_classes in nums_classes
        ])

    @property
    def freeze_encoder(self):
        return self._freeze_encoder

    @freeze_encoder.setter
    def freeze_encoder(self, freeze: bool):
        self._freeze_encoder = freeze
        self.change_encoder_grad(not freeze)

    def change_encoder_grad(self, requires_grad: bool):
        for param in self.encoder.parameters():
            param.requires_grad = requires_grad

    def forward(self, X):
        embeddings = self.encoder(*X)[0][:, 0, :]
        embeddings = self.classifier(embeddings)

        if self.training:
            outputs = [
                out_layer(embeddings)
                for out_layer in self.out_layers
            ]
        else:
            outputs = [
                act(out_layer(embeddings))
                for out_layer, act in zip(self.out_layers, self.out_act_funcs)
            ]

        return outputs
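Both multi-task heads hard-code a 256-wide embedding taken as encoder(*X)[0][:, 0, :], i.e. the encoder must return a tuple or list whose first element is a [batch, seq_len, 256] tensor, with position 0 used as the pooled representation. MultiTaskMultiClassBlock trains one full BNReLULinearBlock per task, while MuMcHardBlock shares a single trunk and adds one BNReLULinear output layer per task, but both expose the same call interface. A sketch with a hypothetical stand-in encoder; DummyEncoder is not part of hjxdl, and the example assumes get_activation in hdl.ops.utils accepts the default 'softmax' out_act:

```python
import torch
from torch import nn
from hdl.layers.general.linear import MultiTaskMultiClassBlock


class DummyEncoder(nn.Module):
    """Stand-in for the real reaction encoder: returns a ([B, L, 256] tensor, ...) tuple."""

    def forward(self, tokens):
        hidden = torch.randn(tokens.size(0), tokens.size(1), 256)
        return hidden, None


head = MultiTaskMultiClassBlock(
    encoder=DummyEncoder(),
    nums_classes=[3, 5],        # two tasks with 3 and 5 classes
    hidden_size=128,
    num_hidden_layers=4,
)
tokens = torch.zeros(4, 12, dtype=torch.long)   # X is unpacked as encoder(*X)
outputs = head([tokens])                        # training mode: raw per-task logits
print([o.shape for o in outputs])               # [torch.Size([4, 3]), torch.Size([4, 5])]
```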