hjxdl 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/models/utils.py ADDED
@@ -0,0 +1,83 @@
1
+ import typing as t
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
def save_model(
    model: t.Union[nn.Module, nn.DataParallel],
    save_dir: str = "./model.ckpt",
    epoch: int = 0,
    optimizer: torch.optim.Optimizer = None,
    loss: float = None,
) -> None:
    """Serialize a model (and optionally its optimizer) to a checkpoint file.

    The checkpoint is a dict with the keys ``init_args``, ``epoch``,
    ``model_state_dict``, ``optimizer_state_dict`` and ``loss`` -- the keys
    that ``load_model`` reads back.

    Args:
        model: The model to save. A ``nn.DataParallel`` wrapper is unwrapped,
            so the stored state dict and ``init_args`` belong to the inner
            module (state-dict keys carry no ``module.`` prefix).
        save_dir: Path of the checkpoint *file* to write (despite the name,
            this is not a directory).
        epoch: Epoch number stored alongside the weights.
        optimizer: Optional optimizer whose ``state_dict()`` is stored;
            ``None`` is stored when no optimizer is given.
        loss: Optional loss value stored alongside the weights.

    Raises:
        AttributeError: If the (unwrapped) model has no ``init_args``
            attribute.
    """
    # Unwrap DataParallel so both the weights and the constructor arguments
    # come from the real model: the wrapper itself has no ``init_args``.
    inner = model.module if isinstance(model, nn.DataParallel) else model
    optim_params = None if optimizer is None else optimizer.state_dict()
    torch.save(
        {
            'init_args': inner.init_args,
            'epoch': epoch,
            'model_state_dict': inner.state_dict(),
            'optimizer_state_dict': optim_params,
            'loss': loss,
        },
        save_dir,
    )
32
+
33
+
34
def _to_data_parallel_keys(state_dict):
    """Return a copy of *state_dict* with keys renamed for a ``nn.DataParallel`` wrapper."""
    from collections import OrderedDict

    remapped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith('module.'):
            # Already saved from a DataParallel model.
            new_key = key
        elif 'features.module.' in key:
            # Legacy checkpoints in which only a ``features`` submodule was
            # wrapped: move the ``module.`` prefix to the front.
            new_key = key.replace('features.module.', 'module.features.')
        else:
            # Prefix check (not substring) so names such as ``submodule.w``
            # are still prefixed.
            new_key = 'module.' + key
        remapped[new_key] = value
    return remapped


def load_model(
    save_dir: str,
    model_name: str = None,
    model: t.Union[nn.Module, nn.DataParallel] = None,
    optimizer: torch.optim.Optimizer = None,
    train: bool = False,
    map_location=None,
) -> t.Tuple[
    t.Union[nn.Module, nn.DataParallel],
    torch.optim.Optimizer,
    int,
    float
]:
    """Load a checkpoint written by ``save_model``.

    Args:
        save_dir: Path of the checkpoint file.
        model_name: Key into ``MODEL_DICT``; required when ``model`` is None,
            in which case a new model is built from the stored ``init_args``.
        model: Existing model to load the weights into. For a
            ``nn.DataParallel`` wrapper the stored keys are given the
            ``module.`` prefix the wrapper expects.
        optimizer: Optional optimizer to restore. It is left untouched when
            the checkpoint holds no optimizer state.
        train: Put the model in train mode if True, eval mode otherwise.
        map_location: Forwarded to ``torch.load`` (e.g. ``'cpu'`` to load a
            GPU checkpoint on a CPU-only machine).

    Returns:
        ``(model, optimizer, epoch, loss)``; ``epoch`` defaults to 0 and
        ``loss`` to 0.0 when missing from the checkpoint.

    Raises:
        ValueError: If both ``model`` and ``model_name`` are None.
    """
    checkpoint = torch.load(save_dir, map_location=map_location)
    state_dict = checkpoint['model_state_dict']

    if model is None:
        if model_name is None:
            raise ValueError('model_name is required when no model instance is given')
        # Imported lazily: the registry is only needed to rebuild a model.
        from .model_dict import MODEL_DICT
        model = MODEL_DICT[model_name](**checkpoint['init_args'])
        model.load_state_dict(state_dict)
    elif isinstance(model, nn.DataParallel):
        model.load_state_dict(_to_data_parallel_keys(state_dict))
    else:
        model.load_state_dict(state_dict)

    if optimizer is not None:
        optim_state = checkpoint.get('optimizer_state_dict')
        # save_model stores None when it was called without an optimizer.
        if optim_state is not None:
            optimizer.load_state_dict(optim_state)

    epoch = checkpoint.get('epoch', 0)
    loss = checkpoint.get('loss', 0.0)

    if train:
        model.train()
    else:
        model.eval()

    return model, optimizer, epoch, loss
hdl/ops/__init__.py ADDED
File without changes
hdl/ops/utils.py ADDED
@@ -0,0 +1,42 @@
1
+ import typing as t
2
+
3
+ # import torch
4
+ from torch import nn
5
+
6
+
7
__all__ = [
    'get_activation',
]


def get_activation(
    name: str,
    **kwargs
) -> t.Optional[nn.Module]:
    """Build an activation module from its (case-insensitive) name.

    Args:
        name: One of ``relu``, ``elu``, ``selu``, ``leaky_relu``,
            ``softmax``, ``sigmoid``, ``tanh``, ``gelu`` or ``none``.
        **kwargs: Optional settings read by some activations:
            ``inplace`` (relu, elu, selu, leaky_relu; default ``False``),
            ``alpha`` (elu; default ``1.0``),
            ``negative_slope`` (leaky_relu; default ``0.01``),
            ``dim`` (softmax; default ``-1``). Other keys are ignored.

    Returns:
        The activation module, or ``None`` when ``name`` is ``'none'``.

    Raises:
        ValueError: If ``name`` is not a supported activation.
    """
    name = name.lower()
    inplace = kwargs.get('inplace', False)
    if name == 'relu':
        return nn.ReLU(inplace=inplace)
    if name == 'elu':
        return nn.ELU(alpha=kwargs.get('alpha', 1.), inplace=inplace)
    if name == 'selu':
        return nn.SELU(inplace=inplace)
    if name == 'leaky_relu':
        return nn.LeakyReLU(
            negative_slope=kwargs.get('negative_slope', 0.01),
            inplace=inplace,
        )
    if name == 'softmax':
        return nn.Softmax(dim=kwargs.get('dim', -1))
    if name == 'sigmoid':
        return nn.Sigmoid()
    if name == 'tanh':
        return nn.Tanh()
    if name == 'gelu':
        return nn.GELU()
    if name == 'none':
        # Explicitly "no activation".
        return None
    raise ValueError(f'Activation not implemented: {name!r}')
hdl/optims/__init__.py ADDED
File without changes
hdl/optims/nadam.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch as torch
2
+ from torch import optim
3
+
4
+
5
class Nadam(optim.Adam):
    """
    Adaptive moment with Nesterov gradients.

    http://cs229.stanford.edu/proj2015/054_report.pdf

    Parameters
    ----------
    params
        iterable of parameters to optimize or dicts defining
        parameter groups
    lr
        learning rate (default: 1e-3)
    betas
        coefficients used for computing
        running averages of gradient and its square (default: (0.9, 0.999))
    eps
        term added to the denominator to improve
        numerical stability (default: 1e-8)
    weight_decay
        weight decay (L2 penalty) (default: 0). It is added to a copy of the
        gradient, so ``p.grad`` is not modified by ``step``.
    decay
        a decay scheme for `betas[0]`, called as ``decay(beta1, t)``.
        Default: :math:`\\beta * (1 - 0.5 * 0.96^{\\frac{t}{250}})`
        where `t` is the training step.

    Notes
    -----
    The running first moment is updated with the constant ``betas[0]``; the
    decayed values ``decay(beta1, t)`` and ``decay(beta1, t + 1)`` enter only
    the bias corrections and the Nesterov combination.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
                 decay=lambda x, t: x * (1. - .5 * .96 ** (t / 250.))):
        super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.decay = decay

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Parameters
        ----------
        closure
            optional callable that re-evaluates the model and returns the
            loss; it is run with autograd enabled.

        Returns
        -------
        The value returned by ``closure``, or ``None`` if none is given.

        Raises
        ------
        RuntimeError
            If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, but the closure must be able to
            # call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('NAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    # Running product of the decayed beta1 values
                    state['beta1_cum'] = 1.

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # int() also normalizes a tensor-valued step, which Adam's
                # state-dict loading may produce.
                step = int(state['step']) + 1
                state['step'] = step

                if weight_decay != 0:
                    # Out-of-place: do not mutate the caller's p.grad.
                    grad = grad.add(p, alpha=weight_decay)

                beta1_t = self.decay(beta1, step)
                beta1_tp1 = self.decay(beta1, step + 1.)
                beta1_cum = state['beta1_cum'] * beta1_t

                g_hat_t = grad / (1. - beta1_cum)
                exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
                m_hat_t = exp_avg / (1. - beta1_cum * beta1_tp1)

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
                v_hat_t = exp_avg_sq / (1. - beta2 ** step)
                # Nesterov-style combination of the current gradient and the
                # look-ahead momentum estimate.
                m_bar_t = (1. - beta1) * g_hat_t + beta1_tp1 * m_hat_t

                denom = v_hat_t.sqrt().add_(eps)
                p.addcdiv_(m_bar_t, denom, value=-lr)
                state['beta1_cum'] = beta1_cum

        return loss
hdl/utils/__init__.py ADDED
File without changes
@@ -0,0 +1,2 @@
1
+ from .sdf import sdf2df
2
+ from .query_info import query_a_compound
@@ -0,0 +1,149 @@
1
+ import re
2
+
3
+ import cirpy
4
+ import pubchempy as pcp
5
+ # from rdkit import Chem
6
+ import molvs as mv
7
+ from psycopg import sql
8
+
9
+ from hdl.utils.database_tools.connect import connect_by_infofile
10
+
11
+
12
def _as_list(value):
    """Normalize a ``cirpy.resolve`` result (None/empty, a string, or a list) to a list."""
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    return value


def query_from_cir(query_name: str):
    """Resolve a compound identifier with ``cirpy.resolve``.

    Args:
        query_name: Identifier passed unchanged to ``cirpy.resolve``.

    Returns:
        A ``(smiles, cas_list, name_list)`` tuple. ``smiles`` is the
        MolVS-standardized SMILES, the raw resolved SMILES when
        standardization fails, or ``None`` when no SMILES was resolved.
        ``cas_list`` and ``name_list`` are empty lists when nothing was
        resolved for them.
    """
    cas_list = _as_list(cirpy.resolve(query_name, 'cas'))
    name_list = _as_list(cirpy.resolve(query_name, 'names'))

    smiles = cirpy.resolve(query_name, 'smiles')
    # Only standardize a real result; standardizing None just raised and
    # printed a spurious error.
    if smiles is not None:
        try:
            smiles = mv.standardize_smiles(smiles)
        except Exception as e:
            # Keep the raw SMILES when MolVS cannot standardize it.
            print(e)

    return smiles, cas_list, name_list
36
+
37
+
38
_CAS_RE = re.compile(r'\d{2,7}-\d{2}-\d')


def _is_cas_number(text: str) -> bool:
    """Return True if *text* is a well-formed CAS Registry Number.

    The whole string must have the form ``NN(NNNNN)-NN-N``, and the final
    check digit must equal the sum of the other digits, weighted 1, 2, 3, ...
    from the right, modulo 10.
    """
    if not _CAS_RE.fullmatch(text):
        return False
    digits = text.replace('-', '')
    body, check = digits[:-1], int(digits[-1])
    total = sum(i * int(d) for i, d in enumerate(reversed(body), start=1))
    return total % 10 == check


def query_from_pubchem(query_name: str):
    """Resolve a compound name with PubChem (``pcp.get_compounds`` by name).

    Args:
        query_name: Compound name to search for.

    Returns:
        A ``(smiles, cas_list, name_list)`` tuple:
        ``smiles`` is the MolVS-standardized canonical SMILES of the first
        hit (the raw canonical SMILES when standardization fails), or
        ``None`` when PubChem returns no compounds; ``cas_list`` holds the
        valid CAS numbers found among the synonyms of all hits;
        ``name_list`` holds all synonyms of all hits. Both lists are
        de-duplicated and unordered.
    """
    results = pcp.get_compounds(query_name, 'name')
    smiles = None
    names = set()
    cas_numbers = set()

    if results:
        raw_smiles = results[0].canonical_smiles
        try:
            smiles = mv.standardize_smiles(raw_smiles)
        except Exception as e:
            smiles = raw_smiles
            print(smiles)
            print(e)
        for compound in results:
            # Guard against compounds for which no synonyms are returned.
            synonyms = compound.synonyms or []
            names.update(synonyms)
            # Full match + check digit: a prefix match also accepted strings
            # such as dates ("2011-01-01" -> "2011-01-0").
            cas_numbers.update(syn for syn in synonyms if _is_cas_number(syn))

    return smiles, list(cas_numbers), list(names)
62
+
63
+
64
def _resolve_online(query_name: str):
    """Resolve *query_name* with PubChem, falling back to CIR.

    Returns:
        ``(smiles, cas_list, name_list)`` from the first resolver that yields
        a SMILES, or ``(None, [], [])`` when neither does. Resolver errors
        are printed and treated as "not found".
    """
    for resolver in (query_from_pubchem, query_from_cir):
        try:
            smiles, cas_list, name_list = resolver(query_name)
        except Exception as e:
            print(e)
            continue
        if smiles is not None:
            return smiles, cas_list, name_list
    return None, [], []


def _insert_ignoring_errors(conn, query, params):
    """Run one INSERT in a savepoint so a failure does not abort the outer transaction."""
    try:
        with conn.transaction():
            conn.execute(query, params)
    except Exception as e:
        print(e)


def _store_compound(conn, fei, smiles, cas_list, name_list):
    """Insert the compound row and its CAS/name aliases; conflicting rows are skipped."""
    insert_compounds_sql = sql.SQL(
        "INSERT INTO compounds (fei, smiles) VALUES (%s, %s) ON CONFLICT (fei) DO NOTHING"
    )
    conn.execute(insert_compounds_sql, [fei, smiles])

    insert_cas_map_sql = sql.SQL(
        "INSERT INTO cas_maps (fei, cas) VALUES (%s, %s) ON CONFLICT (cas) DO NOTHING"
    )
    for cas in cas_list:
        _insert_ignoring_errors(conn, insert_cas_map_sql, [fei, cas])

    insert_name_map_sql = sql.SQL(
        "INSERT INTO name_maps (fei, name) VALUES (%s, %s) ON CONFLICT (name) DO NOTHING"
    )
    for name in name_list:
        _insert_ignoring_errors(conn, insert_name_map_sql, [fei, name.lower()])


def _log_unresolved(log_file: str, query_name: str) -> None:
    """Append *query_name* to *log_file*, one query per line."""
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(query_name)
        f.write('\n')


def query_a_compound(
    query_name: str,
    connect_info: str,
    by: str = 'name',
    log_file: str = './err.log'
):
    """Look up a compound's FEI, resolving it online and caching it on a miss.

    The lower-cased query is first looked up in the ``<by>_maps`` table. On a
    miss it is resolved through PubChem and then CIR. The resulting SMILES is
    matched against ``compounds``; when no row exists, the first CAS number
    becomes the new FEI. The compound, its CAS numbers and its synonyms are
    then inserted (conflicting rows skipped) and committed.

    Args:
        query_name: Compound name (or CAS number when ``by='cas'``).
        connect_info: Path of the connection-info file read by
            ``connect_by_infofile``.
        by: Lookup column: ``'name'`` (table ``name_maps``) or ``'cas'``
            (table ``cas_maps``).
        log_file: File to which unresolved queries are appended.

    Returns:
        The FEI, or None when the compound could not be resolved or there
        was no identifier to store it under.

    Raises:
        ValueError: If ``by`` is neither ``'name'`` nor ``'cas'``.
    """
    if by not in ('name', 'cas'):
        raise ValueError(f"Unsupported lookup column {by!r}; expected 'name' or 'cas'")

    query_name = query_name.lower()
    lookup = sql.SQL(
        "select fei from {table} where {by} = %s"
    ).format(
        table=sql.Identifier(by + '_maps'),
        by=sql.Identifier(by),
    )

    # The psycopg connection context manager commits on normal exit
    # (including the early returns below), rolls back on an exception and
    # always closes the connection.
    with connect_by_infofile(connect_info) as conn:
        row = conn.execute(lookup, [query_name]).fetchone()
        if row is not None:
            return row[0]

        smiles, cas_list, name_list = _resolve_online(query_name)
        if smiles is None:
            _log_unresolved(log_file, query_name)
            return None

        query_compound = sql.SQL(
            "select fei from compounds where smiles = %s"
        )
        row = conn.execute(query_compound, [smiles]).fetchone()
        if row is not None:
            fei = row[0]
        elif cas_list:
            fei = cas_list[0]
        else:
            # Not in the database and no CAS number to key a new record on:
            # nothing can be stored, so record the query for follow-up.
            _log_unresolved(log_file, query_name)
            return None

        _store_compound(conn, fei, smiles, cas_list, name_list)

    return fei
@@ -0,0 +1,20 @@
1
+ from rdkit import Chem
2
+ import pandas as pd
3
+
4
+
5
def sdf2df(
    sdf_file,
    id_col: str = 'Molecule Name',
    target_col: str = 'Average △G (kcal/mol)'
):
    """Load an SDF file into a DataFrame.

    Every SD property of a molecule becomes a column, and a ``smiles``
    column is added. The ``target_col`` property is renamed to ``y`` and the
    ``id_col`` property to ``name``. Records that RDKit cannot parse are
    skipped, with a single warning reporting how many were dropped.

    Args:
        sdf_file: Path of the SDF file.
        id_col: Property holding the molecule identifier.
        target_col: Property holding the regression target.

    Returns:
        pandas.DataFrame with one row per parsed molecule.

    Raises:
        KeyError: If a parsed molecule lacks ``target_col`` or ``id_col``.
    """
    import warnings

    records = []
    n_skipped = 0
    for mol in Chem.SDMolSupplier(sdf_file):
        if mol is None:
            # SDMolSupplier yields None for records it fails to parse;
            # MolToSmiles(None) would crash.
            n_skipped += 1
            continue
        record = mol.GetPropsAsDict()
        record['smiles'] = Chem.MolToSmiles(mol)
        record['y'] = record.pop(target_col)
        record['name'] = record.pop(id_col)
        records.append(record)

    if n_skipped:
        warnings.warn(f'sdf2df: skipped {n_skipped} unparsable record(s) in {sdf_file}')
    return pd.DataFrame(records)
File without changes
@@ -0,0 +1,28 @@
1
+ import psycopg
2
+
3
+
4
def connect_by_infofile(info_file: str) -> psycopg.Connection:
    """Create a postgres connection from a connection string stored in a file.

    Args:
        info_file (str):
            Path of a file whose first line is the connection string, like
            ``host=127.0.0.1 dbname=dbname port=5432 user=postgres password=lala``.
            Only the first line is read.

    Returns:
        psycopg.Connection:
            The connection instance; it should be closed after committing,
            e.g. by using it as a context manager::

                with connect_by_infofile('./conn.info') as conn:
                    cur = conn.execute('select * from name_maps;')
                    cur.fetchone()
    """
    # Read inside ``with`` so the info file is closed immediately instead of
    # leaking the handle until garbage collection.
    with open(info_file) as f:
        conninfo = f.readline().strip()
    return psycopg.connect(conninfo)
File without changes
@@ -0,0 +1,21 @@
1
+ import subprocess
2
+
3
+
4
def get_num_lines(file):
    """Return the number of newline characters in *file*, as ``wc -l`` reports.

    A final line without a trailing newline is therefore not counted. The
    file is scanned in binary chunks, so no external process is spawned and
    the function also works where ``wc`` is unavailable.

    Raises:
        OSError: If the file cannot be opened.
    """
    count = 0
    with open(file, 'rb') as fh:
        # 1 MiB chunks keep memory flat for very large files.
        for chunk in iter(lambda: fh.read(1 << 20), b''):
            count += chunk.count(b'\n')
    return count
9
+
10
+
11
def str_from_line(file, line, split=False):
    """Return line *line* (0-based) of *file*, stripped of surrounding whitespace.

    Args:
        file: Path of a UTF-8 text file.
        line: 0-based index of the line to read.
        split: If True, return only the first whitespace-separated token of
            the line.

    Returns:
        The stripped line (or its first token), or ``''`` when the line is
        blank or *line* is past the end of the file.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the line is not valid UTF-8.
    """
    import itertools

    # Binary iteration splits on b'\n' only (like sed); decode afterwards.
    with open(file, 'rb') as fh:
        raw = next(itertools.islice(fh, line, None), b'')
    text = raw.decode().strip()
    if split:
        # Blank or missing lines have no tokens: return '' instead of
        # raising IndexError.
        tokens = text.split()
        text = tokens[0] if tokens else ''
    return text
File without changes
@@ -0,0 +1,108 @@
1
+ from typing import List, Union
2
+ import numpy as np
3
+
4
+ from torch.optim import Optimizer
5
+ from torch.optim.lr_scheduler import _LRScheduler
6
+
7
+
8
class NoamLR(_LRScheduler):
    """
    Noam learning rate scheduler with piecewise linear increase and exponential decay.
    The learning rate increases linearly from init_lr to max_lr over the course of
    the first warmup_steps (where warmup_steps = warmup_epochs * steps_per_epoch).
    Then the learning rate decreases exponentially from max_lr to final_lr over the
    course of the remaining total_steps - warmup_steps (where total_steps =
    total_epochs * steps_per_epoch). This is roughly based on the learning rate
    schedule from Attention is All You Need, section 5.3 (https://arxiv.org/abs/1706.03762).
    Every list argument holds one value per optimizer param group.
    """
    def __init__(self,
                 optimizer: Optimizer,
                 warmup_epochs: List[Union[float, int]],
                 total_epochs: List[int],
                 steps_per_epoch: int,
                 init_lr: List[float],
                 max_lr: List[float],
                 final_lr: List[float]):
        """
        Initializes the learning rate scheduler.
        :param optimizer: A PyTorch optimizer.
        :param warmup_epochs: The number of epochs during which to linearly increase the learning rate.
        :param total_epochs: The total number of epochs.
        :param steps_per_epoch: The number of steps (batches) per epoch.
        :param init_lr: The initial learning rate.
        :param max_lr: The maximum learning rate (achieved after warmup_epochs).
        :param final_lr: The final learning rate (achieved after total_epochs).
        :raises ValueError: If any list does not have one entry per param group.
        """
        n_groups = len(optimizer.param_groups)
        lengths = [len(x) for x in (warmup_epochs, total_epochs, init_lr, max_lr, final_lr)]
        if any(n != n_groups for n in lengths):
            raise ValueError(
                f'NoamLR expects one value per param group ({n_groups}); got lengths '
                f'{lengths} for warmup_epochs, total_epochs, init_lr, max_lr, final_lr'
            )

        self.num_lrs = n_groups

        self.optimizer = optimizer
        self.warmup_epochs = np.array(warmup_epochs)
        self.total_epochs = np.array(total_epochs)
        self.steps_per_epoch = steps_per_epoch
        self.init_lr = np.array(init_lr)
        self.max_lr = np.array(max_lr)
        self.final_lr = np.array(final_lr)

        self.current_step = 0
        # Copy: step() updates this list in place and must not mutate the
        # caller's init_lr argument.
        self.lr = list(init_lr)
        self.warmup_steps = (self.warmup_epochs * self.steps_per_epoch).astype(int)
        self.total_steps = self.total_epochs * self.steps_per_epoch
        # Denominators are clamped to >= 1 so a zero-length warmup or decay
        # phase yields finite values instead of inf/nan. A clamped value is
        # only ever multiplied by 0 (warmup at step 0) or never used (decay).
        self.linear_increment = (self.max_lr - self.init_lr) / np.maximum(self.warmup_steps, 1)
        decay_steps = np.maximum(self.total_steps - self.warmup_steps, 1)
        self.exponential_gamma = (self.final_lr / self.max_lr) ** (1 / decay_steps)

        # The base constructor performs an initial step(), so the first
        # learning rate applied is the one for current_step == 1.
        super(NoamLR, self).__init__(optimizer)

    def get_lr(self) -> List[float]:
        """Gets a list of the current learning rates."""
        return list(self.lr)

    def step(self, current_step: int = None):
        """
        Updates the learning rate by taking a step.
        :param current_step: Optionally specify what step to set the learning rate to.
                             If None, current_step = self.current_step + 1.
        """
        if current_step is not None:
            self.current_step = current_step
        else:
            self.current_step += 1

        for i in range(self.num_lrs):
            if self.current_step <= self.warmup_steps[i]:
                lr = self.init_lr[i] + self.current_step * self.linear_increment[i]
            elif self.current_step <= self.total_steps[i]:
                lr = self.max_lr[i] * (self.exponential_gamma[i] ** (self.current_step - self.warmup_steps[i]))
            else:  # theoretically this case should never be reached since training should stop at total_steps
                lr = self.final_lr[i]
            lr = float(lr)
            self.lr[i] = lr
            self.optimizer.param_groups[i]['lr'] = lr

        # The base-class step() that normally records this is overridden;
        # keep get_last_lr() working.
        self._last_lr = list(self.lr)
83
+
84
+
85
def build_lr_scheduler(
    optimizer: Optimizer,
    warmup_epochs: int,
    n_epochs: int,
    steps_per_epoch: int,
    lr: float,
    init_div_factor: float = 10,
    final_div_factor: float = 10,
) -> _LRScheduler:
    """
    Builds a NoamLR scheduler that warms up to ``lr`` and then decays.

    Every param group of ``optimizer`` receives the same schedule.

    :param optimizer: The Optimizer whose learning rate will be scheduled.
    :param warmup_epochs: Number of epochs of linear warmup to ``lr``.
    :param n_epochs: Total number of training epochs.
    :param steps_per_epoch: Number of optimizer steps (batches) per epoch.
    :param lr: Peak learning rate, reached at the end of warmup.
    :param init_div_factor: The starting learning rate is ``lr / init_div_factor`` (default 10).
    :param final_div_factor: The final learning rate is ``lr / final_div_factor`` (default 10).
    :return: An initialized learning rate scheduler.
    """
    # NoamLR needs one entry per param group; replicate the single schedule.
    n_groups = len(optimizer.param_groups)
    return NoamLR(
        optimizer=optimizer,
        warmup_epochs=[warmup_epochs] * n_groups,
        total_epochs=[n_epochs] * n_groups,
        steps_per_epoch=steps_per_epoch,  # e.g. train_data_size // batch_size
        init_lr=[lr / init_div_factor] * n_groups,
        max_lr=[lr] * n_groups,
        final_lr=[lr / final_div_factor] * n_groups,
    )
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.1
2
+ Name: hjxdl
3
+ Version: 0.0.1
4
+ Summary: A collection of functions for Jupyter notebooks
5
+ Home-page: https://github.com/huluxiaohuowa/hdl
6
+ Author: Jianxing Hu
7
+ Author-email: j.hu@pku.edu.cn
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.6
12
+ Description-Content-Type: text/markdown
13
+
14
+ # DL framework by Jianxing
15
+
16
+ ```bash
17
+ git clone git@github.com:huluxiaohuowa/hdl.git
18
+ python setup.py install
19
+ ```
@@ -0,0 +1,91 @@
1
+ hdl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ hdl/_version.py,sha256=pMnmqZnpVmaqR5nqHztNWzbbtb1oy5bPN_v7uhOH8K8,411
3
+ hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
5
+ hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ hdl/controllers/al/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ hdl/controllers/al/al.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ hdl/controllers/al/dispatcher.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ hdl/controllers/al/feedback.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ hdl/controllers/explain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ hdl/controllers/explain/shapley.py,sha256=6dPc_ICPgvllT8YtFxe9Ds-TVccsXr2M0lP4zKxGHQA,12436
12
+ hdl/controllers/explain/subgraphx.py,sha256=2lxdmKlveyxWpYf2AdTvjm_C578WH-_z1fqhiHOzYXQ,38202
13
+ hdl/controllers/train/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ hdl/controllers/train/rxn_train.py,sha256=G1MTRrcP-g9ohzl7tbGM1tYd9HAdcM-3_wHQtRYg4LE,4876
15
+ hdl/controllers/train/train.py,sha256=K7UjerhMFks9gQINQtUNXfGBN-XejmJZs0Mqzd7k5rU,1637
16
+ hdl/controllers/train/train_ginet.py,sha256=KFGIMwCF62HuuFCQcz5Vr394Lf0kkgOT9x_0pvMyVmU,8561
17
+ hdl/controllers/train/trainer_base.py,sha256=SZpa-4aV9ptKf4wgKVTNqqy4yJnLPOaEUX56LE28u74,4012
18
+ hdl/controllers/train/trainer_iterative.py,sha256=ydUJWuAvvfjosGBa4MvBeqqoNB6tlB4WWprBOXu9vG4,11803
19
+ hdl/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
+ hdl/data/to_mols.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
+ hdl/data/dataset/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ hdl/data/dataset/base_dataset.py,sha256=qPLCCpIV-9dRqXf3MuyjCHt5QAxB2v1lUMtyo3IJ0V8,2664
23
+ hdl/data/dataset/utils.py,sha256=tAh-a05Ireu6TiY89HHZ3WcLOz_Th8Cz0-lwlslklAM,1071
24
+ hdl/data/dataset/fp/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
+ hdl/data/dataset/fp/fp_dataset.py,sha256=ei7M7xRL81hzX4r7ybxupkGE0Oj2dTXxCTgDYEWgtVQ,4033
26
+ hdl/data/dataset/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
+ hdl/data/dataset/graph/chiral.py,sha256=Tyz9-iBWoKESrYcnK6K0rif3A1SFcNMm9hbLfCmCVUA,1985
28
+ hdl/data/dataset/graph/gin.py,sha256=To-wvp-u-VXK6w3W3CohqcSCEfMZ7grqZQzr9KiEZQM,8848
29
+ hdl/data/dataset/graph/molnet.py,sha256=8VsKO3CDXwmq-3tlOFQLVIpxwoftAxI4WS9YtnOcI4Y,12806
30
+ hdl/data/dataset/loaders/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
+ hdl/data/dataset/loaders/chiral_graph.py,sha256=efPpN5ovUqUhIPcPnL0zADD-zxaG8HRGD4KPwA_vbS8,2239
32
+ hdl/data/dataset/loaders/general.py,sha256=eObFlDsIfoiqJkeZHpJ4-cFuFCUfDbc2tfZe1z8qm70,532
33
+ hdl/data/dataset/loaders/spliter.py,sha256=5-OFKEbASsgaGaaYerNgRXm1WFRwLm_aZ-AZ0LefJ4U,3142
34
+ hdl/data/dataset/loaders/collate_funcs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
35
+ hdl/data/dataset/loaders/collate_funcs/fp.py,sha256=nch2UjodF9Of7fTc61j7Td1yBhYPhAfVmrA_F3K_LV0,1436
36
+ hdl/data/dataset/loaders/collate_funcs/rxn.py,sha256=LdwheQfilB9Obc7407nWgpWp7aGFFegTlYBeY20l77Y,860
37
+ hdl/data/dataset/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
+ hdl/data/dataset/samplers/chiral.py,sha256=ZS83kg5e2gdHVGgIuCjCepDwk2SKqWDgJawH37oXy78,463
39
+ hdl/data/dataset/seq/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
+ hdl/data/dataset/seq/rxn_dataset.py,sha256=jfXFlR3ITAf0KwUfIevzUZHnLBnFYrL69Cc81EMv0x0,1668
41
+ hdl/features/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
+ hdl/features/fp/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
43
+ hdl/features/fp/features_generators.py,sha256=HbyS97i2I2mOcANdJMohs2okA1LlZmkG4ZIIX6Y9fr4,9017
44
+ hdl/features/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
45
+ hdl/features/graph/featurization.py,sha256=QLbj33JsgO-OWarIC2HXQP7eMu8pd-GWmppZQj_tQ_k,10902
46
+ hdl/features/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
47
+ hdl/features/utils/utils.py,sha256=aL4UAALblaw1w0AjK7MX8nSj9zwTmrp9CTLwJUX8ZtE,4225
48
+ hdl/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
49
+ hdl/layers/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
+ hdl/layers/general/gp.py,sha256=no1P6i2nCa539b0I5S6hd2mC8CeW0Ds726GM0swwlzc,525
51
+ hdl/layers/general/linear.py,sha256=d8NJwONVpIRr9malj1YGtK3rSgHzTMmSa-_eq46HGdI,18511
52
+ hdl/layers/graph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
53
+ hdl/layers/graph/chiral_graph.py,sha256=tqYPCUK4S5beUnuGzsCPrUyXEouCTi5CtUv74vZ9tws,8551
54
+ hdl/layers/graph/gcn.py,sha256=Eg-v672GQ3gBr4Sez3qARD2D95WAhFjC5i0j94IyFTo,397
55
+ hdl/layers/graph/gin.py,sha256=RkG8bZA_GgV2bcxwi6vCvk29QVVB9rCFchT3YpTULw4,1627
56
+ hdl/layers/graph/tetra.py,sha256=dKocIszOZh93ZVd0cznZ4AJGDYdoSFN6vcnCwIto6mI,5169
57
+ hdl/layers/graph/transformer.py,sha256=ZT7OId-0-i5obeoB-SG0XzotA-C-OeYl1CloGltdy48,7959
58
+ hdl/layers/sequential/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
+ hdl/metric_loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
60
+ hdl/metric_loss/loss.py,sha256=s89G3h6A4GVrqj0UXAvcvyVzJPvIfZQ_FjD5f5J6nh4,2024
61
+ hdl/metric_loss/metric.py,sha256=6fx1XOqtbekohgHZXJvLvnl73zictndCYoynJsy8Q-c,5030
62
+ hdl/metric_loss/multi_label.py,sha256=XOONQZVfO3EhcpwwO6IeQP2MacW0qDkHcrh0oEoMk5M,1738
63
+ hdl/metric_loss/nt_xent.py,sha256=CFPFf1mZ64SaAt5Q8DEFYU3AGSgcErCl49vOWIequLM,2476
64
+ hdl/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
65
+ hdl/models/chiral_gnn.py,sha256=NDwSbyefev9SawKJf1DfnBUxOcHxZCFoL13rbXAh2oc,6026
66
+ hdl/models/fast_transformer.py,sha256=tTBk6svXEdJDycQJtpGJMFudbVy2KJbrplSJUdvwofY,6594
67
+ hdl/models/ginet.py,sha256=k8_j6uvZ4TJU5eshrZOBBLaIo4YmLUX3WPrvgEtTn3A,5311
68
+ hdl/models/linear.py,sha256=brVYb1_Nq3IwPZ-i3ykS0JPYSxxmDb7QH1yaPMNDQJY,4523
69
+ hdl/models/model_dict.py,sha256=uRh13nch7rab3DajnDR_usp_1R0b37WUgNoiqOobk-s,402
70
+ hdl/models/norm_flows.py,sha256=nROlAaToUOtqi99_BPe6rlPHpUS38YFAtDll1xmGv5U,830
71
+ hdl/models/optim_dict.py,sha256=xvefZwCGebIXhRLSbcoGa3WVUWW2kEEm0Gsgp8Q9SQw,231
72
+ hdl/models/rxn.py,sha256=6MEkzjWEgWgl14EkdprvVtycX1q-xlYnXekM1ROKFL0,1557
73
+ hdl/models/utils.py,sha256=OOdwZ7f5ciaEFNM9zouNpLxj8S8kGKsSIM0msnoJ3bQ,2221
74
+ hdl/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
75
+ hdl/ops/utils.py,sha256=GIe95pJYVqOQQ91ELb1tfogVFzGg2VJdQmIFX8aHeyM,1053
76
+ hdl/optims/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
77
+ hdl/optims/nadam.py,sha256=l9DFQaHwE-uE0a_oSQZcvaqMv7lzYd2x_nE9no4zz64,3046
78
+ hdl/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
79
+ hdl/utils/chemical_tools/__init__.py,sha256=_QRNtVx0ieZZgSEsHndPFKm6XU3WXfRb7GYq-exe1nU,65
80
+ hdl/utils/chemical_tools/query_info.py,sha256=wyQXwKSY_gBGVUNvYggHpYBtOLAtpYKq3PN5wqDb7Co,4204
81
+ hdl/utils/chemical_tools/sdf.py,sha256=71PEqU0H885L6IeGHEa6n7ZLZThvMsZOVLuFG2wnoyM,542
82
+ hdl/utils/database_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
83
+ hdl/utils/database_tools/connect.py,sha256=KUnVG-8raifEJ_N0b3c8LkTTIfn9NIyw8LX6qvpA3YU,723
84
+ hdl/utils/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
85
+ hdl/utils/general/glob.py,sha256=8-RCnt6L297wMIfn34ZAMCsGCZUjHG3MGglGZI1cX0g,491
86
+ hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
87
+ hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
88
+ hjxdl-0.0.1.dist-info/METADATA,sha256=w_oFdQOD0f8NHTb_8bVi7ypyO-lCZhHWqAMvGYADljo,525
89
+ hjxdl-0.0.1.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
90
+ hjxdl-0.0.1.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
91
+ hjxdl-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (70.1.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+