gtsr 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gtsr-0.0.1/GTsRunner.py +5 -0
- gtsr-0.0.1/LICENSE +21 -0
- gtsr-0.0.1/PKG-INFO +129 -0
- gtsr-0.0.1/README.md +105 -0
- gtsr-0.0.1/ckpt/__init__.py +1 -0
- gtsr-0.0.1/ckpt/all_best.pth +0 -0
- gtsr-0.0.1/ckpt/free_best.pth +0 -0
- gtsr-0.0.1/ckpt/stability_best.pkl +0 -0
- gtsr-0.0.1/gtsr/__init__.py +5 -0
- gtsr-0.0.1/gtsr/runner.py +317 -0
- gtsr-0.0.1/gtsr.egg-info/PKG-INFO +129 -0
- gtsr-0.0.1/gtsr.egg-info/SOURCES.txt +20 -0
- gtsr-0.0.1/gtsr.egg-info/dependency_links.txt +1 -0
- gtsr-0.0.1/gtsr.egg-info/requires.txt +8 -0
- gtsr-0.0.1/gtsr.egg-info/top_level.txt +2 -0
- gtsr-0.0.1/setup.cfg +4 -0
- gtsr-0.0.1/setup.py +51 -0
- gtsr-0.0.1/src/GCN.py +141 -0
- gtsr-0.0.1/src/__init__.py +2 -0
- gtsr-0.0.1/src/cif_utils.py +368 -0
- gtsr-0.0.1/src/data.py +149 -0
- gtsr-0.0.1/src/utils.py +58 -0
gtsr-0.0.1/GTsRunner.py
ADDED
gtsr-0.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 coollkr
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
gtsr-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: gtsr
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Graph neural network tool for solvent removal from MOF structures
|
|
5
|
+
Author: Xiao-Yan Li Group
|
|
6
|
+
License: MIT
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Chemistry
|
|
13
|
+
Requires-Python: >=3.9
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Requires-Dist: ase>=3.19
|
|
17
|
+
Requires-Dist: numpy>=1.21
|
|
18
|
+
Requires-Dist: pymatgen>=2018.6.11
|
|
19
|
+
Requires-Dist: scikit-learn>=1.0
|
|
20
|
+
Requires-Dist: torch>=1.12
|
|
21
|
+
Requires-Dist: molSimplify==1.8.0
|
|
22
|
+
Requires-Dist: rdkit
|
|
23
|
+
Requires-Dist: networkx
|
|
24
|
+
|
|
25
|
+
# GTsR
|
|
26
|
+
|
|
27
|
+
<div align="center">
|
|
28
|
+
<img src="https://raw.githubusercontent.com/Xiao-Yan-Li-group/GTsR/main/webapp/imgs/gtsr_logo.png" alt="GTsR logo" width="500"/>
|
|
29
|
+
</div>
|
|
30
|
+
|
|
31
|
+
[Python >= 3.9](https://python.org/downloads)
|
|
32
|
+
[License: MIT](https://github.com/sxm13/pypi-dev/blob/main/LICENSE)
|
|
33
|
+
|
|
34
|
+
**GTsR (GNN Tool for Solvent Removal)** is a tool for solvent identification, solvent removal, and activation-stability prediction in metal-organic frameworks (MOFs).
|
|
35
|
+
|
|
36
|
+
GTsR uses graph neural networks to classify atoms in CIF structures and generate solvent-free framework CIF files. It also provides a random forest model that predicts the activation stability of cleaned MOFs using structural, pore, and RAC descriptors.
|
|
37
|
+
|
|
38
|
+
## Models
|
|
39
|
+
|
|
40
|
+
| `checkpoint` | Model file | Purpose |
|
|
41
|
+
| --- | --- | --- |
|
|
42
|
+
| `free` (default) | `ckpt/free_best.pth` | Remove free solvent |
|
|
43
|
+
| `all` | `ckpt/all_best.pth` | Remove all solvent |
|
|
44
|
+
| `stability` | `ckpt/stability_best.pkl` | Predict activation stability |
|
|
45
|
+
|
|
46
|
+
The `free` and `all` checkpoints are atom-level GNN classifiers. The `stability` checkpoint is a random forest model bundled with its missing-value imputer.
|
|
47
|
+
|
|
48
|
+
## Installation
|
|
49
|
+
|
|
50
|
+
```bash
|
|
51
|
+
git clone https://github.com/coollkr/GTsR.git
|
|
52
|
+
cd GTsR
|
|
53
|
+
pip install -e .
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
## Usage
|
|
57
|
+
|
|
58
|
+
### Solvent Removal
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
from gtsr import GTsRunner
|
|
62
|
+
|
|
63
|
+
runner = GTsRunner(checkpoint="free") ### for free solvent removal
|
|
64
|
+
runner = GTsRunner(checkpoint="all") ### for all solvent removal
|
|
65
|
+
runner = GTsRunner(checkpoint="path/to/ckpt.pth", device="cpu") #### use your model
|
|
66
|
+
result = runner.clean(
|
|
67
|
+
cif="input.cif",
|
|
68
|
+
output="prediction",
|
|
69
|
+
threshold=0.5,
|
|
70
|
+
)
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
#### `clean()` Result
|
|
74
|
+
|
|
75
|
+
`clean()` returns a dictionary containing the following fields:
|
|
76
|
+
|
|
77
|
+
| Field | Description |
|
|
78
|
+
| --- | --- |
|
|
79
|
+
| `input` | Absolute path to the input CIF |
|
|
80
|
+
| `output` | Output directory |
|
|
81
|
+
| `framework` | Path to the cleaned framework CIF |
|
|
82
|
+
| `solvent` | Path to the solvent CIF, or `None` if no file was generated |
|
|
83
|
+
| `checkpoint` | Path to the checkpoint used for prediction |
|
|
84
|
+
| `task` | Task name stored in the checkpoint |
|
|
85
|
+
| `threshold` | Atom-classification threshold |
|
|
86
|
+
| `num_atoms` | Total number of atoms |
|
|
87
|
+
| `num_framework_atoms` | Number of framework atoms |
|
|
88
|
+
| `num_solvent_atoms` | Number of solvent atoms |
|
|
89
|
+
| `probabilities` | Solvent probability for each atom |
|
|
90
|
+
| `labels` | Predicted class label for each atom |
|
|
91
|
+
| `solvent_smiles` | SMILES strings of identified solvents |
|
|
92
|
+
|
|
93
|
+
### Predict Activation Stability
|
|
94
|
+
|
|
95
|
+
```python
|
|
96
|
+
from gtsr import GTsRunner
|
|
97
|
+
|
|
98
|
+
runner = GTsRunner(checkpoint="stability")
|
|
99
|
+
score = runner.stability(cif="cleaned_framework.cif")
|
|
100
|
+
|
|
101
|
+
if score == 1:
|
|
102
|
+
print("The cleaned structure is stable.")
|
|
103
|
+
else:
|
|
104
|
+
print("The cleaned structure is not stable.")
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
## Web Interface
|
|
108
|
+
|
|
109
|
+
[Host on Streamlit](https://xiao-yan-li-group.streamlit.app/GTsR)
|
|
110
|
+
or run it locally:
|
|
111
|
+
```bash
|
|
112
|
+
streamlit run webapp/Home.py
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
## Citation
|
|
116
|
+
|
|
117
|
+
If you use GTsR, please cite it as follows (this entry will be updated once the associated publication is available):
|
|
118
|
+
|
|
119
|
+
```bibtex
|
|
120
|
+
@article{gtsr-xyl-group,
|
|
121
|
+
title = {GTSR: A GNN Based Tool for Solvent Removal from MOF with Stability Check},
|
|
122
|
+
author = {Liang, Kairui and Zhao, Guobin and Li, Xiao-Yan},
|
|
123
|
+
year = {2026}
|
|
124
|
+
}
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
## License
|
|
128
|
+
|
|
129
|
+
The repository's [`LICENSE`](LICENSE) file currently uses the MIT License.
|
gtsr-0.0.1/README.md
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# GTsR
|
|
2
|
+
|
|
3
|
+
<div align="center">
|
|
4
|
+
<img src="https://raw.githubusercontent.com/Xiao-Yan-Li-group/GTsR/main/webapp/imgs/gtsr_logo.png" alt="GTsR logo" width="500"/>
|
|
5
|
+
</div>
|
|
6
|
+
|
|
7
|
+
[Python >= 3.9](https://python.org/downloads)
|
|
8
|
+
[License: MIT](https://github.com/sxm13/pypi-dev/blob/main/LICENSE)
|
|
9
|
+
|
|
10
|
+
**GTsR (GNN Tool for Solvent Removal)** is a tool for solvent identification, solvent removal, and activation-stability prediction in metal-organic frameworks (MOFs).
|
|
11
|
+
|
|
12
|
+
GTsR uses graph neural networks to classify atoms in CIF structures and generate solvent-free framework CIF files. It also provides a random forest model that predicts the activation stability of cleaned MOFs using structural, pore, and RAC descriptors.
|
|
13
|
+
|
|
14
|
+
## Models
|
|
15
|
+
|
|
16
|
+
| `checkpoint` | Model file | Purpose |
|
|
17
|
+
| --- | --- | --- |
|
|
18
|
+
| `free` (default) | `ckpt/free_best.pth` | Remove free solvent |
|
|
19
|
+
| `all` | `ckpt/all_best.pth` | Remove all solvent |
|
|
20
|
+
| `stability` | `ckpt/stability_best.pkl` | Predict activation stability |
|
|
21
|
+
|
|
22
|
+
The `free` and `all` checkpoints are atom-level GNN classifiers. The `stability` checkpoint is a random forest model bundled with its missing-value imputer.
|
|
23
|
+
|
|
24
|
+
## Installation
|
|
25
|
+
|
|
26
|
+
```bash
|
|
27
|
+
git clone https://github.com/coollkr/GTsR.git
|
|
28
|
+
cd GTsR
|
|
29
|
+
pip install -e .
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## Usage
|
|
33
|
+
|
|
34
|
+
### Solvent Removal
|
|
35
|
+
|
|
36
|
+
```python
|
|
37
|
+
from gtsr import GTsRunner
|
|
38
|
+
|
|
39
|
+
runner = GTsRunner(checkpoint="free") ### for free solvent removal
|
|
40
|
+
runner = GTsRunner(checkpoint="all") ### for all solvent removal
|
|
41
|
+
runner = GTsRunner(checkpoint="path/to/ckpt.pth", device="cpu") #### use your model
|
|
42
|
+
result = runner.clean(
|
|
43
|
+
cif="input.cif",
|
|
44
|
+
output="prediction",
|
|
45
|
+
threshold=0.5,
|
|
46
|
+
)
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
#### `clean()` Result
|
|
50
|
+
|
|
51
|
+
`clean()` returns a dictionary containing the following fields:
|
|
52
|
+
|
|
53
|
+
| Field | Description |
|
|
54
|
+
| --- | --- |
|
|
55
|
+
| `input` | Absolute path to the input CIF |
|
|
56
|
+
| `output` | Output directory |
|
|
57
|
+
| `framework` | Path to the cleaned framework CIF |
|
|
58
|
+
| `solvent` | Path to the solvent CIF, or `None` if no file was generated |
|
|
59
|
+
| `checkpoint` | Path to the checkpoint used for prediction |
|
|
60
|
+
| `task` | Task name stored in the checkpoint |
|
|
61
|
+
| `threshold` | Atom-classification threshold |
|
|
62
|
+
| `num_atoms` | Total number of atoms |
|
|
63
|
+
| `num_framework_atoms` | Number of framework atoms |
|
|
64
|
+
| `num_solvent_atoms` | Number of solvent atoms |
|
|
65
|
+
| `probabilities` | Solvent probability for each atom |
|
|
66
|
+
| `labels` | Predicted class label for each atom |
|
|
67
|
+
| `solvent_smiles` | SMILES strings of identified solvents |
|
|
68
|
+
|
|
69
|
+
### Predict Activation Stability
|
|
70
|
+
|
|
71
|
+
```python
|
|
72
|
+
from gtsr import GTsRunner
|
|
73
|
+
|
|
74
|
+
runner = GTsRunner(checkpoint="stability")
|
|
75
|
+
score = runner.stability(cif="cleaned_framework.cif")
|
|
76
|
+
|
|
77
|
+
if score == 1:
|
|
78
|
+
print("The cleaned structure is stable.")
|
|
79
|
+
else:
|
|
80
|
+
print("The cleaned structure is not stable.")
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
## Web Interface
|
|
84
|
+
|
|
85
|
+
[Host on Streamlit](https://xiao-yan-li-group.streamlit.app/GTsR)
|
|
86
|
+
or run it locally:
|
|
87
|
+
```bash
|
|
88
|
+
streamlit run webapp/Home.py
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
## Citation
|
|
92
|
+
|
|
93
|
+
If you use GTsR, please cite it as follows (this entry will be updated once the associated publication is available):
|
|
94
|
+
|
|
95
|
+
```bibtex
|
|
96
|
+
@article{gtsr-xyl-group,
|
|
97
|
+
title = {GTSR: A GNN Based Tool for Solvent Removal from MOF with Stability Check},
|
|
98
|
+
author = {Liang, Kairui and Zhao, Guobin and Li, Xiao-Yan},
|
|
99
|
+
year = {2026}
|
|
100
|
+
}
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
## License
|
|
104
|
+
|
|
105
|
+
The repository's [`LICENSE`](LICENSE) file currently uses the MIT License.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Bundled GTsR model checkpoints."""
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import pickle
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from .src.GCN import SolventAtomClassifier
|
|
12
|
+
from .src.cif_utils import (
|
|
13
|
+
PoreDiameter,
|
|
14
|
+
PoreVolume,
|
|
15
|
+
RACs,
|
|
16
|
+
cif2graph,
|
|
17
|
+
cif2pos,
|
|
18
|
+
flatten_cell,
|
|
19
|
+
flatten_rac,
|
|
20
|
+
get_cell,
|
|
21
|
+
get_sol_smi,
|
|
22
|
+
label2cif,
|
|
23
|
+
n_atom,
|
|
24
|
+
convert2pymatgen
|
|
25
|
+
)
|
|
26
|
+
from .src.data import GaussianDistance
|
|
27
|
+
from .src.utils import load_checkpoint
|
|
28
|
+
except ImportError:
|
|
29
|
+
from src.GCN import SolventAtomClassifier
|
|
30
|
+
from src.cif_utils import (
|
|
31
|
+
PoreDiameter,
|
|
32
|
+
PoreVolume,
|
|
33
|
+
RACs,
|
|
34
|
+
cif2graph,
|
|
35
|
+
cif2pos,
|
|
36
|
+
flatten_cell,
|
|
37
|
+
flatten_rac,
|
|
38
|
+
get_cell,
|
|
39
|
+
get_sol_smi,
|
|
40
|
+
label2cif,
|
|
41
|
+
n_atom,
|
|
42
|
+
convert2pymatgen
|
|
43
|
+
)
|
|
44
|
+
from src.data import GaussianDistance
|
|
45
|
+
from src.utils import load_checkpoint
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
PACKAGE_DIR = Path(__file__).resolve().parent
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _bundled_model(filename: str) -> Path:
|
|
52
|
+
candidates = (
|
|
53
|
+
PACKAGE_DIR / "ckpt" / filename,
|
|
54
|
+
PACKAGE_DIR.parent / "ckpt" / filename,
|
|
55
|
+
)
|
|
56
|
+
return next((path for path in candidates if path.is_file()), candidates[0])
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _bundled_checkpoint(name: str) -> Path:
|
|
60
|
+
return _bundled_model(f"{name}_best.pth")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
CHECKPOINTS = {
|
|
64
|
+
"free": _bundled_checkpoint("free"),
|
|
65
|
+
"all": _bundled_checkpoint("all"),
|
|
66
|
+
}
|
|
67
|
+
DEFAULT_CHECKPOINT = CHECKPOINTS["free"]
|
|
68
|
+
STABILITY_MODEL = _bundled_model("stability_best.pkl")
|
|
69
|
+
RAC_FEATURE_NAMES = tuple(
|
|
70
|
+
f"{prefix}-{property_name}-{depth}"
|
|
71
|
+
for prefix, property_names in (
|
|
72
|
+
("f-sbu", ("chi", "Z", "I", "T", "S")),
|
|
73
|
+
("mc", ("chi", "Z", "I", "T", "S")),
|
|
74
|
+
("D_mc", ("chi", "Z", "I", "T", "S")),
|
|
75
|
+
("f-link", ("chi", "Z", "I", "T", "S")),
|
|
76
|
+
("lc", ("chi", "Z", "I", "T", "S", "alpha")),
|
|
77
|
+
("D_lc", ("chi", "Z", "I", "T", "S", "alpha")),
|
|
78
|
+
("func", ("chi", "Z", "I", "T", "S", "alpha")),
|
|
79
|
+
("D_func", ("chi", "Z", "I", "T", "S", "alpha")),
|
|
80
|
+
)
|
|
81
|
+
for property_name in property_names
|
|
82
|
+
for depth in range(4)
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class GTsRunner:
    """Run the GTsR models on CIF structures.

    Depending on ``checkpoint``, an instance wraps one of two models:

    * a GNN atom classifier (``"free"``, ``"all"``, the default, or a custom
      checkpoint path), used by :meth:`clean`;
    * the pickled activation-stability model (``"stability"``), used by
      :meth:`stability`.

    :meth:`stability` also works on a GNN runner; it loads the bundled
    stability model on first use.
    """

    def __init__(
        self,
        checkpoint: str | Path | None = "",
        device: str | torch.device | None = None,
    ) -> None:
        """Load the requested model.

        Args:
            checkpoint: ``"free"``, ``"all"``, ``"stability"``, a path to a
                custom GNN checkpoint, or ``""``/``None`` for the default
                (``"free"``) checkpoint. Names are matched ignoring case and
                surrounding whitespace.
            device: Torch device. ``None`` or ``"auto"`` selects CUDA when it
                is available and CPU otherwise.

        Raises:
            FileNotFoundError: If the checkpoint or stability model file is
                missing.
            ValueError: If a GNN checkpoint has no ``model_config`` dict.
            RuntimeError: If CUDA is requested but not available.
        """
        checkpoint_name = "" if checkpoint is None else str(checkpoint).strip().lower()
        self.device = self._resolve_device(device)
        self.stability_model = None
        self.stability_imputer = None

        if checkpoint_name == "stability":
            self.checkpoint_path = self._resolve_stability_model()
            self._load_stability_model()
            self.checkpoint = None
            self.task = "stability"
            return

        self.checkpoint_path = self._resolve_checkpoint(checkpoint)
        self.checkpoint = load_checkpoint(self.checkpoint_path, device=self.device)

        model_config = self.checkpoint.get("model_config")
        if not isinstance(model_config, dict):
            raise ValueError(
                f"Checkpoint does not contain a valid model_config: {self.checkpoint_path}"
            )

        self.model = SolventAtomClassifier(**model_config).to(self.device)
        self.model.load_state_dict(self.checkpoint["state_dict"])
        self.model.eval()

        # Graph/featurization settings stored with the checkpoint, with fallbacks.
        self.radius = float(self.checkpoint.get("radius", 8.0))
        self.dmin = float(self.checkpoint.get("dmin", 0.0))
        self.step = float(self.checkpoint.get("step", 0.2))
        self.default_threshold = float(self.checkpoint.get("threshold", 0.5))
        self.task = str(self.checkpoint.get("task", "unknown"))
        # The one-hot atom encoding in _build_tensors covers atomic numbers 0..118.
        self.max_atomic_number = 118
        self.gdf = GaussianDistance(
            dmin=self.dmin,
            dmax=self.radius,
            step=self.step,
        )

    def clean(
        self,
        cif: str | Path = "",
        output: str | Path = "",
        threshold: float | None = None,
    ) -> dict[str, Any]:
        """Classify every atom of ``cif`` and write framework/solvent CIFs.

        Args:
            cif: Path to the input CIF file.
            output: Output directory. Defaults to ``<cif dir>/<cif stem>_gtsr``
                and is created if it does not exist.
            threshold: Solvent-probability cutoff in [0, 1]. Atoms with a
                probability >= cutoff are labelled 1 (solvent) and the rest
                0 (framework). Defaults to the checkpoint's stored threshold.

        Returns:
            A dict with the input and output paths, the framework CIF path and
            the solvent CIF path (``None`` if no solvent file exists). It also
            holds the checkpoint and task info, the atom counts, the per-atom
            probabilities and labels, and the solvent SMILES (``None`` if they
            could not be extracted).

        Raises:
            RuntimeError: If this runner was created with ``checkpoint="stability"``.
            ValueError: If ``cif`` is empty, ``threshold`` is outside [0, 1],
                or the structure fails the checks in ``_build_tensors``.
            FileNotFoundError: If ``cif`` does not exist.
        """
        if self.task == "stability":
            raise RuntimeError(
                "clean() requires a GNN checkpoint; initialize GTsRunner with "
                "checkpoint='free' or checkpoint='all'"
            )

        # Validate every argument before touching the filesystem. A bad call
        # then neither runs convert2pymatgen on the input nor leaves an empty
        # output directory behind.
        cif_path = self._resolve_cif(cif)
        cutoff = self.default_threshold if threshold is None else float(threshold)
        if not 0.0 <= cutoff <= 1.0:
            raise ValueError(f"threshold must be between 0 and 1, got {cutoff}")

        # Pre-process the input exactly as before, but on the validated
        # absolute path. NOTE(review): convert2pymatgen (src/cif_utils) may
        # rewrite the CIF in place; confirm.
        convert2pymatgen(str(cif_path))
        output_dir = self._resolve_output(cif_path, output)

        tensors = self._build_tensors(cif_path)
        with torch.inference_mode():
            probabilities = torch.sigmoid(self.model(*tensors)).cpu().numpy()

        labels = (probabilities >= cutoff).astype(np.int64)
        label2cif(cif_path, labels, str(output_dir))

        # File names written by label2cif for this input.
        stem = cif_path.stem
        framework_path = output_dir / f"{stem}_gtsr.cif"
        solvent_path = output_dir / f"{stem}_sol.cif"
        try:
            sol_smis = get_sol_smi(solvent_path)
        except Exception:
            # Best effort: a SMILES-extraction failure (presumably including
            # the case where no solvent file was written) must not discard the
            # cleaned structure. Unlike a bare ``except:``, this still lets
            # KeyboardInterrupt and SystemExit propagate.
            sol_smis = None
        return {
            "input": str(cif_path),
            "output": str(output_dir),
            "framework": str(framework_path),
            "solvent": str(solvent_path) if solvent_path.exists() else None,
            "checkpoint": str(self.checkpoint_path),
            "task": self.task,
            "threshold": cutoff,
            "num_atoms": int(labels.size),
            "num_framework_atoms": int((labels == 0).sum()),
            "num_solvent_atoms": int((labels == 1).sum()),
            "probabilities": probabilities.tolist(),
            "labels": labels.tolist(),
            "solvent_smiles": sol_smis
        }

    def _build_tensors(self, cif_path: Path) -> tuple[torch.Tensor, ...]:
        """Featurize ``cif_path`` into the positional inputs of ``self.model``.

        The returned tensors, in order, all live on ``self.device``:

        1. atom features: a one-hot atomic number concatenated with the atom
           positions;
        2. neighbor features: the Gaussian-expanded neighbor distances;
        3. ``index1``;
        4. ``index2``;
        5. an all-zero per-atom index (presumably a batch index for a single
           structure; confirm against SolventAtomClassifier).

        Raises:
            ValueError: If the CIF has no atoms, has an atomic number outside
                1..max_atomic_number, or its position count differs from its
                atom count.
        """
        graph = cif2graph(cif_path, radius=self.radius)
        positions = np.asarray(cif2pos(cif_path), dtype=np.float32)
        numbers = np.asarray(graph["numbers"], dtype=np.int64)

        if numbers.size == 0:
            raise ValueError(f"CIF contains no atoms: {cif_path}")
        if numbers.min() < 1 or numbers.max() > self.max_atomic_number:
            raise ValueError(
                f"CIF contains an unsupported atomic number; supported range is "
                f"1-{self.max_atomic_number}"
            )
        if len(positions) != len(numbers):
            raise ValueError(
                f"Position/atom mismatch in {cif_path}: "
                f"{len(positions)} positions for {len(numbers)} atoms"
            )

        atom_features = np.eye(
            self.max_atomic_number + 1,
            dtype=np.float32,
        )[numbers]
        # Concatenating along axis 1 requires positions to be 2-D
        # (n_atoms, k); presumably k == 3. TODO confirm against cif2pos.
        atom_features = np.concatenate([atom_features, positions], axis=1)

        distances = np.asarray(graph["dij"], dtype=np.float32)
        neighbor_features = self.gdf.expand(distances).astype(np.float32)
        index1 = np.asarray(graph["index1"], dtype=np.int64)
        index2 = np.asarray(graph["index2"], dtype=np.int64)
        atom_index = np.zeros(len(numbers), dtype=np.int64)

        tensors = (
            torch.from_numpy(atom_features),
            torch.from_numpy(neighbor_features),
            torch.from_numpy(index1),
            torch.from_numpy(index2),
            torch.from_numpy(atom_index),
        )
        return tuple(tensor.to(self.device) for tensor in tensors)

    @staticmethod
    def _resolve_device(device: str | torch.device | None) -> torch.device:
        """Resolve the torch device.

        ``None`` or ``"auto"`` means CUDA if available, else CPU.

        Raises:
            RuntimeError: If CUDA is requested explicitly but is not available.
        """
        if device is None or str(device).lower() == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        resolved = torch.device(device)
        if resolved.type == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("CUDA was requested but is not available")
        return resolved

    @staticmethod
    def _resolve_checkpoint(checkpoint: str | Path | None) -> Path:
        """Map a bundled checkpoint name, a custom path, or ``""``/``None`` to a file.

        Raises:
            FileNotFoundError: If the resolved checkpoint file does not exist.
        """
        checkpoint_name = "" if checkpoint is None else str(checkpoint).strip().lower()
        if not checkpoint_name:
            path = DEFAULT_CHECKPOINT
        elif checkpoint_name in CHECKPOINTS:
            path = CHECKPOINTS[checkpoint_name]
        else:
            # Custom paths keep their original case; only the name lookup
            # above is case-insensitive.
            path = Path(checkpoint).expanduser()
        path = path.resolve()
        if not path.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {path}")
        return path

    @staticmethod
    def _resolve_stability_model() -> Path:
        """Return the absolute path of the bundled stability model.

        Raises:
            FileNotFoundError: If the bundled stability model file is missing.
        """
        path = STABILITY_MODEL.resolve()
        if not path.is_file():
            raise FileNotFoundError(f"Stability model not found: {path}")
        return path

    def _load_stability_model(self) -> None:
        """Load the bundled stability model and its optional imputer.

        The pickle may contain either a bare estimator or a dict with a
        ``"model"`` entry and an optional ``"imputer"`` entry.
        """
        model_path = self._resolve_stability_model()
        # SECURITY: unpickling can execute arbitrary code. Only the bundled,
        # trusted model file is ever loaded here; never point this at user input.
        with model_path.open("rb") as model_file:
            saved_model = pickle.load(model_file)

        if isinstance(saved_model, dict):
            self.stability_model = saved_model["model"]
            self.stability_imputer = saved_model.get("imputer")
        else:
            self.stability_model = saved_model
            self.stability_imputer = None

        self._make_stability_model_compatible()

    def _make_stability_model_compatible(self) -> None:
        """Fill attributes absent from models saved by older scikit-learn versions."""
        # Patch the top-level estimator and, for ensembles, each sub-estimator.
        estimators = [self.stability_model]
        estimators.extend(getattr(self.stability_model, "estimators_", []))
        for estimator in estimators:
            if estimator is not None and not hasattr(estimator, "monotonic_cst"):
                estimator.monotonic_cst = None

    @staticmethod
    def _resolve_cif(cif: str | Path) -> Path:
        """Return ``cif`` as an absolute, user-expanded path to an existing file.

        Raises:
            ValueError: If ``cif`` is empty.
            FileNotFoundError: If the file does not exist.
        """
        if not cif:
            raise ValueError("cif must be a path to an input CIF file")
        path = Path(cif).expanduser().resolve()
        if not path.is_file():
            raise FileNotFoundError(f"CIF not found: {path}")
        return path

    @staticmethod
    def _resolve_output(cif_path: Path, output: str | Path) -> Path:
        """Resolve and create the output directory.

        The default is ``<cif dir>/<cif stem>_gtsr``.
        """
        path = (
            Path(output).expanduser()
            if output
            else cif_path.parent / f"{cif_path.stem}_gtsr"
        )
        path = path.resolve()
        path.mkdir(parents=True, exist_ok=True)
        return path

    def stability(self, cif: str | Path) -> Any:
        """Predict the activation stability of a (cleaned) framework CIF.

        The feature vector is, in order:

        * the atom count;
        * the flattened cell parameters;
        * the pore diameters Di, Df and Dif;
        * the density and the void fraction (VF);
        * the RAC descriptors, in ``RAC_FEATURE_NAMES`` order.

        Missing RACs become NaN. They are filled by the bundled imputer when
        one was saved with the model.

        Returns:
            The stability model's prediction for the structure. The README
            treats ``1`` as stable.

        Raises:
            ValueError / FileNotFoundError: If ``cif`` is empty or missing.
        """
        cif_path = self._resolve_cif(cif)

        # Load the model (lazily, for GNN runners) before the expensive
        # descriptor computation, so a missing model file fails fast.
        if self.stability_model is None:
            self._load_stability_model()

        cif_filename = str(cif_path)
        cell = flatten_cell(get_cell(cif_filename))
        pore_diameter = PoreDiameter(cif_filename)
        pore_volume = PoreVolume(cif_filename)
        rac = flatten_rac(RACs(cif_filename))

        features = [
            n_atom(cif_filename),
            *cell.values(),
            pore_diameter["Di"],
            pore_diameter["Df"],
            pore_diameter["Dif"],
            pore_volume["Density"],
            pore_volume["VF"],
            *(rac.get(name, np.nan) for name in RAC_FEATURE_NAMES),
        ]
        feature_batch = np.asarray([features], dtype=np.float64)

        if self.stability_imputer is not None:
            feature_batch = self.stability_imputer.transform(feature_batch)

        return self.stability_model.predict(feature_batch)[0]
|