hjxdl 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/layers/general/linear.py
@@ -0,0 +1,641 @@
+ import typing as t
+
+ import torch
+ from torch import nn
+ from torch.autograd import Function
+ import torch_scatter
+ from torch.utils import checkpoint as tuc
+
+ from hdl.ops.utils import get_activation
+
+ __all__ = [
+     "WeaveLayer",
+     "DenseNet",
+     "AvgPooling",
+     "SumPooling",
+     "CasualWeave",
+     "DenseLayer"
+ ]
+
+
+ def _bn_function_factory(bn_module):
+     """Wrap a bottleneck module into a single closure over a variable
+     number of feature tensors, so it can be passed to
+     `torch.utils.checkpoint` in the dense layers below."""
+     def bn_function(*inputs):
+         concated_features = torch.cat(inputs, -1)
+         bottleneck_output = bn_module(concated_features)
+         return bottleneck_output
+
+     return bn_function
+
+
+ class BNReLULinear(nn.Module):
+     """
+     Linear block with BatchNorm -> Linear -> activation architecture.
+     The linear layer is bias-free: the affine shift of the preceding
+     BatchNorm already passes through the weights as an effective bias.
+     """
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         activation: str = 'elu',
+         **kwargs
+     ):
+         """
+         Args:
+             in_features (int):
+                 The number of input features
+             out_features (int):
+                 The number of output features
+             activation (str):
+                 The type of activation unit to use in this module,
+                 defaults to elu
+         """
+         super(BNReLULinear, self).__init__()
+         self.bn_relu_linear = nn.Sequential(
+             nn.BatchNorm1d(in_features),
+             nn.Linear(
+                 in_features,
+                 out_features,
+                 bias=False
+             ),
+             get_activation(
+                 activation,
+                 inplace=True,
+                 **kwargs
+             )
+         )
+
+     def forward(self, x):
+         """The forward method"""
+         return self.bn_relu_linear(x)
+
+
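A minimal usage sketch for this block, assuming the wheel installs as the top-level `hdl` package and that `get_activation('elu', inplace=True)` resolves to an ELU module:

import torch
from hdl.layers.general.linear import BNReLULinear

layer = BNReLULinear(in_features=16, out_features=32).eval()  # eval: BN uses running stats
out = layer(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 32])
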
+ class SelectAdd(Function):
+     """
+     Memory-efficient version of
+     `a.index_select(0, indices_a) + b.index_select(0, indices)`
+     (or `a + b.index_select(0, indices)` when `indices_a` is None).
+     Written in the static-method autograd style, since the legacy
+     instance-style `Function` API was removed from PyTorch.
+     """
+
+     @staticmethod
+     def forward(ctx,
+                 a: torch.Tensor,
+                 b: torch.Tensor,
+                 indices: torch.Tensor,
+                 indices_a: torch.Tensor = None) -> torch.Tensor:
+         """
+         The forward pass
+         Args:
+             a (torch.Tensor)
+             b (torch.Tensor): The input tensors
+             indices (torch.Tensor): The indices to select from `b`
+             indices_a (torch.Tensor or None):
+                 The indices to select from `a`. Default to None
+         Returns:
+             torch.Tensor:
+                 The output tensor
+         """
+         # Keep the gather indices and the original first-dim sizes so
+         # backward can scatter the gradients back to the right shapes
+         ctx.indices = indices
+         ctx.indices_a = indices_a
+         ctx.size_a = a.size(0)
+         ctx.size_b = b.size(0)
+         if indices_a is not None:
+             return (a.index_select(dim=0, index=indices_a) +
+                     b.index_select(dim=0, index=indices))
+         return a + b.index_select(dim=0, index=indices)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         # For the input a
+         if ctx.indices_a is not None:
+             grad_a = torch_scatter.scatter_add(grad_output,
+                                                index=ctx.indices_a,
+                                                dim=0,
+                                                dim_size=ctx.size_a)
+         else:
+             # If a is not index selected, simply clone the gradient
+             grad_a = grad_output.clone()
+         # For the input b, perform a segment sum
+         grad_b = torch_scatter.scatter_add(grad_output,
+                                            index=ctx.indices,
+                                            dim=0,
+                                            dim_size=ctx.size_b)
+         # The integer index arguments receive no gradient
+         return grad_a, grad_b, None, None
+
+
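A small equivalence check of `SelectAdd` against plain fancy indexing (a sketch; tensor sizes and index values are arbitrary):

import torch
from hdl.layers.general.linear import SelectAdd

a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8, requires_grad=True)
idx = torch.tensor([0, 2, 1, 3, 3])    # rows gathered from b
idx_a = torch.tensor([1, 1, 0, 2, 3])  # rows gathered from a

out = SelectAdd.apply(a, b, idx, idx_a)
assert torch.allclose(out, a[idx_a] + b[idx])
out.sum().backward()               # grads scatter back to the original shapes
print(a.grad.shape, b.grad.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
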
+ class WeaveLayer(nn.Module):
+     def __init__(
+         self,
+         num_in_feat: int,
+         num_out_feat: int,
+         activation: str = 'relu',
+         is_first_layer: bool = False
+     ):
+         super().__init__()
+         self.num_in_feat = num_in_feat
+         self.num_out_feat = num_out_feat
+         self.activation = activation
+         # Broadcasting node features to edges
+         if is_first_layer:
+             self.broadcast = nn.Linear(self.num_in_feat,
+                                        self.num_out_feat * 5)
+         else:
+             self.broadcast = BNReLULinear(self.num_in_feat,
+                                           self.num_out_feat * 5,
+                                           self.activation)
+         # Gather edge features to node
+         self.gather = nn.Sequential(nn.BatchNorm1d(self.num_out_feat),
+                                     get_activation(self.activation,
+                                                    inplace=True))
+
+         # Update node features
+         self.update = BNReLULinear(self.num_out_feat * 2,
+                                    self.num_out_feat,
+                                    self.activation)
+
+     def forward(
+         self,
+         n_feat: torch.Tensor,
+         adj: torch.Tensor
+     ):
+         # Project each node into five chunks: a self term plus
+         # begin/end terms for the sum and max edge aggregations
+         node_broadcast = self.broadcast(n_feat)
+         (self_features,
+          begin_features_sum,
+          end_features_sum,
+          begin_features_max,
+          end_features_max) = torch.split(node_broadcast,
+                                          self.num_out_feat,
+                                          dim=-1)
+         edge_info = adj._indices()
+         begin_ids, end_ids = edge_info[0, :], edge_info[1, :]
+         edge_features_max = SelectAdd.apply(begin_features_max,
+                                             end_features_max,
+                                             end_ids, begin_ids)
+         edge_features_sum = SelectAdd.apply(begin_features_sum,
+                                             end_features_sum,
+                                             end_ids, begin_ids)
+         # Sum-aggregate edge features back to their begin nodes;
+         # dim_size keeps the output aligned with the node count even
+         # when trailing nodes have no outgoing edges
+         edge_gathered_sum = self.gather(edge_features_sum)
+         edge_gathered_sum = torch_scatter.scatter_add(
+             edge_gathered_sum,
+             begin_ids,
+             dim=0,
+             dim_size=n_feat.size(0))
+         # Shift values to be non-negative before scatter_max, so empty
+         # segments fall back to the (shifted) zero, then undo the shift
+         min_val = edge_features_max.min()
+         edge_gathered_max = edge_features_max - min_val
+         edge_gathered_max = torch_scatter.scatter_max(
+             edge_gathered_max,
+             begin_ids,
+             dim=0,
+             dim_size=n_feat.size(0))[0]
+         edge_gathered_max = edge_gathered_max + min_val
+         edge_gathered = torch.cat([edge_gathered_max,
+                                    edge_gathered_sum],
+                                   dim=-1)
+         node_update = self.update(edge_gathered)
+         outputs = self_features + node_update
+         return outputs
+
+
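A forward-pass sketch on a toy graph, assuming `torch_scatter` is installed; the adjacency is a coalesced sparse COO tensor whose indices supply the begin/end node ids:

import torch
from hdl.layers.general.linear import WeaveLayer

# 3-node graph with bidirectional edges 0-1 and 1-2
edges = torch.tensor([[0, 1, 1, 2],
                      [1, 0, 2, 1]])
adj = torch.sparse_coo_tensor(edges, torch.ones(4), size=(3, 3)).coalesce()

layer = WeaveLayer(16, 32, activation='relu', is_first_layer=True).eval()
out = layer(torch.randn(3, 16), adj)
print(out.shape)  # torch.Size([3, 32])
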
+ class CasualWeave(nn.Module):
+     def __init__(
+         self,
+         num_feat: int,
+         hidden_sizes: t.Iterable,
+         activation: str = 'elu'
+     ):
+         super().__init__()
+         self.num_feat = num_feat
+         self.hidden_sizes = list(hidden_sizes)
+         self.activation = activation
+
+         layers = []
+         for i, (in_feat, out_feat) in enumerate(
+             zip(
+                 [self.num_feat, ] +
+                 list(self.hidden_sizes)[:-1],  # in_features
+                 self.hidden_sizes              # out_features
+             )
+         ):
+             layers.append(
+                 WeaveLayer(
+                     in_feat,
+                     out_feat,
+                     self.activation,
+                     is_first_layer=(i == 0)
+                 )
+             )
+         self.layers = nn.ModuleList(layers)
+
+     def forward(
+         self,
+         feat: torch.Tensor,
+         adj: torch.Tensor
+     ):
+         feat_out = feat
+         for layer in self.layers:
+             feat_out = layer(
+                 feat_out,
+                 adj
+             )
+         return feat_out
+
+
+ class DenseLayer(nn.Module):
+     def __init__(
+         self,
+         num_in_feat: int,
+         num_botnec_feat: int,
+         num_out_feat: int,
+         activation: str = 'elu',
+     ):
+         super().__init__()
+         self.num_in_feat = num_in_feat
+         self.num_out_feat = num_out_feat
+         self.num_botnec_feat = num_botnec_feat
+         self.activation = activation
+
+         self.bottlenec = BNReLULinear(
+             self.num_in_feat,
+             self.num_botnec_feat,
+             self.activation
+         )
+
+         self.weave = WeaveLayer(
+             self.num_botnec_feat,
+             self.num_out_feat,
+             self.activation
+         )
+
+     def forward(
+         self,
+         ls_feat: t.List[torch.Tensor],
+         adj: torch.Tensor,
+     ):
+         # Checkpoint the concatenation + bottleneck so the concatenated
+         # features are recomputed in backward instead of being stored
+         bn_fn = _bn_function_factory(self.bottlenec)
+         feat = tuc.checkpoint(bn_fn, *ls_feat)
+         return self.weave(
+             feat,
+             adj
+         )
+
+
+ class DenseNet(nn.Module):
+     def __init__(
+         self,
+         num_feat: int,
+         casual_hidden_sizes: t.Iterable,
+         num_botnec_feat: int,
+         num_k_feat: int,
+         num_dense_layers: int,
+         num_out_feat: int,
+         activation: str = 'elu'
+     ):
+         super().__init__()
+         self.num_feat = num_feat
+         self.num_dense_layers = num_dense_layers
+         self.casual_hidden_sizes = list(casual_hidden_sizes)
+         self.num_out_feat = num_out_feat
+         self.activation = activation
+         self.num_k_feat = num_k_feat
+         self.num_botnec_feat = num_botnec_feat
+         self.casual = CasualWeave(
+             self.num_feat,
+             self.casual_hidden_sizes,
+             self.activation
+         )
+         # Each dense layer consumes all previous feature maps and
+         # contributes num_k_feat new features (DenseNet-style growth)
+         dense_layers = []
+         for i in range(self.num_dense_layers):
+             dense_layers.append(
+                 DenseLayer(
+                     self.casual_hidden_sizes[-1] + i * self.num_k_feat,
+                     self.num_botnec_feat,
+                     self.num_k_feat,
+                     self.activation
+                 )
+             )
+         self.dense_layers = nn.ModuleList(dense_layers)
+
+         self.output = BNReLULinear(
+             (
+                 self.casual_hidden_sizes[-1] +
+                 self.num_dense_layers * self.num_k_feat
+             ),
+             self.num_out_feat,
+             self.activation
+         )
+
+     def forward(
+         self,
+         feat,
+         adj
+     ):
+         feat = self.casual(
+             feat,
+             adj
+         )
+         ls_feat = [feat, ]
+         for dense_layer in self.dense_layers:
+             feat_i = dense_layer(
+                 ls_feat,
+                 adj
+             )
+             ls_feat.append(feat_i)
+         feat_cat = torch.cat(ls_feat, dim=-1)
+         return self.output(feat_cat)
+
+
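An end-to-end sketch wiring CasualWeave, the dense layers, and the output projection together (parameter values are arbitrary, and the toy adjacency is the same as above):

import torch
from hdl.layers.general.linear import DenseNet

edges = torch.tensor([[0, 1, 1, 2],
                      [1, 0, 2, 1]])
adj = torch.sparse_coo_tensor(edges, torch.ones(4), size=(3, 3)).coalesce()

net = DenseNet(
    num_feat=16,
    casual_hidden_sizes=[32, 32],
    num_botnec_feat=24,
    num_k_feat=8,
    num_dense_layers=3,
    num_out_feat=64,
).eval()
out = net(torch.randn(3, 16), adj)
print(out.shape)  # torch.Size([3, 64])
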
+ class _Pooling(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         pooling_op: t.Callable = torch_scatter.scatter_mean,
+         activation: str = 'elu'
+     ):
+         """Base class for graph-level pooling with a bn->activation head
+         Args:
+             in_features (int): The number of input features
+             pooling_op (t.Callable, optional): The segment operation used
+                 to pool node features per graph
+             activation (str, optional): The activation to apply before
+                 pooling, default to elu
+         """
+         super(_Pooling, self).__init__()
+         self.bn_relu = nn.Sequential(
+             nn.BatchNorm1d(in_features),
+             get_activation(activation, inplace=True)
+         )
+         self.pooling_op = pooling_op
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         ids: torch.Tensor,
+         num_seg: int = None
+     ) -> torch.Tensor:
+         """
+         Args:
+             x (torch.Tensor): The input tensor, size=[N, in_features]
+             ids (torch.Tensor): A tensor of type `torch.long`, size=[N, ],
+                 mapping each row of `x` to its segment (graph)
+             num_seg (int): The number of segments (graphs)
+         Returns:
+             torch.Tensor: Output tensor with size=[num_seg, in_features]
+         """
+
+         # performing batch_normalization and activation
+         x_bn = self.bn_relu(x)  # size=[N, in_features]
+
+         # performing segment operation
+         x_pooled = self.pooling_op(
+             x_bn,
+             dim=0,
+             index=ids,
+             dim_size=num_seg
+         )  # size=[num_seg, in_features]
+
+         return x_pooled
+
+
+ class AvgPooling(_Pooling):
+     """Average pooling layer for graphs"""
+
+     def __init__(
+         self,
+         in_features: int,
+         activation: str = 'elu'
+     ):
+         """Performing graph-level average pooling (with bn_relu)
+         Args:
+             in_features (int):
+                 The number of input features
+             activation (str):
+                 The type of activation function to use, default to elu
+         """
+         super(AvgPooling, self).__init__(
+             in_features,
+             activation=activation,
+             pooling_op=torch_scatter.scatter_mean
+         )
+
+
+ class SumPooling(_Pooling):
+     """Sum pooling layer for graphs"""
+
+     def __init__(
+         self,
+         in_features: int,
+         activation: str = 'elu'
+     ):
+         """Performing graph-level sum pooling (with bn_relu)
+         Args:
+             in_features (int):
+                 The number of input features
+             activation (str):
+                 The type of activation function to use, default to elu
+         """
+         super(SumPooling, self).__init__(
+             in_features,
+             activation=activation,
+             pooling_op=torch_scatter.scatter_add
+         )
+
+
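A pooling sketch over a two-graph batch; `ids` assigns each node row to its graph:

import torch
from hdl.layers.general.linear import SumPooling

pool = SumPooling(in_features=8).eval()
x = torch.randn(5, 8)                 # 5 nodes across both graphs
ids = torch.tensor([0, 0, 0, 1, 1])   # node -> graph assignment
print(pool(x, ids, num_seg=2).shape)  # torch.Size([2, 8])
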
+ class BNReLULinearBlock(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         num_layers: int,
+         hidden_size: int,
+         activation: str = 'elu',
+         # out_act: str = 'sigmoid',
+         **kwargs
+     ):
+         super().__init__()
+
+         input_brl = BNReLULinear(
+             in_features,
+             hidden_size,
+             activation
+         )
+
+         # num_layers counts the input and output projections as well,
+         # so num_layers - 2 hidden blocks are stacked in between
+         btn_brl = [
+             BNReLULinear(
+                 hidden_size,
+                 hidden_size,
+                 activation
+             )
+             for _ in range(num_layers - 2)
+         ]
+
+         output_brl = BNReLULinear(
+             hidden_size,
+             out_features,
+             activation,
+         )
+         # self.out_act = get_activation(out_act, **kwargs)
+
+         self.brl_block = nn.Sequential(
+             input_brl,
+             *btn_brl,
+             output_brl,
+             # self.out_act
+         )
+
+     def forward(self, X):
+         return self.brl_block(X)
+
+
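A quick shape check for the block (values arbitrary); note that `num_layers` includes the input and output projections:

import torch
from hdl.layers.general.linear import BNReLULinearBlock

# 1 input + 2 middle + 1 output BNReLULinear stages
block = BNReLULinearBlock(in_features=256, out_features=3,
                          num_layers=4, hidden_size=128).eval()
print(block(torch.randn(2, 256)).shape)  # torch.Size([2, 3])
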
+ class MultiTaskMultiClassBlock(nn.Module):
+     _NAME = 'rxn_trans'
+
+     def __init__(
+         self,
+         encoder: nn.Module = None,
+         nums_classes: t.List[int] = [3, 3],
+         hidden_size: int = 128,
+         num_hidden_layers: int = 10,
+         activation: str = 'elu',
+         out_act: str = 'softmax',
+         **kwargs,
+     ):
+         super().__init__()
+         self.init_args = {
+             'encoder': encoder,
+             'nums_classes': nums_classes,
+             'hidden_size': hidden_size,
+             'num_hidden_layers': num_hidden_layers,
+             'activation': activation,
+             'out_act': out_act,
+             **kwargs
+         }
+         # One output activation per task; a single string is broadcast
+         if isinstance(out_act, str):
+             self.out_acts = [out_act] * len(nums_classes)
+         else:
+             self.out_acts = out_act
+         self.out_act_funcs = nn.ModuleList(
+             [get_activation(act, **kwargs) for act in self.out_acts]
+         )
+
+         self.encoder = encoder
+         # NOTE: this flag only takes effect once assigned through the
+         # `freeze_encoder` property setter below
+         self._freeze_encoder = True
+         # One classifier head per task; 256 is the expected width of
+         # the encoder's token embeddings
+         self.classifiers = nn.ModuleList([
+             BNReLULinearBlock(
+                 256,
+                 num_class,
+                 num_hidden_layers,
+                 hidden_size,
+                 activation,
+                 # out_action,
+                 **kwargs
+             )
+             for num_class in nums_classes
+         ])
+
+     @property
+     def freeze_encoder(self):
+         return self._freeze_encoder
+
+     @freeze_encoder.setter
+     def freeze_encoder(self, freeze: bool):
+         self._freeze_encoder = freeze
+         self.change_encoder_grad(not freeze)
+
+     def change_encoder_grad(self, requires_grad: bool):
+         for param in self.encoder.parameters():
+             param.requires_grad = requires_grad
+
+     def forward(self, X):
+         # Use the first token's embedding of the encoder's first output
+         embeddings = self.encoder(*X)[0][:, 0, :]
+         if self.training:
+             # Return raw logits in training (e.g. for CrossEntropyLoss)
+             outputs = [
+                 classifier(embeddings)
+                 for classifier in self.classifiers
+             ]
+         else:
+             outputs = [
+                 act(classifier(embeddings))
+                 for classifier, act in zip(self.classifiers,
+                                            self.out_act_funcs)
+             ]
+
+         return outputs
+
+
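A usage sketch with a stand-in encoder; `DummyEncoder` is hypothetical and only mimics the expected `(batch, seq, 256)` first output, and the example assumes this package's `get_activation` resolves the default 'softmax' activation:

import torch
from torch import nn
from hdl.layers.general.linear import MultiTaskMultiClassBlock

class DummyEncoder(nn.Module):
    """Hypothetical stand-in returning a (batch, seq, 256) tuple."""
    def forward(self, tokens):
        return (torch.randn(tokens.size(0), tokens.size(1), 256),)

model = MultiTaskMultiClassBlock(encoder=DummyEncoder(),
                                 nums_classes=[3, 5]).eval()
outputs = model((torch.zeros(2, 7, dtype=torch.long),))
print([o.shape for o in outputs])  # [torch.Size([2, 3]), torch.Size([2, 5])]
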
+ class MuMcHardBlock(nn.Module):
+     _NAME = 'rxn_trans_hard'
+
+     def __init__(
+         self,
+         encoder: nn.Module = None,
+         nums_classes: t.List[int] = [3, 3],
+         hidden_size: int = 128,
+         num_hidden_layers: int = 10,
+         activation: str = 'elu',
+         out_act: str = 'softmax',
+         **kwargs,
+     ):
+         super().__init__()
+         self.init_args = {
+             'encoder': encoder,
+             'nums_classes': nums_classes,
+             'hidden_size': hidden_size,
+             'num_hidden_layers': num_hidden_layers,
+             'activation': activation,
+             'out_act': out_act,
+             **kwargs
+         }
+         if isinstance(out_act, str):
+             self.out_acts = [out_act] * len(nums_classes)
+         else:
+             self.out_acts = out_act
+         self.out_act_funcs = nn.ModuleList(
+             [get_activation(act, **kwargs) for act in self.out_acts]
+         )
+
+         self.encoder = encoder
+         self._freeze_encoder = True
+         # Hard parameter sharing: one trunk shared by all tasks ...
+         self.classifier = BNReLULinearBlock(
+             256,
+             hidden_size,
+             num_hidden_layers,
+             hidden_size,
+             activation,
+             # out_action,
+             **kwargs
+         )
+
+         # ... followed by one lightweight output layer per task
+         self.out_layers = nn.ModuleList([
+             BNReLULinear(
+                 hidden_size,
+                 num_classes
+             )
+             for num_classes in nums_classes
+         ])
+
+     @property
+     def freeze_encoder(self):
+         return self._freeze_encoder
+
+     @freeze_encoder.setter
+     def freeze_encoder(self, freeze: bool):
+         self._freeze_encoder = freeze
+         self.change_encoder_grad(not freeze)
+
+     def change_encoder_grad(self, requires_grad: bool):
+         for param in self.encoder.parameters():
+             param.requires_grad = requires_grad
+
+     def forward(self, X):
+         embeddings = self.encoder(*X)[0][:, 0, :]
+         embeddings = self.classifier(embeddings)
+
+         if self.training:
+             outputs = [
+                 out_layer(embeddings)
+                 for out_layer in self.out_layers
+             ]
+         else:
+             outputs = [
+                 act(out_layer(embeddings))
+                 for out_layer, act in zip(self.out_layers,
+                                           self.out_act_funcs)
+             ]
+
+         return outputs