hjxdl-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/metric_loss/loss.py
ADDED
@@ -0,0 +1,79 @@
import typing as t

import torch
from torch import nn

from .multi_label import BPMLLLoss


def get_lossfunc(
    name: str,
    *args,
    **kwargs
) -> t.Callable:
    """Get loss function by name

    Args:
        name (str): The name of the loss function

    Returns:
        t.Callable: the loss function
    """
    name = name.lower()
    if name == 'bce':
        return nn.BCELoss(*args, **kwargs)
    elif name == 'ce':
        return nn.CrossEntropyLoss(*args, **kwargs)
    elif name == 'mse':
        return nn.MSELoss(*args, **kwargs)
    elif name == 'bpmll':
        return BPMLLLoss(*args, **kwargs)
    elif name == 'nll':
        return nn.GaussianNLLLoss(*args, **kwargs)


def mtmc_loss(
    y_preds: t.Iterable,
    y_trues: t.Iterable,
    loss_names: t.Iterable[str] = None,
    individual: bool = False,
    task_weights: t.List = None,
    device=torch.device('cpu'),
    **kwargs
):
    num_tasks = len(y_preds)
    if loss_names is None:
        loss_func = nn.CrossEntropyLoss()
        loss_funcs = [loss_func] * num_tasks
    elif isinstance(loss_names, str):
        loss_func = get_lossfunc(loss_names, **kwargs)
        loss_funcs = [loss_func] * num_tasks
    else:
        loss_funcs = [
            get_lossfunc(loss_str)
            for loss_str in loss_names
        ]

    if task_weights is None:
        task_weights = torch.ones(num_tasks).to(device)
    else:
        assert len(task_weights) == num_tasks
        task_weights = torch.FloatTensor(task_weights).to(device)

    loss_values = [
        loss_func(y_pred, y_true)
        for y_pred, y_true, loss_func in zip(
            y_preds, y_trues, loss_funcs
        )
    ]

    loss_final = sum([
        loss_value * task_weight
        for loss_value, task_weight in zip(loss_values, task_weights)
    ])
    # loss_final = sum(loss_values) / num_tasks
    if not individual:
        return loss_final
    else:
        loss_list = [loss_value for loss_value in loss_values]
        return (loss_final, loss_list)
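
A minimal usage sketch for the two entry points above (not part of the package diff), assuming the wheel is installed so that `hdl` is importable; the tensors are toy data:

import torch

from hdl.metric_loss.loss import get_lossfunc, mtmc_loss

# Single task: look up a loss by name and call it like any nn.Module.
criterion = get_lossfunc('mse')
loss = criterion(torch.randn(4, 1), torch.randn(4, 1))

# Multi-task: one (prediction, target) pair per task, optionally weighted.
y_preds = [torch.randn(4, 3), torch.randn(4, 5)]                  # per-task logits
y_trues = [torch.randint(0, 3, (4,)), torch.randint(0, 5, (4,))]  # class indices
total = mtmc_loss(y_preds, y_trues, loss_names='ce', task_weights=[1.0, 0.5])
total, per_task = mtmc_loss(y_preds, y_trues, loss_names='ce', individual=True)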
hdl/metric_loss/metric.py
ADDED
@@ -0,0 +1,178 @@
import math
from typing import Callable, List, Union
from functools import partial

import numpy as np
from sklearn.metrics import (
    auc, mean_absolute_error, mean_squared_error,
    precision_recall_curve, r2_score,
    roc_auc_score, accuracy_score, log_loss, matthews_corrcoef,
    # top_k_accuracy_score
)
import torch
import torch.nn as nn
import scipy.stats


def prc_auc(targets: List[int], preds: List[float]) -> float:
    """
    Computes the area under the precision-recall curve.

    :param targets: A list of binary targets.
    :param preds: A list of prediction probabilities.
    :return: The computed prc-auc.
    """
    precision, recall, _ = precision_recall_curve(targets, preds)
    return auc(recall, precision)


def bce(targets: List[int], preds: List[float]) -> float:
    """
    Computes the binary cross entropy loss.

    :param targets: A list of binary targets.
    :param preds: A list of prediction probabilities.
    :return: The computed binary cross entropy.
    """
    # Don't use logits because the sigmoid is added in all places except training itself
    bce_func = nn.BCELoss(reduction='mean')
    loss = bce_func(target=torch.Tensor(targets), input=torch.Tensor(preds)).item()

    return loss


def rmse(targets: List[float], preds: List[float]) -> float:
    """
    Computes the root mean squared error.

    :param targets: A list of targets.
    :param preds: A list of predictions.
    :return: The computed rmse.
    """
    return math.sqrt(mean_squared_error(targets, preds))


def mse(targets: List[float], preds: List[float]) -> float:
    """
    Computes the mean squared error.

    :param targets: A list of targets.
    :param preds: A list of predictions.
    :return: The computed mse.
    """
    return mean_squared_error(targets, preds)


def accuracy(targets: List[int], preds: Union[List[float], List[List[float]]], threshold: float = 0.5) -> float:
    """
    Computes the accuracy of a binary prediction task using a given threshold for generating hard predictions.

    Alternatively, computes accuracy for a multiclass prediction task by picking the largest probability.

    :param targets: A list of binary targets.
    :param preds: A list of prediction probabilities.
    :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0.
    :return: The computed accuracy.
    """
    if type(preds[0]) == list:  # multiclass
        hard_preds = [p.index(max(p)) for p in preds]
    else:
        hard_preds = [1 if p > threshold else 0 for p in preds]  # binary prediction

    return accuracy_score(targets, hard_preds)


def rsquared(x, y):
    """Return R^2 where x and y are array-like."""

    _, _, r_value, _, _ = scipy.stats.linregress(x, y)
    return r_value ** 2


def mcc(y_true, y_pred):
    y_true = np.array(y_true).astype(int)
    # y_true = np.where(y_true == 1, 1, -1).astype(int)
    y_pred = np.array(y_pred)
    y_pred = (y_pred >= 0.5).astype(int)

    return matthews_corrcoef(y_true, y_pred)


def topk(y_true, y_pred, k=1):

    y_true = np.array(y_true).astype(int)

    y_pred = np.array(y_pred)

    sorted_pred = np.argsort(y_pred, axis=1, kind='mergesort')[:, ::-1]
    hits = (y_true == sorted_pred[:, :k].T).any(axis=0)
    num_hits = np.sum(hits)

    return num_hits / len(y_true)


def get_metric(metric: str) -> Callable[[Union[List[int], List[float]], List[float]], float]:
    r"""
    Gets the metric function corresponding to a given metric name.

    Supports:

    * :code:`auc`: Area under the receiver operating characteristic curve
    * :code:`prc-auc`: Area under the precision recall curve
    * :code:`rmse`: Root mean squared error
    * :code:`mse`: Mean squared error
    * :code:`mae`: Mean absolute error
    * :code:`r2`: Coefficient of determination R\ :superscript:`2`
    * :code:`accuracy`: Accuracy (using a threshold to binarize predictions)
    * :code:`cross_entropy`: Cross entropy
    * :code:`binary_cross_entropy`: Binary cross entropy

    :param metric: Metric name.
    :return: A metric function which takes as arguments a list of targets and a list of predictions and returns the metric value.
    """
    if metric == 'mcc':
        return mcc

    if metric == 'rsquared':
        return rsquared

    if metric == 'auc':
        return roc_auc_score

    if metric == 'prc-auc':
        return prc_auc

    if metric == 'rmse':
        return rmse

    if metric == 'mse':
        return mse

    if metric == 'mae':
        return mean_absolute_error

    if metric == 'r2':
        return r2_score

    if metric == 'acc':
        return accuracy

    if metric == 'ce':
        return log_loss

    if metric == 'bce':
        return bce

    if metric == 'topk':
        return topk

    if metric == 'top3':
        return partial(topk, k=3)

    if metric == 'top5':
        return partial(topk, k=5)

    if metric == 'top10':
        return partial(topk, k=10)

    raise ValueError(f'Metric "{metric}" not supported.')
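
A short sketch of get_metric on toy data (not part of the package diff), assuming `hdl` is importable; each returned function takes (targets, preds) and yields a float:

from hdl.metric_loss.metric import get_metric

targets = [0, 1, 1, 0]
probs = [0.2, 0.8, 0.6, 0.4]
for name in ('auc', 'prc-auc', 'acc', 'mcc', 'rmse'):
    print(name, get_metric(name)(targets, probs))

# The top-k variants expect one row of class probabilities per sample.
top3 = get_metric('top3')
print(top3([2, 0], [[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]]))  # both hits -> 1.0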
hdl/metric_loss/multi_label.py
ADDED
@@ -0,0 +1,42 @@
import torch
from torch import Tensor


class BPMLLLoss(torch.nn.Module):
    def __init__(self, reduction: str = 'mean', bias=(1, 1)):
        super(BPMLLLoss, self).__init__()
        self.bias = bias
        self.reduction = reduction
        assert len(self.bias) == 2 \
            and all(map(lambda x: isinstance(x, int) and x > 0, bias)), \
            "bias must be positive integers"

    def forward(self, c: Tensor, y: Tensor) -> Tensor:
        r"""
        compute the loss, which has the form:

        L = \sum_{i=1}^{m} \frac{1}{|Y_i| \cdot |\bar{Y}_i|} \sum_{(k, l) \in Y_i \times \bar{Y}_i} \exp(-c^i_k + c^i_l)

        :param c: prediction tensor, size: batch_size * n_labels
        :param y: target tensor, size: batch_size * n_labels
        :return: size: scalar tensor
        """
        y = y.float()
        y_bar = -y + 1
        y_norm = torch.pow(y.sum(dim=(1,)), self.bias[0])
        y_bar_norm = torch.pow(y_bar.sum(dim=(1,)), self.bias[1])
        assert torch.all(y_norm != 0) or torch.all(y_bar_norm != 0), \
            "an instance cannot have none or all the labels"
        loss = 1 / torch.mul(y_norm, y_bar_norm) \
            * self.pairwise_sub_exp(y, y_bar, c)

        if self.reduction == 'mean':
            return torch.mean(loss)
        elif self.reduction == 'none':
            return loss

    def pairwise_sub_exp(self, y: Tensor, y_bar: Tensor, c: Tensor) -> Tensor:
        r"""
        compute \sum_{(k, l) \in Y_i \times \bar{Y}_i} \exp(-c^i_k + c^i_l)
        """
        truth_matrix = y.unsqueeze(2).float() @ y_bar.unsqueeze(1).float()
        exp_matrix = torch.exp(c.unsqueeze(1) - c.unsqueeze(2))
        return (torch.mul(truth_matrix, exp_matrix)).sum(dim=(1, 2))
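
A minimal sketch of BPMLLLoss on a toy multi-label batch (not part of the package diff), assuming `hdl` is importable; note the forward() assertion means a row of y should have at least one positive and one negative label:

import torch

from hdl.metric_loss.multi_label import BPMLLLoss

criterion = BPMLLLoss()              # reduction='mean', bias=(1, 1)
c = torch.randn(2, 4)                # prediction scores, batch_size x n_labels
y = torch.tensor([[1, 0, 1, 0],
                  [0, 1, 0, 0]])     # binary label matrix
loss = criterion(c, y)               # scalar tensor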
hdl/metric_loss/nt_xent.py
ADDED
@@ -0,0 +1,65 @@
import torch
import numpy as np


class NTXentLoss(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self):
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        representations = torch.cat([zjs, zis], dim=0)

        similarity_matrix = self.similarity_function(representations, representations)

        # filter out the scores from the positive samples
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)

        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)

        logits = torch.cat((positives, negatives), dim=1)
        logits /= self.temperature

        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
        loss = self.criterion(logits, labels)

        return loss / (2 * self.batch_size)
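
A minimal sketch of NTXentLoss (the SimCLR-style contrastive loss) on random embeddings, not part of the package diff, assuming `hdl` is importable; zis/zjs are the two augmented views' projections in the same sample order, and batch_size must match their first dimension:

import torch

from hdl.metric_loss.nt_xent import NTXentLoss

batch_size, dim = 8, 32
criterion = NTXentLoss(device=torch.device('cpu'), batch_size=batch_size,
                       temperature=0.5, use_cosine_similarity=True)
zis = torch.randn(batch_size, dim)  # projections of view 1
zjs = torch.randn(batch_size, dim)  # projections of view 2
loss = criterion(zis, zjs)          # scalar, averaged over the 2N samples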
hdl/models/__init__.py
ADDED
File without changes
hdl/models/chiral_gnn.py
ADDED
@@ -0,0 +1,176 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from hdl.layers.graph.chiral_graph import (
    GCNConv,
    GINEConv,
    DMPNNConv,
    # get_tetra_update,
)

from hdl.layers.graph.tetra import (
    # get_tetra_update,
    TETRA_UPDATE_DICT
)


class GNN(nn.Module):
    def __init__(
        self,
        # args,
        num_node_features: int = 48,
        num_edge_features: int = 7,
        depth: int = 15,
        hidden_size: int = 128,
        dropout: float = 0.1,
        gnn_type: str = 'dmpnn',
        graph_pool: str = 'mean',
        tetra: bool = True,
        task: str = 'classification',
        output_num: int = None,
        message: str = 'tetra_permute_concat',
        include_vars: bool = False,
    ):
        super(GNN, self).__init__()

        self.init_args = {
            "num_node_features": num_node_features,
            "num_edge_features": num_edge_features,
            "depth": depth,
            "hidden_size": hidden_size,
            "dropout": dropout,
            "gnn_type": gnn_type,
            "graph_pool": graph_pool,
            "tetra": tetra,
            "task": task,
            "message": message,
            "include_vars": include_vars
        }

        self.depth = depth
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.gnn_type = gnn_type
        self.graph_pool = graph_pool
        self.tetra = tetra
        self.task = task
        self.out_dim = output_num
        self.include_vars = include_vars

        if self.gnn_type == 'dmpnn':
            self.edge_init = nn.Linear(61, self.hidden_size)
            self.edge_to_node = DMPNNConv(
                hidden_size=hidden_size,
                tetra=tetra,
                message=message
            )
        else:
            self.node_init = nn.Linear(num_node_features, self.hidden_size)
            self.edge_init = nn.Linear(13, self.hidden_size)

        # layers
        self.convs = torch.nn.ModuleList()

        for _ in range(self.depth):
            if self.gnn_type == 'gin':
                self.convs.append(GINEConv(
                    hidden_size=hidden_size,
                    tetra=tetra,
                    message=message
                ))
            elif self.gnn_type == 'gcn':
                self.convs.append(GCNConv(
                    hidden_size=hidden_size,
                    tetra=tetra,
                    message=message
                ))
            elif self.gnn_type == 'dmpnn':
                self.convs.append(DMPNNConv(
                    hidden_size=hidden_size,
                    tetra=tetra,
                    message=message
                ))
            else:
                raise ValueError('Undefined GNN type called {}'.format(self.gnn_type))

        # graph pooling
        if self.tetra:
            self.tetra_update = TETRA_UPDATE_DICT[message](hidden_size)
            # self.tetra_update = get_tetra_update(args)

        if self.graph_pool == "sum":
            self.pool = global_add_pool
        elif self.graph_pool == "mean":
            self.pool = global_mean_pool
        elif self.graph_pool == "max":
            self.pool = global_max_pool
        elif self.graph_pool == "attn":
            self.pool = GlobalAttention(
                gate_nn=torch.nn.Sequential(torch.nn.Linear(self.hidden_size, 2 * self.hidden_size),
                                            torch.nn.BatchNorm1d(2 * self.hidden_size),
                                            torch.nn.ReLU(),
                                            torch.nn.Linear(2 * self.hidden_size, 1)))
        elif self.graph_pool == "set2set":
            self.pool = Set2Set(self.hidden_size, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # ffn
        self.mult = 2 if self.graph_pool == "set2set" else 1
        if self.include_vars:
            out_dim = 2
        elif self.out_dim:
            out_dim = self.out_dim
        else:
            out_dim = 1
        self.ffn = nn.Linear(self.mult * self.hidden_size, out_dim)

    def forward(self, data):
        x, edge_index, edge_attr, batch, parity_atoms, parity_bond_index = data.x, data.edge_index, data.edge_attr, data.batch, data.parity_atoms, data.parity_bond_index

        if self.gnn_type == 'dmpnn':
            row, col = edge_index
            edge_attr = torch.cat([x[row], edge_attr], dim=1)
            edge_attr = F.relu(self.edge_init(edge_attr))
        else:
            x = F.relu(self.node_init(x))
            edge_attr = F.relu(self.edge_init(edge_attr))

        x_list = [x]
        edge_attr_list = [edge_attr]

        # convolutions
        for layer_idx in range(self.depth):

            x_h, edge_attr_h = self.convs[layer_idx](x_list[-1], edge_index, edge_attr_list[-1], parity_atoms, parity_bond_index)
            h = edge_attr_h if self.gnn_type == 'dmpnn' else x_h

            if layer_idx == self.depth - 1:
                h = F.dropout(h, self.dropout, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.dropout, training=self.training)

            if self.gnn_type == 'dmpnn':
                h += edge_attr_h
                edge_attr_list.append(h)
            else:
                h += x_h
                x_list.append(h)

        # dmpnn edge -> node aggregation
        if self.gnn_type == 'dmpnn':
            h, _ = self.edge_to_node(x_list[-1], edge_index, h, parity_atoms, parity_bond_index)

        if self.task == 'regression':
            output = torch.sigmoid(self.ffn(self.pool(h, batch)))
        elif self.task == 'classification':
            output = torch.sigmoid(self.ffn(self.pool(h, batch)))

        if not self.include_vars:
            return output
        else:
            mean = output[:, 0]
            var = F.softplus(output[:, 1])
            return mean, var
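
A construction-only sketch of GNN (not part of the package diff), assuming `hdl` and torch_geometric are installed. The forward pass needs a PyG batch carrying x, edge_index, edge_attr, batch, parity_atoms and parity_bond_index, which the package's chiral graph dataset code produces and is not reconstructed here, so only instantiation is shown:

from hdl.models.chiral_gnn import GNN

model = GNN(
    num_node_features=48,   # defaults mirror the constructor signature
    num_edge_features=7,
    depth=3,                # shallower than the default 15, for a quick test
    hidden_size=64,
    gnn_type='gcn',         # one of 'gin', 'gcn', 'dmpnn'
    graph_pool='mean',
    tetra=True,             # enables the tetra update selected via `message`
    task='classification',
)
print(sum(p.numel() for p in model.parameters()), 'parameters')

# out = model(batch)  # batch: a torch_geometric Batch with the fields above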