pxmeter 0.1.6__tar.gz → 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pxmeter-0.1.6/pxmeter.egg-info → pxmeter-1.0.0}/PKG-INFO +3 -2
- {pxmeter-0.1.6 → pxmeter-1.0.0}/README.md +53 -11
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/calc_metric.py +222 -133
- pxmeter-1.0.0/pxmeter/cli.py +365 -0
- pxmeter-1.0.0/pxmeter/configs/run_config.py +188 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/constants.py +169 -133
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/ccd.py +102 -15
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/parser.py +218 -21
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/struct.py +77 -15
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/utils.py +7 -7
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/writer.py +4 -1
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/eval.py +6 -4
- pxmeter-1.0.0/pxmeter/input_builder/constants.py +16 -0
- pxmeter-1.0.0/pxmeter/input_builder/gen_input.py +381 -0
- pxmeter-1.0.0/pxmeter/input_builder/interactive.py +702 -0
- pxmeter-1.0.0/pxmeter/input_builder/model_inputs/alphafold3.py +366 -0
- pxmeter-1.0.0/pxmeter/input_builder/model_inputs/boltz.py +360 -0
- pxmeter-1.0.0/pxmeter/input_builder/model_inputs/protenix.py +559 -0
- pxmeter-1.0.0/pxmeter/input_builder/seq.py +584 -0
- pxmeter-1.0.0/pxmeter/input_builder/utils/__init__.py +0 -0
- pxmeter-1.0.0/pxmeter/input_builder/utils/unstd_res_mapping.py +225 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/mapping.py +40 -27
- pxmeter-1.0.0/pxmeter/metrics/__init__.py +0 -0
- pxmeter-1.0.0/pxmeter/metrics/dockq.py +523 -0
- pxmeter-1.0.0/pxmeter/metrics/lddt_metrics.py +310 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/metrics/rmsd.py +7 -5
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/metrics/rmsd_metrics.py +35 -32
- pxmeter-1.0.0/pxmeter/metrics/stereochemistry/__init__.py +0 -0
- pxmeter-1.0.0/pxmeter/metrics/stereochemistry/check.py +1706 -0
- pxmeter-1.0.0/pxmeter/metrics/stereochemistry/params.py +2324 -0
- pxmeter-1.0.0/pxmeter/permutation/__init__.py +0 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/permutation/atom.py +9 -4
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/permutation/chain.py +91 -31
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/permutation/residue.py +30 -8
- pxmeter-1.0.0/pxmeter/utils.py +89 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0/pxmeter.egg-info}/PKG-INFO +3 -2
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter.egg-info/SOURCES.txt +15 -2
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter.egg-info/requires.txt +2 -1
- {pxmeter-0.1.6 → pxmeter-1.0.0}/requirements.txt +3 -2
- {pxmeter-0.1.6 → pxmeter-1.0.0}/setup.py +1 -1
- pxmeter-0.1.6/pxmeter/cli.py +0 -182
- pxmeter-0.1.6/pxmeter/configs/data_config.py +0 -116
- pxmeter-0.1.6/pxmeter/configs/run_config.py +0 -41
- pxmeter-0.1.6/pxmeter/metrics/clashes.py +0 -88
- pxmeter-0.1.6/pxmeter/metrics/lddt_metrics.py +0 -248
- pxmeter-0.1.6/pxmeter/utils.py +0 -38
- {pxmeter-0.1.6 → pxmeter-1.0.0}/LICENSE +0 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/MANIFEST.in +0 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/__init__.py +0 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/configs/__init__.py +0 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/__init__.py +0 -0
- {pxmeter-0.1.6/pxmeter/metrics → pxmeter-1.0.0/pxmeter/input_builder}/__init__.py +0 -0
- {pxmeter-0.1.6/pxmeter/permutation → pxmeter-1.0.0/pxmeter/input_builder/model_inputs}/__init__.py +0 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter.egg-info/dependency_links.txt +0 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter.egg-info/entry_points.txt +0 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter.egg-info/top_level.txt +0 -0
- {pxmeter-0.1.6 → pxmeter-1.0.0}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pxmeter
|
|
3
|
-
Version: 0.1.6
|
|
3
|
+
Version: 1.0.0
|
|
4
4
|
Summary: PXMeter is a comprehensive toolkit for evaluating the quality of structures generated by biomolecular structure prediction models.
|
|
5
5
|
Author: Bytedance Inc.
|
|
6
6
|
Author-email: ai4s-bio@bytedance.com
|
|
@@ -9,7 +9,6 @@ Platform: manylinux1
|
|
|
9
9
|
Requires-Python: >=3.11
|
|
10
10
|
License-File: LICENSE
|
|
11
11
|
Requires-Dist: biotite>=1.2.0
|
|
12
|
-
Requires-Dist: dockq==2.1.3
|
|
13
12
|
Requires-Dist: gemmi==0.7.0
|
|
14
13
|
Requires-Dist: joblib
|
|
15
14
|
Requires-Dist: ml_collections
|
|
@@ -22,6 +21,8 @@ Requires-Dist: scipy
|
|
|
22
21
|
Requires-Dist: tabulate
|
|
23
22
|
Requires-Dist: tqdm
|
|
24
23
|
Requires-Dist: click
|
|
24
|
+
Requires-Dist: pyarrow
|
|
25
|
+
Requires-Dist: PyYAML
|
|
25
26
|
Dynamic: author
|
|
26
27
|
Dynamic: author-email
|
|
27
28
|
Dynamic: license
|
|
@@ -32,7 +32,7 @@ pip install -r requirements.txt
|
|
|
32
32
|
pip install -e .
|
|
33
33
|
```
|
|
34
34
|
|
|
35
|
-
PXMeter
|
|
35
|
+
PXMeter directly uses the Chemical Component Dictionary (CCD) bundled with Biotite. To update the CCD files:
|
|
36
36
|
|
|
37
37
|
```bash
|
|
38
38
|
pxm ccd update
|
|
@@ -48,12 +48,13 @@ pxm -r examples/7rss.cif -m examples/7rss_protenix_pred.cif -o pxm_output.json
|
|
|
48
48
|
**Key Parameters**:
|
|
49
49
|
- `-r` or `--ref_cif`: Path to reference CIF file
|
|
50
50
|
- `-m` or `--model_cif`: Path to model CIF file
|
|
51
|
-
- `-o` or `--
|
|
51
|
+
- `-o` or `--output_json`: Path to save evaluation results (default: "pxm_output.json")
|
|
52
52
|
- `--ref_model`: Specify model number of reference CIF (default: 1)
|
|
53
53
|
- `--ref_assembly_id`: Specify the assembly ID for the reference CIF (default: None; uses the Asymmetric Unit for evaluation)
|
|
54
|
-
-
|
|
54
|
+
- `--ref_altloc`: Specify the alternative location identifier for the reference CIF (default: "first", uses the first alternative location code for each residue).
|
|
55
55
|
- `--chain_id_to_mol_json`: JSON file defining custom ligands, where keys are chain IDs (label_asym_id) and values are the corresponding ligand SMILES strings.
|
|
56
56
|
- `-l` or `--interested_lig_label_asym_id`: Indicate the `label_asym_id` of ligands for metrics like pocket-aligned RMSD. Multiple ligands should be comma-separated.
|
|
57
|
+
- `-C key.path=value`: Override fields in `pxmeter.configs.run_config.RUN_CONFIG` (repeatable; e.g., `-C metric.lddt.eps=1e-4 -C mapping.mapping_ligand=false`).
|
|
57
58
|
|
|
58
59
|
To access the full list of parameters, use the `--help` option.
|
|
59
60
|
|
|
@@ -80,18 +81,59 @@ For detailed descriptions of additional parameters, use the `help()` function:
|
|
|
80
81
|
help(evaluate)
|
|
81
82
|
```
|
|
82
83
|
|
|
84
|
+
If you need to modify the runtime settings defined in
|
|
85
|
+
`pxmeter.configs.run_config.RUN_CONFIG` (equivalent to using `-C` on the command line),
|
|
86
|
+
you may directly update the values in `RUN_CONFIG` and then pass it into the evaluate() function.
|
|
87
|
+
```python
|
|
88
|
+
from pxmeter.configs.run_config import RUN_CONFIG
|
|
89
|
+
|
|
90
|
+
RUN_CONFIG.mapping.res_id_alignments = False
|
|
91
|
+
metric_result = evaluate(
|
|
92
|
+
...,
|
|
93
|
+
run_config=RUN_CONFIG,
|
|
94
|
+
)
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
For a detailed, step-by-step description of the PXMeter runtime evaluation pipeline (mapping, alignment, and metric computation), please refer to the [PXMeter evaluation pipeline details](docs/pxmeter_eval_details.md).
|
|
98
|
+
|
|
99
|
+
For a comprehensive overview of the runtime configuration options, recommended defaults, and advanced usage examples, see the [PXMeter run configuration guide](docs/run_config_details.md).
|
|
100
|
+
|
|
101
|
+
### Optional: Stereochemistry checks
|
|
102
|
+
|
|
103
|
+
Run stereochemistry checks for a single CIF and export a CSV report:
|
|
104
|
+
|
|
105
|
+
```bash
|
|
106
|
+
pxm stereocheck -c examples/7rss_protenix_pred.cif -o stereochem_report.csv
|
|
107
|
+
```
|
|
108
|
+
**`pxm stereocheck` Parameters**:
|
|
109
|
+
- `-c` or `--cif` (required): Path to the CIF file
|
|
110
|
+
- `-o` or `--output-csv`: Path to the output CSV report (default: `stereochem_report.csv`)
|
|
111
|
+
|
|
112
|
+
|
|
83
113
|
## 📊 Benchmarking
|
|
84
|
-
Refer to [benchmark/README.md](./benchmark/README.md) for evaluation protocols on:
|
|
85
|
-
- RecentPDB dataset
|
|
86
|
-
- PoseBusters V2
|
|
87
114
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
115
|
+
PXMeter offers a reproducible workflow covering both dataset creation and model evaluation.
|
|
116
|
+
|
|
117
|
+
**Note**: The benchmarking workflow (the `benchmark/` directory) is only available in the source repository and is not shipped with the PyPI package. To run benchmarking, please clone the repository first:
|
|
118
|
+
|
|
119
|
+
```bash
|
|
120
|
+
git clone https://github.com/bytedance/PXMeter.git
|
|
121
|
+
cd PXMeter
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
- The **[Benchmark Documentation](docs/benchmark.md)** explains how to run evaluations on model predictions and how the aggregated metrics are computed.
|
|
125
|
+
- The **[Dataset Pipeline Overview](docs/datapipeline.md)** describes the complete construction of the RecentPDB low-homology dataset,
|
|
126
|
+
including filtering, homology scans, clustering, and subset labeling.
|
|
127
|
+
The pipeline also allows users to **rebuild the evaluation dataset from scratch using any custom time window**.
|
|
128
|
+
This makes the benchmark fully flexible and adaptable to different release periods or ongoing updates from the PDB.
|
|
129
|
+
- For details on the dataset used in our paper, please refer to the **[legacy dataset documentation](docs/legacy_dataset_reference.md)**, which describes the dataset version and evaluation code used at the time of the initial release.
|
|
130
|
+
|
|
131
|
+
## ➡️ Preparing input files
|
|
91
132
|
|
|
92
|
-
|
|
93
|
-
Structure Prediction Benchmarks with PXMeter</a>
|
|
133
|
+
When working with structural inputs—e.g., converting mmCIF, AlphaFold3, Protenix, or Boltz formats—you may find the following utility helpful:
|
|
94
134
|
|
|
135
|
+
[pxm gen-input Usage Guide](docs/gen_input.md).
|
|
136
|
+
— a tool for generating and converting model input files via CLI or Python API.
|
|
95
137
|
|
|
96
138
|
|
|
97
139
|
## 💪 Contributing to PXMeter
|
|
@@ -12,20 +12,16 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
import copy
|
|
16
15
|
import dataclasses
|
|
17
|
-
import gzip
|
|
18
16
|
import json
|
|
19
17
|
import logging
|
|
20
18
|
import tempfile
|
|
21
19
|
from pathlib import Path
|
|
22
|
-
from typing import Any
|
|
20
|
+
from typing import Any, Optional, Union
|
|
23
21
|
|
|
24
|
-
import DockQ.parsers as dockq_parsers
|
|
25
22
|
import numpy as np
|
|
26
23
|
import pandas as pd
|
|
27
24
|
from biotite.structure.io import pdb
|
|
28
|
-
from DockQ.DockQ import run_on_all_native_interfaces
|
|
29
25
|
from ml_collections.config_dict import ConfigDict
|
|
30
26
|
from posebusters import PoseBusters
|
|
31
27
|
from rdkit import Chem
|
|
@@ -34,105 +30,18 @@ from pxmeter.configs.run_config import RUN_CONFIG
|
|
|
34
30
|
from pxmeter.constants import IONS, LIGAND
|
|
35
31
|
from pxmeter.data.ccd import get_ccd_mol_from_chain_atom_array
|
|
36
32
|
from pxmeter.data.struct import Structure
|
|
37
|
-
from pxmeter.metrics.
|
|
33
|
+
from pxmeter.metrics.dockq import compute_dockq
|
|
38
34
|
from pxmeter.metrics.lddt_metrics import LDDT
|
|
39
35
|
from pxmeter.metrics.rmsd_metrics import RMSDMetrics
|
|
40
36
|
|
|
41
37
|
logging.getLogger("posebusters").setLevel(logging.ERROR)
|
|
42
38
|
|
|
43
39
|
|
|
44
|
-
def load_PDB(path, chains=None, small_molecule=False, n_model=0):
|
|
45
|
-
"""
|
|
46
|
-
Modified from DockQ.DockQ.load_PDB to avoid ResourceWarning warnings.
|
|
47
|
-
ResourceWarning: Enable tracemalloc to get the object allocation traceback
|
|
48
|
-
DockQ/DockQ.py:660: ResourceWarning: unclosed file
|
|
49
|
-
"""
|
|
50
|
-
if chains is None:
|
|
51
|
-
chains = []
|
|
52
|
-
try:
|
|
53
|
-
pdb_parser = dockq_parsers.PDBParser(QUIET=True)
|
|
54
|
-
with (
|
|
55
|
-
gzip.open(path, "rt") if path.endswith(".gz") else open(path, "rt")
|
|
56
|
-
) as file_obj:
|
|
57
|
-
model = pdb_parser.get_structure(
|
|
58
|
-
"-",
|
|
59
|
-
file_obj,
|
|
60
|
-
chains=chains,
|
|
61
|
-
parse_hetatms=small_molecule,
|
|
62
|
-
model_number=n_model,
|
|
63
|
-
)
|
|
64
|
-
except Exception:
|
|
65
|
-
pdb_parser = dockq_parsers.MMCIFParser(QUIET=True)
|
|
66
|
-
with (
|
|
67
|
-
gzip.open(path, "rt") if path.endswith(".gz") else open(path, "rt")
|
|
68
|
-
) as file_obj:
|
|
69
|
-
model = pdb_parser.get_structure(
|
|
70
|
-
"-",
|
|
71
|
-
file_obj,
|
|
72
|
-
chains=chains,
|
|
73
|
-
parse_hetatms=small_molecule,
|
|
74
|
-
auth_chains=not small_molecule,
|
|
75
|
-
model_number=n_model,
|
|
76
|
-
)
|
|
77
|
-
model.id = path
|
|
78
|
-
return model
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
def compute_dockq(
|
|
82
|
-
ref_struct: Structure,
|
|
83
|
-
model_struct: Structure,
|
|
84
|
-
ref_to_model_chain_map: dict[str, str],
|
|
85
|
-
) -> dict[str, dict[str, Any]]:
|
|
86
|
-
"""
|
|
87
|
-
Computes the DockQ score between a reference structure and a model structure.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
ref_struct (Structure): The reference structure.
|
|
91
|
-
model_struct (Structure): The model structure to be evaluated.
|
|
92
|
-
ref_to_model_chain_map (dict[str, str]): A dictionary mapping reference chain IDs to model chain IDs.
|
|
93
|
-
|
|
94
|
-
Returns:
|
|
95
|
-
dict[str, dict[str, Any]]: A dictionary containing the DockQ score and other related metrics.
|
|
96
|
-
"""
|
|
97
|
-
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
98
|
-
tmp_dir = Path(tmp_dir)
|
|
99
|
-
tmp_ref_cif = tmp_dir / "tmp_ref.cif"
|
|
100
|
-
tmp_model_cif = tmp_dir / "tmp_model.cif"
|
|
101
|
-
|
|
102
|
-
# Calculate DockQ using exclusively valid atoms
|
|
103
|
-
# Use uni_chain_id as label_asym_id
|
|
104
|
-
ref_struct.to_cif(tmp_ref_cif, use_uni_chain_id=True)
|
|
105
|
-
model_struct.to_cif(tmp_model_cif, use_uni_chain_id=True)
|
|
106
|
-
|
|
107
|
-
# small_molecule=False means only polymer is considered
|
|
108
|
-
model = load_PDB(str(tmp_model_cif), small_molecule=False)
|
|
109
|
-
native = load_PDB(str(tmp_ref_cif), small_molecule=False)
|
|
110
|
-
|
|
111
|
-
native_chains = [c.id for c in native]
|
|
112
|
-
model_chains = [c.id for c in model]
|
|
113
|
-
|
|
114
|
-
valid_ref_to_model_chain_map = {}
|
|
115
|
-
for k, v in ref_to_model_chain_map.items():
|
|
116
|
-
if (
|
|
117
|
-
k in ref_struct.uni_chain_id
|
|
118
|
-
and k in native_chains
|
|
119
|
-
and v in model_chains
|
|
120
|
-
):
|
|
121
|
-
# some all UNK structure will not be load by load_PDB(), e.g. chain Q in 7q6i
|
|
122
|
-
valid_ref_to_model_chain_map[k] = v
|
|
123
|
-
assert v in model_struct.uni_chain_id
|
|
124
|
-
|
|
125
|
-
dockq_result_dict, _total_dockq = run_on_all_native_interfaces(
|
|
126
|
-
model, native, chain_map=valid_ref_to_model_chain_map
|
|
127
|
-
)
|
|
128
|
-
return dockq_result_dict
|
|
129
|
-
|
|
130
|
-
|
|
131
40
|
def compute_pb_valid(
|
|
132
41
|
ref_struct: Structure,
|
|
133
42
|
model_struct: Structure,
|
|
134
|
-
ref_lig_label_asym_id: str
|
|
135
|
-
) -> pd.DataFrame
|
|
43
|
+
ref_lig_label_asym_id: Union[str, list[str]],
|
|
44
|
+
) -> Optional[pd.DataFrame]:
|
|
136
45
|
"""
|
|
137
46
|
Compute pose-busting validation metrics for a given reference structure, model structure, and reference features.
|
|
138
47
|
|
|
@@ -152,6 +61,8 @@ def compute_pb_valid(
|
|
|
152
61
|
ref_lig_label_asym_ids = list(ref_lig_label_asym_id)
|
|
153
62
|
|
|
154
63
|
df_list = []
|
|
64
|
+
buster = PoseBusters(config="redock")
|
|
65
|
+
|
|
155
66
|
for lig_label_asym_id in ref_lig_label_asym_ids:
|
|
156
67
|
lig_mask = ref_struct.atom_array.label_asym_id == lig_label_asym_id
|
|
157
68
|
|
|
@@ -159,13 +70,19 @@ def compute_pb_valid(
|
|
|
159
70
|
model_lig_chain_id = model_struct.uni_chain_id[lig_mask][0]
|
|
160
71
|
|
|
161
72
|
ref_lig_atom_array = ref_struct.atom_array[lig_mask]
|
|
162
|
-
model_lig_atom_array =
|
|
73
|
+
model_lig_atom_array = model_struct.atom_array[lig_mask].copy()
|
|
163
74
|
# reset res_name for model ligand atoms by ref Structure
|
|
164
75
|
model_lig_atom_array.res_name = ref_lig_atom_array.res_name
|
|
165
|
-
model_cond_atom_array = model_struct.atom_array[~lig_mask]
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
76
|
+
model_cond_atom_array = model_struct.atom_array[~lig_mask].copy()
|
|
77
|
+
|
|
78
|
+
try:
|
|
79
|
+
ref_lig_mol = get_ccd_mol_from_chain_atom_array(ref_lig_atom_array)
|
|
80
|
+
model_lig_mol = get_ccd_mol_from_chain_atom_array(model_lig_atom_array)
|
|
81
|
+
except Exception:
|
|
82
|
+
logging.warning(
|
|
83
|
+
f"Failed to create RDKit molecule for ligand {lig_label_asym_id}. Skipping PoseBusters."
|
|
84
|
+
)
|
|
85
|
+
continue
|
|
169
86
|
|
|
170
87
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
171
88
|
tmp_dir = Path(tmp_dir)
|
|
@@ -182,26 +99,34 @@ def compute_pb_valid(
|
|
|
182
99
|
sdf_writer.close()
|
|
183
100
|
|
|
184
101
|
pdb_file = pdb.PDBFile()
|
|
185
|
-
|
|
186
|
-
# PDB file only support one letter chain_id
|
|
187
|
-
model_cond_atom_array.chain_id =
|
|
188
|
-
i[0] for i in model_cond_atom_array.chain_id
|
|
189
|
-
|
|
102
|
+
|
|
103
|
+
# PDB file only support one letter chain_id, 3 letters res_name, 4 letters atom_name
|
|
104
|
+
model_cond_atom_array.chain_id = np.array(
|
|
105
|
+
[i[0] if len(i) > 0 else " " for i in model_cond_atom_array.chain_id],
|
|
106
|
+
dtype="U1",
|
|
107
|
+
)
|
|
108
|
+
model_cond_atom_array.res_name = np.array(
|
|
109
|
+
[i[:3] for i in model_cond_atom_array.res_name], dtype="U3"
|
|
110
|
+
)
|
|
111
|
+
model_cond_atom_array.atom_name = np.array(
|
|
112
|
+
[i[:4] for i in model_cond_atom_array.atom_name], dtype="U4"
|
|
113
|
+
)
|
|
114
|
+
model_cond_atom_array.bonds = None
|
|
115
|
+
|
|
190
116
|
pdb_file.set_structure(model_cond_atom_array)
|
|
191
117
|
pdb_file.write(model_cond_pdb)
|
|
192
118
|
|
|
193
|
-
buster = PoseBusters(config="redock")
|
|
194
119
|
df = buster.bust(
|
|
195
120
|
mol_pred=model_lig_sdf,
|
|
196
121
|
mol_true=ref_lig_sdf,
|
|
197
122
|
mol_cond=model_cond_pdb,
|
|
198
123
|
full_report=True,
|
|
199
124
|
)
|
|
200
|
-
|
|
201
125
|
# record ligand chain id
|
|
202
126
|
df["ref_lig_chain_id"] = ref_lig_chain_id
|
|
203
127
|
df["model_lig_chain_id"] = model_lig_chain_id
|
|
204
128
|
df_list.append(df)
|
|
129
|
+
|
|
205
130
|
df_cat = pd.concat(df_list)
|
|
206
131
|
return df_cat
|
|
207
132
|
|
|
@@ -232,6 +157,7 @@ class CalcLDDTMetric:
|
|
|
232
157
|
is_nucleotide_threshold=lddt_config.nucleotide_threshold,
|
|
233
158
|
is_not_nucleotide_threshold=lddt_config.non_nucleotide_threshold,
|
|
234
159
|
eps=lddt_config.eps,
|
|
160
|
+
stereochecks=lddt_config.stereochecks,
|
|
235
161
|
)
|
|
236
162
|
|
|
237
163
|
def get_chains_mask(
|
|
@@ -278,7 +204,7 @@ class CalcLDDTMetric:
|
|
|
278
204
|
merged_chain_2_masks = np.array(merged_chain_2_masks)
|
|
279
205
|
return merged_chain_1_masks, merged_chain_2_masks
|
|
280
206
|
|
|
281
|
-
def get_complex_lddt(self) -> float:
|
|
207
|
+
def get_complex_lddt(self, atom_mask: Optional[np.ndarray] = None) -> float:
|
|
282
208
|
"""
|
|
283
209
|
Calculate the LDDT score for a complex.
|
|
284
210
|
|
|
@@ -286,6 +212,9 @@ class CalcLDDTMetric:
|
|
|
286
212
|
and true coordinates of the complex. The LDDT score is a measure of the
|
|
287
213
|
structural similarity between the predicted and true structures.
|
|
288
214
|
|
|
215
|
+
Args:
|
|
216
|
+
atom_mask (np.ndarray): A mask for the atoms to include in the calculation.
|
|
217
|
+
|
|
289
218
|
Returns:
|
|
290
219
|
float: The LDDT score for the complex.
|
|
291
220
|
"""
|
|
@@ -293,11 +222,15 @@ class CalcLDDTMetric:
|
|
|
293
222
|
complex_lddt = self.lddt_calculator.run(
|
|
294
223
|
chain_1_masks=None,
|
|
295
224
|
chain_2_masks=None,
|
|
225
|
+
atom_mask=atom_mask,
|
|
296
226
|
)
|
|
297
227
|
return complex_lddt
|
|
298
228
|
|
|
299
229
|
def get_chain_interface_lddt(
|
|
300
|
-
self,
|
|
230
|
+
self,
|
|
231
|
+
chains: list[str],
|
|
232
|
+
interfaces: list[tuple[str, str]],
|
|
233
|
+
atom_mask: Optional[np.ndarray] = None,
|
|
301
234
|
) -> list[float]:
|
|
302
235
|
"""
|
|
303
236
|
Calculate the LDDT scores for chains and interfaces.
|
|
@@ -305,7 +238,9 @@ class CalcLDDTMetric:
|
|
|
305
238
|
Args:
|
|
306
239
|
chains (list[str]): A list of chain identifiers.
|
|
307
240
|
interfaces (list[tuple[str, str]]): A list of tuples, each containing
|
|
308
|
-
|
|
241
|
+
two chain identifiers representing an interface.
|
|
242
|
+
atom_mask (np.ndarray, optional): A mask for the atoms to include in the calculation.
|
|
243
|
+
Defaults to None.
|
|
309
244
|
|
|
310
245
|
Returns:
|
|
311
246
|
list[float]: A list of LDDT scores for chains and interfaces.
|
|
@@ -317,6 +252,7 @@ class CalcLDDTMetric:
|
|
|
317
252
|
lddt_list = self.lddt_calculator.run(
|
|
318
253
|
chain_1_masks=merged_chain_1_masks,
|
|
319
254
|
chain_2_masks=merged_chain_2_masks,
|
|
255
|
+
atom_mask=atom_mask,
|
|
320
256
|
)
|
|
321
257
|
return lddt_list
|
|
322
258
|
|
|
@@ -343,9 +279,11 @@ class MetricResult:
|
|
|
343
279
|
interface: dict[tuple[str, str], dict[str, Any]]
|
|
344
280
|
|
|
345
281
|
# [ref_chain_id: {metric: value}]
|
|
346
|
-
pb_valid: dict[str, dict[str, Any]]
|
|
282
|
+
pb_valid: Optional[dict[str, dict[str, Any]]] = None
|
|
347
283
|
|
|
348
|
-
ori_model_chain_ids: list[str]
|
|
284
|
+
ori_model_chain_ids: Optional[list[str]] = None
|
|
285
|
+
|
|
286
|
+
update_data: Optional[dict[str, Any]] = None
|
|
349
287
|
|
|
350
288
|
@staticmethod
|
|
351
289
|
def _get_chain_info(ref_struct: Structure) -> dict[str, dict[str, str]]:
|
|
@@ -420,26 +358,31 @@ class MetricResult:
|
|
|
420
358
|
chains: list[str],
|
|
421
359
|
interfaces: list[tuple[str, str]],
|
|
422
360
|
chain_interface_lddt: list[float],
|
|
361
|
+
metric_name: str = "lddt",
|
|
423
362
|
) -> tuple[dict[str, dict[str, float]], dict[tuple[str, str], dict[str, float]]]:
|
|
424
363
|
chain_lddt_dict = {}
|
|
425
364
|
interface_lddt_dict = {}
|
|
426
365
|
num_chains = len(chains)
|
|
427
366
|
for idx, chain_id in enumerate(chains):
|
|
428
|
-
|
|
367
|
+
lddt_value = chain_interface_lddt[idx]
|
|
368
|
+
if np.isnan(lddt_value):
|
|
369
|
+
continue
|
|
370
|
+
chain_lddt_dict[chain_id] = {metric_name: lddt_value}
|
|
429
371
|
|
|
430
372
|
for idx, interface in enumerate(interfaces):
|
|
431
373
|
sorted_interface = tuple(
|
|
432
374
|
sorted(interface)
|
|
433
375
|
) # Sort chains to ensure consistent order
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
376
|
+
lddt_value = chain_interface_lddt[idx + num_chains]
|
|
377
|
+
if np.isnan(lddt_value):
|
|
378
|
+
continue
|
|
379
|
+
interface_lddt_dict[sorted_interface] = {metric_name: lddt_value}
|
|
437
380
|
return chain_lddt_dict, interface_lddt_dict
|
|
438
381
|
|
|
439
382
|
@staticmethod
|
|
440
383
|
def _post_process_dockq(
|
|
441
384
|
dockq_result_dict: dict[str, Any],
|
|
442
|
-
) -> dict[str, float
|
|
385
|
+
) -> dict[str, Union[float, dict[str, float]]]:
|
|
443
386
|
polymer_dockq_metrics = {"F1", "iRMSD", "LRMSD", "fnat", "nat_correct",
|
|
444
387
|
"nat_total", "fnonnat", "nonnat_count", "model_total",
|
|
445
388
|
"clashes", "len1", "len2", "class1", "class2", "is_het",
|
|
@@ -475,8 +418,8 @@ class MetricResult:
|
|
|
475
418
|
|
|
476
419
|
@staticmethod
|
|
477
420
|
def _post_process_pb_valid(
|
|
478
|
-
pb_valid_result_df: pd.DataFrame
|
|
479
|
-
) -> dict[str, dict[str, Any]]
|
|
421
|
+
pb_valid_result_df: Optional[pd.DataFrame],
|
|
422
|
+
) -> Optional[dict[str, dict[str, Any]]]:
|
|
480
423
|
if pb_valid_result_df is None:
|
|
481
424
|
return
|
|
482
425
|
|
|
@@ -505,14 +448,129 @@ class MetricResult:
|
|
|
505
448
|
else:
|
|
506
449
|
tar_dict[key] = value
|
|
507
450
|
|
|
451
|
+
@staticmethod
|
|
452
|
+
def _calc_stereochecks_summary(
|
|
453
|
+
atom_mask: np.ndarray,
|
|
454
|
+
clash_df: pd.DataFrame,
|
|
455
|
+
bad_bond_df: pd.DataFrame,
|
|
456
|
+
bad_angle_df: pd.DataFrame,
|
|
457
|
+
) -> dict[str, int]:
|
|
458
|
+
"""
|
|
459
|
+
Aggregate stereochemistry violations within an atom subset.
|
|
460
|
+
|
|
461
|
+
- `clash_atoms`: number of unique atoms involved in clashes (within subset)
|
|
462
|
+
- `bad_bonds`: number of bad bonds (within subset)
|
|
463
|
+
- `bad_angles`: number of bad angles (within subset)
|
|
464
|
+
|
|
465
|
+
The `idx*` columns in DataFrames are indices into the mapped atom arrays.
|
|
466
|
+
"""
|
|
467
|
+
|
|
468
|
+
atom_mask = np.asarray(atom_mask, dtype=bool)
|
|
469
|
+
|
|
470
|
+
clash_atoms = 0
|
|
471
|
+
if clash_df is not None and (not clash_df.empty):
|
|
472
|
+
idx1 = clash_df["idx1"].to_numpy(dtype=np.int64, copy=False)
|
|
473
|
+
idx2 = clash_df["idx2"].to_numpy(dtype=np.int64, copy=False)
|
|
474
|
+
row_mask = atom_mask[idx1] & atom_mask[idx2]
|
|
475
|
+
if np.any(row_mask):
|
|
476
|
+
clash_atoms = int(
|
|
477
|
+
np.unique(np.concatenate([idx1[row_mask], idx2[row_mask]])).size
|
|
478
|
+
)
|
|
479
|
+
|
|
480
|
+
bond_cnt = 0
|
|
481
|
+
if bad_bond_df is not None and (not bad_bond_df.empty):
|
|
482
|
+
idx1 = bad_bond_df["idx1"].to_numpy(dtype=np.int64, copy=False)
|
|
483
|
+
idx2 = bad_bond_df["idx2"].to_numpy(dtype=np.int64, copy=False)
|
|
484
|
+
bond_cnt = int(np.sum(atom_mask[idx1] & atom_mask[idx2]))
|
|
485
|
+
|
|
486
|
+
angle_cnt = 0
|
|
487
|
+
if bad_angle_df is not None and (not bad_angle_df.empty):
|
|
488
|
+
idx_a = bad_angle_df["idx_a"].to_numpy(dtype=np.int64, copy=False)
|
|
489
|
+
idx_b = bad_angle_df["idx_b"].to_numpy(dtype=np.int64, copy=False)
|
|
490
|
+
idx_c = bad_angle_df["idx_c"].to_numpy(dtype=np.int64, copy=False)
|
|
491
|
+
angle_cnt = int(
|
|
492
|
+
np.sum(atom_mask[idx_a] & atom_mask[idx_b] & atom_mask[idx_c])
|
|
493
|
+
)
|
|
494
|
+
|
|
495
|
+
return {
|
|
496
|
+
"clash_atoms": clash_atoms,
|
|
497
|
+
"bad_bonds": bond_cnt,
|
|
498
|
+
"bad_angles": angle_cnt,
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
@classmethod
|
|
502
|
+
def _maybe_add_lddt_stereochecks_summaries(
|
|
503
|
+
cls,
|
|
504
|
+
*,
|
|
505
|
+
lddt_config: ConfigDict,
|
|
506
|
+
lddt_calculator: LDDT,
|
|
507
|
+
ref_struct: Structure,
|
|
508
|
+
chains: list[str],
|
|
509
|
+
interfaces: list[tuple[str, str]],
|
|
510
|
+
complex_result_dict: dict[str, Any],
|
|
511
|
+
chain_result_dict: dict[str, dict[str, Any]],
|
|
512
|
+
interface_result_dict: dict[tuple[str, str], dict[str, Any]],
|
|
513
|
+
) -> None:
|
|
514
|
+
"""Attach stereochemistry violation summaries to output dicts.
|
|
515
|
+
|
|
516
|
+
Only active when `metric.lddt.stereochecks=True` and the underlying
|
|
517
|
+
stereochemistry checker produced violation tables.
|
|
518
|
+
"""
|
|
519
|
+
|
|
520
|
+
if not lddt_config.stereochecks:
|
|
521
|
+
return
|
|
522
|
+
|
|
523
|
+
stereo_violation_dfs = getattr(lddt_calculator, "stereo_violation_dfs", None)
|
|
524
|
+
if stereo_violation_dfs is None:
|
|
525
|
+
return
|
|
526
|
+
|
|
527
|
+
clash_df, bad_bond_df, bad_angle_df = stereo_violation_dfs
|
|
528
|
+
n_atoms = len(ref_struct.atom_array)
|
|
529
|
+
|
|
530
|
+
# Complex-level summary
|
|
531
|
+
complex_result_dict["stereochecks"] = cls._calc_stereochecks_summary(
|
|
532
|
+
atom_mask=np.ones(n_atoms, dtype=bool),
|
|
533
|
+
clash_df=clash_df,
|
|
534
|
+
bad_bond_df=bad_bond_df,
|
|
535
|
+
bad_angle_df=bad_angle_df,
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
# Chain-level summary (keyed by reference chain IDs)
|
|
539
|
+
for chain_id in chains:
|
|
540
|
+
chain_atom_mask = ref_struct.uni_chain_id == chain_id
|
|
541
|
+
chain_result_dict.setdefault(chain_id, {})[
|
|
542
|
+
"stereochecks"
|
|
543
|
+
] = cls._calc_stereochecks_summary(
|
|
544
|
+
atom_mask=chain_atom_mask,
|
|
545
|
+
clash_df=clash_df,
|
|
546
|
+
bad_bond_df=bad_bond_df,
|
|
547
|
+
bad_angle_df=bad_angle_df,
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
# Interface-level summary (keyed by sorted(reference chain IDs))
|
|
551
|
+
for chain_1, chain_2 in interfaces:
|
|
552
|
+
interface_key = tuple(sorted((chain_1, chain_2)))
|
|
553
|
+
interface_atom_mask = (ref_struct.uni_chain_id == chain_1) | (
|
|
554
|
+
ref_struct.uni_chain_id == chain_2
|
|
555
|
+
)
|
|
556
|
+
interface_result_dict.setdefault(interface_key, {})[
|
|
557
|
+
"stereochecks"
|
|
558
|
+
] = cls._calc_stereochecks_summary(
|
|
559
|
+
atom_mask=interface_atom_mask,
|
|
560
|
+
clash_df=clash_df,
|
|
561
|
+
bad_bond_df=bad_bond_df,
|
|
562
|
+
bad_angle_df=bad_angle_df,
|
|
563
|
+
)
|
|
564
|
+
|
|
508
565
|
@classmethod
|
|
509
566
|
def from_struct(
|
|
510
567
|
cls,
|
|
511
568
|
ref_struct: Structure,
|
|
512
569
|
model_struct: Structure,
|
|
513
|
-
ori_model_chain_ids: list[str]
|
|
514
|
-
interested_lig_label_asym_id: str
|
|
570
|
+
ori_model_chain_ids: Optional[list[str]] = None,
|
|
571
|
+
interested_lig_label_asym_id: Optional[Union[str, list[str]]] = None,
|
|
515
572
|
metric_config: ConfigDict = RUN_CONFIG.metric,
|
|
573
|
+
update_data: Optional[dict[str, Any]] = None,
|
|
516
574
|
) -> "MetricResult":
|
|
517
575
|
"""
|
|
518
576
|
Create a MetricResult instance from given structures and features.
|
|
@@ -525,6 +583,8 @@ class MetricResult:
|
|
|
525
583
|
specifying the ligand label asym IDs of interest.
|
|
526
584
|
metric_config (dict[str, Any]): A dictionary containing configuration for
|
|
527
585
|
metrics. Defaults to RUN_CONFIG.metric.
|
|
586
|
+
update_data (dict[str, Any] | None): A dictionary containing additional data to update.
|
|
587
|
+
Defaults to None.
|
|
528
588
|
|
|
529
589
|
Returns:
|
|
530
590
|
MetricResult: An instance of MetricResult containing the calculated metrics.
|
|
@@ -555,16 +615,6 @@ class MetricResult:
|
|
|
555
615
|
meta_info_dict["ref_to_model_chain_mapping"] = chain_map
|
|
556
616
|
meta_info_dict["ref_chain_info"] = cls._get_chain_info(ref_struct)
|
|
557
617
|
|
|
558
|
-
# Calculate clashes
|
|
559
|
-
if metric_config.calc_clashes:
|
|
560
|
-
clashes = check_clashes_by_vdw(
|
|
561
|
-
model_struct.atom_array,
|
|
562
|
-
vdw_scale_factor=metric_config.clashes.vdw_scale_factor,
|
|
563
|
-
)
|
|
564
|
-
complex_result_dict["clashes"] = len(
|
|
565
|
-
{x for a, b in clashes for x in (a, b)}
|
|
566
|
-
)
|
|
567
|
-
|
|
568
618
|
# Calculate RMSD (if ligand and pocket specified in ref_features)
|
|
569
619
|
if metric_config.calc_rmsd and interested_lig_label_asym_id:
|
|
570
620
|
rmsd_metrics = RMSDMetrics(
|
|
@@ -590,8 +640,21 @@ class MetricResult:
|
|
|
590
640
|
model_struct=model_struct,
|
|
591
641
|
lddt_config=metric_config.lddt,
|
|
592
642
|
)
|
|
643
|
+
|
|
644
|
+
cls._maybe_add_lddt_stereochecks_summaries(
|
|
645
|
+
lddt_config=metric_config.lddt,
|
|
646
|
+
lddt_calculator=calc_lddt.lddt_calculator,
|
|
647
|
+
ref_struct=ref_struct,
|
|
648
|
+
chains=chains,
|
|
649
|
+
interfaces=interfaces,
|
|
650
|
+
complex_result_dict=complex_result_dict,
|
|
651
|
+
chain_result_dict=chain_result_dict,
|
|
652
|
+
interface_result_dict=interface_result_dict,
|
|
653
|
+
)
|
|
654
|
+
|
|
593
655
|
complex_lddt = calc_lddt.get_complex_lddt()
|
|
594
|
-
|
|
656
|
+
if not np.isnan(complex_lddt):
|
|
657
|
+
complex_result_dict["lddt"] = complex_lddt
|
|
595
658
|
|
|
596
659
|
chain_interface_lddt = calc_lddt.get_chain_interface_lddt(
|
|
597
660
|
chains, interfaces
|
|
@@ -605,12 +668,34 @@ class MetricResult:
|
|
|
605
668
|
cls._update_src_to_tar_dict(chain_lddt_dict, chain_result_dict)
|
|
606
669
|
cls._update_src_to_tar_dict(interface_lddt_dict, interface_result_dict)
|
|
607
670
|
|
|
671
|
+
if metric_config.lddt.calc_backbone_lddt:
|
|
672
|
+
backbone_mask = ref_struct.get_backbone_atom_masks(only_rep_atom=True)
|
|
673
|
+
complex_bb_lddt = calc_lddt.get_complex_lddt(atom_mask=backbone_mask)
|
|
674
|
+
if not np.isnan(complex_bb_lddt):
|
|
675
|
+
complex_result_dict["bb_lddt"] = complex_bb_lddt
|
|
676
|
+
|
|
677
|
+
# It reuses the chains and interfaces from the previous step
|
|
678
|
+
chain_interface_lddt = calc_lddt.get_chain_interface_lddt(
|
|
679
|
+
chains, interfaces, atom_mask=backbone_mask
|
|
680
|
+
)
|
|
681
|
+
(
|
|
682
|
+
chain_bb_lddt_dict,
|
|
683
|
+
interface_bb_lddt_dict,
|
|
684
|
+
) = cls._post_process_chain_interface_lddt(
|
|
685
|
+
chains, interfaces, chain_interface_lddt, metric_name="bb_lddt"
|
|
686
|
+
)
|
|
687
|
+
cls._update_src_to_tar_dict(chain_bb_lddt_dict, chain_result_dict)
|
|
688
|
+
cls._update_src_to_tar_dict(
|
|
689
|
+
interface_bb_lddt_dict, interface_result_dict
|
|
690
|
+
)
|
|
691
|
+
|
|
608
692
|
# Calculate DockQ
|
|
609
693
|
if metric_config.calc_dockq:
|
|
610
694
|
dockq_result_dict = compute_dockq(
|
|
611
695
|
ref_struct=ref_struct,
|
|
612
696
|
model_struct=model_struct,
|
|
613
697
|
ref_to_model_chain_map=chain_map,
|
|
698
|
+
exclude_hetatms=metric_config.dockq.exclude_hetatms,
|
|
614
699
|
)
|
|
615
700
|
interface_dockq_dict = cls._post_process_dockq(dockq_result_dict)
|
|
616
701
|
cls._update_src_to_tar_dict(interface_dockq_dict, interface_result_dict)
|
|
@@ -635,6 +720,7 @@ class MetricResult:
|
|
|
635
720
|
interface=interface_result_dict,
|
|
636
721
|
pb_valid=chain_pb_valid_dict,
|
|
637
722
|
ori_model_chain_ids=ori_model_chain_ids,
|
|
723
|
+
update_data=update_data,
|
|
638
724
|
)
|
|
639
725
|
|
|
640
726
|
def to_json_dict(self) -> dict[str, Any]:
|
|
@@ -663,7 +749,7 @@ class MetricResult:
|
|
|
663
749
|
json_dict["ori_model_chain_ids"] = self.ori_model_chain_ids
|
|
664
750
|
return json_dict
|
|
665
751
|
|
|
666
|
-
def to_json(self, json_file: Path, update_data: dict
|
|
752
|
+
def to_json(self, json_file: Path, update_data: Optional[dict] = None):
|
|
667
753
|
"""
|
|
668
754
|
Convert the MetricResult instance to a JSON string.
|
|
669
755
|
|
|
@@ -677,5 +763,8 @@ class MetricResult:
|
|
|
677
763
|
if update_data:
|
|
678
764
|
json_dict.update(update_data)
|
|
679
765
|
|
|
766
|
+
if self.update_data is not None:
|
|
767
|
+
json_dict.update(self.update_data)
|
|
768
|
+
|
|
680
769
|
with open(json_file, "w", encoding="utf-8") as f:
|
|
681
770
|
json.dump(json_dict, f, indent=4, ensure_ascii=False)
|