hjxdl-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/metric_loss/loss.py ADDED
@@ -0,0 +1,79 @@
+ import typing as t
+
+ import torch
+ from torch import nn
+
+ from .multi_label import BPMLLLoss
+
+
+ def get_lossfunc(
+     name: str,
+     *args,
+     **kwargs
+ ) -> t.Callable:
+     """Get a loss function by name.
+
+     Args:
+         name (str): The name of the loss function.
+
+     Returns:
+         t.Callable: The loss function.
+     """
+     name = name.lower()
+     if name == 'bce':
+         return nn.BCELoss(*args, **kwargs)
+     elif name == 'ce':
+         return nn.CrossEntropyLoss(*args, **kwargs)
+     elif name == 'mse':
+         return nn.MSELoss(*args, **kwargs)
+     elif name == 'bpmll':
+         return BPMLLLoss(*args, **kwargs)
+     elif name == 'nll':
+         return nn.GaussianNLLLoss(*args, **kwargs)
+     else:
+         # Fail fast instead of silently returning None for unknown names.
+         raise ValueError(f'Loss function "{name}" not supported.')
+
+
+ def mtmc_loss(
+     y_preds: t.Iterable,
+     y_trues: t.Iterable,
+     loss_names: t.Iterable[str] = None,
+     individual: bool = False,
+     task_weights: t.List = None,
+     device=torch.device('cpu'),
+     **kwargs
+ ):
+     num_tasks = len(y_preds)
+     if loss_names is None:
+         loss_func = nn.CrossEntropyLoss()
+         loss_funcs = [loss_func] * num_tasks
+     elif isinstance(loss_names, str):
+         loss_func = get_lossfunc(loss_names, **kwargs)
+         loss_funcs = [loss_func] * num_tasks
+     else:
+         loss_funcs = [
+             get_lossfunc(loss_str)
+             for loss_str in loss_names
+         ]
+
+     if task_weights is None:
+         task_weights = torch.ones(num_tasks).to(device)
+     else:
+         assert len(task_weights) == num_tasks
+         task_weights = torch.FloatTensor(task_weights).to(device)
+
+     loss_values = [
+         loss_func(y_pred, y_true)
+         for y_pred, y_true, loss_func in zip(
+             y_preds, y_trues, loss_funcs
+         )
+     ]
+
+     # Weighted sum of the per-task losses.
+     loss_final = sum(
+         loss_value * task_weight
+         for loss_value, task_weight in zip(loss_values, task_weights)
+     )
+     # loss_final = sum(loss_values) / num_tasks
+     if not individual:
+         return loss_final
+     else:
+         return (loss_final, list(loss_values))
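For orientation, a minimal usage sketch of the two helpers above. The two-task setup, batch size, and tensor shapes are illustrative assumptions, not part of the package:

    import torch
    from hdl.metric_loss.loss import get_lossfunc, mtmc_loss

    # Hypothetical two-task batch: a 3-class classification head and a regression head.
    y_preds = [torch.randn(8, 3), torch.randn(8, 1)]
    y_trues = [torch.randint(0, 3, (8,)), torch.randn(8, 1)]

    # One loss name per task; 'ce' -> nn.CrossEntropyLoss, 'mse' -> nn.MSELoss.
    total = mtmc_loss(y_preds, y_trues, loss_names=['ce', 'mse'])

    # Weighted variant that also returns the per-task losses.
    total, per_task = mtmc_loss(
        y_preds, y_trues,
        loss_names=['ce', 'mse'],
        task_weights=[1.0, 0.5],
        individual=True,
    )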
hdl/metric_loss/metric.py ADDED
@@ -0,0 +1,178 @@
+ import math
+ from typing import Callable, List, Union
+ from functools import partial
+
+ import numpy as np
+ from sklearn.metrics import (
+     auc, mean_absolute_error, mean_squared_error,
+     precision_recall_curve, r2_score,
+     roc_auc_score, accuracy_score, log_loss, matthews_corrcoef,
+     # top_k_accuracy_score
+ )
+ import torch
+ import torch.nn as nn
+ import scipy.stats
+
+
+ def prc_auc(targets: List[int], preds: List[float]) -> float:
+     """
+     Computes the area under the precision-recall curve.
+
+     :param targets: A list of binary targets.
+     :param preds: A list of prediction probabilities.
+     :return: The computed prc-auc.
+     """
+     precision, recall, _ = precision_recall_curve(targets, preds)
+     return auc(recall, precision)
+
+
+ def bce(targets: List[int], preds: List[float]) -> float:
+     """
+     Computes the binary cross entropy loss.
+
+     :param targets: A list of binary targets.
+     :param preds: A list of prediction probabilities.
+     :return: The computed binary cross entropy.
+     """
+     # Don't use logits because the sigmoid is added in all places except training itself.
+     bce_func = nn.BCELoss(reduction='mean')
+     loss = bce_func(target=torch.Tensor(targets), input=torch.Tensor(preds)).item()
+
+     return loss
+
+
+ def rmse(targets: List[float], preds: List[float]) -> float:
+     """
+     Computes the root mean squared error.
+
+     :param targets: A list of targets.
+     :param preds: A list of predictions.
+     :return: The computed rmse.
+     """
+     return math.sqrt(mean_squared_error(targets, preds))
+
+
+ def mse(targets: List[float], preds: List[float]) -> float:
+     """
+     Computes the mean squared error.
+
+     :param targets: A list of targets.
+     :param preds: A list of predictions.
+     :return: The computed mse.
+     """
+     return mean_squared_error(targets, preds)
+
+
+ def accuracy(targets: List[int], preds: Union[List[float], List[List[float]]], threshold: float = 0.5) -> float:
+     """
+     Computes the accuracy of a binary prediction task using a given threshold for generating hard predictions.
+
+     Alternatively, computes accuracy for a multiclass prediction task by picking the largest probability.
+
+     :param targets: A list of binary targets.
+     :param preds: A list of prediction probabilities.
+     :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0.
+     :return: The computed accuracy.
+     """
+     if isinstance(preds[0], list):  # multiclass
+         hard_preds = [p.index(max(p)) for p in preds]
+     else:
+         hard_preds = [1 if p > threshold else 0 for p in preds]  # binary prediction
+
+     return accuracy_score(targets, hard_preds)
+
+
+ def rsquared(x, y):
+     """Return R^2 where x and y are array-like."""
+     _, _, r_value, _, _ = scipy.stats.linregress(x, y)
+     return r_value ** 2
+
+
+ def mcc(y_true, y_pred):
+     """Computes the Matthews correlation coefficient, binarizing predictions at 0.5."""
+     y_true = np.array(y_true).astype(int)
+     # y_true = np.where(y_true == 1, 1, -1).astype(int)
+     y_pred = np.array(y_pred)
+     y_pred = (y_pred >= 0.5).astype(int)
+
+     return matthews_corrcoef(y_true, y_pred)
+
+
+ def topk(y_true, y_pred, k=1):
+     """Computes top-k accuracy: the fraction of samples whose true class is among the k highest-scored classes."""
+     y_true = np.array(y_true).astype(int)
+     y_pred = np.array(y_pred)
+
+     sorted_pred = np.argsort(y_pred, axis=1, kind='mergesort')[:, ::-1]
+     hits = (y_true == sorted_pred[:, :k].T).any(axis=0)
+     num_hits = np.sum(hits)
+
+     return num_hits / len(y_true)
+
+
+ def get_metric(metric: str) -> Callable[[Union[List[int], List[float]], List[float]], float]:
+     r"""
+     Gets the metric function corresponding to a given metric name.
+
+     Supports:
+
+     * :code:`auc`: Area under the receiver operating characteristic curve
+     * :code:`prc-auc`: Area under the precision-recall curve
+     * :code:`rmse`: Root mean squared error
+     * :code:`mse`: Mean squared error
+     * :code:`mae`: Mean absolute error
+     * :code:`r2`: Coefficient of determination R\ :superscript:`2`
+     * :code:`rsquared`: Squared Pearson correlation from a linear regression
+     * :code:`acc`: Accuracy (using a threshold to binarize predictions)
+     * :code:`mcc`: Matthews correlation coefficient
+     * :code:`ce`: Cross entropy
+     * :code:`bce`: Binary cross entropy
+     * :code:`topk`, :code:`top3`, :code:`top5`, :code:`top10`: Top-k accuracy
+
+     :param metric: Metric name.
+     :return: A metric function which takes a list of targets and a list of predictions and returns the metric value.
+     """
+     if metric == 'mcc':
+         return mcc
+
+     if metric == 'rsquared':
+         return rsquared
+
+     if metric == 'auc':
+         return roc_auc_score
+
+     if metric == 'prc-auc':
+         return prc_auc
+
+     if metric == 'rmse':
+         return rmse
+
+     if metric == 'mse':
+         return mse
+
+     if metric == 'mae':
+         return mean_absolute_error
+
+     if metric == 'r2':
+         return r2_score
+
+     if metric == 'acc':
+         return accuracy
+
+     if metric == 'ce':
+         return log_loss
+
+     if metric == 'bce':
+         return bce
+
+     if metric == 'topk':
+         return topk
+
+     if metric == 'top3':
+         return partial(topk, k=3)
+
+     if metric == 'top5':
+         return partial(topk, k=5)
+
+     if metric == 'top10':
+         return partial(topk, k=10)
+
+     raise ValueError(f'Metric "{metric}" not supported.')
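A short sketch of the get_metric dispatch; the target and probability values below are made up for illustration:

    from hdl.metric_loss.metric import get_metric

    binary_targets = [0, 1, 1, 0, 1]
    binary_probs = [0.2, 0.9, 0.6, 0.4, 0.7]
    auc_value = get_metric('auc')(binary_targets, binary_probs)   # sklearn roc_auc_score
    acc_value = get_metric('acc')(binary_targets, binary_probs)   # thresholded at 0.5

    # Top-k metrics expect one row of class scores per sample.
    multi_targets = [2, 0, 1]
    multi_scores = [[0.1, 0.2, 0.7],
                    [0.5, 0.3, 0.2],
                    [0.2, 0.6, 0.2]]
    top1 = get_metric('topk')(multi_targets, multi_scores)        # defaults to k=1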
hdl/metric_loss/multi_label.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ from torch import Tensor
+
+
+ class BPMLLLoss(torch.nn.Module):
+     def __init__(self, reduction: str = 'mean', bias=(1, 1)):
+         super(BPMLLLoss, self).__init__()
+         self.bias = bias
+         self.reduction = reduction
+         assert len(self.bias) == 2 \
+             and all(map(lambda x: isinstance(x, int) and x > 0, bias)), \
+             "bias must be positive integers"
+
+     def forward(self, c: Tensor, y: Tensor) -> Tensor:
+         r"""
+         Compute the loss, which has the form:
+
+         .. math::
+             L = \sum_{i=1}^{m} \frac{1}{|Y_i| \cdot |\bar{Y}_i|}
+                 \sum_{(k, l) \in Y_i \times \bar{Y}_i} \exp(-c^i_k + c^i_l)
+
+         :param c: prediction tensor, size: batch_size * n_labels
+         :param y: target tensor, size: batch_size * n_labels
+         :return: scalar tensor (or a per-instance tensor if reduction='none')
+         """
+         y = y.float()
+         y_bar = -y + 1
+         y_norm = torch.pow(y.sum(dim=(1,)), self.bias[0])
+         y_bar_norm = torch.pow(y_bar.sum(dim=(1,)), self.bias[1])
+         # Every instance needs at least one positive and one negative label,
+         # otherwise the normalizer below divides by zero ("and", not "or").
+         assert torch.all(y_norm != 0) and torch.all(y_bar_norm != 0), \
+             "an instance cannot have none or all the labels"
+         loss = 1 / torch.mul(y_norm, y_bar_norm) \
+             * self.pairwise_sub_exp(y, y_bar, c)
+
+         if self.reduction == 'mean':
+             return torch.mean(loss)
+         elif self.reduction == 'none':
+             return loss
+         raise ValueError(f'Unsupported reduction "{self.reduction}"')
+
+     def pairwise_sub_exp(self, y: Tensor, y_bar: Tensor, c: Tensor) -> Tensor:
+         r"""
+         Compute :math:`\sum_{(k, l) \in Y_i \times \bar{Y}_i} \exp(-c^i_k + c^i_l)`.
+         """
+         truth_matrix = y.unsqueeze(2).float() @ y_bar.unsqueeze(1).float()
+         exp_matrix = torch.exp(c.unsqueeze(1) - c.unsqueeze(2))
+         return (torch.mul(truth_matrix, exp_matrix)).sum(dim=(1, 2))
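A minimal smoke test of BPMLLLoss, assuming raw (pre-sigmoid) label scores; the batch size and label count are arbitrary, and every row must contain at least one positive and one negative label or the assert above fires:

    import torch
    from hdl.metric_loss.multi_label import BPMLLLoss

    criterion = BPMLLLoss()                 # reduction='mean', bias=(1, 1)
    scores = torch.randn(4, 5)              # batch of 4 instances, 5 labels
    labels = torch.tensor([[1, 0, 1, 0, 0],
                           [0, 1, 0, 0, 1],
                           [1, 1, 0, 1, 0],
                           [0, 0, 1, 0, 0]])
    loss = criterion(scores, labels)        # scalar tensor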
hdl/metric_loss/nt_xent.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import numpy as np
+
+
+ class NTXentLoss(torch.nn.Module):
+
+     def __init__(self, device, batch_size, temperature, use_cosine_similarity):
+         super(NTXentLoss, self).__init__()
+         self.batch_size = batch_size
+         self.temperature = temperature
+         self.device = device
+         self.softmax = torch.nn.Softmax(dim=-1)
+         self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
+         self.similarity_function = self._get_similarity_function(use_cosine_similarity)
+         self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
+
+     def _get_similarity_function(self, use_cosine_similarity):
+         if use_cosine_similarity:
+             self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
+             return self._cosine_simililarity
+         else:
+             return self._dot_simililarity
+
+     def _get_correlated_mask(self):
+         # Mask out the main diagonal (self-similarity) and the two
+         # off-diagonals that pair each sample with its own augmented view.
+         diag = np.eye(2 * self.batch_size)
+         l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
+         l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
+         mask = torch.from_numpy((diag + l1 + l2))
+         mask = (1 - mask).type(torch.bool)
+         return mask.to(self.device)
+
+     @staticmethod
+     def _dot_simililarity(x, y):
+         v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
+         # x shape: (N, 1, C)
+         # y shape: (1, C, 2N)
+         # v shape: (N, 2N)
+         return v
+
+     def _cosine_simililarity(self, x, y):
+         # x shape: (N, 1, C)
+         # y shape: (1, 2N, C)
+         # v shape: (N, 2N)
+         v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
+         return v
+
+     def forward(self, zis, zjs):
+         representations = torch.cat([zjs, zis], dim=0)
+
+         similarity_matrix = self.similarity_function(representations, representations)
+
+         # filter out the scores from the positive samples
+         l_pos = torch.diag(similarity_matrix, self.batch_size)
+         r_pos = torch.diag(similarity_matrix, -self.batch_size)
+         positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
+
+         negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
+
+         logits = torch.cat((positives, negatives), dim=1)
+         logits /= self.temperature
+
+         labels = torch.zeros(2 * self.batch_size).to(self.device).long()
+         loss = self.criterion(logits, labels)
+
+         return loss / (2 * self.batch_size)
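A sketch of NTXentLoss in a SimCLR-style setup; the projection dimension (128) and batch size are arbitrary assumptions, and zis/zjs must hold the two augmented views of the same samples in matching order:

    import torch
    from hdl.metric_loss.nt_xent import NTXentLoss

    device = torch.device('cpu')
    criterion = NTXentLoss(device, batch_size=4, temperature=0.5,
                           use_cosine_similarity=True)
    zis = torch.randn(4, 128)   # projections of augmentation 1
    zjs = torch.randn(4, 128)   # projections of augmentation 2
    loss = criterion(zis, zjs)  # normalized by 2 * batch_size

Note that the correlated mask is built for a fixed batch_size at construction time, so the final (smaller) batch of an epoch has to be dropped or padded.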
hdl/models/__init__.py ADDED
File without changes
hdl/models/chiral_gnn.py ADDED
@@ -0,0 +1,176 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
+ from hdl.layers.graph.chiral_graph import (
+     GCNConv,
+     GINEConv,
+     DMPNNConv,
+     # get_tetra_update,
+ )
+
+ from hdl.layers.graph.tetra import (
+     # get_tetra_update,
+     TETRA_UPDATE_DICT
+ )
+
+
+ class GNN(nn.Module):
+     def __init__(
+         self,
+         # args,
+         num_node_features: int = 48,
+         num_edge_features: int = 7,
+         depth: int = 15,
+         hidden_size: int = 128,
+         dropout: float = 0.1,
+         gnn_type: str = 'dmpnn',
+         graph_pool: str = 'mean',
+         tetra: bool = True,
+         task: str = 'classification',
+         output_num: int = None,
+         message: str = 'tetra_permute_concat',
+         include_vars: bool = False,
+     ):
+         super(GNN, self).__init__()
+
+         self.init_args = {
+             "num_node_features": num_node_features,
+             "num_edge_features": num_edge_features,
+             "depth": depth,
+             "hidden_size": hidden_size,
+             "dropout": dropout,
+             "gnn_type": gnn_type,
+             "graph_pool": graph_pool,
+             "tetra": tetra,
+             "task": task,
+             "message": message,
+             "include_vars": include_vars
+         }
+
+         self.depth = depth
+         self.hidden_size = hidden_size
+         self.dropout = dropout
+         self.gnn_type = gnn_type
+         self.graph_pool = graph_pool
+         self.tetra = tetra
+         self.task = task
+         self.out_dim = output_num
+         self.include_vars = include_vars
+
+         if self.gnn_type == 'dmpnn':
+             # input dimensions (61 and 13 below) are hard-coded in this release
+             self.edge_init = nn.Linear(61, self.hidden_size)
+             self.edge_to_node = DMPNNConv(
+                 hidden_size=hidden_size,
+                 tetra=tetra,
+                 message=message
+             )
+         else:
+             self.node_init = nn.Linear(num_node_features, self.hidden_size)
+             self.edge_init = nn.Linear(13, self.hidden_size)
+
+         # layers
+         self.convs = torch.nn.ModuleList()
+
+         for _ in range(self.depth):
+             if self.gnn_type == 'gin':
+                 self.convs.append(GINEConv(
+                     hidden_size=hidden_size,
+                     tetra=tetra,
+                     message=message
+                 ))
+             elif self.gnn_type == 'gcn':
+                 self.convs.append(GCNConv(
+                     hidden_size=hidden_size,
+                     tetra=tetra,
+                     message=message
+                 ))
+             elif self.gnn_type == 'dmpnn':
+                 self.convs.append(DMPNNConv(
+                     hidden_size=hidden_size,
+                     tetra=tetra,
+                     message=message
+                 ))
+             else:
+                 # raise, rather than silently constructing an empty model
+                 raise ValueError('Undefined GNN type called {}'.format(self.gnn_type))
+
+         # graph pooling
+         if self.tetra:
+             self.tetra_update = TETRA_UPDATE_DICT[message](hidden_size)
+             # self.tetra_update = get_tetra_update(args)
+
+         if self.graph_pool == "sum":
+             self.pool = global_add_pool
+         elif self.graph_pool == "mean":
+             self.pool = global_mean_pool
+         elif self.graph_pool == "max":
+             self.pool = global_max_pool
+         elif self.graph_pool == "attn":
+             self.pool = GlobalAttention(
+                 gate_nn=torch.nn.Sequential(
+                     torch.nn.Linear(self.hidden_size, 2 * self.hidden_size),
+                     torch.nn.BatchNorm1d(2 * self.hidden_size),
+                     torch.nn.ReLU(),
+                     torch.nn.Linear(2 * self.hidden_size, 1)))
+         elif self.graph_pool == "set2set":
+             self.pool = Set2Set(self.hidden_size, processing_steps=2)
+         else:
+             raise ValueError("Invalid graph pooling type.")
+
+         # ffn
+         self.mult = 2 if self.graph_pool == "set2set" else 1
+         if self.include_vars:
+             out_dim = 2
+         elif self.out_dim:
+             out_dim = self.out_dim
+         else:
+             out_dim = 1
+         self.ffn = nn.Linear(self.mult * self.hidden_size, out_dim)
+
+     def forward(self, data):
+         x, edge_index, edge_attr, batch, parity_atoms, parity_bond_index = \
+             data.x, data.edge_index, data.edge_attr, data.batch, data.parity_atoms, data.parity_bond_index
+
+         if self.gnn_type == 'dmpnn':
+             row, col = edge_index
+             edge_attr = torch.cat([x[row], edge_attr], dim=1)
+             edge_attr = F.relu(self.edge_init(edge_attr))
+         else:
+             x = F.relu(self.node_init(x))
+             edge_attr = F.relu(self.edge_init(edge_attr))
+
+         x_list = [x]
+         edge_attr_list = [edge_attr]
+
+         # convolutions
+         for layer_idx in range(self.depth):
+             x_h, edge_attr_h = self.convs[layer_idx](
+                 x_list[-1], edge_index, edge_attr_list[-1], parity_atoms, parity_bond_index)
+             h = edge_attr_h if self.gnn_type == 'dmpnn' else x_h
+
+             if layer_idx == self.depth - 1:
+                 h = F.dropout(h, self.dropout, training=self.training)
+             else:
+                 h = F.dropout(F.relu(h), self.dropout, training=self.training)
+
+             # residual connection
+             if self.gnn_type == 'dmpnn':
+                 h += edge_attr_h
+                 edge_attr_list.append(h)
+             else:
+                 h += x_h
+                 x_list.append(h)
+
+         # dmpnn edge -> node aggregation
+         if self.gnn_type == 'dmpnn':
+             h, _ = self.edge_to_node(x_list[-1], edge_index, h, parity_atoms, parity_bond_index)
+
+         # regression and classification share the same sigmoid head in this release
+         output = torch.sigmoid(self.ffn(self.pool(h, batch)))
+         if not self.include_vars:
+             return output
+         else:
+             mean = output[:, 0]
+             var = F.softplus(output[:, 1])
+             return mean, var
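A construction-only sketch of the GNN; the hyperparameter values are arbitrary, and the forward pass needs a torch_geometric batch carrying x, edge_index, edge_attr, batch, parity_atoms and parity_bond_index, so the call at the end is left commented:

    from hdl.models.chiral_gnn import GNN

    model = GNN(
        num_node_features=48,
        num_edge_features=7,
        depth=3,                 # shallow, for a quick smoke test
        hidden_size=64,
        gnn_type='gin',          # 'gin', 'gcn', or 'dmpnn'
        graph_pool='mean',
        tetra=False,
        task='classification',
        output_num=2,
    )
    # out = model(data)  # data: a Batch with the parity attributes above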