tyche-tools 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tyche_tools-0.1.0/LICENSE +21 -0
- tyche_tools-0.1.0/PKG-INFO +17 -0
- tyche_tools-0.1.0/pyproject.toml +27 -0
- tyche_tools-0.1.0/setup.cfg +4 -0
- tyche_tools-0.1.0/tyche_tools/__init__.py +9 -0
- tyche_tools-0.1.0/tyche_tools/_features.py +175 -0
- tyche_tools-0.1.0/tyche_tools/_network.py +250 -0
- tyche_tools-0.1.0/tyche_tools/_utils.py +51 -0
- tyche_tools-0.1.0/tyche_tools/median.py +241 -0
- tyche_tools-0.1.0/tyche_tools/optimize.py +709 -0
- tyche_tools-0.1.0/tyche_tools/subspace.py +292 -0
- tyche_tools-0.1.0/tyche_tools.egg-info/PKG-INFO +17 -0
- tyche_tools-0.1.0/tyche_tools.egg-info/SOURCES.txt +14 -0
- tyche_tools-0.1.0/tyche_tools.egg-info/dependency_links.txt +1 -0
- tyche_tools-0.1.0/tyche_tools.egg-info/requires.txt +11 -0
- tyche_tools-0.1.0/tyche_tools.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Robert Pollice
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: tyche-tools
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Extended cheminformatics toolkit: median molecules, chemical subspace enumeration, and evolutionary molecular optimization.
|
|
5
|
+
Project-URL: Repository, https://git.lwp.rug.nl/pollice-research-group/artificial-design/tyche
|
|
6
|
+
Requires-Python: >=3.8
|
|
7
|
+
License-File: LICENSE
|
|
8
|
+
Requires-Dist: tyche-core
|
|
9
|
+
Requires-Dist: selfies<=2.1.2,>=2.0.0
|
|
10
|
+
Requires-Dist: rdkit
|
|
11
|
+
Requires-Dist: numpy
|
|
12
|
+
Provides-Extra: nn
|
|
13
|
+
Requires-Dist: torch; extra == "nn"
|
|
14
|
+
Provides-Extra: all
|
|
15
|
+
Requires-Dist: torch; extra == "all"
|
|
16
|
+
Requires-Dist: pyyaml; extra == "all"
|
|
17
|
+
Dynamic: license-file
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
[build-system]
# The [project] table below uses the PEP 639 `license-files` key, which
# setuptools only accepts from 77.0 onwards; older releases reject the table.
requires = ["setuptools>=77.0.3"]
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "tyche-tools"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Extended cheminformatics toolkit: median molecules, chemical subspace enumeration, and evolutionary molecular optimization."
|
|
9
|
+
requires-python = ">=3.8"
|
|
10
|
+
license-files = ["LICENSE"]
|
|
11
|
+
dependencies = [
|
|
12
|
+
"tyche-core",
|
|
13
|
+
"selfies>=2.0.0,<=2.1.2",
|
|
14
|
+
"rdkit",
|
|
15
|
+
"numpy",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[project.optional-dependencies]
|
|
19
|
+
nn = ["torch"]
|
|
20
|
+
all = ["torch", "pyyaml"]
|
|
21
|
+
|
|
22
|
+
[project.urls]
|
|
23
|
+
Repository = "https://git.lwp.rug.nl/pollice-research-group/artificial-design/tyche"
|
|
24
|
+
|
|
25
|
+
[tool.setuptools.packages.find]
|
|
26
|
+
where = ["."]
|
|
27
|
+
include = ["tyche_tools*"]
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Molecular feature extraction for the optimizer's neural network classifier.
|
|
2
|
+
|
|
3
|
+
Computes a 51-dimensional property vector for a given SMILES string. Features
|
|
4
|
+
include atom-count ratios, RDKit descriptors, bond type ratios, and ring
|
|
5
|
+
statistics.
|
|
6
|
+
"""
|
|
7
|
+
import inspect
|
|
8
|
+
from collections import OrderedDict
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from rdkit import Chem, RDLogger
|
|
12
|
+
from rdkit.Chem import Descriptors
|
|
13
|
+
|
|
14
|
+
RDLogger.DisableLog('rdApp.*')
|
|
15
|
+
|
|
16
|
+
# Names of the 17 RDKit ``Descriptors`` functions evaluated by get_mol_info().
# get_mol_info() selects descriptors by membership in this list, so the order
# of the entries here does not determine the order of the resulting features.
_DESCRIPTOR_NAMES = [
    "RingCount", "HallKierAlpha", "BalabanJ",
    "NumAliphaticCarbocycles", "NumAliphaticHeterocycles", "NumAliphaticRings",
    "NumAromaticCarbocycles", "NumAromaticHeterocycles", "NumAromaticRings",
    "NumHAcceptors", "NumHDonors", "NumHeteroatoms",
    "NumRadicalElectrons", "NumSaturatedCarbocycles", "NumSaturatedHeterocycles",
    "NumSaturatedRings", "NumValenceElectrons",
]
|
|
24
|
+
|
|
25
|
+
# SMARTS query: any atom, joined by a bond that is single (-) and not in a
# ring (!@), to any atom. This is what the module treats as a rotatable bond.
_ROTATABLE_BOND_SMARTS = Chem.MolFromSmarts('*-&!@*')
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _get_rot_bonds_posn(mol):
    """Locate the rotatable bonds of ``mol``.

    Runs the module-level ``_ROTATABLE_BOND_SMARTS`` query against ``mol`` and
    returns the substructure matches, one atom-index pair per matched bond.
    """
    matches = mol.GetSubstructMatches(_ROTATABLE_BOND_SMARTS)
    return matches
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _get_bond_indices(mol, rot):
|
|
34
|
+
"""Convert rotatable bond atom pairs to bond indices."""
|
|
35
|
+
return [mol.GetBondBetweenAtoms(r[0], r[1]).GetIdx() for r in rot]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _obtain_rings(smi):
    """Collect the ring-containing fragments of a molecule as SMILES strings.

    The molecule is cut at every rotatable bond (no dummy atoms are added) and
    the resulting SMILES is split on ``'.'``. A fragment is kept when its
    SMILES contains the character ``'1'`` (used here as the marker of a ring
    closure) and RDKit can parse it again.

    Returns
    -------
    list of str, or the tuple ``(None, None)``
        ``(None, None)`` is the sentinel for a molecule without any rotatable
        bond; callers compare against exactly that tuple.
    """
    mol = Chem.MolFromSmiles(smi)
    rot_pairs = _get_rot_bonds_posn(mol)
    if not rot_pairs:
        return None, None

    cut_bonds = _get_bond_indices(mol, rot_pairs)
    fragmented = Chem.FragmentOnBonds(mol, cut_bonds, addDummies=False)

    ring_fragments = []
    for piece in Chem.MolToSmiles(fragmented).split('.'):
        if '1' not in piece:
            continue
        if Chem.MolFromSmiles(piece) is None:
            continue
        ring_fragments.append(piece)
    return ring_fragments
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _count_atoms(mol, atomic_num):
    """Return how many substructure matches ``mol`` has for ``[#atomic_num]``.

    The count is obtained by SMARTS matching, i.e. it is the number of matches
    reported by ``mol.GetSubstructMatches`` for a single-atom query on the
    given atomic number.
    """
    query = Chem.MolFromSmarts('[#{}]'.format(atomic_num))
    hits = mol.GetSubstructMatches(query)
    return len(hits)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _get_num_bond_types(mol):
    """Return the share of single, double, triple and aromatic bonds in ``mol``.

    Returns
    -------
    list of float
        ``[single, double, triple, aromatic]``, each divided by the total
        number of bonds (bonds of any other type count towards the total
        only). All four entries are ``0.0`` for a molecule without bonds.
    """
    from rdkit.Chem import rdchem

    tracked = (
        rdchem.BondType.SINGLE,
        rdchem.BondType.DOUBLE,
        rdchem.BondType.TRIPLE,
        rdchem.BondType.AROMATIC,
    )
    observed = [bond.GetBondType() for bond in mol.GetBonds()]
    total = len(observed)
    if not total:
        return [0.0] * 4
    return [observed.count(kind) / total for kind in tracked]
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _count_conseq_double(mol):
    """Count neighbouring double bonds in the bond list of ``mol``.

    "Neighbouring" refers to the iteration order of ``mol.GetBonds()``: every
    position where a double bond directly follows another double bond in that
    sequence adds one to the result.
    """
    from rdkit.Chem import rdchem

    double = rdchem.BondType.DOUBLE
    bond_types = [bond.GetBondType() for bond in mol.GetBonds()]
    return sum(
        1
        for earlier, later in zip(bond_types, bond_types[1:])
        if earlier == double and later == double
    )
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _size_ring_counter(ring_ls):
|
|
98
|
+
"""Return a 19-element vector: [consecutive doubles in rings, ring counts by size 3–20].
|
|
99
|
+
|
|
100
|
+
Returns all zeros when ``ring_ls`` is ``(None, None)`` (no rotatable bonds).
|
|
101
|
+
"""
|
|
102
|
+
if ring_ls == (None, None):
|
|
103
|
+
return [0] * 19
|
|
104
|
+
ring_mols = [Chem.MolFromSmiles(s) for s in ring_ls]
|
|
105
|
+
conseq = sum(_count_conseq_double(m) for m in ring_mols)
|
|
106
|
+
size_counts = [
|
|
107
|
+
sum(1 for m in ring_mols if m.GetNumAtoms() == sz)
|
|
108
|
+
for sz in range(3, 21)
|
|
109
|
+
]
|
|
110
|
+
return [conseq] + size_counts
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_mol_info(smi):
    """Compute a 51-dimensional molecular feature vector for classifier training.

    Features, in the order they appear in the returned array:

    - 17 RDKit descriptor values (the names in ``_DESCRIPTOR_NAMES``, taken in
      alphabetical order)
    - 8 atom-count ratios relative to the carbon count (atoms in the parsed
      molecule, H, N, S, O, Cl, Br, F)
    - 4 bond-type fractions (single, double, triple, aromatic)
    - 2 ring summary features (number of ring fragments, number of ``#``
      characters in the ring-fragment SMILES)
    - 19 ring features from ``_size_ring_counter`` (consecutive doubles in
      ring fragments, then fragment counts by size 3-20)
    - 1 consecutive-double-bond count for the whole molecule

    Parameters
    ----------
    smi : str
        Valid SMILES string. NOTE(review): an unparsable string makes
        ``Chem.MolFromSmiles`` return ``None`` and the first attribute access
        below then fails; callers are expected to pass sanitized SMILES.

    Returns
    -------
    numpy.ndarray of shape (51,)
    """
    mol = Chem.MolFromSmiles(smi)

    num_atoms = mol.GetNumAtoms()
    num_hydro = Chem.AddHs(mol).GetNumAtoms() - num_atoms
    # A tiny stand-in denominator keeps carbon-free molecules from dividing by zero.
    num_carbon = _count_atoms(mol, 6) or 0.0001

    basic_props = [num_atoms / num_carbon, num_hydro / num_carbon]
    basic_props += [
        _count_atoms(mol, atomic_num) / num_carbon
        for atomic_num in (7, 16, 8, 17, 35, 9)  # N, S, O, Cl, Br, F
    ]

    # Look every descriptor up by name so that exactly one value per listed
    # name is emitted; a missing or failing descriptor contributes 0.0 instead
    # of silently shortening the vector. Alphabetical order matches the order
    # previously obtained from inspect.getmembers().
    rdkit_features = []
    for name in sorted(_DESCRIPTOR_NAMES):
        descriptor = getattr(Descriptors, name, None)
        if descriptor is None:
            rdkit_features.append(0.0)
            continue
        try:
            rdkit_features.append(descriptor(mol))
        except Exception:
            # Best effort: a descriptor that cannot be evaluated counts as 0.
            rdkit_features.append(0.0)

    bond_info = _get_num_bond_types(mol)

    # _obtain_rings returns the tuple (None, None) when there is no rotatable
    # bond, otherwise a (possibly empty) list of fragment SMILES.
    ring_ls = _obtain_rings(smi)
    if ring_ls != (None, None) and len(ring_ls) > 0:
        num_triple_in_rings = sum(item.count('#') for item in ring_ls)
        bond_info.append(len(ring_ls))
    else:
        num_triple_in_rings = 0
        bond_info.append(0)
    bond_info.append(num_triple_in_rings)
    bond_info += _size_ring_counter(ring_ls)
    bond_info.append(_count_conseq_double(mol))

    return np.array(rdkit_features + basic_props + bond_info)
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
"""Neural network classifier for guided molecular exploration.
|
|
2
|
+
|
|
3
|
+
Trains a small MLP on previously evaluated (SMILES, fitness) pairs and uses it
|
|
4
|
+
to predict which newly generated molecules are likely to score highly. This
|
|
5
|
+
biases the exploration phase toward promising chemical space without requiring
|
|
6
|
+
a full fitness evaluation for every candidate.
|
|
7
|
+
|
|
8
|
+
Requires PyTorch. If PyTorch is not installed, import of this module will
|
|
9
|
+
raise an ImportError; the optimizer falls back to random sampling in that case.
|
|
10
|
+
"""
|
|
11
|
+
import copy
|
|
12
|
+
import multiprocessing
|
|
13
|
+
from typing import List
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import torch
|
|
17
|
+
import torch.nn as nn
|
|
18
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
19
|
+
|
|
20
|
+
from tyche_tools._features import get_mol_info
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# ── Feature extraction ─────────────────────────────────────────────────────────
|
|
24
|
+
|
|
25
|
+
def _get_mol_feature(smi: str) -> np.ndarray:
    """Return the ``get_mol_info`` feature vector of ``smi`` as a NumPy array.

    Kept as a module-level function so it can be handed to
    ``multiprocessing.Pool.map`` by ``obtain_features``.
    """
    features = get_mol_info(smi)
    return np.array(features)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def obtain_features(smi_list: List[str], num_workers: int = 1) -> np.ndarray:
    """Build the feature matrix for a list of SMILES.

    Parameters
    ----------
    smi_list : list of str
        SMILES strings to featurize; one output row per entry.
    num_workers : int
        ``1`` computes the features in the current process. Any other value
        is passed to ``multiprocessing.Pool`` and the work is mapped over
        that pool.

    Returns
    -------
    numpy.ndarray
        Stacked per-molecule feature vectors, shape (N, 51) for the default
        features.
    """
    if num_workers != 1:
        with multiprocessing.Pool(num_workers) as pool:
            rows = pool.map(_get_mol_feature, smi_list)
    else:
        rows = [_get_mol_feature(smi) for smi in smi_list]
    return np.array(rows)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# ── Model architecture ─────────────────────────────────────────────────────────
|
|
49
|
+
|
|
50
|
+
class MLP(nn.Module):
    """Feed-forward network that applies a sigmoid after every layer.

    The hidden layers are stored in ``self.hidden`` and the output layer in
    ``self.predict``; both the hidden activations and the final output pass
    through ``torch.sigmoid``.

    Parameters
    ----------
    h_sizes : list of int
        Widths of the hidden layers, first to last. Must not be empty.
    n_input : int
        Input dimensionality (51 for the default molecular features).
    n_output : int
        Output dimensionality (1 for binary classification).
    """

    def __init__(self, h_sizes: List[int], n_input: int, n_output: int):
        super().__init__()
        # Consecutive pairs of widths give the (in, out) sizes of each hidden layer.
        widths = [n_input, *h_sizes]
        self.hidden = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])]
        )
        self.predict = nn.Linear(h_sizes[-1], n_output)

    def forward(self, x):
        activation = x
        for layer in self.hidden:
            activation = torch.sigmoid(layer(activation))
        output = self.predict(activation)
        return torch.sigmoid(output)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# ── Training utilities ─────────────────────────────────────────────────────────
|
|
77
|
+
|
|
78
|
+
class _EarlyStopping:
|
|
79
|
+
"""Monitor validation loss and restore best weights when improvement stalls."""
|
|
80
|
+
|
|
81
|
+
def __init__(self, patience: int = 500, min_delta: float = 1e-7):
|
|
82
|
+
self.patience = patience
|
|
83
|
+
self.min_delta = min_delta
|
|
84
|
+
self.best_val = np.inf
|
|
85
|
+
self.best_weights = None
|
|
86
|
+
self.best_epoch = 0
|
|
87
|
+
self.checkpoint = 0
|
|
88
|
+
|
|
89
|
+
def step(self, net, epoch: int, val_loss: float) -> bool:
|
|
90
|
+
"""Update state. Returns True when training should stop."""
|
|
91
|
+
if val_loss + self.min_delta < self.best_val:
|
|
92
|
+
self.best_val = val_loss
|
|
93
|
+
self.best_weights = copy.deepcopy(net.state_dict())
|
|
94
|
+
self.best_epoch = epoch
|
|
95
|
+
self.checkpoint = 0
|
|
96
|
+
else:
|
|
97
|
+
self.checkpoint += 1
|
|
98
|
+
return self.checkpoint > self.patience
|
|
99
|
+
|
|
100
|
+
def restore_best(self, net) -> nn.Module:
|
|
101
|
+
print(f' Early stopping at epoch {self.best_epoch}, val loss {self.best_val:.6f}')
|
|
102
|
+
net.load_state_dict(self.best_weights)
|
|
103
|
+
return net
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _get_device(use_gpu: bool) -> str:
|
|
107
|
+
if use_gpu and torch.cuda.is_available():
|
|
108
|
+
return 'cuda'
|
|
109
|
+
if use_gpu:
|
|
110
|
+
print('No GPU available, defaulting to CPU.')
|
|
111
|
+
return 'cpu'
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _train_valid_split(data_x, data_y, train_ratio=0.8, seed=30624700):
|
|
115
|
+
"""Deterministic 80/20 train-validation split."""
|
|
116
|
+
n = data_x.shape[0]
|
|
117
|
+
train_n = int(np.floor(n * train_ratio))
|
|
118
|
+
idx = np.random.RandomState(seed=seed).permutation(n)
|
|
119
|
+
return (
|
|
120
|
+
data_x[idx[:train_n]], data_y[idx[:train_n]],
|
|
121
|
+
data_x[idx[train_n:]], data_y[idx[train_n:]],
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _do_training(data_x, data_y, net, optimizer, loss_fn, steps=20000, batch_size=1024, device='cpu'):
    """Train ``net`` for up to ``steps`` epochs with early stopping on validation loss.

    The data is split with ``_train_valid_split``. After every epoch the mean
    per-batch validation loss is computed and handed to an ``_EarlyStopping``
    monitor (patience 500, min_delta 1e-7). When the monitor fires, the best
    recorded weights are restored and training ends; otherwise the network is
    returned as it stands after the last epoch. Progress is printed every
    1000 epochs.

    Parameters
    ----------
    data_x, data_y : numpy.ndarray
        Features and targets with matching first dimension.
    net : torch.nn.Module
        Model to train; expected to live on ``device`` already.
    optimizer : torch.optim.Optimizer
    loss_fn : callable
        Called as ``loss_fn(prediction, target)``.
    steps : int
        Maximum number of epochs.
    batch_size : int
    device : str
        Device on which the data tensors are created.

    Returns
    -------
    torch.nn.Module
        The trained network.

    Raises
    ------
    ValueError
        If the split leaves the training or the validation part empty, which
        would otherwise surface later as an unbound ``loss`` or a division by
        zero.
    """
    train_x, train_y, valid_x, valid_y = _train_valid_split(data_x, data_y)
    if len(train_x) == 0 or len(valid_x) == 0:
        raise ValueError(
            'Not enough samples to train: got '
            f'{len(train_x)} training and {len(valid_x)} validation samples.'
        )

    train_x = torch.tensor(train_x, device=device, dtype=torch.float)
    train_y = torch.tensor(train_y, device=device, dtype=torch.float)
    valid_x = torch.tensor(valid_x, device=device, dtype=torch.float)
    valid_y = torch.tensor(valid_y, device=device, dtype=torch.float)

    loader = DataLoader(TensorDataset(train_x, train_y), batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(TensorDataset(valid_x, valid_y), batch_size=batch_size)
    early_stop = _EarlyStopping(patience=500, min_delta=1e-7)

    net.train()
    for epoch in range(steps):
        for x, y in loader:
            pred = net(x)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation pass: mean of the per-batch losses, without gradients.
        val_loss = 0.0
        net.eval()
        with torch.no_grad():
            for x, y in valid_loader:
                val_loss += loss_fn(net(x), y).item()
        val_loss /= len(valid_loader)
        net.train()

        if epoch % 1000 == 0:
            # The reported train loss is that of the last batch of this epoch.
            print(f' Epoch {epoch}: train loss {loss.item():.6f}, val loss {val_loss:.6f}')

        if early_stop.step(net, epoch, val_loss):
            net = early_stop.restore_best(net)
            break

    return net
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# ── Public API ─────────────────────────────────────────────────────────────────
|
|
165
|
+
|
|
166
|
+
def create_and_train_network(
    smi_list: List[str],
    targets: List[float],
    n_hidden: List[int] = None,
    use_gpu: bool = True,
    num_workers: int = 1,
) -> MLP:
    """Featurize SMILES, build a binary-classification MLP, and train it.

    A molecule is labelled 1 when its fitness is at or above the 80th
    percentile of ``targets`` and 0 otherwise. The network is trained with
    Adam (lr 0.001, weight decay 1e-4) on a binary cross-entropy loss for up
    to 20000 epochs in batches of 1024.

    Parameters
    ----------
    smi_list : list of str
        SMILES of all previously evaluated molecules.
    targets : list of float
        Fitness values corresponding to each SMILES in smi_list.
    n_hidden : list of int, default [100, 10]
        Hidden layer widths of the MLP.
    use_gpu : bool
        Use CUDA if available.
    num_workers : int
        Parallel workers for feature extraction.

    Returns
    -------
    MLP
        Trained PyTorch model.
    """
    hidden_sizes = [100, 10] if n_hidden is None else n_hidden

    features = obtain_features(smi_list, num_workers=num_workers)

    cutoff = np.percentile(targets, 80)
    flags = [1.0 if fitness >= cutoff else 0.0 for fitness in targets]
    labels = np.expand_dims(flags, axis=-1)

    device = _get_device(use_gpu)
    n_in, n_out = features.shape[-1], labels.shape[-1]
    model = MLP(hidden_sizes, n_in, n_out).to(device)
    adam = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    bce = nn.BCELoss()

    return _do_training(
        features, labels, model, adam, bce,
        steps=20000, batch_size=1024, device=device,
    )
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def obtain_model_pred(
    smi_list: List[str],
    net: MLP,
    use_gpu: bool = True,
    num_workers: int = 1,
    batch_size: int = 1024,
) -> np.ndarray:
    """Return classifier predictions for a list of SMILES.

    Parameters
    ----------
    smi_list : list of str
        Molecules to score. An empty list yields an empty result.
    net : MLP
        Trained model from ``create_and_train_network``. It is moved to the
        device selected by ``use_gpu`` and switched to eval mode.
    use_gpu : bool
        Use CUDA if available.
    num_workers : int
        Parallel workers for feature extraction.
    batch_size : int
        Number of molecules per forward pass.

    Returns
    -------
    numpy.ndarray of shape (N, 1)
        Predicted probability of belonging to the high-fitness class.
    """
    if len(smi_list) == 0:
        # Nothing to featurize, and np.concatenate rejects an empty list.
        # Shape (0, 1) assumes the single-output model built by
        # create_and_train_network.
        return np.empty((0, 1), dtype=np.float32)

    device = _get_device(use_gpu)
    data_x = obtain_features(smi_list, num_workers=num_workers)
    data_x = torch.tensor(data_x, device=device, dtype=torch.float)
    loader = DataLoader(TensorDataset(data_x), batch_size=batch_size)

    # Model and inputs must share a device; this is a no-op when they already do.
    net = net.to(device)
    net.eval()
    predictions = []
    with torch.no_grad():
        for (x,) in loader:
            predictions.append(net(x).detach().cpu().numpy())
    return np.concatenate(predictions, axis=0)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from rdkit import RDLogger
|
|
2
|
+
from rdkit.Chem import MolFromSmiles as smi2mol
|
|
3
|
+
from rdkit.Chem import MolToSmiles as mol2smi
|
|
4
|
+
|
|
5
|
+
RDLogger.DisableLog('rdApp.*')
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_selfie_chars(selfie):
    """Split a SELFIES string into a list of its tokens.

    Tokens are read left to right; each one runs from the first ``[`` to the
    first ``]`` of the remaining text. Text that follows the last ``]`` is
    ignored, so malformed input without a closing bracket terminates instead
    of looping forever.

    Parameters
    ----------
    selfie : str
        A valid SELFIES string.

    Returns
    -------
    list of str

    Examples
    --------
    >>> get_selfie_chars('[C][=C][C][=C][C][=C][Ring1][Branch1_1]')
    ['[C]', '[=C]', '[C]', '[=C]', '[C]', '[=C]', '[Ring1]', '[Branch1_1]']
    """
    chars = []
    while selfie:
        close = selfie.find(']')
        if close == -1:
            # No complete token left; slicing from 0 would never consume input.
            break
        chars.append(selfie[selfie.find('['):close + 1])
        selfie = selfie[close + 1:]
    return chars
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def sanitize_smiles(smi):
    """Parse a SMILES string and derive its canonical, non-isomeric form.

    Any exception raised while parsing or writing the molecule is treated as
    a failed conversion rather than propagated.

    Parameters
    ----------
    smi : str

    Returns
    -------
    mol : rdkit.Chem.rdchem.Mol or None
        Parsed molecule, ``None`` on failure.
    smi_canon : str or None
        Canonical, non-isomeric SMILES string, ``None`` on failure.
    success : bool
        ``True`` when both steps completed without an exception.
    """
    try:
        mol = smi2mol(smi, sanitize=True)
        canonical = mol2smi(mol, isomericSmiles=False, canonical=True)
    except Exception:
        return None, None, False
    return mol, canonical, True
|