gtsr 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
1
+ """Backward-compatible source-tree import for GTsRunner."""
2
+
3
+ from gtsr.runner import GTsRunner
4
+
5
+ __all__ = ["GTsRunner"]
gtsr-0.0.1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 coollkr
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
gtsr-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,129 @@
1
+ Metadata-Version: 2.1
2
+ Name: gtsr
3
+ Version: 0.0.1
4
+ Summary: Graph neural network tool for solvent removal from MOF structures
5
+ Author: Xiao-Yan Li Group
6
+ License: MIT
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3.9
9
+ Classifier: Programming Language :: Python :: 3.10
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Operating System :: OS Independent
12
+ Classifier: Topic :: Scientific/Engineering :: Chemistry
13
+ Requires-Python: >=3.9
14
+ Description-Content-Type: text/markdown
15
+ License-File: LICENSE
16
+ Requires-Dist: ase>=3.19
17
+ Requires-Dist: numpy>=1.21
18
+ Requires-Dist: pymatgen>=2018.6.11
19
+ Requires-Dist: scikit-learn>=1.0
20
+ Requires-Dist: torch>=1.12
21
+ Requires-Dist: molSimplify==1.8.0
22
+ Requires-Dist: rdkit
23
+ Requires-Dist: networkx
24
+
25
+ # GTsR
26
+
27
+ <div align="center">
28
+ <img src="https://raw.githubusercontent.com/Xiao-Yan-Li-group/GTsR/main/webapp/imgs/gtsr_logo.png" alt="GTsR logo" width="500"/>
29
+ </div>
30
+
31
+ [![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9%2B-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
32
+ [![MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
33
+
34
+ **GTsR (GNN Tool for Solvent Removal)** is a tool for solvent identification, solvent removal, and activation-stability prediction in metal-organic frameworks (MOFs).
35
+
36
+ GTsR uses graph neural networks to classify atoms in CIF structures and generate solvent-free framework CIF files. It also provides a random forest model that predicts the activation stability of cleaned MOFs using structural, pore, and RAC descriptors.
37
+
38
+ ## Models
39
+
40
+ | `checkpoint` | Model file | Purpose |
41
+ | --- | --- | --- |
42
+ | `free` (default) | `ckpt/free_best.pth` | Remove free solvent |
43
+ | `all` | `ckpt/all_best.pth` | Remove all solvent |
44
+ | `stability` | `ckpt/stability_best.pkl` | Predict activation stability |
45
+
46
+ The `free` and `all` checkpoints are atom-level GNN classifiers. The `stability` checkpoint is a random forest model bundled with its missing-value imputer.
47
+
48
+ ## Installation
49
+
50
+ ```bash
51
+ git clone https://github.com/coollkr/GTsR.git
52
+ cd GTsR
53
+ pip install -e .
54
+ ```
55
+
56
+ ## Usage
57
+
58
+ ### Solvent Removal
59
+
60
+ ```python
61
+ from gtsr import GTsRunner
62
+
63
+ runner = GTsRunner(checkpoint="free")  # remove free solvent (default checkpoint)
64
+ runner = GTsRunner(checkpoint="all")  # or: remove all solvent
65
+ runner = GTsRunner(checkpoint="path/to/ckpt.pth", device="cpu")  # or: load your own checkpoint
66
+ result = runner.clean(
67
+ cif="input.cif",
68
+ output="prediction",
69
+ threshold=0.5,
70
+ )
71
+ ```
72
+
73
+ #### `clean()` Result
74
+
75
+ `clean()` returns a dictionary containing the following fields:
76
+
77
+ | Field | Description |
78
+ | --- | --- |
79
+ | `input` | Absolute path to the input CIF |
80
+ | `output` | Output directory |
81
+ | `framework` | Path to the cleaned framework CIF |
82
+ | `solvent` | Path to the solvent CIF, or `None` if no file was generated |
83
+ | `checkpoint` | Path to the checkpoint used for prediction |
84
+ | `task` | Task name stored in the checkpoint |
85
+ | `threshold` | Atom-classification threshold |
86
+ | `num_atoms` | Total number of atoms |
87
+ | `num_framework_atoms` | Number of framework atoms |
88
+ | `num_solvent_atoms` | Number of solvent atoms |
89
+ | `probabilities` | Solvent probability for each atom |
90
+ | `labels` | Predicted class label for each atom |
91
+ | `solvent_smiles` | SMILES strings of identified solvents |
92
+
93
+ ### Predict Activation Stability
94
+
95
+ ```python
96
+ from gtsr import GTsRunner
97
+
98
+ runner = GTsRunner(checkpoint="stability")
99
+ score = runner.stability(cif="cleaned_framework.cif")
100
+
101
+ if score == 1:
102
+ print("The cleaned structure is stable.")
103
+ else:
104
+ print("The cleaned structure is not stable.")
105
+ ```
106
+
107
+ ## Web Interface
108
+
109
+ Use the [hosted Streamlit app](https://xiao-yan-li-group.streamlit.app/GTsR)
110
+ or run it locally:
111
+ ```bash
112
+ streamlit run webapp/Home.py
113
+ ```
114
+
115
+ ## Citation
116
+
117
+ Please cite GTsR as follows (this entry will be updated once the associated publication is available):
118
+
119
+ ```bibtex
120
+ @article{gtsr-xyl-group,
121
+ title = {GTSR: A GNN Based Tool for Solvent Removal from MOF with Stability Check},
122
+ author = {Liang, Kairui and Zhao, Guobin and Li, Xiao-Yan},
123
+ year = {2026}
124
+ }
125
+ ```
126
+
127
+ ## License
128
+
129
+ GTsR is released under the MIT License; see the [`LICENSE`](LICENSE) file.
gtsr-0.0.1/README.md ADDED
@@ -0,0 +1,105 @@
1
+ # GTsR
2
+
3
+ <div align="center">
4
+ <img src="https://raw.githubusercontent.com/Xiao-Yan-Li-group/GTsR/main/webapp/imgs/gtsr_logo.png" alt="GTsR logo" width="500"/>
5
+ </div>
6
+
7
+ [![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9%2B-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
8
+ [![MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
9
+
10
+ **GTsR (GNN Tool for Solvent Removal)** is a tool for solvent identification, solvent removal, and activation-stability prediction in metal-organic frameworks (MOFs).
11
+
12
+ GTsR uses graph neural networks to classify atoms in CIF structures and generate solvent-free framework CIF files. It also provides a random forest model that predicts the activation stability of cleaned MOFs using structural, pore, and RAC descriptors.
13
+
14
+ ## Models
15
+
16
+ | `checkpoint` | Model file | Purpose |
17
+ | --- | --- | --- |
18
+ | `free` (default) | `ckpt/free_best.pth` | Remove free solvent |
19
+ | `all` | `ckpt/all_best.pth` | Remove all solvent |
20
+ | `stability` | `ckpt/stability_best.pkl` | Predict activation stability |
21
+
22
+ The `free` and `all` checkpoints are atom-level GNN classifiers. The `stability` checkpoint is a random forest model bundled with its missing-value imputer.
23
+
24
+ ## Installation
25
+
26
+ ```bash
27
+ git clone https://github.com/coollkr/GTsR.git
28
+ cd GTsR
29
+ pip install -e .
30
+ ```
31
+
32
+ ## Usage
33
+
34
+ ### Solvent Removal
35
+
36
+ ```python
37
+ from gtsr import GTsRunner
38
+
39
+ runner = GTsRunner(checkpoint="free")  # remove free solvent (default checkpoint)
40
+ runner = GTsRunner(checkpoint="all")  # or: remove all solvent
41
+ runner = GTsRunner(checkpoint="path/to/ckpt.pth", device="cpu")  # or: load your own checkpoint
42
+ result = runner.clean(
43
+ cif="input.cif",
44
+ output="prediction",
45
+ threshold=0.5,
46
+ )
47
+ ```
48
+
49
+ #### `clean()` Result
50
+
51
+ `clean()` returns a dictionary containing the following fields:
52
+
53
+ | Field | Description |
54
+ | --- | --- |
55
+ | `input` | Absolute path to the input CIF |
56
+ | `output` | Output directory |
57
+ | `framework` | Path to the cleaned framework CIF |
58
+ | `solvent` | Path to the solvent CIF, or `None` if no file was generated |
59
+ | `checkpoint` | Path to the checkpoint used for prediction |
60
+ | `task` | Task name stored in the checkpoint |
61
+ | `threshold` | Atom-classification threshold |
62
+ | `num_atoms` | Total number of atoms |
63
+ | `num_framework_atoms` | Number of framework atoms |
64
+ | `num_solvent_atoms` | Number of solvent atoms |
65
+ | `probabilities` | Solvent probability for each atom |
66
+ | `labels` | Predicted class label for each atom |
67
+ | `solvent_smiles` | SMILES strings of identified solvents |
68
+
69
+ ### Predict Activation Stability
70
+
71
+ ```python
72
+ from gtsr import GTsRunner
73
+
74
+ runner = GTsRunner(checkpoint="stability")
75
+ score = runner.stability(cif="cleaned_framework.cif")
76
+
77
+ if score == 1:
78
+ print("The cleaned structure is stable.")
79
+ else:
80
+ print("The cleaned structure is not stable.")
81
+ ```
82
+
83
+ ## Web Interface
84
+
85
+ Use the [hosted Streamlit app](https://xiao-yan-li-group.streamlit.app/GTsR)
86
+ or run it locally:
87
+ ```bash
88
+ streamlit run webapp/Home.py
89
+ ```
90
+
91
+ ## Citation
92
+
93
+ Please cite GTsR as follows (this entry will be updated once the associated publication is available):
94
+
95
+ ```bibtex
96
+ @article{gtsr-xyl-group,
97
+ title = {GTSR: A GNN Based Tool for Solvent Removal from MOF with Stability Check},
98
+ author = {Liang, Kairui and Zhao, Guobin and Li, Xiao-Yan},
99
+ year = {2026}
100
+ }
101
+ ```
102
+
103
+ ## License
104
+
105
+ GTsR is released under the MIT License; see the [`LICENSE`](LICENSE) file.
@@ -0,0 +1 @@
1
+ """Bundled GTsR model checkpoints."""
Binary file
Binary file
Binary file
@@ -0,0 +1,5 @@
1
+ """GTsR solvent-removal prediction API."""
2
+
3
+ from .runner import GTsRunner
4
+
5
+ __all__ = ["GTsRunner"]
@@ -0,0 +1,317 @@
1
+ from __future__ import annotations
2
+
3
+ import pickle
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ try:
11
+ from .src.GCN import SolventAtomClassifier
12
+ from .src.cif_utils import (
13
+ PoreDiameter,
14
+ PoreVolume,
15
+ RACs,
16
+ cif2graph,
17
+ cif2pos,
18
+ flatten_cell,
19
+ flatten_rac,
20
+ get_cell,
21
+ get_sol_smi,
22
+ label2cif,
23
+ n_atom,
24
+ convert2pymatgen
25
+ )
26
+ from .src.data import GaussianDistance
27
+ from .src.utils import load_checkpoint
28
+ except ImportError:
29
+ from src.GCN import SolventAtomClassifier
30
+ from src.cif_utils import (
31
+ PoreDiameter,
32
+ PoreVolume,
33
+ RACs,
34
+ cif2graph,
35
+ cif2pos,
36
+ flatten_cell,
37
+ flatten_rac,
38
+ get_cell,
39
+ get_sol_smi,
40
+ label2cif,
41
+ n_atom,
42
+ convert2pymatgen
43
+ )
44
+ from src.data import GaussianDistance
45
+ from src.utils import load_checkpoint
46
+
47
+
48
# Directory containing this module; bundled model files are searched for
# relative to it.
PACKAGE_DIR = Path(__file__).resolve().parent


def _bundled_model(filename: str) -> Path:
    """Locate a bundled model file called ``filename``.

    ``<package>/ckpt/<filename>`` is preferred; ``<package parent>/ckpt/<filename>``
    is used as a fallback. If neither file exists, the package-local path is
    returned anyway so that callers can report where the file was expected.
    """
    in_package = PACKAGE_DIR / "ckpt" / filename
    if in_package.is_file():
        return in_package
    beside_package = PACKAGE_DIR.parent / "ckpt" / filename
    if beside_package.is_file():
        return beside_package
    return in_package


def _bundled_checkpoint(name: str) -> Path:
    """Return the location of the bundled GNN checkpoint ``<name>_best.pth``."""
    checkpoint_file = f"{name}_best.pth"
    return _bundled_model(checkpoint_file)
61
+
62
+
63
# Bundled GNN checkpoints that can be selected by name via
# GTsRunner(checkpoint="free" | "all"); insertion order is "free", "all".
CHECKPOINTS = {name: _bundled_checkpoint(name) for name in ("free", "all")}
# Checkpoint used when GTsRunner is given an empty checkpoint argument.
DEFAULT_CHECKPOINT = CHECKPOINTS["free"]
# Pickled stability model used by GTsRunner(checkpoint="stability").
STABILITY_MODEL = _bundled_model("stability_best.pkl")
69
def _build_rac_feature_names() -> tuple[str, ...]:
    """Return RAC descriptor names in the column order the stability model uses.

    Names have the form ``<group>-<property>-<depth>`` with depth 0..3. The
    first four groups use five properties; the last four additionally include
    ``alpha``. The ordering (group, then property, then depth) must not
    change, because it defines the feature columns fed to the saved model.
    """
    core = ("chi", "Z", "I", "T", "S")
    extended = core + ("alpha",)
    groups = [
        ("f-sbu", core),
        ("mc", core),
        ("D_mc", core),
        ("f-link", core),
        ("lc", extended),
        ("D_lc", extended),
        ("func", extended),
        ("D_func", extended),
    ]
    names = []
    for group, properties in groups:
        for prop in properties:
            names.extend(f"{group}-{prop}-{depth}" for depth in range(4))
    return tuple(names)


RAC_FEATURE_NAMES = _build_rac_feature_names()
84
+
85
+
86
class GTsRunner:
    """High-level entry point for GTsR predictions.

    The ``checkpoint`` argument selects the mode:

    * solvent removal (GNN): ``""``/``"free"`` (bundled free-solvent model,
      the default), ``"all"`` (bundled all-solvent model), or any other value,
      which is treated as a path to a GNN checkpoint. Use :meth:`clean`.
    * activation stability: ``"stability"`` loads the bundled pickled model.
      Use :meth:`stability`.

    :meth:`stability` also works on a GNN runner, because the stability model
    is loaded lazily on first use.
    """

    def __init__(
        self,
        checkpoint: str | Path = "",
        device: str | torch.device | None = None,
    ) -> None:
        """Load the requested model.

        Args:
            checkpoint: ``"free"``, ``"all"``, ``"stability"``, a path to a
                GNN checkpoint, or empty for the default ``"free"`` model.
                Names are matched case-insensitively after stripping.
            device: Torch device spec. ``None`` or ``"auto"`` selects CUDA when
                available, otherwise CPU.

        Raises:
            FileNotFoundError: If the checkpoint or stability model is missing.
            ValueError: If a GNN checkpoint has no ``model_config`` dict.
            RuntimeError: If CUDA is requested but not available.
        """
        checkpoint_name = str(checkpoint).strip().lower()
        self.device = self._resolve_device(device)
        # ``model`` stays None for a stability-only runner, so clean() can
        # refuse it explicitly instead of failing with an AttributeError.
        self.model = None
        self.stability_model = None
        self.stability_imputer = None

        if checkpoint_name == "stability":
            self.checkpoint_path = self._resolve_stability_model()
            self._load_stability_model()
            self.checkpoint = None
            self.task = "stability"
            return

        self.checkpoint_path = self._resolve_checkpoint(checkpoint)
        self.checkpoint = load_checkpoint(self.checkpoint_path, device=self.device)

        model_config = self.checkpoint.get("model_config")
        if not isinstance(model_config, dict):
            raise ValueError(
                f"Checkpoint does not contain a valid model_config: {self.checkpoint_path}"
            )

        self.model = SolventAtomClassifier(**model_config).to(self.device)
        self.model.load_state_dict(self.checkpoint["state_dict"])
        self.model.eval()

        # Graph-construction settings saved with the weights; the literal
        # fallbacks apply to checkpoints that do not store them.
        self.radius = float(self.checkpoint.get("radius", 8.0))
        self.dmin = float(self.checkpoint.get("dmin", 0.0))
        self.step = float(self.checkpoint.get("step", 0.2))
        self.default_threshold = float(self.checkpoint.get("threshold", 0.5))
        self.task = str(self.checkpoint.get("task", "unknown"))
        # Width of the one-hot atomic-number encoding built in _build_tensors
        # is max_atomic_number + 1.
        self.max_atomic_number = 118
        self.gdf = GaussianDistance(
            dmin=self.dmin,
            dmax=self.radius,
            step=self.step,
        )

    def clean(
        self,
        cif: str | Path = "",
        output: str | Path = "",
        threshold: float | None = None,
    ) -> dict[str, Any]:
        """Classify every atom of ``cif`` and write framework/solvent CIFs.

        Args:
            cif: Path to the input CIF file.
            output: Output directory. Defaults to ``<cif stem>_gtsr`` next to
                the input; it is created if missing.
            threshold: Cutoff in ``[0, 1]``. Atoms whose predicted probability
                is ``>= threshold`` get label 1 (solvent), all others label 0
                (framework). Defaults to the threshold stored in the
                checkpoint (0.5 if absent).

        Returns:
            A dict with the input/output/framework/solvent paths, the
            checkpoint path and task, the threshold used, atom counts,
            per-atom ``probabilities`` and ``labels``, and ``solvent_smiles``.
            ``solvent_smiles`` is ``None`` when no solvent CIF was written or
            SMILES extraction failed.

        Raises:
            RuntimeError: If this runner has no GNN (``"stability"`` mode).
            ValueError: If ``cif`` is empty or ``threshold`` is outside [0, 1].
            FileNotFoundError: If ``cif`` does not exist.
        """
        # Validate all cheap preconditions first, so bad calls fail before
        # the CIF is converted or an output directory is created.
        if self.task == "stability" or getattr(self, "model", None) is None:
            raise RuntimeError(
                "clean() requires a GNN checkpoint; initialize GTsRunner with "
                "checkpoint='free' or checkpoint='all'"
            )

        cif_path = self._resolve_cif(cif)
        cutoff = self.default_threshold if threshold is None else float(threshold)
        # Written so that NaN also fails the range check.
        if not 0.0 <= cutoff <= 1.0:
            raise ValueError(f"threshold must be between 0 and 1, got {cutoff}")

        # NOTE(review): the return value is ignored, so this presumably
        # normalises the CIF file in place before it is parsed; confirm.
        convert2pymatgen(str(cif_path))

        output_dir = self._resolve_output(cif_path, output)
        tensors = self._build_tensors(cif_path)
        with torch.inference_mode():
            probabilities = torch.sigmoid(self.model(*tensors)).cpu().numpy()

        labels = (probabilities >= cutoff).astype(np.int64)
        label2cif(cif_path, labels, str(output_dir))

        stem = cif_path.stem
        framework_path = output_dir / f"{stem}_gtsr.cif"
        solvent_path = output_dir / f"{stem}_sol.cif"
        solvent_written = solvent_path.exists()

        sol_smis = None
        if solvent_written:
            try:
                sol_smis = get_sol_smi(solvent_path)
            except Exception:
                # SMILES extraction is best-effort: a failure here must not
                # discard the CIF files already written. Deliberately not a
                # bare except, so KeyboardInterrupt/SystemExit still propagate.
                sol_smis = None

        return {
            "input": str(cif_path),
            "output": str(output_dir),
            "framework": str(framework_path),
            "solvent": str(solvent_path) if solvent_written else None,
            "checkpoint": str(self.checkpoint_path),
            "task": self.task,
            "threshold": cutoff,
            "num_atoms": int(labels.size),
            "num_framework_atoms": int((labels == 0).sum()),
            "num_solvent_atoms": int((labels == 1).sum()),
            "probabilities": probabilities.tolist(),
            "labels": labels.tolist(),
            "solvent_smiles": sol_smis,
        }

    def _build_tensors(self, cif_path: Path) -> tuple[torch.Tensor, ...]:
        """Convert a CIF into the model's input tensors on ``self.device``.

        Returns ``(atom_features, neighbor_features, index1, index2,
        atom_index)``.

        * ``atom_features`` is a one-hot atomic-number encoding
          (``max_atomic_number + 1`` columns) concatenated with the per-atom
          positions from ``cif2pos``.
        * ``neighbor_features`` is ``self.gdf.expand`` applied to the edge
          distances ``graph["dij"]`` from ``cif2graph`` (built with
          ``self.radius``).
        * ``index1`` and ``index2`` are that graph's edge index arrays.

        Raises:
            ValueError: If the CIF has no atoms, an atomic number outside
                ``1..max_atomic_number``, or a different number of positions
                than atoms.
        """
        graph = cif2graph(cif_path, radius=self.radius)
        positions = np.asarray(cif2pos(cif_path), dtype=np.float32)
        numbers = np.asarray(graph["numbers"], dtype=np.int64)

        if numbers.size == 0:
            raise ValueError(f"CIF contains no atoms: {cif_path}")
        if numbers.min() < 1 or numbers.max() > self.max_atomic_number:
            raise ValueError(
                f"CIF contains an unsupported atomic number; supported range is "
                f"1-{self.max_atomic_number}"
            )
        if len(positions) != len(numbers):
            raise ValueError(
                f"Position/atom mismatch in {cif_path}: "
                f"{len(positions)} positions for {len(numbers)} atoms"
            )

        # Row ``z`` of the identity matrix is the one-hot vector for atomic number z.
        atom_features = np.eye(
            self.max_atomic_number + 1,
            dtype=np.float32,
        )[numbers]
        atom_features = np.concatenate([atom_features, positions], axis=1)

        distances = np.asarray(graph["dij"], dtype=np.float32)
        neighbor_features = self.gdf.expand(distances).astype(np.float32)
        index1 = np.asarray(graph["index1"], dtype=np.int64)
        index2 = np.asarray(graph["index2"], dtype=np.int64)
        # All atoms are assigned to structure 0 because a single CIF is fed
        # per call. NOTE(review): presumably the model's per-structure batch
        # index; confirm against SolventAtomClassifier.
        atom_index = np.zeros(len(numbers), dtype=np.int64)

        tensors = (
            torch.from_numpy(atom_features),
            torch.from_numpy(neighbor_features),
            torch.from_numpy(index1),
            torch.from_numpy(index2),
            torch.from_numpy(atom_index),
        )
        return tuple(tensor.to(self.device) for tensor in tensors)

    @staticmethod
    def _resolve_device(device: str | torch.device | None) -> torch.device:
        """Map ``None``/``"auto"`` to CUDA-if-available, else parse ``device``.

        Raises:
            RuntimeError: If a CUDA device is requested but CUDA is unavailable.
        """
        if device is None or str(device).lower() == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        resolved = torch.device(device)
        if resolved.type == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("CUDA was requested but is not available")
        return resolved

    @staticmethod
    def _resolve_checkpoint(checkpoint: str | Path) -> Path:
        """Resolve a bundled checkpoint name or a user path to an existing file.

        Raises:
            FileNotFoundError: If the resolved checkpoint file does not exist.
        """
        checkpoint_name = str(checkpoint).strip().lower()
        if not checkpoint_name:
            path = DEFAULT_CHECKPOINT
        elif checkpoint_name in CHECKPOINTS:
            path = CHECKPOINTS[checkpoint_name]
        else:
            path = Path(checkpoint).expanduser()
        path = path.resolve()
        if not path.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {path}")
        return path

    @staticmethod
    def _resolve_stability_model() -> Path:
        """Return the absolute path of the bundled stability model.

        Raises:
            FileNotFoundError: If the bundled model file is missing.
        """
        path = STABILITY_MODEL.resolve()
        if not path.is_file():
            raise FileNotFoundError(f"Stability model not found: {path}")
        return path

    def _load_stability_model(self) -> None:
        """Unpickle the bundled stability model (and optional imputer).

        The pickle may be either the bare estimator or a dict with a
        ``"model"`` entry and an optional ``"imputer"`` entry.
        """
        model_path = self._resolve_stability_model()
        # pickle can execute arbitrary code while loading; this only ever
        # reads the bundled package file, never user-supplied data.
        with model_path.open("rb") as model_file:
            saved_model = pickle.load(model_file)

        if isinstance(saved_model, dict):
            self.stability_model = saved_model["model"]
            self.stability_imputer = saved_model.get("imputer")
        else:
            self.stability_model = saved_model
            self.stability_imputer = None

        self._make_stability_model_compatible()

    def _make_stability_model_compatible(self) -> None:
        """Fill attributes absent from models saved by older scikit-learn versions."""
        estimators = [self.stability_model]
        # Ensemble models keep their sub-estimators in ``estimators_``.
        estimators.extend(getattr(self.stability_model, "estimators_", []))
        for estimator in estimators:
            if estimator is not None and not hasattr(estimator, "monotonic_cst"):
                estimator.monotonic_cst = None

    @staticmethod
    def _resolve_cif(cif: str | Path) -> Path:
        """Expand and resolve ``cif`` to an existing file path.

        Raises:
            ValueError: If ``cif`` is empty.
            FileNotFoundError: If the path is not an existing file.
        """
        if not cif:
            raise ValueError("cif must be a path to an input CIF file")
        path = Path(cif).expanduser().resolve()
        if not path.is_file():
            raise FileNotFoundError(f"CIF not found: {path}")
        return path

    @staticmethod
    def _resolve_output(cif_path: Path, output: str | Path) -> Path:
        """Resolve and create the output directory.

        Defaults to ``<cif dir>/<cif stem>_gtsr`` when ``output`` is empty.
        """
        path = (
            Path(output).expanduser()
            if output
            else cif_path.parent / f"{cif_path.stem}_gtsr"
        )
        path = path.resolve()
        path.mkdir(parents=True, exist_ok=True)
        return path

    def stability(self, cif: str | Path):
        """Predict activation stability for a (cleaned) framework CIF.

        Builds one feature row, in this order:

        * the atom count;
        * the flattened cell values;
        * the ``Di``, ``Df`` and ``Dif`` pore diameters;
        * ``Density`` and ``VF`` from ``PoreVolume``;
        * the RAC descriptors in ``RAC_FEATURE_NAMES`` order, with NaN where
          a descriptor is missing.

        Missing values are imputed when an imputer was bundled with the
        model. Returns the model's prediction for that row; per the project
        README, ``1`` means stable.

        Raises:
            ValueError: If ``cif`` is empty.
            FileNotFoundError: If ``cif`` or the bundled stability model is
                missing.
        """
        cif_path = self._resolve_cif(cif)
        # Load the model before computing descriptors, so that a missing
        # model file fails fast.
        if self.stability_model is None:
            self._load_stability_model()

        cif_filename = str(cif_path)
        cell = flatten_cell(get_cell(cif_filename))
        pore_diameter = PoreDiameter(cif_filename)
        pore_volume = PoreVolume(cif_filename)
        rac = flatten_rac(RACs(cif_filename))

        features = [
            n_atom(cif_filename),
            *cell.values(),
            pore_diameter["Di"],
            pore_diameter["Df"],
            pore_diameter["Dif"],
            pore_volume["Density"],
            pore_volume["VF"],
            *(rac.get(name, np.nan) for name in RAC_FEATURE_NAMES),
        ]
        feature_batch = np.asarray([features], dtype=np.float64)

        if self.stability_imputer is not None:
            feature_batch = self.stability_imputer.transform(feature_batch)

        return self.stability_model.predict(feature_batch)[0]