tyche-tools 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,9 @@
1
+ from tyche_tools.median import get_median_mols
2
+ from tyche_tools.subspace import get_local_chemical_subspace
3
+ from tyche_tools.optimize import optimize_molecules
4
+
5
# Public API of the package: the three entry points imported above.
__all__ = ["get_median_mols", "get_local_chemical_subspace", "optimize_molecules"]
@@ -0,0 +1,175 @@
1
+ """Molecular feature extraction for the optimizer's neural network classifier.
2
+
3
+ Computes a 51-dimensional property vector for a given SMILES string. Features
4
+ include atom-count ratios, RDKit descriptors, bond type ratios, and ring
5
+ statistics.
6
+ """
7
+ import inspect
8
+ from collections import OrderedDict
9
+
10
+ import numpy as np
11
+ from rdkit import Chem, RDLogger
12
+ from rdkit.Chem import Descriptors
13
+
14
+ RDLogger.DisableLog('rdApp.*')
15
+
16
# Names of the RDKit ``Descriptors`` functions used as classifier features.
# ``get_mol_info`` emits their values in alphabetical order of these names,
# not in the order written here.
_DESCRIPTOR_NAMES = [
    "RingCount",
    "HallKierAlpha",
    "BalabanJ",
    "NumAliphaticCarbocycles",
    "NumAliphaticHeterocycles",
    "NumAliphaticRings",
    "NumAromaticCarbocycles",
    "NumAromaticHeterocycles",
    "NumAromaticRings",
    "NumHAcceptors",
    "NumHDonors",
    "NumHeteroatoms",
    "NumRadicalElectrons",
    "NumSaturatedCarbocycles",
    "NumSaturatedHeterocycles",
    "NumSaturatedRings",
    "NumValenceElectrons",
]
24
+
25
# Compiled SMARTS query matching any single (``-``) bond that is not in a ring
# (``!@``) between two arbitrary atoms (``*``). Despite the "rotatable" naming
# used by the helpers below, this also matches bonds to terminal heavy atoms
# (e.g. both bonds of ``CCO``), not only true rotors.
_ROTATABLE_BOND_SMARTS = Chem.MolFromSmarts('*-&!@*')
26
+
27
+
28
def _get_rot_bonds_posn(mol):
    """Find every acyclic single bond of ``mol``.

    The query is ``_ROTATABLE_BOND_SMARTS`` (single, non-ring bonds), so bonds
    to terminal heavy atoms are included alongside genuine rotors.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol

    Returns
    -------
    tuple of tuple of int
        One ``(atom_idx, atom_idx)`` pair per matched bond.
    """
    matches = mol.GetSubstructMatches(_ROTATABLE_BOND_SMARTS)
    return matches
31
+
32
+
33
def _get_bond_indices(mol, rot):
    """Translate atom-index pairs into the indices of the bonds joining them.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
    rot : sequence of sequence of int
        Atom-index pairs, e.g. as returned by ``_get_rot_bonds_posn``. Only the
        first two entries of each item are used.

    Returns
    -------
    list of int
        Bond indices, in the same order as ``rot``.
    """
    indices = []
    for pair in rot:
        bond = mol.GetBondBetweenAtoms(pair[0], pair[1])
        indices.append(bond.GetIdx())
    return indices
36
+
37
+
38
def _obtain_rings(smi):
    """Cut ``smi`` on its acyclic single bonds and keep the ring-bearing pieces.

    Parameters
    ----------
    smi : str
        SMILES of the parent molecule.

    Returns
    -------
    list of str, or the tuple ``(None, None)``
        SMILES of each fragment whose text contains the ring-closure digit
        ``'1'`` and which RDKit can parse again. When the molecule has no
        acyclic single bond at all (e.g. benzene), the sentinel
        ``(None, None)`` is returned instead, so such molecules report no ring
        fragments.
    """
    parent = Chem.MolFromSmiles(smi)
    bond_pairs = _get_rot_bonds_posn(parent)
    if not bond_pairs:
        return None, None

    cut_bonds = _get_bond_indices(parent, bond_pairs)
    fragmented = Chem.FragmentOnBonds(parent, cut_bonds, addDummies=False)
    pieces = Chem.MolToSmiles(fragmented).split('.')

    ring_pieces = []
    for piece in pieces:
        # A '1' in the text is used as the marker for a ring closure.
        if '1' not in piece:
            continue
        # Some fragments cannot be re-parsed (e.g. failed sanitization).
        if Chem.MolFromSmiles(piece) is None:
            continue
        ring_pieces.append(piece)
    return ring_pieces
53
+
54
+
55
def _count_atoms(mol, atomic_num):
    """Count the atoms of ``mol`` whose atomic number equals ``atomic_num``.

    Only atoms present in the molecular graph are counted. Hydrogens that
    RDKit keeps implicit are not graph atoms, so they never contribute.

    The atoms are iterated directly rather than matched with a ``[#N]``
    SMARTS query. ``GetSubstructMatches`` stops after ``maxMatches=1000``
    matches by default, which silently undercounts very large molecules. The
    query would also have to be recompiled on every call.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
    atomic_num : int

    Returns
    -------
    int
    """
    return sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == atomic_num)
59
+
60
+
61
def _get_num_bond_types(mol):
    """Fractions of single, double, triple and aromatic bonds in ``mol``.

    The denominator is the total number of bonds of any type, so the four
    fractions need not sum to 1 when other bond types are present.

    Returns
    -------
    list of float
        ``[single, double, triple, aromatic]``; all ``0.0`` for a molecule
        without bonds.
    """
    from rdkit.Chem import rdchem

    tracked = (
        rdchem.BondType.SINGLE,
        rdchem.BondType.DOUBLE,
        rdchem.BondType.TRIPLE,
        rdchem.BondType.AROMATIC,
    )
    bond_types = [bond.GetBondType() for bond in mol.GetBonds()]
    if not bond_types:
        return [0.0] * 4
    n_bonds = len(bond_types)
    return [bond_types.count(kind) / n_bonds for kind in tracked]
82
+
83
+
84
def _count_conseq_double(mol):
    """Count back-to-back DOUBLE bonds in bond-index order.

    Bonds are compared in the order ``mol.GetBonds()`` yields them. Two
    successive DOUBLE bonds are therefore counted even if they do not share an
    atom, so this is not a check for cumulated (allene-type) double bonds. A
    run of k consecutive DOUBLE bonds contributes k - 1.

    Returns
    -------
    int
    """
    from rdkit.Chem import rdchem

    double = rdchem.BondType.DOUBLE
    kinds = [bond.GetBondType() for bond in mol.GetBonds()]
    return sum(
        1
        for earlier, later in zip(kinds, kinds[1:])
        if earlier == later == double
    )
95
+
96
+
97
def _size_ring_counter(ring_ls):
    """Summarise ring fragments as a fixed-length 19-element vector.

    Element 0 is the sum of ``_count_conseq_double`` over all fragments.
    Elements 1-18 count the fragments having exactly 3, 4, ..., 20 atoms.
    Fragments outside that size range are not counted.

    NOTE(review): when ``ring_ls`` comes from ``_obtain_rings``, a fragment
    can include non-ring atoms that are not attached by acyclic single bonds
    (e.g. an exocyclic ``=O``). The histogram is therefore only approximately
    a ring-size histogram.

    Parameters
    ----------
    ring_ls : list of str or tuple
        Fragment SMILES, or the ``(None, None)`` sentinel, which yields all
        zeros.
    """
    if ring_ls == (None, None):
        return [0] * 19

    fragments = [Chem.MolFromSmiles(frag) for frag in ring_ls]

    total_conseq = 0
    for frag in fragments:
        total_conseq += _count_conseq_double(frag)

    atom_counts = [frag.GetNumAtoms() for frag in fragments]
    histogram = [atom_counts.count(size) for size in range(3, 21)]
    return [total_conseq, *histogram]
111
+
112
+
113
def get_mol_info(smi):
    """Compute a 51-dimensional molecular feature vector for classifier training.

    Features, in this order:

    - 17 RDKit descriptor values, one per name in ``_DESCRIPTOR_NAMES``, in
      alphabetical order of the names.
    - 8 ratios relative to the carbon count: atoms in the parsed graph,
      hydrogens added by ``Chem.AddHs``, then N, S, O, Cl, Br, F.
    - 4 bond-type fractions (single, double, triple, aromatic).
    - 2 ring-fragment summaries: the number of ring fragments from
      ``_obtain_rings``, and the number of ``#`` (triple-bond) symbols in
      their SMILES.
    - 19 values from ``_size_ring_counter``: consecutive double bonds within
      ring fragments, then fragment counts by size 3-20.
    - 1 count of consecutive double bonds over the whole molecule.

    Parameters
    ----------
    smi : str
        SMILES string.

    Returns
    -------
    numpy.ndarray of shape (51,)

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smi!r}")

    num_atoms = mol.GetNumAtoms()
    num_hydro = Chem.AddHs(mol).GetNumAtoms() - num_atoms
    # Carbon-free molecules would divide by zero. A tiny denominator keeps the
    # ratios finite (they just become large).
    num_carbon = _count_atoms(mol, 6) or 0.0001

    basic_props = [
        num_atoms / num_carbon,
        num_hydro / num_carbon,
        _count_atoms(mol, 7) / num_carbon,   # N
        _count_atoms(mol, 16) / num_carbon,  # S
        _count_atoms(mol, 8) / num_carbon,   # O
        _count_atoms(mol, 17) / num_carbon,  # Cl
        _count_atoms(mol, 35) / num_carbon,  # Br
        _count_atoms(mol, 9) / num_carbon,   # F
    ]

    # Descriptors are looked up by name, in alphabetical order, so each column
    # is always present at a fixed position regardless of how RDKit implements
    # the descriptor (plain function, lambda or builtin). This also avoids
    # scanning the whole Descriptors module for every molecule. A descriptor
    # that is missing or raises contributes 0.0, keeping the width fixed.
    rdkit_features = []
    for name in sorted(_DESCRIPTOR_NAMES):
        fn = getattr(Descriptors, name, None)
        if fn is None:
            rdkit_features.append(0.0)
            continue
        try:
            rdkit_features.append(fn(mol))
        except Exception:
            rdkit_features.append(0.0)

    bond_info = _get_num_bond_types(mol)

    ring_ls = _obtain_rings(smi)
    # ``_obtain_rings`` returns the sentinel (None, None) when the molecule has
    # no acyclic single bond; treat that as "no ring fragments".
    ring_fragments = [] if ring_ls == (None, None) else ring_ls
    bond_info.append(len(ring_fragments))
    bond_info.append(sum(frag.count('#') for frag in ring_fragments))
    bond_info += _size_ring_counter(ring_ls)
    bond_info.append(_count_conseq_double(mol))

    return np.array(rdkit_features + basic_props + bond_info)
@@ -0,0 +1,250 @@
1
+ """Neural network classifier for guided molecular exploration.
2
+
3
+ Trains a small MLP on previously evaluated (SMILES, fitness) pairs and uses it
4
+ to predict which newly generated molecules are likely to score highly. This
5
+ biases the exploration phase toward promising chemical space without requiring
6
+ a full fitness evaluation for every candidate.
7
+
8
+ Requires PyTorch. If PyTorch is not installed, import of this module will
9
+ raise an ImportError; the optimizer falls back to random sampling in that case.
10
+ """
11
+ import copy
12
+ import multiprocessing
13
+ from typing import List
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch.utils.data import DataLoader, TensorDataset
19
+
20
+ from tyche_tools._features import get_mol_info
21
+
22
+
23
+ # ── Feature extraction ─────────────────────────────────────────────────────────
24
+
25
def _get_mol_feature(smi: str) -> np.ndarray:
    """Featurize one SMILES string as a NumPy vector.

    Defined at module level so that ``multiprocessing.Pool.map`` can pickle it.
    """
    features = get_mol_info(smi)
    return np.array(features)
27
+
28
+
29
def obtain_features(smi_list: List[str], num_workers: int = 1) -> np.ndarray:
    """Compute the 51-dim feature matrix for a list of SMILES.

    Parameters
    ----------
    smi_list : list of str
    num_workers : int
        Number of parallel worker processes. Values ``<= 1`` featurize
        serially in the current process. ``None`` lets ``multiprocessing.Pool``
        choose the worker count (``os.cpu_count()``).

    Returns
    -------
    numpy.ndarray of shape (N, 51)
        One row per input SMILES, in input order. An empty input yields an
        array of shape (0, 51).
    """
    smi_list = list(smi_list)
    if not smi_list:
        # ``np.array([])`` would be 1-D with shape (0,); keep the documented
        # 2-D shape. Also avoids spawning a pool for nothing.
        return np.empty((0, 51))
    # ``Pool(0)`` raises ValueError, so anything below 2 workers runs serially.
    if num_workers is not None and num_workers <= 1:
        return np.array([_get_mol_feature(s) for s in smi_list])
    with multiprocessing.Pool(num_workers) as pool:
        return np.array(pool.map(_get_mol_feature, smi_list))
46
+
47
+
48
+ # ── Model architecture ─────────────────────────────────────────────────────────
49
+
50
class MLP(nn.Module):
    """Fully connected network with a sigmoid after every layer, output included.

    Parameters
    ----------
    h_sizes : list of int
        Widths of the hidden layers, in order; must contain at least one entry.
    n_input : int
        Number of input features (51 for ``get_mol_info`` vectors).
    n_output : int
        Number of outputs (1 for the binary classifier in this module).
    """

    def __init__(self, h_sizes: List[int], n_input: int, n_output: int):
        super().__init__()
        # Consecutive (fan_in, fan_out) pairs: input -> h0 -> h1 -> ...
        widths = [n_input] + list(h_sizes)
        layers = [
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths, widths[1:])
        ]
        self.hidden = nn.ModuleList(layers)
        self.predict = nn.Linear(h_sizes[-1], n_output)

    def forward(self, x):
        out = x
        for linear in self.hidden:
            out = torch.sigmoid(linear(out))
        logits = self.predict(out)
        return torch.sigmoid(logits)
74
+
75
+
76
+ # ── Training utilities ─────────────────────────────────────────────────────────
77
+
78
class _EarlyStopping:
    """Track the best validation loss seen so far together with its weights.

    ``step`` is called once per epoch. It snapshots the model weights whenever
    the loss improves on the best value by more than ``min_delta``, and it
    reports when more than ``patience`` consecutive epochs have passed without
    such an improvement.
    """

    def __init__(self, patience: int = 500, min_delta: float = 1e-7):
        self.patience = patience
        self.min_delta = min_delta
        self.best_val = np.inf      # lowest validation loss recorded so far
        self.best_weights = None    # deep copy of state_dict at best_val
        self.best_epoch = 0         # epoch at which best_val was recorded
        self.checkpoint = 0         # epochs since the last improvement

    def step(self, net, epoch: int, val_loss: float) -> bool:
        """Record ``val_loss`` for ``epoch``. Returns True when training should stop."""
        # A NaN loss never compares smaller, so it counts as "no improvement".
        if val_loss + self.min_delta < self.best_val:
            self.best_val = val_loss
            # Deep copy: state_dict() tensors alias the live parameters.
            self.best_weights = copy.deepcopy(net.state_dict())
            self.best_epoch = epoch
            self.checkpoint = 0
        else:
            self.checkpoint += 1
        return self.checkpoint > self.patience

    def restore_best(self, net) -> nn.Module:
        """Load the best recorded weights into ``net`` and return it.

        If no epoch ever improved on the initial ``inf`` (e.g. every validation
        loss was NaN), there is nothing to restore. ``net`` is then returned
        unchanged instead of failing inside ``load_state_dict(None)``.
        """
        if self.best_weights is None:
            print(' Early stopping: no improving epoch recorded, keeping current weights.')
            return net
        print(f' Early stopping at epoch {self.best_epoch}, val loss {self.best_val:.6f}')
        net.load_state_dict(self.best_weights)
        return net
104
+
105
+
106
def _get_device(use_gpu: bool) -> str:
    """Choose the torch device string.

    Returns ``'cuda'`` only when a GPU is requested and CUDA is available.
    Otherwise returns ``'cpu'``, printing a notice if a GPU was requested but
    none is available.
    """
    if not use_gpu:
        return 'cpu'
    if torch.cuda.is_available():
        return 'cuda'
    print('No GPU available, defaulting to CPU.')
    return 'cpu'
112
+
113
+
114
def _train_valid_split(data_x, data_y, train_ratio=0.8, seed=30624700):
    """Shuffle rows with a fixed seed and split them into train and validation sets.

    The first ``floor(n * train_ratio)`` entries of a seeded permutation of the
    row indices form the training set, and the remainder forms the validation
    set. The split is therefore reproducible for a given ``n`` and ``seed``.

    Returns
    -------
    tuple
        ``(train_x, train_y, valid_x, valid_y)``.
    """
    n_rows = data_x.shape[0]
    cut = int(np.floor(n_rows * train_ratio))
    order = np.random.RandomState(seed=seed).permutation(n_rows)
    train_idx, valid_idx = order[:cut], order[cut:]
    return data_x[train_idx], data_y[train_idx], data_x[valid_idx], data_y[valid_idx]
123
+
124
+
125
def _do_training(data_x, data_y, net, optimizer, loss_fn, steps=20000, batch_size=1024, device='cpu'):
    """Train ``net`` for up to ``steps`` epochs and keep the best-validation weights.

    The data are split 80/20 (seeded) into training and validation sets.

    - After each epoch over the shuffled training set, the mean per-batch
      validation loss is computed.
    - Training stops early once that loss has not improved for more than 500
      epochs.
    - Whether training stops early or runs all ``steps`` epochs, the returned
      network carries the weights of the epoch with the lowest validation loss
      (when at least one epoch improved on ``inf``).

    Parameters
    ----------
    data_x, data_y : numpy.ndarray
        Row-aligned features and targets.
    net : torch.nn.Module
        Model to train. It is not moved here, so it must already live on
        ``device``.
    optimizer : torch.optim.Optimizer
    loss_fn : callable
        ``loss_fn(prediction, target)`` returning a scalar tensor.
    steps : int
        Maximum number of epochs.
    batch_size : int
    device : str
        Device on which the data tensors are created.

    Returns
    -------
    torch.nn.Module
        ``net``, trained in place.
    """
    train_x, train_y, valid_x, valid_y = _train_valid_split(data_x, data_y)
    train_x = torch.tensor(train_x, device=device, dtype=torch.float)
    train_y = torch.tensor(train_y, device=device, dtype=torch.float)
    valid_x = torch.tensor(valid_x, device=device, dtype=torch.float)
    valid_y = torch.tensor(valid_y, device=device, dtype=torch.float)

    loader = DataLoader(TensorDataset(train_x, train_y), batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(TensorDataset(valid_x, valid_y), batch_size=batch_size)
    early_stop = _EarlyStopping(patience=500, min_delta=1e-7)

    net.train()
    for epoch in range(steps):
        # Stays None when the training split is empty (e.g. a single sample).
        loss = None
        for x, y in loader:
            pred = net(x)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        val_loss = 0.0
        net.eval()
        with torch.no_grad():
            for x, y in valid_loader:
                val_loss += loss_fn(net(x), y).item()
        val_loss /= len(valid_loader)
        net.train()

        if epoch % 1000 == 0:
            train_loss = loss.item() if loss is not None else float('nan')
            print(f' Epoch {epoch}: train loss {train_loss:.6f}, val loss {val_loss:.6f}')

        if early_stop.step(net, epoch, val_loss):
            return early_stop.restore_best(net)

    # Patience never ran out: still return the best epoch's weights rather
    # than whatever the final epoch produced.
    if early_stop.best_weights is not None:
        net.load_state_dict(early_stop.best_weights)
    return net
162
+
163
+
164
+ # ── Public API ─────────────────────────────────────────────────────────────────
165
+
166
def create_and_train_network(
    smi_list: List[str],
    targets: List[float],
    n_hidden: List[int] = None,
    use_gpu: bool = True,
    num_workers: int = 1,
) -> MLP:
    """Featurize SMILES, build a binary MLP, and train it.

    Labels are 1 for molecules at or above the 80th fitness percentile, and 0
    otherwise. Ties at the threshold are all labelled 1. The trained model
    predicts which unseen molecules are likely to score highly.

    Training setup: Adam (lr 1e-3, weight decay 1e-4), binary cross-entropy,
    batch size 1024, and up to 20000 epochs with early stopping on a seeded
    80/20 validation split.

    Parameters
    ----------
    smi_list : list of str
        SMILES of all previously evaluated molecules.
    targets : list of float
        Fitness values corresponding to each SMILES in smi_list.
    n_hidden : list of int, default [100, 10]
        Hidden layer widths of the MLP.
    use_gpu : bool
        Use CUDA if available.
    num_workers : int
        Parallel workers for feature extraction.

    Returns
    -------
    MLP
        Trained PyTorch model.

    Raises
    ------
    ValueError
        If ``smi_list`` and ``targets`` differ in length or are empty.
    """
    smi_list = list(smi_list)
    targets = list(targets)
    # A length mismatch would otherwise silently misalign features and labels
    # (extra targets are ignored by the split but still shift the threshold).
    if len(smi_list) != len(targets):
        raise ValueError(
            f"smi_list and targets must have the same length "
            f"(got {len(smi_list)} and {len(targets)})"
        )
    if not smi_list:
        raise ValueError("Cannot train a classifier on an empty dataset.")

    if n_hidden is None:
        n_hidden = [100, 10]

    dataset_x = obtain_features(smi_list, num_workers=num_workers)
    threshold = np.percentile(targets, 80)
    dataset_y = np.expand_dims(
        [1.0 if t >= threshold else 0.0 for t in targets], axis=-1
    )

    device = _get_device(use_gpu)
    net = MLP(n_hidden, dataset_x.shape[-1], dataset_y.shape[-1]).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-4)
    loss_fn = nn.BCELoss()

    net = _do_training(
        dataset_x, dataset_y, net, optimizer, loss_fn,
        steps=20000, batch_size=1024, device=device,
    )
    return net
215
+
216
+
217
def obtain_model_pred(
    smi_list: List[str],
    net: MLP,
    use_gpu: bool = True,
    num_workers: int = 1,
    batch_size: int = 1024,
) -> np.ndarray:
    """Return classifier predictions for a list of SMILES.

    The network is switched to eval mode, and that mode is left set on return.

    Parameters
    ----------
    smi_list : list of str
    net : MLP
        Trained model from ``create_and_train_network``. NOTE(review): the
        model is not moved here; it must already live on the device selected
        by ``use_gpu``.
    use_gpu : bool
    num_workers : int
        Parallel workers for feature extraction.
    batch_size : int

    Returns
    -------
    numpy.ndarray of shape (N, 1)
        Predicted probability of belonging to the high-fitness class. An
        empty input yields an array of shape (0, 1).
    """
    smi_list = list(smi_list)
    if not smi_list:
        # ``np.concatenate([])`` would raise on an empty list of batches.
        return np.empty((0, 1))

    device = _get_device(use_gpu)
    data_x = obtain_features(smi_list, num_workers=num_workers)
    data_x = torch.tensor(data_x, device=device, dtype=torch.float)
    loader = DataLoader(TensorDataset(data_x), batch_size=batch_size)

    net.eval()
    predictions = []
    with torch.no_grad():
        for (x,) in loader:
            predictions.append(net(x).cpu().numpy())
    return np.concatenate(predictions, axis=0)
tyche_tools/_utils.py ADDED
@@ -0,0 +1,51 @@
1
+ from rdkit import RDLogger
2
+ from rdkit.Chem import MolFromSmiles as smi2mol
3
+ from rdkit.Chem import MolToSmiles as mol2smi
4
+
5
+ RDLogger.DisableLog('rdApp.*')
6
+
7
+
8
def get_selfie_chars(selfie):
    """Split a SELFIES string into a list of its bracketed tokens.

    Each token runs from a ``[`` to the next ``]``, inclusive. Characters
    outside brackets (such as the ``.`` fragment separator) are skipped.

    Parameters
    ----------
    selfie : str
        A valid SELFIES string. An empty string (or ``None``) yields ``[]``.

    Returns
    -------
    list of str

    Raises
    ------
    ValueError
        If a ``[`` has no matching ``]``. Input like this, or input with
        trailing non-bracket characters, previously looped forever.

    Examples
    --------
    >>> get_selfie_chars('[C][=C][C][=C][C][=C][Ring1][Branch1_1]')
    ['[C]', '[=C]', '[C]', '[=C]', '[C]', '[=C]', '[Ring1]', '[Branch1_1]']
    """
    if not selfie:
        return []
    chars = []
    pos = 0
    while True:
        start = selfie.find('[', pos)
        if start == -1:
            break
        # Search for ']' only after this '[' so a stray ']' cannot produce an
        # empty token.
        end = selfie.find(']', start)
        if end == -1:
            raise ValueError(
                f"Unterminated SELFIES token at index {start}: {selfie[start:]!r}"
            )
        chars.append(selfie[start:end + 1])
        pos = end + 1
    return chars
30
+
31
+
32
def sanitize_smiles(smi):
    """Parse ``smi`` with RDKit and re-emit it as canonical, non-isomeric SMILES.

    Parameters
    ----------
    smi : str

    Returns
    -------
    tuple
        ``(mol, smi_canon, True)`` on success. ``mol`` is the sanitized
        ``rdkit.Chem.rdchem.Mol`` and ``smi_canon`` is the canonical SMILES
        written with ``isomericSmiles=False``. If parsing or writing fails,
        returns ``(None, None, False)``.
    """
    try:
        parsed = smi2mol(smi, sanitize=True)
        if parsed is None:
            # RDKit signals an unparsable SMILES by returning None.
            return None, None, False
        canonical = mol2smi(parsed, isomericSmiles=False, canonical=True)
    except Exception:
        return None, None, False
    return parsed, canonical, True