hjxdl 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,111 @@
1
+ import csv
2
+ import os
3
+ import pickle
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from rdkit.Chem import PandasTools
9
+
10
+
11
# Regular expression that splits a SMILES string into tokens. The alternatives
# are, in order: a bracketed expression ``[...]``, the element symbols
# B/Br, C/Cl, N, O, S, P, F, I, the lowercase symbols b, c, n, o, s, p,
# parentheses, the punctuation symbols ``. = # - + \ / : ~ @ ? > >> * $``,
# a two-digit ``%NN`` number and a single digit.
# "." and "=" must be separate alternatives (``\.|=``): written as ``\.\|=``
# the pattern would only match the literal three-character sequence ".|=".
SMI_REGEX_PATTERN = (
    r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
)
13
+
14
+
15
def save_features(path: str, features: List[np.ndarray]) -> None:
    """
    Write molecule features to a compressed :code:`.npz` archive.

    The data is stored in the archive under the single array name "features".

    :param path: Path to the :code:`.npz` file the features are written to.
    :param features: A list of 1D numpy arrays containing the features for molecules.
    """
    named_arrays = {'features': features}
    np.savez_compressed(path, **named_arrays)
23
+
24
+
25
def load_features(path: str) -> np.ndarray:
    """
    Loads features saved in a variety of formats.

    The format is chosen from the file extension, compared case-insensitively.

    Supported formats:

    * :code:`.npz` compressed (assumes features are saved with name "features")
    * :code:`.npy`
    * :code:`.csv` / :code:`.txt` (assumes comma-separated features with a header and with one line per molecule;
      blank lines are ignored)
    * :code:`.pkl` / :code:`.pckl` / :code:`.pickle` containing a sparse numpy array

    .. note::

       All formats assume that the SMILES loaded elsewhere in the code are in the same
       order as the features loaded here.

    .. warning::

       Pickle files are read with :func:`pickle.load`, which can execute arbitrary code.
       Only load pickle files from a trusted source.

    :param path: Path to a file containing features.
    :return: A 2D numpy array of size :code:`(num_molecules, features_size)` containing the features.
    :raises ValueError: If the file extension is not one of the supported formats.
    """
    raw_extension = os.path.splitext(path)[1]
    extension = raw_extension.lower()

    if extension == '.npz':
        # np.load returns an open archive for .npz files; the context manager
        # closes the underlying file once the array has been read into memory.
        with np.load(path) as archive:
            features = archive['features']
    elif extension == '.npy':
        features = np.load(path)
    elif extension in ('.csv', '.txt'):
        # newline='' lets the csv module do its own newline handling.
        with open(path, newline='') as f:
            reader = csv.reader(f)
            next(reader)  # skip header
            # csv.reader yields [] for blank lines; skipping them keeps the
            # result rectangular when the file ends with empty lines.
            features = np.array([[float(value) for value in row] for row in reader if row])
    elif extension in ('.pkl', '.pckl', '.pickle'):
        # NOTE(security): pickle.load must only be used on trusted files.
        with open(path, 'rb') as f:
            features = np.array([np.squeeze(np.array(feat.todense())) for feat in pickle.load(f)])
    else:
        raise ValueError(f'Features path extension {raw_extension} not supported.')

    return features
62
+
63
+
64
def load_valid_atom_or_bond_features(path: str, smiles: List[str]) -> List[np.ndarray]:
    """
    Loads atom- or bond-level features saved in a variety of formats.

    The format is chosen from the file extension, compared case-insensitively.

    Supported formats:

    * :code:`.npz` descriptors are saved as 2D array for each molecule in the order of that in the data.csv
    * :code:`.pkl` / :code:`.pckl` / :code:`.pickle` containing a pandas dataframe with smiles as index and numpy array of descriptors as columns
    * :code:`.sdf` containing all mol blocks with descriptors as entries

    :param path: Path to file containing atomwise features.
    :param smiles: SMILES strings used to select and order the rows read from an
                   :code:`.sdf` file. It is not used for the other formats.
    :return: A list of 2D array.
    :raises ValueError: If the extension is not supported, if the pickled descriptors are
                        neither 1D nor 2D, or if an :code:`.sdf` file has missing values
                        for any of the requested SMILES.
    """
    raw_extension = os.path.splitext(path)[1]
    extension = raw_extension.lower()

    if extension == '.npz':
        # np.load returns an open archive for .npz files; the context manager
        # closes it once every array has been read into memory.
        with np.load(path) as container:
            features = [container[key] for key in container.files]

    elif extension in ('.pkl', '.pckl', '.pickle'):
        # NOTE(security): pd.read_pickle unpickles the file; only use trusted files.
        features_df = pd.read_pickle(path)
        # The first cell decides how the per-column arrays of a row are combined.
        descriptor_ndim = features_df.iloc[0, 0].ndim
        if descriptor_ndim == 1:
            features = features_df.apply(lambda x: np.stack(x.tolist(), axis=1), axis=1).tolist()
        elif descriptor_ndim == 2:
            features = features_df.apply(lambda x: np.concatenate(x.tolist(), axis=1), axis=1).tolist()
        else:
            raise ValueError(f'Atom/bond descriptors input {path} format not supported')

    elif extension == '.sdf':
        features_df = PandasTools.LoadSDF(path).drop(['ID', 'ROMol'], axis=1).set_index('SMILES')

        # keep only the first entry for each SMILES
        features_df = features_df[~features_df.index.duplicated()]

        # locate atomic descriptors columns: those whose first value is a comma-separated string
        features_df = features_df.iloc[:, features_df.iloc[0, :].apply(lambda x: isinstance(x, str) and ',' in x).to_list()]
        features_df = features_df.reindex(smiles)
        if features_df.isnull().any().any():
            raise ValueError('Invalid custom atomic descriptors file, Nan found in data')

        def _parse_descriptor(text):
            # Turn "v1,v2,..." (possibly wrapped over several lines) into a float array.
            return np.array(text.replace('\r', '').replace('\n', '').split(',')).astype(float)

        # DataFrame.applymap is deprecated in newer pandas in favour of DataFrame.map;
        # use whichever the installed version provides.
        elementwise = pd.DataFrame.map if hasattr(pd.DataFrame, 'map') else pd.DataFrame.applymap
        features_df = elementwise(features_df, _parse_descriptor)

        features = features_df.apply(lambda x: np.stack(x.tolist(), axis=1), axis=1).tolist()

    else:
        raise ValueError(f'Extension "{raw_extension}" is not supported.')

    return features
hdl/layers/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,14 @@
1
+ # import torch
2
+ import gpytorch
3
+
4
+
5
class ExactGPModel(gpytorch.models.ExactGP):
    """Gaussian-process model built on :class:`gpytorch.models.ExactGP`.

    Uses a constant mean and an RBF kernel wrapped in a scale kernel.
    """

    def __init__(self, train_x, train_y, likelihood):
        """Create the model and its mean/covariance modules.

        All three arguments are passed unchanged to the ``ExactGP`` constructor.
        """
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)

    def forward(self, x):
        """Return the multivariate normal given by the mean and covariance modules at ``x``."""
        prior_mean = self.mean_module(x)
        prior_covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(prior_mean, prior_covar)