chemtsv3 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chemtsv3-0.1.2/PKG-INFO +296 -0
- chemtsv3-0.1.2/README.md +278 -0
- chemtsv3-0.1.2/chemtsv3/__init__.py +0 -0
- chemtsv3-0.1.2/chemtsv3/cli/__init__.py +0 -0
- chemtsv3-0.1.2/chemtsv3/cli/generation.py +66 -0
- chemtsv3-0.1.2/chemtsv3/cli/model_training.py +47 -0
- chemtsv3-0.1.2/chemtsv3/data/__init__.py +0 -0
- chemtsv3-0.1.2/chemtsv3/data/filtering_substruct_oota_cho.csv +125 -0
- chemtsv3-0.1.2/chemtsv3/data/gbgm/p1.p +0 -0
- chemtsv3-0.1.2/chemtsv3/data/gbgm/p_ring.p +0 -0
- chemtsv3-0.1.2/chemtsv3/data/gbgm/r_s1.p +0 -0
- chemtsv3-0.1.2/chemtsv3/data/gbgm/rs_make_ring.p +0 -0
- chemtsv3-0.1.2/chemtsv3/data/gbgm/rs_ring.p +0 -0
- chemtsv3-0.1.2/chemtsv3/data/j_score/SA_scores.txt +10000 -0
- chemtsv3-0.1.2/chemtsv3/data/j_score/cycle_scores.txt +10000 -0
- chemtsv3-0.1.2/chemtsv3/data/j_score/logP_values.txt +10000 -0
- chemtsv3-0.1.2/chemtsv3/data/pubchem_filter/atoms_dict.txt +1132 -0
- chemtsv3-0.1.2/chemtsv3/data/pubchem_filter/bonds_dict.txt +560 -0
- chemtsv3-0.1.2/chemtsv3/data/pubchem_filter/metadata.py +260 -0
- chemtsv3-0.1.2/chemtsv3/filter/__init__.py +35 -0
- chemtsv3-0.1.2/chemtsv3/filter/aromatic_ring_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/filter/atom_count_filter.py +10 -0
- chemtsv3-0.1.2/chemtsv3/filter/attachment_points_filter.py +11 -0
- chemtsv3-0.1.2/chemtsv3/filter/base.py +100 -0
- chemtsv3-0.1.2/chemtsv3/filter/catalog_filter.py +41 -0
- chemtsv3-0.1.2/chemtsv3/filter/charge_filter.py +14 -0
- chemtsv3-0.1.2/chemtsv3/filter/connectivity_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/filter/hba_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/filter/hbd_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/filter/heavy_atom_count_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/filter/known_list_filter.py +46 -0
- chemtsv3-0.1.2/chemtsv3/filter/lipinski_filter.py +41 -0
- chemtsv3-0.1.2/chemtsv3/filter/log_p_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/filter/pains_filter.py +22 -0
- chemtsv3-0.1.2/chemtsv3/filter/pubchem_filter.py +277 -0
- chemtsv3-0.1.2/chemtsv3/filter/radical_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/filter/ring_bond_filter.py +18 -0
- chemtsv3-0.1.2/chemtsv3/filter/ring_size_filter.py +22 -0
- chemtsv3-0.1.2/chemtsv3/filter/roc_filter.py +56 -0
- chemtsv3-0.1.2/chemtsv3/filter/rotatable_bonds_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/filter/sa_score_filter.py +14 -0
- chemtsv3-0.1.2/chemtsv3/filter/substructure_filter.py +54 -0
- chemtsv3-0.1.2/chemtsv3/filter/tpsa_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/filter/validity_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/filter/weight_filter.py +13 -0
- chemtsv3-0.1.2/chemtsv3/generator/__init__.py +4 -0
- chemtsv3-0.1.2/chemtsv3/generator/base.py +589 -0
- chemtsv3-0.1.2/chemtsv3/generator/heapq_generator.py +35 -0
- chemtsv3-0.1.2/chemtsv3/generator/mcts.py +223 -0
- chemtsv3-0.1.2/chemtsv3/generator/random_generator.py +14 -0
- chemtsv3-0.1.2/chemtsv3/language/__init__.py +14 -0
- chemtsv3-0.1.2/chemtsv3/language/base.py +198 -0
- chemtsv3-0.1.2/chemtsv3/language/fasta.py +59 -0
- chemtsv3-0.1.2/chemtsv3/language/helm.py +110 -0
- chemtsv3-0.1.2/chemtsv3/language/selfies.py +21 -0
- chemtsv3-0.1.2/chemtsv3/language/smiles.py +24 -0
- chemtsv3-0.1.2/chemtsv3/language/tokenizer.py +46 -0
- chemtsv3-0.1.2/chemtsv3/node/__init__.py +10 -0
- chemtsv3-0.1.2/chemtsv3/node/base.py +189 -0
- chemtsv3-0.1.2/chemtsv3/node/selfies_string_node.py +19 -0
- chemtsv3-0.1.2/chemtsv3/node/sentence_node.py +75 -0
- chemtsv3-0.1.2/chemtsv3/node/string_node.py +106 -0
- chemtsv3-0.1.2/chemtsv3/policy/__init__.py +9 -0
- chemtsv3-0.1.2/chemtsv3/policy/base.py +109 -0
- chemtsv3-0.1.2/chemtsv3/policy/puct.py +19 -0
- chemtsv3-0.1.2/chemtsv3/policy/puct_with_predictor.py +233 -0
- chemtsv3-0.1.2/chemtsv3/policy/uct.py +67 -0
- chemtsv3-0.1.2/chemtsv3/reward/__init__.py +10 -0
- chemtsv3-0.1.2/chemtsv3/reward/base.py +122 -0
- chemtsv3-0.1.2/chemtsv3/reward/j_score_reward.py +48 -0
- chemtsv3-0.1.2/chemtsv3/reward/log_p_reward.py +15 -0
- chemtsv3-0.1.2/chemtsv3/reward/similarity_reward.py +16 -0
- chemtsv3-0.1.2/chemtsv3/transition/__init__.py +25 -0
- chemtsv3-0.1.2/chemtsv3/transition/base.py +252 -0
- chemtsv3-0.1.2/chemtsv3/transition/biot5.py +18 -0
- chemtsv3-0.1.2/chemtsv3/transition/chat_gpt.py +35 -0
- chemtsv3-0.1.2/chemtsv3/transition/chat_gpt_with_memory.py +60 -0
- chemtsv3-0.1.2/chemtsv3/transition/gbga.py +221 -0
- chemtsv3-0.1.2/chemtsv3/transition/gbgm.py +171 -0
- chemtsv3-0.1.2/chemtsv3/transition/gpt2.py +229 -0
- chemtsv3-0.1.2/chemtsv3/transition/prot_gpt2.py +42 -0
- chemtsv3-0.1.2/chemtsv3/transition/rnn.py +303 -0
- chemtsv3-0.1.2/chemtsv3/transition/rnn_based_mutation.py +41 -0
- chemtsv3-0.1.2/chemtsv3/transition/smirks.py +225 -0
- chemtsv3-0.1.2/chemtsv3/utils/__init__.py +19 -0
- chemtsv3-0.1.2/chemtsv3/utils/file_utils.py +249 -0
- chemtsv3-0.1.2/chemtsv3/utils/helm_utils.py +416 -0
- chemtsv3-0.1.2/chemtsv3/utils/logging_utils.py +69 -0
- chemtsv3-0.1.2/chemtsv3/utils/math_utils.py +194 -0
- chemtsv3-0.1.2/chemtsv3/utils/mol_utils.py +119 -0
- chemtsv3-0.1.2/chemtsv3/utils/plot_utils.py +86 -0
- chemtsv3-0.1.2/chemtsv3/utils/third_party/fpscores.pkl.gz +0 -0
- chemtsv3-0.1.2/chemtsv3/utils/third_party/sascorer.py +192 -0
- chemtsv3-0.1.2/chemtsv3/utils/yaml_utils.py +275 -0
- chemtsv3-0.1.2/pyproject.toml +24 -0
chemtsv3-0.1.2/PKG-INFO
ADDED
@@ -0,0 +1,296 @@
Metadata-Version: 2.4
Name: chemtsv3
Version: 0.1.2
Summary: ChemTSv3:
License: MIT
Author: Satoru Fujii
Author-email: fujii.sat.rk@yokohama-cu.ac.jp
Requires-Python: >=3.11
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Project-URL: Repository, https://github.com/molecule-generator-collection/ChemTSv3
Description-Content-Type: text/markdown
chemtsv3-0.1.2/README.md
ADDED
@@ -0,0 +1,278 @@
## ChemTSv3
A unified tree search framework for molecular generation.
- **Node is modular**: Supports any molecular representation (e.g., SMILES, SELFIES, FASTA, or HELM) in either string or tensor format.
- **Transition is modular**: Allows any molecular transformation strategy, including graph-based editing, sequence generation with RNN or GPT-2, sequence mutation, or LLM-guided modification.
- **Filter is modular**: Enables flexible constraints such as structural alerts, scaffold preservation, or physicochemical property filters.
- **Reward is modular**: Anything can be optimized, including QSAR predictions or simulation results, for both single- and multi-objective tasks.

## Setup

<details>
<summary><b>Minimal installation (Mac, Linux)</b></summary><br>

### Available classes
- **Transition**: `GBGATransition`, `GPT2Transition`, `RNNBasedMutation`, `RNNTransition`, `SMIRKSTransition`
- **Reward**: `GFPReward`, `SimilarityReward`, `JScoreReward`, `LogPReward`
- **Policy**: `UCT`, `PUCT`
- The corresponding Node classes and all implemented Filter classes are also available in this environment.

### Setup steps

1. Clone the repository
2. Install uv: https://docs.astral.sh/uv/getting-started/installation/
3. Restart the shell
4. Move to the repository root (e.g., `cd molgen`)
5. Run the following commands:
```bash
uv venv --python 3.11.11
source .venv/bin/activate
uv pip install chemtsv3 numpy==1.26.4 pandas==2.3.3 matplotlib==3.10.7 rdkit==2023.09.6 ipykernel==6.30.0 transformers==4.43.4 torch==2.5.1 --torch-backend=auto
```

To activate the virtual environment, run the following command from the repository root (this process can also be automated through VS Code settings):
```bash
source .venv/bin/activate
```
To deactivate the virtual environment, run:
```bash
deactivate
```
</details>

<details>
<summary><b>Minimal installation (Windows)</b></summary><br>

### Available classes
- **Transition**: `GBGATransition`, `GPT2Transition`, `RNNBasedMutation`, `RNNTransition`, `SMIRKSTransition`
- **Reward**: `GFPReward`, `SimilarityReward`, `JScoreReward`, `LogPReward`
- **Policy**: `UCT`, `PUCT`
- The corresponding Node classes and all implemented Filter classes are also available in this environment.

### Setup steps

1. Clone the repository
2. Install uv: https://docs.astral.sh/uv/getting-started/installation/
3. Restart the shell (and VSCode if used)
4. Move to the repository root (e.g., `cd molgen`)
5. Run the following commands:
```bash
uv venv --python 3.11.11
.venv\Scripts\activate
uv pip install chemtsv3 numpy==1.26.4 pandas==2.3.3 matplotlib==3.10.7 rdkit==2023.09.6 ipykernel==6.30.0 transformers==4.43.4 torch==2.5.1 --torch-backend=auto
```

To activate the virtual environment, run the following command from the repository root (this process can also be automated through VS Code settings):
```bash
.venv\Scripts\activate
```
To deactivate the virtual environment, run:
```bash
deactivate
```
</details>

<details>
<summary><b>Full installation (Mac, Linux)</b></summary><br>

### Available classes
- **Transition**: `BioT5Transition`, `ChatGPTTransition`, `ChatGPTTransitionWithMemory`, `GBGATransition`, `GPT2Transition`, `RNNBasedMutation`, `RNNTransition`, `SMIRKSTransition`
- **Reward**: `DScoreReward`, `DyRAMOReward`, `GFPReward`, `SimilarityReward`, `JScoreReward`, `LogPReward`, `TDCReward`
- The corresponding Node classes, along with all implemented Filter and Policy classes, are also available in this environment.
- `ChatGPTTransition` and `ChatGPTTransitionWithMemory` require an OpenAI API key.

### Setup steps
1. Clone the repository
2. Install uv: https://docs.astral.sh/uv/getting-started/installation/
3. Restart the shell
4. Move to the repository root (e.g., `cd molgen`)
5. Run the following commands:
```bash
uv venv --python 3.11.11
source .venv/bin/activate
uv pip install chemtsv3 pytdc==1.1.14 numpy==1.26.4 pandas==2.3.3 matplotlib==3.10.7 rdkit==2023.09.6 selfies==2.2.0 ipykernel==6.30.0 transformers==4.43.4 setuptools==78.1.1 lightgbm==4.6.0 openai==2.6.0 torch==2.5.1 --torch-backend=auto
```
To activate the virtual environment, run the following command from the repository root (this process can also be automated through VS Code settings):
```bash
source .venv/bin/activate
```
To deactivate the virtual environment, run:
```bash
deactivate
```
</details>

<details>
<summary><b>Optional dependencies</b></summary><br>

The full installation includes the following optional packages:

|Package|Required for|Tested version|
|---|---|---|
|`lightgbm`|`DScoreReward`, `DyRAMOReward`, `PUCTWithPredictor`|3.3.5, 4.6.0|
|`selfies`|`SELFIESStringNode`|2.2.0|
|`openai`|`ChatGPTTransition`, `ChatGPTTransitionWithMemory`|2.6.0|
|`pytdc`|`TDCReward`|1.1.14|

</details>

<details>
<summary><b>Troubleshooting</b></summary><br>

### CUDA not available
In some cases (for example, when setting up environments on a control node), it may be necessary to reinstall torch with a different backend to enable CUDA support. However, since major implemented classes (including `RNNTransition`) are likely to run faster on the CPU, this is not strictly required. After reinstalling torch, you may also need to downgrade numpy to version 1.26.4 if it was upgraded during the process.
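
To confirm whether the installed torch build actually sees a GPU, a quick check with standard PyTorch calls:

```python
# Check whether the installed torch build has a usable CUDA backend.
import torch

print(torch.__version__)             # installed torch version
print(torch.cuda.is_available())     # False means generation will run on the CPU
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the first visible GPU
```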
</details>

## Generation via CLI
See `config/mcts/example.yaml` for an example YAML configuration.
```bash
# Simple generation
chemtsv3 -c config/mcts/example.yaml
# Chain generation
chemtsv3 -c config/mcts/example_chain_1.yaml
# Load a checkpoint and continue the generation
chemtsv3 -l generation_results/~~~ --max_generations 100 --time_limit 60
```

## Notebooks
- **Tutorials**: `sandbox/tutorial/***.ipynb`
- **Generation via notebook**: `sandbox/generation.ipynb` (the snippet below shows the equivalent Python calls)
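
Outside the CLI (e.g., in a notebook), the same entry points used by the bundled `chemtsv3` command can be called directly. A minimal sketch based on the bundled `chemtsv3/cli/generation.py`; the config path is illustrative:

```python
# Minimal programmatic run, mirroring what the chemtsv3 CLI does.
from chemtsv3.generator import Generator
from chemtsv3.utils import conf_from_yaml, generator_from_conf

conf = conf_from_yaml("config/mcts/example.yaml")   # path to your YAML config
generator = generator_from_conf(conf)
generator.generate(time_limit=conf.get("time_limit"),
                   max_generations=conf.get("max_generations"))
generator.analyze()
generator.plot(save_only=True)

# Resuming from a saved run works the same way as `chemtsv3 -l <save_dir>`:
# generator = Generator.load_dir("generation_results/<save_dir>")
# generator.generate(max_generations=100)
```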

## Options
See `config/mcts/example.yaml` for an example and advanced options. More examples (the settings used in the paper) can be found in `config/mcts/egfr_de_novo/` and `config/mcts/egfr_lead_opt/`.

All options for each component (class) are defined as arguments in the `__init__()` method of the corresponding class.

<details>
<summary><b>Nodes and Transitions</b></summary><br>

**For general usage:**
|Node class|Transition class|Description|
|---|---|---|
|`MolSentenceNode`|`RNNTransition`|For de novo generation. Uses the RNN (GRU / LSTM) model specified by `model_dir`.|
|`MolSentenceNode`|`GPT2Transition`|For de novo generation. Uses the Transformer (GPT-2) model specified by `model_dir`.|
|`CanonicalSMILESStringNode`|`GBGATransition`|For lead optimization. Uses [GB-GA mutation rules](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc05372c).|
|`CanonicalSMILESStringNode`|`SMIRKSTransition`|For lead optimization. Uses the specified SMIRKS rules (e.g., MMP-based ones).|
|`SMILESStringNode`|`ChatGPTTransition`|For lead optimization. Uses the specified prompt(s) as input to the GPT model specified by `model` (e.g., `"gpt-4o-mini"`). Requires an OpenAI API key.|

**For research purposes (did not perform well in our testing):**
|Node class|Transition class|Description|
|---|---|---|
|`CanonicalSMILESStringNode`|`GBGMTransition`|For de novo generation. Uses [GB-GM rules](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc05372c). Rollouts iteratively apply transitions until the molecule size reaches a sampled value determined by `size_mean` and `size_std`.|
|`FASTAStringNode`|`ProtGPT2Transition`|For de novo protein generation. Uses the [ProtGPT2 model](https://www.nature.com/articles/s41467-022-32007-7).|
|`SELFIESStringNode`|`BioT5Transition`|For lead optimization. Uses the specified prompt(s) as input to the [BioT5 text2mol model](https://github.com/QizhiPei/BioT5).|
|`SMILESStringNode`|`ChatGPTTransitionWithMemory`|For lead optimization. Unlike `ChatGPTTransition`, retains the conversation history and feeds reward-calculation results back to the model.|

</details>

<details>
<summary><b>Policies</b></summary><br>

- `UCT`: Does not use transition probabilities. Performed better with `RNNTransition` in our testing.
- `PUCT`: Incorporates transition probabilities (follows the modification introduced in [AlphaGo Zero](https://www.nature.com/articles/nature24270)). Performed better with `GBGATransition` in our testing.
- `PUCTWithPredictor`: Trains an optimistic predictor of leaf-node evaluations using the generation history, and uses its output as the score for unvisited nodes when the model's performance (measured by the normalized pinball loss) exceeds a specified threshold. This option adds a few seconds of overhead per generation (depending on the number of child nodes per transition and the computational cost of each prediction), and is recommended only when the reward calculations are expensive. Inherits all the arguments of `UCT` and `PUCT`. For non-molecular nodes, a function that returns a feature vector must be defined (see `policy/puct_with_predictor.py` for details). A sketch of the pinball-loss score is shown below.
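
The normalized pinball loss mentioned above is the standard quantile-regression loss. A minimal sketch of the score computation, under the assumption that the baseline is a constant prediction at the target quantile (the package's exact baseline may differ):

```python
import numpy as np

def pinball_loss(y_true, y_pred, alpha=0.9):
    """Mean pinball (quantile) loss at quantile level alpha."""
    diff = np.asarray(y_true) - np.asarray(y_pred)
    return float(np.mean(np.maximum(alpha * diff, (alpha - 1) * diff)))

# Score used to decide whether the predictor is trustworthy:
# 1 - {pinball loss of the model} / {pinball loss of a constant baseline}.
rewards = np.array([0.2, 0.5, 0.8, 0.4])        # observed leaf evaluations (illustrative)
predictions = np.array([0.3, 0.6, 0.7, 0.5])    # model's quantile predictions (illustrative)
baseline = np.quantile(rewards, 0.9)            # one natural constant baseline at the target quantile
score = 1 - pinball_loss(rewards, predictions) / pinball_loss(rewards, np.full_like(rewards, baseline))
# The predictor is used once `score` exceeds `score_threshold` (default 0.6).
print(score)
```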

</details>

<details>
<summary><b>Basic options</b></summary><br>

|Class|Option|Default|Description|
|---|---|---|---|
|-|`max_generations`|-|Stops generation after producing the specified number of molecules.|
|-|`time_limit`|-|Stops generation once the time limit (in seconds) is reached.|
|-|`root`|`""`|Key (string) for the root node (e.g., the SMILES of the starting molecule for `SMILESStringNode`). Multiple roots can be specified by passing a list. If not specified, an empty string `""` is used as the root node's key.|
|`MCTS`|`n_eval_width`|∞|By default (= ∞), evaluates all new leaf nodes after each transition. Setting `n_eval_width = 1` often improves sample efficiency and can be beneficial when reward computation is expensive.|
|`MCTS`|`filter_reward`|0|Substitutes the reward with this value when nodes are filtered. Use a list to specify different reward values for each filtering step. Set to `"ignore"` to skip reward assignment (in this case, another penalty for filtered nodes, such as `failed_parent_reward`, needs to be set).|
|`UCT`, `PUCT`, `PUCTWithPredictor`|`c`|0.3|A larger value prioritizes exploration over exploitation. Recommended range: [0.01, 1]|
|`UCT`, `PUCT`, `PUCTWithPredictor`|`best_rate`|0|A value between 0 and 1. The exploitation term is calculated as `best_rate` * {best reward} + (1 - `best_rate`) * {average reward}. For better sample efficiency, it is often worth setting this value to around 0.5 for de novo generation and around 0.9 for lead optimization (see the sketch after this table).|
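
For intuition, a sketch of how `c` and `best_rate` enter a UCT-style selection score; the exploration term below is the textbook UCT bonus and stands in for whatever exact form the package uses:

```python
import math

def uct_score(best_reward, mean_reward, n_parent, n_child, c=0.3, best_rate=0.0):
    # Exploitation: a mix of the best and the average reward observed below this child.
    exploitation = best_rate * best_reward + (1 - best_rate) * mean_reward
    # Exploration: standard UCT bonus; larger `c` favors less-visited children.
    exploration = c * math.sqrt(math.log(n_parent) / n_child)
    return exploitation + exploration

# best_rate close to 0.9 chases the best reward found so far (lead optimization);
# best_rate around 0.5 balances best and average (de novo generation).
print(uct_score(best_reward=0.8, mean_reward=0.4, n_parent=100, n_child=10, c=0.3, best_rate=0.9))
```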

</details>

<details>
<summary><b>Advanced options</b></summary><br>

For other options and further details, please refer to each class's `__init__()` method.

|Class|Option|Default|Description|
|---|---|---|---|
|-|`seed`|-|The seed value for `random`, `np.random`, and `torch`.|
|-|`device`|-|Torch device specification (e.g., "cpu", "cuda", "cuda:0"). For `RNNTransition`, using the CPU tends to be faster.|
|-|`debug`|False|If True, debug logging is enabled.|
|-|`silent`|False|If True, console logging is disabled.|
|-|`next_yaml_path`|False|If set to the path of the YAML config for the next generator, the generated molecules are passed on for chain generation.|
|-|`n_keys_to_pass`|3|Number of top-k generated molecules (keys) to be used as root nodes for the next generator.|
|`MCTS`|`save_on_completion`|False|If True, saves a checkpoint upon completion of the generation.|
|`MCTS`|`n_eval_iters`|1|The number of child-node evaluations. This value should not be > 1 unless the evaluations are nondeterministic (e.g., involve rollouts).|
|`MCTS`|`n_tries`|1|The number of attempts to obtain an unfiltered node in a single evaluation. This value should not be > 1 unless the evaluations are nondeterministic (e.g., involve rollouts).|
|`MCTS`|`allow_eval_overlaps`|False|Whether to allow overlapping nodes when sampling eval candidates (recommended: False).|
|`MCTS`|`reward_cutoff`|None|Child nodes are removed if their reward is lower than this value. This applies only to nodes for which `has_reward() = True` (i.e., complete molecules).|
|`MCTS`|`reward_cutoff_warmups`|None|If specified, `reward_cutoff` is inactive for the first `reward_cutoff_warmups` generations.|
|`MCTS`|`cut_failed_child`|False|If True, child nodes are removed when {`n_eval_iters` * `n_tries`} evals are filtered.|
|`MCTS`|`failed_parent_reward`|`"ignore"`|Backpropagate this value when {`n_eval_width` * `n_eval_iters` * `n_tries`} evals are filtered from the node.|
|`MCTS`|`terminal_reward`|`"ignore"`|If a float value is set, that value is backpropagated when a leaf node reaches a terminal state. If set to `"ignore"`, no value is backpropagated.|
|`MCTS`|`cut_terminal`|True|If True, terminal nodes are pruned from the search tree and will not be visited more than once.|
|`MCTS`|`avoid_duplicates`|True|If True, duplicate nodes won't be added to the search tree. Should be True if the transition forms a cyclic graph. Unneeded if the tree structure of the transition graph is guaranteed, in which case it can be set to False to reduce memory usage.|
|`MCTS`|`discard_unneeded_states`|True|If True, discards node variables that are no longer needed after expansion. Set this to False when using custom classes that utilize these values.|
|`UCT`, `PUCT`, `PUCTWithPredictor`|`pw_c`, `pw_alpha`, `pw_beta`|None, 0, 0|If `pw_c` is set, the number of available child nodes is limited to `pw_c` * ({visit count} ** `pw_alpha`) + `pw_beta` (progressive widening).|
|`UCT`, `PUCT`, `PUCTWithPredictor`|`max_prior`|None (0)|A lower bound for the best reward. If the actual best reward is lower than this value, this value is used instead.|
|`UCT`, `PUCT`, `PUCTWithPredictor`|`epsilon`|0|The probability of randomly selecting a child node while descending the search tree.|
|`PUCTWithPredictor`|`alpha`|0.9|Quantile level for the predictor, representing the target percentile of the response variable to be estimated and used.|
|`PUCTWithPredictor`|`score_threshold`|0.6|If the recent prediction score (1 - {pinball loss} / {baseline pinball loss}) is better than this threshold, the model is used from then on.|
|`MolSentenceNode`, `MolStringNode`|`use_canonical_smiles_as_key`|False|Whether to convert generated molecules to canonical SMILES when generating keys. If False, the same molecule may be counted multiple times.|
|`RNNTransition`, `GPT2Transition`|`top_p`|0.995|Nucleus sampling threshold in (0, 1]; keeps the smallest set of tokens whose probability mass is ≥ `top_p`.|
|`RNNTransition`, `GPT2Transition`|`temperature`|1|Logit temperature > 0 applied **before** `top_p`; values < 1.0 sharpen the distribution, values > 1.0 smooth it.|
|`RNNTransition`|`sharpness`|1|Probability distribution sharpness > 0 applied **after** `top_p`; values < 1.0 smooth the distribution, values > 1.0 sharpen it (see the sketch after this table).|
|`RNNTransition`|`disable_top_p_on_rollout`|False|If True, `top_p` is not applied during rollouts.|
|`SMIRKSTransition`|`limit`|None|If the number of generated SMILES exceeds this value, no further SMIRKS patterns are applied. When this option is enabled, the order of the SMIRKS patterns is shuffled (with weights) before each transition.|
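
To make the three sampling options concrete, here is a sketch of one sampling step with `temperature`, `top_p`, and `sharpness` applied in the order described above (standard nucleus sampling; the package's implementation details may differ):

```python
import numpy as np

def sample_next_token(logits, temperature=1.0, top_p=0.995, sharpness=1.0, rng=None):
    rng = rng or np.random.default_rng()
    # 1) Temperature is applied to the logits before nucleus filtering.
    logits = np.asarray(logits, dtype=float) / temperature
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # 2) top_p: keep the smallest set of tokens whose cumulative probability >= top_p.
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    keep = order[: int(np.searchsorted(cumulative, top_p) + 1)]
    kept = np.zeros_like(probs)
    kept[keep] = probs[keep]
    # 3) sharpness is applied after top_p: p_i ** sharpness, then renormalize.
    kept = kept ** sharpness
    kept /= kept.sum()
    return rng.choice(len(probs), p=kept)

print(sample_next_token([2.0, 1.0, 0.5, -1.0], temperature=1.0, top_p=0.9, sharpness=1.0))
```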

</details>

<details>
<summary><b>Filters</b></summary><br>

**Sanity**
- `ValidityFilter`: Excludes invalid molecule objects. Since other filters and rewards typically assume validity and do not recheck it, this filter should usually be applied first in molecular generation.
- `RadicalFilter`: Excludes molecules whose number of radical electrons is not 0.
- `ConnectivityFilter`: Excludes molecules whose number of disconnected fragments is not 1.

**Topological**
- `SubstructureFilter`: Excludes molecules that **do not** contain the substructure(s) specified by the `smiles` or `smarts` arguments. If `preserve` is set to False, instead excludes molecules that **do** contain the specified substructure(s). By specifying appropriate SMARTS patterns, it is possible to control where substitutions or structural modifications (i.e., adding a substituent or arm) are allowed to occur.
- `AromaticRingFilter`: Excludes molecules whose number of aromatic rings falls outside the range [`min`, `max`]. (Default: [1, ∞))
- `HeavyAtomCountFilter`: Excludes molecules whose number of heavy atoms falls outside the range [`min`, `max`]. (Default: [0, 45])
- `MaxRingSizeFilter`: Excludes molecules whose largest ring size falls outside the range [`min`, `max`]. (Default: [0, 6])
- `MinRingSizeFilter`: Excludes molecules whose smallest ring size falls outside the range [`min`, `max`]. (Default: (-∞, ∞))
- `RingBondFilter`: Excludes molecules containing ring allenes (`[R]=[R]=[R]`) or double bonds in small rings (`[r3,r4]=[r3,r4]`).
- `RotatableBondsFilter`: Excludes molecules whose number of rotatable bonds falls outside the range [`min`, `max`]. (Default: [0, 10])

**Structural alert**
- `ROCFilter`: Excludes molecules that contain structural alerts defined by Ohta and Cho.
- `CatalogFilter`: Excludes molecules that contain structural alerts in the specified list of [rdkit.Chem.FilterCatalogParams.FilterCatalogs](https://www.rdkit.org/docs/source/rdkit.Chem.rdfiltercatalog.html#rdkit.Chem.rdfiltercatalog.FilterCatalogParams.FilterCatalogs) (e.g., `catalogs = ["PAINS_A", "PAINS_B", "PAINS_C", "NIH", "BRENK"]`).

**Drug-likeness**
- `PubChemFilter`: Excludes molecules based on the frequency of occurrence of molecular patterns in the PubChem database. Reported in [Ma et al.](https://doi.org/10.1021/acs.jcim.1c00679).
- `LipinskiFilter`: Excludes molecules based on Lipinski's Rule of Five. Set `rule_of` to 3 to apply the Rule of Three instead.
- `SAScoreFilter`: Excludes molecules whose synthetic accessibility score (SA Score) falls outside the range [`min`, `max`]. (Default: [1, 3.5])

**Physicochemical**
- `ChargeFilter`: Excludes molecules whose formal charge is not 0.
- `HBAFilter`: Excludes molecules whose number of hydrogen bond acceptors falls outside the range [`min`, `max`]. (Default: [0, 10])
- `HBDFilter`: Excludes molecules whose number of hydrogen bond donors falls outside the range [`min`, `max`]. (Default: [0, 5])
- `LogPFilter`: Excludes molecules whose LogP value falls outside the range [`min`, `max`]. (Default: (-∞, 5])
- `TPSAFilter`: Excludes molecules whose topological polar surface area (TPSA) falls outside the range [`min`, `max`]. (Default: [0, 140])
- `WeightFilter`: Excludes molecules whose molecular weight falls outside the range [`min`, `max`]. (Default: [0, 500])

**Misc**
- `KnownListFilter`: Excludes molecules that are contained in the key column of the input CSV file(s), and overrides their reward with the corresponding value from the reward column (unless the filter is applied in a transition). (CSV files from generation results can be used directly.)

Filters can also be specified via the `filters` argument of transitions that inherit from `TemplateTransition` (e.g., `GBGATransition`, `SMIRKSTransition`, `ChatGPTTransition`) to exclude molecules from the child nodes directly.
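
For reference, the properties that these filters threshold can be computed directly with RDKit; the following is plain RDKit usage to illustrate the underlying checks, not the package's Filter API:

```python
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin, as an example

properties = {
    "MolWt": Descriptors.MolWt(mol),                    # WeightFilter
    "LogP": Descriptors.MolLogP(mol),                   # LogPFilter
    "TPSA": Descriptors.TPSA(mol),                      # TPSAFilter
    "HBD": Lipinski.NumHDonors(mol),                    # HBDFilter
    "HBA": Lipinski.NumHAcceptors(mol),                 # HBAFilter
    "RotatableBonds": Lipinski.NumRotatableBonds(mol),  # RotatableBondsFilter
    "HeavyAtoms": mol.GetNumHeavyAtoms(),               # HeavyAtomCountFilter
    "FormalCharge": Chem.GetFormalCharge(mol),          # ChargeFilter
}
print(properties)

# A substructure check of the kind SubstructureFilter performs:
pattern = Chem.MolFromSmarts("c1ccccc1")  # benzene ring
print(mol.HasSubstructMatch(pattern))
```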

</details>

## Model training
- **RNN (GRU) training** (example): `chemtsv3-train -c config/training/train_rnn_smiles.yaml`
- **Transformer (GPT-2) training** (example): `chemtsv3-train -c config/training/train_gpt2.yaml`

Change `dataset_path` in YAML to train on an arbitrary dataset (1 sentence per line).
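
The dataset is a plain text file with one sentence (e.g., one SMILES string) per line. A tiny illustrative example of producing such a file (the file name and molecules are made up):

```python
# Write a minimal training dataset: one SMILES string per line.
smiles_list = ["CCO", "c1ccccc1", "CC(=O)Oc1ccccc1C(=O)O"]
with open("my_dataset.smi", "w") as f:
    f.write("\n".join(smiles_list) + "\n")
# Then point `dataset_path` in the training YAML at "my_dataset.smi".
```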

chemtsv3-0.1.2/chemtsv3/__init__.py
File without changes

chemtsv3-0.1.2/chemtsv3/cli/__init__.py
File without changes

chemtsv3-0.1.2/chemtsv3/cli/generation.py
ADDED
@@ -0,0 +1,66 @@
```python
# Example (RNN): chemtsv3 -c config/mcts/example.yaml
# Example (Chain): chemtsv3 -c config/mcts/example_chain_1.yaml
# Example (Load): chemtsv3 -l generation_results/~~~ --max_generations 100

# Path setup / Imports
import faulthandler
# import sys
# import os
# repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
# if repo_root not in sys.path:
#     sys.path.insert(0, repo_root)

import argparse
from chemtsv3.generator import Generator
from chemtsv3.utils import conf_from_yaml, generator_from_conf

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--yaml_path", type=str, help="Path to the config file (.yaml)")
    parser.add_argument("-l", "--load_dir", type=str, help="Path to the save directory (contains config.yaml and save.gtr)")

    parser.add_argument("--max_generations", type=int, help="Only used when loading the generator from the save.")
    parser.add_argument("-t", "--time_limit", type=int, help="Only used when loading the generator from the save.")

    args = parser.parse_args()

    yaml_path = args.yaml_path
    load_dir = args.load_dir

    if yaml_path is None and load_dir is None:
        raise ValueError("Specify either 'yaml_path' (-c) or 'load_dir' (-l).")
    elif yaml_path is not None and load_dir is None:
        # Fresh run from a YAML config; loops to support chain generation via next_yaml_path.
        conf = conf_from_yaml(yaml_path)
        generator = generator_from_conf(conf)
        while yaml_path:
            generator.generate(time_limit=conf.get("time_limit"), max_generations=conf.get("max_generations"))
            if "next_yaml_path" not in conf:
                yaml_path = None
                plot_args = conf.get("plot_args", {})
                if "save_only" not in plot_args:
                    plot_args["save_only"] = True
                generator.plot(**plot_args)
                generator.analyze()
            else:
                # Pass the top-k generated keys to the next generator in the chain.
                n_top_keys_to_pass = conf.get("n_keys_to_pass", 3)
                yaml_path = conf["next_yaml_path"]
                conf = conf_from_yaml(yaml_path)
                new_generator = generator_from_conf(conf, predecessor=generator, n_top_keys_to_pass=n_top_keys_to_pass)
                generator = new_generator

    elif yaml_path is None and load_dir is not None:
        # Resume from a checkpoint directory.
        generator = Generator.load_dir(load_dir)
        max_generations = args.max_generations
        time_limit = args.time_limit
        generator.generate(max_generations=max_generations, time_limit=time_limit)
        generator.analyze()
        plot_args = generator.yaml_copy.get("plot_args", {})
        if "save_only" not in plot_args:
            plot_args["save_only"] = True
        generator.plot(**plot_args)
    else:
        raise ValueError("Specify one of 'yaml_path' (-c) or 'load_dir' (-l), not both.")

if __name__ == "__main__":
    faulthandler.enable()
    main()
```