hjxdl 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/models/utils.py
ADDED
@@ -0,0 +1,83 @@
import typing as t

import torch
from torch import nn


def save_model(
    model: t.Union[nn.Module, nn.DataParallel],
    save_dir: str = "./model.ckpt",
    epoch: int = 0,
    optimizer: torch.optim.Optimizer = None,
    loss: float = None,
) -> None:
    # nn.DataParallel does not forward custom attributes such as init_args,
    # so read both the weights and the init args from the wrapped module.
    if isinstance(model, nn.DataParallel):
        state_dict = model.module.state_dict()
        init_args = model.module.init_args
    else:
        state_dict = model.state_dict()
        init_args = model.init_args
    if optimizer is None:
        optim_params = None
    else:
        optim_params = optimizer.state_dict()
    torch.save(
        {
            'init_args': init_args,
            'epoch': epoch,
            'model_state_dict': state_dict,
            'optimizer_state_dict': optim_params,
            'loss': loss,
        },
        save_dir
    )


def load_model(
    save_dir: str,
    model_name: str = None,
    model: t.Union[nn.Module, nn.DataParallel] = None,
    optimizer: torch.optim.Optimizer = None,
    train: bool = False,
) -> t.Tuple[
    t.Union[nn.Module, nn.DataParallel],
    torch.optim.Optimizer,
    int,
    float
]:
    from .model_dict import MODEL_DICT
    checkpoint = torch.load(save_dir)
    if model is None:
        # Rebuild the model from its registered class and saved init args.
        init_args = checkpoint['init_args']
        assert model_name is not None
        model = MODEL_DICT[model_name](**init_args)
        model.load_state_dict(
            checkpoint['model_state_dict'],
        )
    elif isinstance(model, nn.DataParallel):
        # Remap key prefixes so a plain checkpoint fits the wrapped module.
        state_dict = checkpoint['model_state_dict']
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if 'module' not in k:
                k = 'module.' + k
            else:
                k = k.replace('features.module.', 'module.features.')
            new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(
            checkpoint['model_state_dict'],
        )

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint.get('epoch', 0)
    loss = checkpoint.get('loss', 0.0)

    if train:
        model.train()
    else:
        model.eval()

    return model, optimizer, epoch, loss
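A minimal round-trip sketch for the two helpers above. The `TinyMLP` class and the checkpoint path are illustrative, not part of this package; `save_model` assumes the instance stores its constructor kwargs in an `init_args` attribute, and loading by `model_name` would additionally assume the class is registered in `MODEL_DICT`.

```python
import torch
from torch import nn

from hdl.models.utils import save_model, load_model


class TinyMLP(nn.Module):  # hypothetical model, not shipped with hjxdl
    def __init__(self, in_dim: int = 8, out_dim: int = 2):
        super().__init__()
        # save_model reads the constructor kwargs from this attribute
        self.init_args = {'in_dim': in_dim, 'out_dim': out_dim}
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.fc(x)


model = TinyMLP()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# persist weights, optimizer state and bookkeeping in one checkpoint
save_model(model, './model.ckpt', epoch=3, optimizer=optimizer, loss=0.42)

# reload into an existing instance (the model is switched to eval mode)
model2 = TinyMLP()
model2, _, epoch, loss = load_model('./model.ckpt', model=model2)
```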
hdl/ops/__init__.py
ADDED
File without changes
hdl/ops/utils.py
ADDED
@@ -0,0 +1,42 @@
import typing as t

# import torch
from torch import nn


__all__ = [
    'get_activation'
]


def get_activation(
    name: str,
    **kwargs
) -> t.Callable:
    """Get an activation module by name.

    Args:
        name (str): The name of the activation function
            (relu, elu, selu, softmax, sigmoid, none).
        kwargs: Extra keyword arguments forwarded to the module
            (e.g. ``inplace``, ``alpha``, ``dim``).

    Returns:
        nn.Module: The activation module, or None when ``name`` is 'none'.
    """
    name = name.lower()
    if name == 'relu':
        inplace = kwargs.get('inplace', False)
        return nn.ReLU(inplace=inplace)
    elif name == 'elu':
        alpha = kwargs.get('alpha', 1.)
        inplace = kwargs.get('inplace', False)
        return nn.ELU(alpha=alpha, inplace=inplace)
    elif name == 'selu':
        inplace = kwargs.get('inplace', False)
        return nn.SELU(inplace=inplace)
    elif name == 'softmax':
        dim = kwargs.get('dim', -1)
        return nn.Softmax(dim=dim)
    elif name == 'sigmoid':
        return nn.Sigmoid()
    elif name == 'none':
        return None
    else:
        raise ValueError(f'Activation {name!r} not implemented')
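A quick usage sketch for `get_activation`; the keyword arguments shown (`alpha`, `dim`) are exactly the ones the function itself reads.

```python
import torch

from hdl.ops.utils import get_activation

act = get_activation('elu', alpha=0.5)       # nn.ELU(alpha=0.5)
softmax = get_activation('softmax', dim=-1)  # nn.Softmax over the last dim

x = torch.randn(2, 4)
print(act(x).shape)                 # torch.Size([2, 4])
print(softmax(x).sum(dim=-1))       # each row sums to 1
assert get_activation('none') is None  # 'none' deliberately returns None
```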
hdl/optims/__init__.py
ADDED
File without changes
hdl/optims/nadam.py
ADDED
@@ -0,0 +1,86 @@
import torch
from torch import optim


class Nadam(optim.Adam):
    """
    Adaptive moment estimation with Nesterov gradients.

    http://cs229.stanford.edu/proj2015/054_report.pdf

    Parameters
    ----------
    params
        iterable of parameters to optimize or dicts defining
        parameter groups
    lr
        learning rate (default: 1e-3)
    betas
        coefficients used for computing
        running averages of gradient and its square (default: (0.9, 0.999))
    eps
        term added to the denominator to improve
        numerical stability (default: 1e-8)
    weight_decay
        weight decay (L2 penalty) (default: 0)
    decay
        a decay scheme for `betas[0]`.
        Default: :math:`\\beta * (1 - 0.5 * 0.96^{\\frac{t}{250}})`
        where `t` is the training step.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
                 decay=lambda x, t: x * (1. - .5 * .96 ** (t / 250.))):
        super().__init__(params, lr, betas, eps, weight_decay)
        self.decay = decay

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('NAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    # Beta1 accumulation
                    state['beta1_cum'] = 1.

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad.add_(p.data, alpha=group['weight_decay'])

                beta1_t = self.decay(beta1, state['step'])
                beta1_tp1 = self.decay(beta1, state['step'] + 1.)
                beta1_cum = state['beta1_cum'] * beta1_t

                g_hat_t = grad / (1. - beta1_cum)
                exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
                m_hat_t = exp_avg / (1. - beta1_cum * beta1_tp1)

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
                v_hat_t = exp_avg_sq / (1. - beta2 ** state['step'])
                m_bar_t = (1. - beta1) * g_hat_t + beta1_tp1 * m_hat_t

                denom = v_hat_t.sqrt().add_(group['eps'])
                p.data.addcdiv_(m_bar_t, denom, value=-group['lr'])
                state['beta1_cum'] = beta1_cum

        return loss
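A minimal training-step sketch with the `Nadam` optimizer above; the linear model and random data are placeholders, not part of the package.

```python
import torch
from torch import nn

from hdl.optims.nadam import Nadam

model = nn.Linear(4, 1)
opt = Nadam(model.parameters(), lr=1e-3)

x, y = torch.randn(16, 4), torch.randn(16, 1)
for _ in range(3):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()  # one Nadam update with the decayed beta1 schedule
```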
hdl/utils/__init__.py
ADDED
File without changes
hdl/utils/chemical_tools/query_info.py
ADDED
@@ -0,0 +1,149 @@
import re

import cirpy
import pubchempy as pcp
# from rdkit import Chem
import molvs as mv
from psycopg import sql

from hdl.utils.database_tools.connect import connect_by_infofile


def query_from_cir(query_name: str):
    """Resolve a compound name to SMILES, CAS numbers and synonyms via CIR."""
    cas_list = cirpy.resolve(query_name, 'cas')
    if cas_list is None or not cas_list:
        cas_list = []
    if isinstance(cas_list, str):
        cas_list = [cas_list]

    name_list = cirpy.resolve(query_name, 'names')
    if name_list is None or not name_list:
        name_list = []
    if isinstance(name_list, str):
        name_list = [name_list]

    smiles = cirpy.resolve(query_name, 'smiles')
    try:
        smiles = mv.standardize_smiles(smiles)
    except Exception as e:
        print(e)

    return smiles, cas_list, name_list


def query_from_pubchem(query_name: str):
    """Resolve a compound name to SMILES, CAS numbers and synonyms via PubChem."""
    results = pcp.get_compounds(query_name, 'name')
    smiles = None
    name_list = set()
    cas_list = set()

    if any(results):
        try:
            smiles = mv.standardize_smiles(results[0].canonical_smiles)
        except Exception as e:
            smiles = results[0].canonical_smiles
            print(smiles)
            print(e)
        for compound in results:
            name_list.update(set(compound.synonyms))
            for syn in compound.synonyms:
                # CAS registry numbers look like 50-78-2 (2-7 digits, 2 digits, 1 digit)
                match = re.match(r'(\d{2,7}-\d\d-\d)', syn)
                if match:
                    cas_list.add(match.group(1))

    cas_list = list(cas_list)
    name_list = list(name_list)

    return smiles, cas_list, name_list


def query_a_compound(
    query_name: str,
    connect_info: str,
    by: str = 'name',
    log_file: str = './err.log'
):
    """Look a compound up in the local database, falling back to PubChem and CIR."""
    fei = None
    found = False

    query_name = query_name.lower()

    by = 'name'  # only name-based lookup is implemented; the argument is ignored for now
    table = by + '_maps'
    query = sql.SQL(
        "select fei from {table} where {by} = %s"
    ).format(
        table=sql.Identifier(table),
        by=sql.Identifier(by)
    )
    conn = connect_by_infofile(connect_info)

    cur = conn.execute(query, [query_name]).fetchone()

    if cur is not None:
        fei = cur[0]
        conn.close()
        return fei

    if not found:
        try:
            smiles, cas_list, name_list = query_from_pubchem(query_name)
        except Exception as e:
            print(e)
            smiles, cas_list, name_list = None, [], []
        if smiles is not None:
            found = True
        else:
            try:
                smiles, cas_list, name_list = query_from_cir(query_name)
            except Exception as e:
                print(e)
                smiles, cas_list, name_list = None, [], []
            if smiles is not None:
                found = True

    if not found:
        # Nothing resolved anywhere: log the name and give up.
        with open(log_file, 'a') as f:
            f.write(query_name)
            f.write('\n')
        conn.close()
        return
    else:
        query_compound = sql.SQL(
            "select fei from compounds where smiles = %s"
        )
        cur = conn.execute(query_compound, [smiles]).fetchone()
        if cur is not None:
            fei = cur[0]
        elif any(cas_list):
            # Cache the newly resolved compound and all of its aliases.
            fei = cas_list[0]
            insert_compounds_sql = sql.SQL(
                "INSERT INTO compounds (fei, smiles) VALUES (%s, %s) ON CONFLICT (fei) DO NOTHING"
            )
            conn.execute(insert_compounds_sql, [fei, smiles])
            for cas in cas_list:
                insert_cas_map_sql = sql.SQL(
                    "INSERT INTO cas_maps (fei, cas) VALUES (%s, %s) ON CONFLICT (cas) DO NOTHING"
                )
                try:
                    conn.execute(insert_cas_map_sql, [fei, cas])
                except Exception as e:
                    print(e)
            for name in name_list:
                insert_name_map_sql = sql.SQL(
                    "INSERT INTO name_maps (fei, name) VALUES (%s, %s) ON CONFLICT (name) DO NOTHING"
                )
                try:
                    conn.execute(insert_name_map_sql, [fei, name.lower()])
                except Exception as e:
                    print(e)

    conn.commit()
    conn.close()

    return fei
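A hedged sketch of the lookup helpers above. `query_from_pubchem` only needs network access to PubChem; `query_a_compound` additionally needs a reachable postgres instance and a connection-info file in the format expected by `connect_by_infofile` (the query name and file path below are illustrative).

```python
from hdl.utils.chemical_tools.query_info import (
    query_from_pubchem,
    query_a_compound,
)

# resolve a common name via PubChem: standardized SMILES plus CAS/synonym lists
smiles, cas_list, name_list = query_from_pubchem('aspirin')
print(smiles, cas_list[:3])

# full pipeline: check the local name_maps table first, then fall back to
# PubChem/CIR and cache the result (requires a postgres database)
fei = query_a_compound('aspirin', connect_info='./conn.info')
```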
hdl/utils/chemical_tools/sdf.py
ADDED
@@ -0,0 +1,20 @@
from rdkit import Chem
import pandas as pd


def sdf2df(
    sdf_file,
    id_col: str = 'Molecule Name',
    target_col: str = 'Average △G (kcal/mol)'
):
    """Read an SDF file into a DataFrame with 'name', 'smiles' and 'y' columns."""
    supp = Chem.SDMolSupplier(sdf_file)
    mol_dict_list = []
    for mol in supp:
        if mol is None:  # skip records RDKit fails to parse
            continue
        smiles = Chem.MolToSmiles(mol)
        mol_dict = mol.GetPropsAsDict()
        mol_dict['smiles'] = smiles
        mol_dict['y'] = mol_dict.pop(target_col)
        mol_dict['name'] = mol_dict.pop(id_col)
        mol_dict_list.append(mol_dict)
    df = pd.DataFrame(mol_dict_list)
    return df
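A usage sketch for `sdf2df`, assuming an SDF file whose records carry the default property names ('Molecule Name' and 'Average △G (kcal/mol)'); the file path is hypothetical.

```python
from hdl.utils.chemical_tools.sdf import sdf2df

df = sdf2df('ligands.sdf')  # hypothetical input file
print(df[['name', 'smiles', 'y']].head())
```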
hdl/utils/database_tools/__init__.py
ADDED
File without changes
hdl/utils/database_tools/connect.py
ADDED
@@ -0,0 +1,28 @@
import psycopg


def connect_by_infofile(info_file: str) -> psycopg.Connection:
    """Create a postgres connection.

    Args:
        info_file (str):
            the path of a file whose first line holds the connection info, e.g.
            host=127.0.0.1 dbname=dbname port=5432 user=postgres password=lala

    Returns:
        psycopg.Connection:
            the connection instance; it should be closed after committing.
    """
    with open(info_file) as f:  # close the file handle instead of leaking it
        conn_info = f.readline()
    conn = psycopg.connect(conn_info)
    return conn
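The workflow sketched in the source's own commented-out example, assuming a one-line `conn.info` file in the documented `key=value` format (the credentials shown are placeholders):

```python
from hdl.utils.database_tools.connect import connect_by_infofile

# conn.info (single line):
# host=127.0.0.1 dbname=dbname port=5432 user=postgres password=lala
conn = connect_by_infofile('./conn.info')
cur = conn.execute('select * from name_maps;')
print(cur.fetchone())
conn.commit()
conn.close()
```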
hdl/utils/general/__init__.py
ADDED
File without changes
hdl/utils/general/glob.py
ADDED
@@ -0,0 +1,21 @@
import subprocess


def get_num_lines(file):
    """Count the lines of a file by shelling out to `wc -l`."""
    num_lines = subprocess.check_output(
        ['wc', '-l', file]
    ).split()[0]
    return int(num_lines)


def str_from_line(file, line, split=False):
    """Return the (zero-indexed) `line`-th line of a file via `sed`."""
    smi = subprocess.check_output(
        # equivalent: ['sed', '-n', f'{line + 1}p', file]
        ["sed", f"{line + 1}q;d", file]
    )
    if isinstance(smi, bytes):
        smi = smi.decode().strip()
    if split:
        if ' ' in smi or '\t' in smi:  # keep only the first whitespace-separated field
            smi = smi.split()[0]
    return smi
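A quick sketch of the two shell-backed helpers; note that `str_from_line` is zero-indexed (line 0 returns the first line) and both shell out to `wc`/`sed`, so a Unix-like environment is assumed. The file path is hypothetical.

```python
from hdl.utils.general.glob import get_num_lines, str_from_line

path = 'smiles.txt'  # hypothetical file with one record per line
print(get_num_lines(path))                  # total line count via `wc -l`
print(str_from_line(path, 0))               # first line via `sed`
print(str_from_line(path, 0, split=True))   # first whitespace-separated field only
```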
hdl/utils/schedulers/__init__.py
ADDED
File without changes
hdl/utils/schedulers/norm_lr.py
ADDED
@@ -0,0 +1,108 @@
from typing import List, Union
import numpy as np

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class NoamLR(_LRScheduler):
    """
    Noam learning rate scheduler with piecewise linear increase and exponential decay.

    The learning rate increases linearly from init_lr to max_lr over the course of
    the first warmup_steps (where warmup_steps = warmup_epochs * steps_per_epoch).
    Then the learning rate decreases exponentially from max_lr to final_lr over the
    course of the remaining total_steps - warmup_steps (where total_steps =
    total_epochs * steps_per_epoch). This is roughly based on the learning rate
    schedule from Attention is All You Need, section 5.3 (https://arxiv.org/abs/1706.03762).
    """
    def __init__(self,
                 optimizer: Optimizer,
                 warmup_epochs: List[Union[float, int]],
                 total_epochs: List[int],
                 steps_per_epoch: int,
                 init_lr: List[float],
                 max_lr: List[float],
                 final_lr: List[float]):
        """
        Initializes the learning rate scheduler.

        :param optimizer: A PyTorch optimizer.
        :param warmup_epochs: The number of epochs during which to linearly increase the learning rate.
        :param total_epochs: The total number of epochs.
        :param steps_per_epoch: The number of steps (batches) per epoch.
        :param init_lr: The initial learning rate.
        :param max_lr: The maximum learning rate (achieved after warmup_epochs).
        :param final_lr: The final learning rate (achieved after total_epochs).
        """
        # One schedule per parameter group: all per-group lists must align.
        assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \
            len(max_lr) == len(final_lr)

        self.num_lrs = len(optimizer.param_groups)

        self.optimizer = optimizer
        self.warmup_epochs = np.array(warmup_epochs)
        self.total_epochs = np.array(total_epochs)
        self.steps_per_epoch = steps_per_epoch
        self.init_lr = np.array(init_lr)
        self.max_lr = np.array(max_lr)
        self.final_lr = np.array(final_lr)

        self.current_step = 0
        self.lr = init_lr
        self.warmup_steps = (self.warmup_epochs * self.steps_per_epoch).astype(int)
        self.total_steps = self.total_epochs * self.steps_per_epoch
        self.linear_increment = (self.max_lr - self.init_lr) / self.warmup_steps

        self.exponential_gamma = (self.final_lr / self.max_lr) ** (1 / (self.total_steps - self.warmup_steps))

        super(NoamLR, self).__init__(optimizer)

    def get_lr(self) -> List[float]:
        """Gets a list of the current learning rates."""
        return list(self.lr)

    def step(self, current_step: int = None):
        """
        Updates the learning rate by taking a step.

        :param current_step: Optionally specify what step to set the learning rate to.
            If None, current_step = self.current_step + 1.
        """
        if current_step is not None:
            self.current_step = current_step
        else:
            self.current_step += 1

        for i in range(self.num_lrs):
            if self.current_step <= self.warmup_steps[i]:
                self.lr[i] = self.init_lr[i] + self.current_step * self.linear_increment[i]
            elif self.current_step <= self.total_steps[i]:
                self.lr[i] = self.max_lr[i] * (self.exponential_gamma[i] ** (self.current_step - self.warmup_steps[i]))
            else:  # theoretically this case should never be reached since training should stop at total_steps
                self.lr[i] = self.final_lr[i]

            self.optimizer.param_groups[i]['lr'] = self.lr[i]


def build_lr_scheduler(
    optimizer: Optimizer,
    warmup_epochs: int,
    n_epochs: int,
    steps_per_epoch: int,
    lr: float,
) -> _LRScheduler:
    """
    Builds a learning rate scheduler.

    :param optimizer: The Optimizer whose learning rate will be scheduled.
    :param warmup_epochs: The number of warmup epochs.
    :param n_epochs: The total number of epochs.
    :param steps_per_epoch: The number of steps (batches) per epoch.
    :param lr: The peak learning rate; warmup starts and decay ends at lr / 10.
    :return: An initialized learning rate scheduler.
    """
    return NoamLR(
        optimizer=optimizer,
        warmup_epochs=[warmup_epochs],
        total_epochs=[n_epochs],
        steps_per_epoch=steps_per_epoch,
        init_lr=[lr / 10],
        max_lr=[lr],
        final_lr=[lr / 10]
    )
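A minimal sketch wiring `build_lr_scheduler` into a training loop; the model, epoch counts and learning rate are placeholders. `NoamLR.step()` is called once per batch, warming up from `lr / 10` to `lr` and then decaying back down.

```python
import torch
from torch import nn

from hdl.utils.schedulers.norm_lr import build_lr_scheduler

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = build_lr_scheduler(
    optimizer,
    warmup_epochs=2,
    n_epochs=10,
    steps_per_epoch=100,
    lr=1e-3,
)

for epoch in range(10):
    for _ in range(100):
        optimizer.step()   # one batch of training
        scheduler.step()   # per-batch learning rate update
print(scheduler.get_lr())
```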
hjxdl-0.0.1.dist-info/METADATA
ADDED
@@ -0,0 +1,19 @@
Metadata-Version: 2.1
Name: hjxdl
Version: 0.0.1
Summary: A collection of functions for Jupyter notebooks
Home-page: https://github.com/huluxiaohuowa/hdl
Author: Jianxing Hu
Author-email: j.hu@pku.edu.cn
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
Description-Content-Type: text/markdown

# DL framework by Jianxing

```bash
git clone git@github.com:huluxiaohuowa/hdl.git
python setup.py install
```
hjxdl-0.0.1.dist-info/RECORD
ADDED
@@ -0,0 +1,91 @@
hdl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/_version.py,sha256=pMnmqZnpVmaqR5nqHztNWzbbtb1oy5bPN_v7uhOH8K8,411
hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/controllers/al/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/controllers/al/al.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/controllers/al/dispatcher.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/controllers/al/feedback.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/controllers/explain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/controllers/explain/shapley.py,sha256=6dPc_ICPgvllT8YtFxe9Ds-TVccsXr2M0lP4zKxGHQA,12436
hdl/controllers/explain/subgraphx.py,sha256=2lxdmKlveyxWpYf2AdTvjm_C578WH-_z1fqhiHOzYXQ,38202
hdl/controllers/train/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/controllers/train/rxn_train.py,sha256=G1MTRrcP-g9ohzl7tbGM1tYd9HAdcM-3_wHQtRYg4LE,4876
hdl/controllers/train/train.py,sha256=K7UjerhMFks9gQINQtUNXfGBN-XejmJZs0Mqzd7k5rU,1637
hdl/controllers/train/train_ginet.py,sha256=KFGIMwCF62HuuFCQcz5Vr394Lf0kkgOT9x_0pvMyVmU,8561
hdl/controllers/train/trainer_base.py,sha256=SZpa-4aV9ptKf4wgKVTNqqy4yJnLPOaEUX56LE28u74,4012
hdl/controllers/train/trainer_iterative.py,sha256=ydUJWuAvvfjosGBa4MvBeqqoNB6tlB4WWprBOXu9vG4,11803
hdl/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/data/to_mols.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/data/dataset/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/data/dataset/base_dataset.py,sha256=qPLCCpIV-9dRqXf3MuyjCHt5QAxB2v1lUMtyo3IJ0V8,2664
hdl/data/dataset/utils.py,sha256=tAh-a05Ireu6TiY89HHZ3WcLOz_Th8Cz0-lwlslklAM,1071
hdl/data/dataset/fp/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/data/dataset/fp/fp_dataset.py,sha256=ei7M7xRL81hzX4r7ybxupkGE0Oj2dTXxCTgDYEWgtVQ,4033
hdl/data/dataset/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/data/dataset/graph/chiral.py,sha256=Tyz9-iBWoKESrYcnK6K0rif3A1SFcNMm9hbLfCmCVUA,1985
hdl/data/dataset/graph/gin.py,sha256=To-wvp-u-VXK6w3W3CohqcSCEfMZ7grqZQzr9KiEZQM,8848
hdl/data/dataset/graph/molnet.py,sha256=8VsKO3CDXwmq-3tlOFQLVIpxwoftAxI4WS9YtnOcI4Y,12806
hdl/data/dataset/loaders/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/data/dataset/loaders/chiral_graph.py,sha256=efPpN5ovUqUhIPcPnL0zADD-zxaG8HRGD4KPwA_vbS8,2239
hdl/data/dataset/loaders/general.py,sha256=eObFlDsIfoiqJkeZHpJ4-cFuFCUfDbc2tfZe1z8qm70,532
hdl/data/dataset/loaders/spliter.py,sha256=5-OFKEbASsgaGaaYerNgRXm1WFRwLm_aZ-AZ0LefJ4U,3142
hdl/data/dataset/loaders/collate_funcs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/data/dataset/loaders/collate_funcs/fp.py,sha256=nch2UjodF9Of7fTc61j7Td1yBhYPhAfVmrA_F3K_LV0,1436
hdl/data/dataset/loaders/collate_funcs/rxn.py,sha256=LdwheQfilB9Obc7407nWgpWp7aGFFegTlYBeY20l77Y,860
hdl/data/dataset/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/data/dataset/samplers/chiral.py,sha256=ZS83kg5e2gdHVGgIuCjCepDwk2SKqWDgJawH37oXy78,463
hdl/data/dataset/seq/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/data/dataset/seq/rxn_dataset.py,sha256=jfXFlR3ITAf0KwUfIevzUZHnLBnFYrL69Cc81EMv0x0,1668
hdl/features/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/features/fp/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/features/fp/features_generators.py,sha256=HbyS97i2I2mOcANdJMohs2okA1LlZmkG4ZIIX6Y9fr4,9017
hdl/features/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/features/graph/featurization.py,sha256=QLbj33JsgO-OWarIC2HXQP7eMu8pd-GWmppZQj_tQ_k,10902
hdl/features/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/features/utils/utils.py,sha256=aL4UAALblaw1w0AjK7MX8nSj9zwTmrp9CTLwJUX8ZtE,4225
hdl/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/layers/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/layers/general/gp.py,sha256=no1P6i2nCa539b0I5S6hd2mC8CeW0Ds726GM0swwlzc,525
hdl/layers/general/linear.py,sha256=d8NJwONVpIRr9malj1YGtK3rSgHzTMmSa-_eq46HGdI,18511
hdl/layers/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/layers/graph/chiral_graph.py,sha256=tqYPCUK4S5beUnuGzsCPrUyXEouCTi5CtUv74vZ9tws,8551
hdl/layers/graph/gcn.py,sha256=Eg-v672GQ3gBr4Sez3qARD2D95WAhFjC5i0j94IyFTo,397
hdl/layers/graph/gin.py,sha256=RkG8bZA_GgV2bcxwi6vCvk29QVVB9rCFchT3YpTULw4,1627
hdl/layers/graph/tetra.py,sha256=dKocIszOZh93ZVd0cznZ4AJGDYdoSFN6vcnCwIto6mI,5169
hdl/layers/graph/transformer.py,sha256=ZT7OId-0-i5obeoB-SG0XzotA-C-OeYl1CloGltdy48,7959
hdl/layers/sequential/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/metric_loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/metric_loss/loss.py,sha256=s89G3h6A4GVrqj0UXAvcvyVzJPvIfZQ_FjD5f5J6nh4,2024
hdl/metric_loss/metric.py,sha256=6fx1XOqtbekohgHZXJvLvnl73zictndCYoynJsy8Q-c,5030
hdl/metric_loss/multi_label.py,sha256=XOONQZVfO3EhcpwwO6IeQP2MacW0qDkHcrh0oEoMk5M,1738
hdl/metric_loss/nt_xent.py,sha256=CFPFf1mZ64SaAt5Q8DEFYU3AGSgcErCl49vOWIequLM,2476
hdl/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/models/chiral_gnn.py,sha256=NDwSbyefev9SawKJf1DfnBUxOcHxZCFoL13rbXAh2oc,6026
hdl/models/fast_transformer.py,sha256=tTBk6svXEdJDycQJtpGJMFudbVy2KJbrplSJUdvwofY,6594
hdl/models/ginet.py,sha256=k8_j6uvZ4TJU5eshrZOBBLaIo4YmLUX3WPrvgEtTn3A,5311
hdl/models/linear.py,sha256=brVYb1_Nq3IwPZ-i3ykS0JPYSxxmDb7QH1yaPMNDQJY,4523
hdl/models/model_dict.py,sha256=uRh13nch7rab3DajnDR_usp_1R0b37WUgNoiqOobk-s,402
hdl/models/norm_flows.py,sha256=nROlAaToUOtqi99_BPe6rlPHpUS38YFAtDll1xmGv5U,830
hdl/models/optim_dict.py,sha256=xvefZwCGebIXhRLSbcoGa3WVUWW2kEEm0Gsgp8Q9SQw,231
hdl/models/rxn.py,sha256=6MEkzjWEgWgl14EkdprvVtycX1q-xlYnXekM1ROKFL0,1557
hdl/models/utils.py,sha256=OOdwZ7f5ciaEFNM9zouNpLxj8S8kGKsSIM0msnoJ3bQ,2221
hdl/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/ops/utils.py,sha256=GIe95pJYVqOQQ91ELb1tfogVFzGg2VJdQmIFX8aHeyM,1053
hdl/optims/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/optims/nadam.py,sha256=l9DFQaHwE-uE0a_oSQZcvaqMv7lzYd2x_nE9no4zz64,3046
hdl/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/utils/chemical_tools/__init__.py,sha256=_QRNtVx0ieZZgSEsHndPFKm6XU3WXfRb7GYq-exe1nU,65
hdl/utils/chemical_tools/query_info.py,sha256=wyQXwKSY_gBGVUNvYggHpYBtOLAtpYKq3PN5wqDb7Co,4204
hdl/utils/chemical_tools/sdf.py,sha256=71PEqU0H885L6IeGHEa6n7ZLZThvMsZOVLuFG2wnoyM,542
hdl/utils/database_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/utils/database_tools/connect.py,sha256=KUnVG-8raifEJ_N0b3c8LkTTIfn9NIyw8LX6qvpA3YU,723
hdl/utils/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/utils/general/glob.py,sha256=8-RCnt6L297wMIfn34ZAMCsGCZUjHG3MGglGZI1cX0g,491
hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
hjxdl-0.0.1.dist-info/METADATA,sha256=w_oFdQOD0f8NHTb_8bVi7ypyO-lCZhHWqAMvGYADljo,525
hjxdl-0.0.1.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
hjxdl-0.0.1.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
hjxdl-0.0.1.dist-info/RECORD,,