hjxdl 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,111 @@
|
|
1
|
+
import csv
|
2
|
+
import os
|
3
|
+
import pickle
|
4
|
+
from typing import List
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
from rdkit.Chem import PandasTools
|
9
|
+
|
10
|
+
|
11
|
+
# Regular expression with a single capture group that matches one SMILES token
# at a time: bracketed atoms ("[...]"), the two-letter halogens Br/Cl, single
# letter organic-subset and aromatic atoms, parentheses, the "." separator,
# bond symbols (=, #, -, \, /, :, ~), charges/chirality marks (+, @), the
# reaction arrows ">" / ">>", wildcards (*, $, ?), and ring-closure labels
# (a single digit or "%" followed by two digits).
#
# NOTE: "." and "=" must be separate alternatives ("\.|="). The earlier form
# "\.\|=" escaped the pipe and therefore only matched the literal text ".|=",
# silently dropping every "." and "=" from the token stream.
SMI_REGEX_PATTERN = (
    r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
)
|
13
|
+
|
14
|
+
|
15
|
+
def save_features(path: str, features: List[np.ndarray]) -> None:
    """
    Writes molecule features to a compressed :code:`.npz` archive.

    The data is stored under the single array name "features", which is the
    key readers of the archive have to use to get it back.

    :param path: Destination path of the :code:`.npz` file.
    :param features: A list of 1D numpy arrays, one per molecule.
    """
    named_arrays = {'features': features}
    np.savez_compressed(path, **named_arrays)
|
23
|
+
|
24
|
+
|
25
|
+
def load_features(path: str) -> np.ndarray:
    """
    Loads features saved in a variety of formats.

    Supported formats:

    * :code:`.npz` compressed (assumes features are saved with name "features")
    * :code:`.npy`
    * :code:`.csv` / :code:`.txt` (assumes comma-separated features with a header and with one line per molecule)
    * :code:`.pkl` / :code:`.pckl` / :code:`.pickle` containing a sequence of sparse
      arrays (each element must provide :code:`todense()`)

    .. note::

       All formats assume that the SMILES loaded elsewhere in the code are in the same
       order as the features loaded here.

    :param path: Path to a file containing features.
    :return: A 2D numpy array of size :code:`(num_molecules, features_size)` containing the features.
    :raises ValueError: If the file extension is not one of the supported formats.
    """
    extension = os.path.splitext(path)[1]

    if extension == '.npz':
        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it once the array has been read into memory.
        with np.load(path) as archive:
            features = archive['features']
    elif extension == '.npy':
        features = np.load(path)
    elif extension in ['.csv', '.txt']:
        # newline='' lets the csv module handle line endings itself.
        with open(path, newline='') as f:
            reader = csv.reader(f)
            next(reader)  # skip header
            # Blank lines (e.g. a trailing newline) come back as empty rows;
            # skip them so the result stays a rectangular 2D array.
            features = np.array([[float(value) for value in row] for row in reader if row])
    elif extension in ['.pkl', '.pckl', '.pickle']:
        # SECURITY: pickle.load can execute arbitrary code; only use this
        # branch on files from a trusted source.
        with open(path, 'rb') as f:
            features = np.array([np.squeeze(np.array(feat.todense())) for feat in pickle.load(f)])
    else:
        raise ValueError(f'Features path extension {extension} not supported.')

    return features
|
62
|
+
|
63
|
+
|
64
|
+
def load_valid_atom_or_bond_features(path: str, smiles: List[str]) -> List[np.ndarray]:
    """
    Loads atom- or bond-level features saved in a variety of formats.

    Supported formats:

    * :code:`.npz` descriptors are saved as 2D array for each molecule in the order of that in the data.csv
    * :code:`.pkl` / :code:`.pckl` / :code:`.pickle` containing a pandas dataframe with smiles as index and numpy array of descriptors as columns
    * :code:`.sdf` containing all mol blocks with descriptors as entries

    :param path: Path to file containing atomwise features.
    :param smiles: SMILES strings used to select and order the rows of an
                   :code:`.sdf` file; the other formats do not consult it.
    :return: A list of 2D array.
    :raises ValueError: If the extension is unsupported, if the pickled
                        descriptors are neither 1D nor 2D, or if an
                        :code:`.sdf` file lacks descriptors for a requested SMILES.
    """

    extension = os.path.splitext(path)[1]

    if extension == '.npz':
        # The NpzFile holds the archive open; read every array inside the
        # context manager so the handle is closed afterwards.
        with np.load(path) as container:
            features = [container[key] for key in container]

    elif extension in ['.pkl', '.pckl', '.pickle']:
        # SECURITY: read_pickle unpickles the file and can execute arbitrary
        # code; only use it on files from a trusted source.
        features_df = pd.read_pickle(path)
        # The dimensionality of the first cell decides how a row's per-column
        # arrays are combined into one array per molecule.
        cell_ndim = features_df.iloc[0, 0].ndim
        if cell_ndim == 1:
            features = features_df.apply(lambda x: np.stack(x.tolist(), axis=1), axis=1).tolist()
        elif cell_ndim == 2:
            features = features_df.apply(lambda x: np.concatenate(x.tolist(), axis=1), axis=1).tolist()
        else:
            raise ValueError(f'Atom/bond descriptors input {path} format not supported')

    elif extension == '.sdf':
        features_df = PandasTools.LoadSDF(path).drop(['ID', 'ROMol'], axis=1).set_index('SMILES')

        # Keep only the first record for each SMILES.
        features_df = features_df[~features_df.index.duplicated()]

        # locate atomic descriptors columns: those whose first entry is a
        # comma-separated string
        is_descriptor = features_df.iloc[0, :].apply(lambda x: isinstance(x, str) and ',' in x).to_list()
        features_df = features_df.iloc[:, is_descriptor]
        # Align rows with the requested SMILES; missing ones become NaN rows.
        features_df = features_df.reindex(smiles)
        if features_df.isnull().any().any():
            raise ValueError('Invalid custom atomic descriptors file, Nan found in data')

        # DataFrame.applymap is deprecated in favour of DataFrame.map in newer
        # pandas; use whichever this pandas version provides.
        elementwise = features_df.map if hasattr(features_df, 'map') else features_df.applymap
        features_df = elementwise(
            lambda x: np.array(x.replace('\r', '').replace('\n', '').split(',')).astype(float)
        )

        features = features_df.apply(lambda x: np.stack(x.tolist(), axis=1), axis=1).tolist()

    else:
        raise ValueError(f'Extension "{extension}" is not supported.')

    return features
|
hdl/layers/__init__.py
ADDED
File without changes
|
File without changes
|
hdl/layers/general/gp.py
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
# import torch
|
2
|
+
import gpytorch
|
3
|
+
|
4
|
+
|
5
|
+
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact Gaussian-process model with a constant mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        """
        Builds the model from training data and a likelihood.

        :param train_x: Training inputs, forwarded to :code:`gpytorch.models.ExactGP`.
        :param train_y: Training targets, forwarded to :code:`gpytorch.models.ExactGP`.
        :param likelihood: Likelihood object, forwarded to :code:`gpytorch.models.ExactGP`.
        """
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)

    def forward(self, x):
        """
        Evaluates the mean and covariance modules at :code:`x`.

        :param x: Inputs passed to both the mean and the covariance module.
        :return: A :code:`gpytorch.distributions.MultivariateNormal` built from
                 the resulting mean and covariance.
        """
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )
|