pxmeter 0.1.6__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. {pxmeter-0.1.6/pxmeter.egg-info → pxmeter-1.0.0}/PKG-INFO +3 -2
  2. {pxmeter-0.1.6 → pxmeter-1.0.0}/README.md +53 -11
  3. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/calc_metric.py +222 -133
  4. pxmeter-1.0.0/pxmeter/cli.py +365 -0
  5. pxmeter-1.0.0/pxmeter/configs/run_config.py +188 -0
  6. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/constants.py +169 -133
  7. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/ccd.py +102 -15
  8. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/parser.py +218 -21
  9. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/struct.py +77 -15
  10. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/utils.py +7 -7
  11. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/writer.py +4 -1
  12. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/eval.py +6 -4
  13. pxmeter-1.0.0/pxmeter/input_builder/constants.py +16 -0
  14. pxmeter-1.0.0/pxmeter/input_builder/gen_input.py +381 -0
  15. pxmeter-1.0.0/pxmeter/input_builder/interactive.py +702 -0
  16. pxmeter-1.0.0/pxmeter/input_builder/model_inputs/alphafold3.py +366 -0
  17. pxmeter-1.0.0/pxmeter/input_builder/model_inputs/boltz.py +360 -0
  18. pxmeter-1.0.0/pxmeter/input_builder/model_inputs/protenix.py +559 -0
  19. pxmeter-1.0.0/pxmeter/input_builder/seq.py +584 -0
  20. pxmeter-1.0.0/pxmeter/input_builder/utils/__init__.py +0 -0
  21. pxmeter-1.0.0/pxmeter/input_builder/utils/unstd_res_mapping.py +225 -0
  22. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/mapping.py +40 -27
  23. pxmeter-1.0.0/pxmeter/metrics/__init__.py +0 -0
  24. pxmeter-1.0.0/pxmeter/metrics/dockq.py +523 -0
  25. pxmeter-1.0.0/pxmeter/metrics/lddt_metrics.py +310 -0
  26. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/metrics/rmsd.py +7 -5
  27. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/metrics/rmsd_metrics.py +35 -32
  28. pxmeter-1.0.0/pxmeter/metrics/stereochemistry/__init__.py +0 -0
  29. pxmeter-1.0.0/pxmeter/metrics/stereochemistry/check.py +1706 -0
  30. pxmeter-1.0.0/pxmeter/metrics/stereochemistry/params.py +2324 -0
  31. pxmeter-1.0.0/pxmeter/permutation/__init__.py +0 -0
  32. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/permutation/atom.py +9 -4
  33. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/permutation/chain.py +91 -31
  34. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/permutation/residue.py +30 -8
  35. pxmeter-1.0.0/pxmeter/utils.py +89 -0
  36. {pxmeter-0.1.6 → pxmeter-1.0.0/pxmeter.egg-info}/PKG-INFO +3 -2
  37. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter.egg-info/SOURCES.txt +15 -2
  38. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter.egg-info/requires.txt +2 -1
  39. {pxmeter-0.1.6 → pxmeter-1.0.0}/requirements.txt +3 -2
  40. {pxmeter-0.1.6 → pxmeter-1.0.0}/setup.py +1 -1
  41. pxmeter-0.1.6/pxmeter/cli.py +0 -182
  42. pxmeter-0.1.6/pxmeter/configs/data_config.py +0 -116
  43. pxmeter-0.1.6/pxmeter/configs/run_config.py +0 -41
  44. pxmeter-0.1.6/pxmeter/metrics/clashes.py +0 -88
  45. pxmeter-0.1.6/pxmeter/metrics/lddt_metrics.py +0 -248
  46. pxmeter-0.1.6/pxmeter/utils.py +0 -38
  47. {pxmeter-0.1.6 → pxmeter-1.0.0}/LICENSE +0 -0
  48. {pxmeter-0.1.6 → pxmeter-1.0.0}/MANIFEST.in +0 -0
  49. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/__init__.py +0 -0
  50. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/configs/__init__.py +0 -0
  51. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter/data/__init__.py +0 -0
  52. {pxmeter-0.1.6/pxmeter/metrics → pxmeter-1.0.0/pxmeter/input_builder}/__init__.py +0 -0
  53. {pxmeter-0.1.6/pxmeter/permutation → pxmeter-1.0.0/pxmeter/input_builder/model_inputs}/__init__.py +0 -0
  54. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter.egg-info/dependency_links.txt +0 -0
  55. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter.egg-info/entry_points.txt +0 -0
  56. {pxmeter-0.1.6 → pxmeter-1.0.0}/pxmeter.egg-info/top_level.txt +0 -0
  57. {pxmeter-0.1.6 → pxmeter-1.0.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pxmeter
3
- Version: 0.1.6
3
+ Version: 1.0.0
4
4
  Summary: PXMeter is a comprehensive toolkit for evaluating the quality of structures generated by biomolecular structure prediction models.
5
5
  Author: Bytedance Inc.
6
6
  Author-email: ai4s-bio@bytedance.com
@@ -9,7 +9,6 @@ Platform: manylinux1
9
9
  Requires-Python: >=3.11
10
10
  License-File: LICENSE
11
11
  Requires-Dist: biotite>=1.2.0
12
- Requires-Dist: dockq==2.1.3
13
12
  Requires-Dist: gemmi==0.7.0
14
13
  Requires-Dist: joblib
15
14
  Requires-Dist: ml_collections
@@ -22,6 +21,8 @@ Requires-Dist: scipy
22
21
  Requires-Dist: tabulate
23
22
  Requires-Dist: tqdm
24
23
  Requires-Dist: click
24
+ Requires-Dist: pyarrow
25
+ Requires-Dist: PyYAML
25
26
  Dynamic: author
26
27
  Dynamic: author-email
27
28
  Dynamic: license
@@ -32,7 +32,7 @@ pip install -r requirements.txt
32
32
  pip install -e .
33
33
  ```
34
34
 
35
- PXMeter will automatically download the Chemical Component Dictionary (CCD) upon its first run. To update the CCD files:
35
+ PXMeter directly uses the Chemical Component Dictionary (CCD) bundled with Biotite. To update the CCD files:
36
36
 
37
37
  ```bash
38
38
  pxm ccd update
@@ -48,12 +48,13 @@ pxm -r examples/7rss.cif -m examples/7rss_protenix_pred.cif -o pxm_output.json
48
48
  **Key Parameters**:
49
49
  - `-r` or `--ref_cif`: Path to reference CIF file
50
50
  - `-m` or `--model_cif`: Path to model CIF file
51
- - `-o` or `--output`: Path to save evaluation results (default: "pxm_output.json")
51
+ - `-o` or `--output_json`: Path to save evaluation results (default: "pxm_output.json")
52
52
  - `--ref_model`: Specify model number of reference CIF (default: 1)
53
53
  - `--ref_assembly_id`: Specify the assembly ID for the reference CIF (default: None; uses the Asymmetric Unit for evaluation)
54
- - `ref_altloc`: Specify the alternative location identifier for the reference CIF (default: "first", uses the first alternative location code for each residue).
54
+ - `--ref_altloc`: Specify the alternative location identifier for the reference CIF (default: "first", uses the first alternative location code for each residue).
55
55
  - `--chain_id_to_mol_json`: JSON file defining custom ligands, where keys are chain IDs (label_asym_id) and values are the corresponding ligand SMILES strings.
56
56
  - `-l` or `--interested_lig_label_asym_id`: Indicate the `label_asym_id` of ligands for metrics like pocket-aligned RMSD. Multiple ligands should be comma-separated.
57
+ - `-C key.path=value`: Override fields in `pxmeter.configs.run_config.RUN_CONFIG` (repeatable; e.g., `-C metric.lddt.eps=1e-4 -C mapping.mapping_ligand=false`).
57
58
 
58
59
  To access the full list of parameters, use the `--help` option.
59
60
 
@@ -80,18 +81,59 @@ For detailed descriptions of additional parameters, use the `help()` function:
80
81
  help(evaluate)
81
82
  ```
82
83
 
84
+ If you need to modify the runtime settings defined in
85
+ `pxmeter.configs.run_config.RUN_CONFIG` (equivalent to using `-C` on the command line),
86
+ you may directly update the values in `RUN_CONFIG` and then pass it into the evaluate() function.
87
+ ```python
88
+ from pxmeter.configs.run_config import RUN_CONFIG
89
+
90
+ RUN_CONFIG.mapping.res_id_alignments = False
91
+ metric_result = evaluate(
92
+ ...,
93
+ run_config=RUN_CONFIG,
94
+ )
95
+ ```
96
+
97
+ For a detailed, step-by-step description of the PXMeter runtime evaluation pipeline (mapping, alignment, and metric computation), please refer to the [PXMeter evaluation pipeline details](docs/pxmeter_eval_details.md).
98
+
99
+ For a comprehensive overview of the runtime configuration options, recommended defaults, and advanced usage examples, see the [PXMeter run configuration guide](docs/run_config_details.md).
100
+
101
+ ### Optional: Stereochemistry checks
102
+
103
+ Run stereochemistry checks for a single CIF and export a CSV report:
104
+
105
+ ```bash
106
+ pxm stereocheck -c examples/7rss_protenix_pred.cif -o stereochem_report.csv
107
+ ```
108
+ **`pxm stereocheck` Parameters**:
109
+ - `-c` or `--cif` (required): Path to the CIF file
110
+ - `-o` or `--output-csv`: Path to the output CSV report (default: `stereochem_report.csv`)
111
+
112
+
83
113
  ## 📊 Benchmarking
84
- Refer to [benchmark/README.md](./benchmark/README.md) for evaluation protocols on:
85
- - RecentPDB dataset
86
- - PoseBusters V2
87
114
 
88
- The benchmark data is released under the CC0 license.
89
- We include code in the `benchmark` directory that evaluates various models using PXMeter and aggregates their metrics.
90
- This serves as an example of best practices for using the tool. For more details, please refer to our paper:
115
+ PXMeter offers a reproducible workflow covering both dataset creation and model evaluation.
116
+
117
+ **Note**: The benchmarking workflow (the `benchmark/` directory) is only available in the source repository and is not shipped with the PyPI package. To run benchmarking, please clone the repository first:
118
+
119
+ ```bash
120
+ git clone https://github.com/bytedance/PXMeter.git
121
+ cd PXMeter
122
+ ```
123
+
124
+ - The **[Benchmark Documentation](docs/benchmark.md)** explains how to run evaluations on model predictions and how the aggregated metrics are computed.
125
+ - The **[Dataset Pipeline Overview](docs/datapipeline.md)** describes the complete construction of the RecentPDB low-homology dataset,
126
+ including filtering, homology scans, clustering, and subset labeling.
127
+ The pipeline also allows users to **rebuild the evaluation dataset from scratch using any custom time window**.
128
+ This makes the benchmark fully flexible and adaptable to different release periods or ongoing updates from the PDB.
129
+ - For details on the dataset used in our paper, please refer to the **[legacy dataset documentation](docs/legacy_dataset_reference.md)**, which describes the dataset version and evaluation code used at the time of the initial release.
130
+
131
+ ## ➡️ Preparing input files
91
132
 
92
- 📄 <a href="https://www.biorxiv.org/content/10.1101/2025.07.17.664878v1">From Dataset Curation to Unified Evaluation: Revisiting
93
- Structure Prediction Benchmarks with PXMeter</a>
133
+ When working with structural inputs—e.g., converting mmCIF, AlphaFold3, Protenix, or Boltz formats—you may find the following utility helpful:
94
134
 
135
+ [pxm gen-input Usage Guide](docs/gen_input.md).
136
+ — a tool for generating and converting model input files via CLI or Python API.
95
137
 
96
138
 
97
139
  ## 💪 Contributing to PXMeter
@@ -12,20 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import copy
16
15
  import dataclasses
17
- import gzip
18
16
  import json
19
17
  import logging
20
18
  import tempfile
21
19
  from pathlib import Path
22
- from typing import Any
20
+ from typing import Any, Optional, Union
23
21
 
24
- import DockQ.parsers as dockq_parsers
25
22
  import numpy as np
26
23
  import pandas as pd
27
24
  from biotite.structure.io import pdb
28
- from DockQ.DockQ import run_on_all_native_interfaces
29
25
  from ml_collections.config_dict import ConfigDict
30
26
  from posebusters import PoseBusters
31
27
  from rdkit import Chem
@@ -34,105 +30,18 @@ from pxmeter.configs.run_config import RUN_CONFIG
34
30
  from pxmeter.constants import IONS, LIGAND
35
31
  from pxmeter.data.ccd import get_ccd_mol_from_chain_atom_array
36
32
  from pxmeter.data.struct import Structure
37
- from pxmeter.metrics.clashes import check_clashes_by_vdw
33
+ from pxmeter.metrics.dockq import compute_dockq
38
34
  from pxmeter.metrics.lddt_metrics import LDDT
39
35
  from pxmeter.metrics.rmsd_metrics import RMSDMetrics
40
36
 
41
37
  logging.getLogger("posebusters").setLevel(logging.ERROR)
42
38
 
43
39
 
44
- def load_PDB(path, chains=None, small_molecule=False, n_model=0):
45
- """
46
- Modified from DockQ.DockQ.load_PDB to avoid ResourceWarning warnings.
47
- ResourceWarning: Enable tracemalloc to get the object allocation traceback
48
- DockQ/DockQ.py:660: ResourceWarning: unclosed file
49
- """
50
- if chains is None:
51
- chains = []
52
- try:
53
- pdb_parser = dockq_parsers.PDBParser(QUIET=True)
54
- with (
55
- gzip.open(path, "rt") if path.endswith(".gz") else open(path, "rt")
56
- ) as file_obj:
57
- model = pdb_parser.get_structure(
58
- "-",
59
- file_obj,
60
- chains=chains,
61
- parse_hetatms=small_molecule,
62
- model_number=n_model,
63
- )
64
- except Exception:
65
- pdb_parser = dockq_parsers.MMCIFParser(QUIET=True)
66
- with (
67
- gzip.open(path, "rt") if path.endswith(".gz") else open(path, "rt")
68
- ) as file_obj:
69
- model = pdb_parser.get_structure(
70
- "-",
71
- file_obj,
72
- chains=chains,
73
- parse_hetatms=small_molecule,
74
- auth_chains=not small_molecule,
75
- model_number=n_model,
76
- )
77
- model.id = path
78
- return model
79
-
80
-
81
- def compute_dockq(
82
- ref_struct: Structure,
83
- model_struct: Structure,
84
- ref_to_model_chain_map: dict[str, str],
85
- ) -> dict[str, dict[str, Any]]:
86
- """
87
- Computes the DockQ score between a reference structure and a model structure.
88
-
89
- Args:
90
- ref_struct (Structure): The reference structure.
91
- model_struct (Structure): The model structure to be evaluated.
92
- ref_to_model_chain_map (dict[str, str]): A dictionary mapping reference chain IDs to model chain IDs.
93
-
94
- Returns:
95
- dict[str, dict[str, Any]]: A dictionary containing the DockQ score and other related metrics.
96
- """
97
- with tempfile.TemporaryDirectory() as tmp_dir:
98
- tmp_dir = Path(tmp_dir)
99
- tmp_ref_cif = tmp_dir / "tmp_ref.cif"
100
- tmp_model_cif = tmp_dir / "tmp_model.cif"
101
-
102
- # Calculate DockQ using exclusively valid atoms
103
- # Use uni_chain_id as label_asym_id
104
- ref_struct.to_cif(tmp_ref_cif, use_uni_chain_id=True)
105
- model_struct.to_cif(tmp_model_cif, use_uni_chain_id=True)
106
-
107
- # small_molecule=False means only polymer is considered
108
- model = load_PDB(str(tmp_model_cif), small_molecule=False)
109
- native = load_PDB(str(tmp_ref_cif), small_molecule=False)
110
-
111
- native_chains = [c.id for c in native]
112
- model_chains = [c.id for c in model]
113
-
114
- valid_ref_to_model_chain_map = {}
115
- for k, v in ref_to_model_chain_map.items():
116
- if (
117
- k in ref_struct.uni_chain_id
118
- and k in native_chains
119
- and v in model_chains
120
- ):
121
- # some all UNK structure will not be load by load_PDB(), e.g. chain Q in 7q6i
122
- valid_ref_to_model_chain_map[k] = v
123
- assert v in model_struct.uni_chain_id
124
-
125
- dockq_result_dict, _total_dockq = run_on_all_native_interfaces(
126
- model, native, chain_map=valid_ref_to_model_chain_map
127
- )
128
- return dockq_result_dict
129
-
130
-
131
40
  def compute_pb_valid(
132
41
  ref_struct: Structure,
133
42
  model_struct: Structure,
134
- ref_lig_label_asym_id: str | list[str],
135
- ) -> pd.DataFrame | None:
43
+ ref_lig_label_asym_id: Union[str, list[str]],
44
+ ) -> Optional[pd.DataFrame]:
136
45
  """
137
46
  Compute pose-busting validation metrics for a given reference structure, model structure, and reference features.
138
47
 
@@ -152,6 +61,8 @@ def compute_pb_valid(
152
61
  ref_lig_label_asym_ids = list(ref_lig_label_asym_id)
153
62
 
154
63
  df_list = []
64
+ buster = PoseBusters(config="redock")
65
+
155
66
  for lig_label_asym_id in ref_lig_label_asym_ids:
156
67
  lig_mask = ref_struct.atom_array.label_asym_id == lig_label_asym_id
157
68
 
@@ -159,13 +70,19 @@ def compute_pb_valid(
159
70
  model_lig_chain_id = model_struct.uni_chain_id[lig_mask][0]
160
71
 
161
72
  ref_lig_atom_array = ref_struct.atom_array[lig_mask]
162
- model_lig_atom_array = copy.deepcopy(model_struct.atom_array[lig_mask])
73
+ model_lig_atom_array = model_struct.atom_array[lig_mask].copy()
163
74
  # reset res_name for model ligand atoms by ref Structure
164
75
  model_lig_atom_array.res_name = ref_lig_atom_array.res_name
165
- model_cond_atom_array = model_struct.atom_array[~lig_mask]
166
-
167
- ref_lig_mol = get_ccd_mol_from_chain_atom_array(ref_lig_atom_array)
168
- model_lig_mol = get_ccd_mol_from_chain_atom_array(model_lig_atom_array)
76
+ model_cond_atom_array = model_struct.atom_array[~lig_mask].copy()
77
+
78
+ try:
79
+ ref_lig_mol = get_ccd_mol_from_chain_atom_array(ref_lig_atom_array)
80
+ model_lig_mol = get_ccd_mol_from_chain_atom_array(model_lig_atom_array)
81
+ except Exception:
82
+ logging.warning(
83
+ f"Failed to create RDKit molecule for ligand {lig_label_asym_id}. Skipping PoseBusters."
84
+ )
85
+ continue
169
86
 
170
87
  with tempfile.TemporaryDirectory() as tmp_dir:
171
88
  tmp_dir = Path(tmp_dir)
@@ -182,26 +99,34 @@ def compute_pb_valid(
182
99
  sdf_writer.close()
183
100
 
184
101
  pdb_file = pdb.PDBFile()
185
- model_cond_atom_array = copy.deepcopy(model_cond_atom_array)
186
- # PDB file only support one letter chain_id
187
- model_cond_atom_array.chain_id = [
188
- i[0] for i in model_cond_atom_array.chain_id
189
- ]
102
+
103
+ # PDB file only support one letter chain_id, 3 letters res_name, 4 letters atom_name
104
+ model_cond_atom_array.chain_id = np.array(
105
+ [i[0] if len(i) > 0 else " " for i in model_cond_atom_array.chain_id],
106
+ dtype="U1",
107
+ )
108
+ model_cond_atom_array.res_name = np.array(
109
+ [i[:3] for i in model_cond_atom_array.res_name], dtype="U3"
110
+ )
111
+ model_cond_atom_array.atom_name = np.array(
112
+ [i[:4] for i in model_cond_atom_array.atom_name], dtype="U4"
113
+ )
114
+ model_cond_atom_array.bonds = None
115
+
190
116
  pdb_file.set_structure(model_cond_atom_array)
191
117
  pdb_file.write(model_cond_pdb)
192
118
 
193
- buster = PoseBusters(config="redock")
194
119
  df = buster.bust(
195
120
  mol_pred=model_lig_sdf,
196
121
  mol_true=ref_lig_sdf,
197
122
  mol_cond=model_cond_pdb,
198
123
  full_report=True,
199
124
  )
200
-
201
125
  # record ligand chain id
202
126
  df["ref_lig_chain_id"] = ref_lig_chain_id
203
127
  df["model_lig_chain_id"] = model_lig_chain_id
204
128
  df_list.append(df)
129
+
205
130
  df_cat = pd.concat(df_list)
206
131
  return df_cat
207
132
 
@@ -232,6 +157,7 @@ class CalcLDDTMetric:
232
157
  is_nucleotide_threshold=lddt_config.nucleotide_threshold,
233
158
  is_not_nucleotide_threshold=lddt_config.non_nucleotide_threshold,
234
159
  eps=lddt_config.eps,
160
+ stereochecks=lddt_config.stereochecks,
235
161
  )
236
162
 
237
163
  def get_chains_mask(
@@ -278,7 +204,7 @@ class CalcLDDTMetric:
278
204
  merged_chain_2_masks = np.array(merged_chain_2_masks)
279
205
  return merged_chain_1_masks, merged_chain_2_masks
280
206
 
281
- def get_complex_lddt(self) -> float:
207
+ def get_complex_lddt(self, atom_mask: Optional[np.ndarray] = None) -> float:
282
208
  """
283
209
  Calculate the LDDT score for a complex.
284
210
 
@@ -286,6 +212,9 @@ class CalcLDDTMetric:
286
212
  and true coordinates of the complex. The LDDT score is a measure of the
287
213
  structural similarity between the predicted and true structures.
288
214
 
215
+ Args:
216
+ atom_mask (np.ndarray): A mask for the atoms to include in the calculation.
217
+
289
218
  Returns:
290
219
  float: The LDDT score for the complex.
291
220
  """
@@ -293,11 +222,15 @@ class CalcLDDTMetric:
293
222
  complex_lddt = self.lddt_calculator.run(
294
223
  chain_1_masks=None,
295
224
  chain_2_masks=None,
225
+ atom_mask=atom_mask,
296
226
  )
297
227
  return complex_lddt
298
228
 
299
229
  def get_chain_interface_lddt(
300
- self, chains: list[str], interfaces: list[tuple[str, str]]
230
+ self,
231
+ chains: list[str],
232
+ interfaces: list[tuple[str, str]],
233
+ atom_mask: Optional[np.ndarray] = None,
301
234
  ) -> list[float]:
302
235
  """
303
236
  Calculate the LDDT scores for chains and interfaces.
@@ -305,7 +238,9 @@ class CalcLDDTMetric:
305
238
  Args:
306
239
  chains (list[str]): A list of chain identifiers.
307
240
  interfaces (list[tuple[str, str]]): A list of tuples, each containing
308
- two chain identifiers representing an interface.
241
+ two chain identifiers representing an interface.
242
+ atom_mask (np.ndarray, optional): A mask for the atoms to include in the calculation.
243
+ Defaults to None.
309
244
 
310
245
  Returns:
311
246
  list[float]: A list of LDDT scores for chains and interfaces.
@@ -317,6 +252,7 @@ class CalcLDDTMetric:
317
252
  lddt_list = self.lddt_calculator.run(
318
253
  chain_1_masks=merged_chain_1_masks,
319
254
  chain_2_masks=merged_chain_2_masks,
255
+ atom_mask=atom_mask,
320
256
  )
321
257
  return lddt_list
322
258
 
@@ -343,9 +279,11 @@ class MetricResult:
343
279
  interface: dict[tuple[str, str], dict[str, Any]]
344
280
 
345
281
  # [ref_chain_id: {metric: value}]
346
- pb_valid: dict[str, dict[str, Any]] | None = None
282
+ pb_valid: Optional[dict[str, dict[str, Any]]] = None
347
283
 
348
- ori_model_chain_ids: list[str] | None = None
284
+ ori_model_chain_ids: Optional[list[str]] = None
285
+
286
+ update_data: Optional[dict[str, Any]] = None
349
287
 
350
288
  @staticmethod
351
289
  def _get_chain_info(ref_struct: Structure) -> dict[str, dict[str, str]]:
@@ -420,26 +358,31 @@ class MetricResult:
420
358
  chains: list[str],
421
359
  interfaces: list[tuple[str, str]],
422
360
  chain_interface_lddt: list[float],
361
+ metric_name: str = "lddt",
423
362
  ) -> tuple[dict[str, dict[str, float]], dict[tuple[str, str], dict[str, float]]]:
424
363
  chain_lddt_dict = {}
425
364
  interface_lddt_dict = {}
426
365
  num_chains = len(chains)
427
366
  for idx, chain_id in enumerate(chains):
428
- chain_lddt_dict[chain_id] = {"lddt": chain_interface_lddt[idx]}
367
+ lddt_value = chain_interface_lddt[idx]
368
+ if np.isnan(lddt_value):
369
+ continue
370
+ chain_lddt_dict[chain_id] = {metric_name: lddt_value}
429
371
 
430
372
  for idx, interface in enumerate(interfaces):
431
373
  sorted_interface = tuple(
432
374
  sorted(interface)
433
375
  ) # Sort chains to ensure consistent order
434
- interface_lddt_dict[sorted_interface] = {
435
- "lddt": chain_interface_lddt[idx + num_chains]
436
- }
376
+ lddt_value = chain_interface_lddt[idx + num_chains]
377
+ if np.isnan(lddt_value):
378
+ continue
379
+ interface_lddt_dict[sorted_interface] = {metric_name: lddt_value}
437
380
  return chain_lddt_dict, interface_lddt_dict
438
381
 
439
382
  @staticmethod
440
383
  def _post_process_dockq(
441
384
  dockq_result_dict: dict[str, Any],
442
- ) -> dict[str, float | dict[str, float]]:
385
+ ) -> dict[str, Union[float, dict[str, float]]]:
443
386
  polymer_dockq_metrics = {"F1", "iRMSD", "LRMSD", "fnat", "nat_correct",
444
387
  "nat_total", "fnonnat", "nonnat_count", "model_total",
445
388
  "clashes", "len1", "len2", "class1", "class2", "is_het",
@@ -475,8 +418,8 @@ class MetricResult:
475
418
 
476
419
  @staticmethod
477
420
  def _post_process_pb_valid(
478
- pb_valid_result_df: pd.DataFrame | None,
479
- ) -> dict[str, dict[str, Any]] | None:
421
+ pb_valid_result_df: Optional[pd.DataFrame],
422
+ ) -> Optional[dict[str, dict[str, Any]]]:
480
423
  if pb_valid_result_df is None:
481
424
  return
482
425
 
@@ -505,14 +448,129 @@ class MetricResult:
505
448
  else:
506
449
  tar_dict[key] = value
507
450
 
451
+ @staticmethod
452
+ def _calc_stereochecks_summary(
453
+ atom_mask: np.ndarray,
454
+ clash_df: pd.DataFrame,
455
+ bad_bond_df: pd.DataFrame,
456
+ bad_angle_df: pd.DataFrame,
457
+ ) -> dict[str, int]:
458
+ """
459
+ Aggregate stereochemistry violations within an atom subset.
460
+
461
+ - `clash_atoms`: number of unique atoms involved in clashes (within subset)
462
+ - `bad_bonds`: number of bad bonds (within subset)
463
+ - `bad_angles`: number of bad angles (within subset)
464
+
465
+ The `idx*` columns in DataFrames are indices into the mapped atom arrays.
466
+ """
467
+
468
+ atom_mask = np.asarray(atom_mask, dtype=bool)
469
+
470
+ clash_atoms = 0
471
+ if clash_df is not None and (not clash_df.empty):
472
+ idx1 = clash_df["idx1"].to_numpy(dtype=np.int64, copy=False)
473
+ idx2 = clash_df["idx2"].to_numpy(dtype=np.int64, copy=False)
474
+ row_mask = atom_mask[idx1] & atom_mask[idx2]
475
+ if np.any(row_mask):
476
+ clash_atoms = int(
477
+ np.unique(np.concatenate([idx1[row_mask], idx2[row_mask]])).size
478
+ )
479
+
480
+ bond_cnt = 0
481
+ if bad_bond_df is not None and (not bad_bond_df.empty):
482
+ idx1 = bad_bond_df["idx1"].to_numpy(dtype=np.int64, copy=False)
483
+ idx2 = bad_bond_df["idx2"].to_numpy(dtype=np.int64, copy=False)
484
+ bond_cnt = int(np.sum(atom_mask[idx1] & atom_mask[idx2]))
485
+
486
+ angle_cnt = 0
487
+ if bad_angle_df is not None and (not bad_angle_df.empty):
488
+ idx_a = bad_angle_df["idx_a"].to_numpy(dtype=np.int64, copy=False)
489
+ idx_b = bad_angle_df["idx_b"].to_numpy(dtype=np.int64, copy=False)
490
+ idx_c = bad_angle_df["idx_c"].to_numpy(dtype=np.int64, copy=False)
491
+ angle_cnt = int(
492
+ np.sum(atom_mask[idx_a] & atom_mask[idx_b] & atom_mask[idx_c])
493
+ )
494
+
495
+ return {
496
+ "clash_atoms": clash_atoms,
497
+ "bad_bonds": bond_cnt,
498
+ "bad_angles": angle_cnt,
499
+ }
500
+
501
+ @classmethod
502
+ def _maybe_add_lddt_stereochecks_summaries(
503
+ cls,
504
+ *,
505
+ lddt_config: ConfigDict,
506
+ lddt_calculator: LDDT,
507
+ ref_struct: Structure,
508
+ chains: list[str],
509
+ interfaces: list[tuple[str, str]],
510
+ complex_result_dict: dict[str, Any],
511
+ chain_result_dict: dict[str, dict[str, Any]],
512
+ interface_result_dict: dict[tuple[str, str], dict[str, Any]],
513
+ ) -> None:
514
+ """Attach stereochemistry violation summaries to output dicts.
515
+
516
+ Only active when `metric.lddt.stereochecks=True` and the underlying
517
+ stereochemistry checker produced violation tables.
518
+ """
519
+
520
+ if not lddt_config.stereochecks:
521
+ return
522
+
523
+ stereo_violation_dfs = getattr(lddt_calculator, "stereo_violation_dfs", None)
524
+ if stereo_violation_dfs is None:
525
+ return
526
+
527
+ clash_df, bad_bond_df, bad_angle_df = stereo_violation_dfs
528
+ n_atoms = len(ref_struct.atom_array)
529
+
530
+ # Complex-level summary
531
+ complex_result_dict["stereochecks"] = cls._calc_stereochecks_summary(
532
+ atom_mask=np.ones(n_atoms, dtype=bool),
533
+ clash_df=clash_df,
534
+ bad_bond_df=bad_bond_df,
535
+ bad_angle_df=bad_angle_df,
536
+ )
537
+
538
+ # Chain-level summary (keyed by reference chain IDs)
539
+ for chain_id in chains:
540
+ chain_atom_mask = ref_struct.uni_chain_id == chain_id
541
+ chain_result_dict.setdefault(chain_id, {})[
542
+ "stereochecks"
543
+ ] = cls._calc_stereochecks_summary(
544
+ atom_mask=chain_atom_mask,
545
+ clash_df=clash_df,
546
+ bad_bond_df=bad_bond_df,
547
+ bad_angle_df=bad_angle_df,
548
+ )
549
+
550
+ # Interface-level summary (keyed by sorted(reference chain IDs))
551
+ for chain_1, chain_2 in interfaces:
552
+ interface_key = tuple(sorted((chain_1, chain_2)))
553
+ interface_atom_mask = (ref_struct.uni_chain_id == chain_1) | (
554
+ ref_struct.uni_chain_id == chain_2
555
+ )
556
+ interface_result_dict.setdefault(interface_key, {})[
557
+ "stereochecks"
558
+ ] = cls._calc_stereochecks_summary(
559
+ atom_mask=interface_atom_mask,
560
+ clash_df=clash_df,
561
+ bad_bond_df=bad_bond_df,
562
+ bad_angle_df=bad_angle_df,
563
+ )
564
+
508
565
  @classmethod
509
566
  def from_struct(
510
567
  cls,
511
568
  ref_struct: Structure,
512
569
  model_struct: Structure,
513
- ori_model_chain_ids: list[str] | None = None,
514
- interested_lig_label_asym_id: str | list[str] | None = None,
570
+ ori_model_chain_ids: Optional[list[str]] = None,
571
+ interested_lig_label_asym_id: Optional[Union[str, list[str]]] = None,
515
572
  metric_config: ConfigDict = RUN_CONFIG.metric,
573
+ update_data: Optional[dict[str, Any]] = None,
516
574
  ) -> "MetricResult":
517
575
  """
518
576
  Create a MetricResult instance from given structures and features.
@@ -525,6 +583,8 @@ class MetricResult:
525
583
  specifying the ligand label asym IDs of interest.
526
584
  metric_config (dict[str, Any]): A dictionary containing configuration for
527
585
  metrics. Defaults to RUN_CONFIG.metric.
586
+ update_data (dict[str, Any] | None): A dictionary containing additional data to update.
587
+ Defaults to None.
528
588
 
529
589
  Returns:
530
590
  MetricResult: An instance of MetricResult containing the calculated metrics.
@@ -555,16 +615,6 @@ class MetricResult:
555
615
  meta_info_dict["ref_to_model_chain_mapping"] = chain_map
556
616
  meta_info_dict["ref_chain_info"] = cls._get_chain_info(ref_struct)
557
617
 
558
- # Calculate clashes
559
- if metric_config.calc_clashes:
560
- clashes = check_clashes_by_vdw(
561
- model_struct.atom_array,
562
- vdw_scale_factor=metric_config.clashes.vdw_scale_factor,
563
- )
564
- complex_result_dict["clashes"] = len(
565
- {x for a, b in clashes for x in (a, b)}
566
- )
567
-
568
618
  # Calculate RMSD (if ligand and pocket specified in ref_features)
569
619
  if metric_config.calc_rmsd and interested_lig_label_asym_id:
570
620
  rmsd_metrics = RMSDMetrics(
@@ -590,8 +640,21 @@ class MetricResult:
590
640
  model_struct=model_struct,
591
641
  lddt_config=metric_config.lddt,
592
642
  )
643
+
644
+ cls._maybe_add_lddt_stereochecks_summaries(
645
+ lddt_config=metric_config.lddt,
646
+ lddt_calculator=calc_lddt.lddt_calculator,
647
+ ref_struct=ref_struct,
648
+ chains=chains,
649
+ interfaces=interfaces,
650
+ complex_result_dict=complex_result_dict,
651
+ chain_result_dict=chain_result_dict,
652
+ interface_result_dict=interface_result_dict,
653
+ )
654
+
593
655
  complex_lddt = calc_lddt.get_complex_lddt()
594
- complex_result_dict["lddt"] = complex_lddt
656
+ if not np.isnan(complex_lddt):
657
+ complex_result_dict["lddt"] = complex_lddt
595
658
 
596
659
  chain_interface_lddt = calc_lddt.get_chain_interface_lddt(
597
660
  chains, interfaces
@@ -605,12 +668,34 @@ class MetricResult:
605
668
  cls._update_src_to_tar_dict(chain_lddt_dict, chain_result_dict)
606
669
  cls._update_src_to_tar_dict(interface_lddt_dict, interface_result_dict)
607
670
 
671
+ if metric_config.lddt.calc_backbone_lddt:
672
+ backbone_mask = ref_struct.get_backbone_atom_masks(only_rep_atom=True)
673
+ complex_bb_lddt = calc_lddt.get_complex_lddt(atom_mask=backbone_mask)
674
+ if not np.isnan(complex_bb_lddt):
675
+ complex_result_dict["bb_lddt"] = complex_bb_lddt
676
+
677
+ # It reuses the chains and interfaces from the previous step
678
+ chain_interface_lddt = calc_lddt.get_chain_interface_lddt(
679
+ chains, interfaces, atom_mask=backbone_mask
680
+ )
681
+ (
682
+ chain_bb_lddt_dict,
683
+ interface_bb_lddt_dict,
684
+ ) = cls._post_process_chain_interface_lddt(
685
+ chains, interfaces, chain_interface_lddt, metric_name="bb_lddt"
686
+ )
687
+ cls._update_src_to_tar_dict(chain_bb_lddt_dict, chain_result_dict)
688
+ cls._update_src_to_tar_dict(
689
+ interface_bb_lddt_dict, interface_result_dict
690
+ )
691
+
608
692
  # Calculate DockQ
609
693
  if metric_config.calc_dockq:
610
694
  dockq_result_dict = compute_dockq(
611
695
  ref_struct=ref_struct,
612
696
  model_struct=model_struct,
613
697
  ref_to_model_chain_map=chain_map,
698
+ exclude_hetatms=metric_config.dockq.exclude_hetatms,
614
699
  )
615
700
  interface_dockq_dict = cls._post_process_dockq(dockq_result_dict)
616
701
  cls._update_src_to_tar_dict(interface_dockq_dict, interface_result_dict)
@@ -635,6 +720,7 @@ class MetricResult:
635
720
  interface=interface_result_dict,
636
721
  pb_valid=chain_pb_valid_dict,
637
722
  ori_model_chain_ids=ori_model_chain_ids,
723
+ update_data=update_data,
638
724
  )
639
725
 
640
726
  def to_json_dict(self) -> dict[str, Any]:
@@ -663,7 +749,7 @@ class MetricResult:
663
749
  json_dict["ori_model_chain_ids"] = self.ori_model_chain_ids
664
750
  return json_dict
665
751
 
666
- def to_json(self, json_file: Path, update_data: dict | None = None):
752
+ def to_json(self, json_file: Path, update_data: Optional[dict] = None):
667
753
  """
668
754
  Convert the MetricResult instance to a JSON string.
669
755
 
@@ -677,5 +763,8 @@ class MetricResult:
677
763
  if update_data:
678
764
  json_dict.update(update_data)
679
765
 
766
+ if self.update_data is not None:
767
+ json_dict.update(self.update_data)
768
+
680
769
  with open(json_file, "w", encoding="utf-8") as f:
681
770
  json.dump(json_dict, f, indent=4, ensure_ascii=False)