boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,37 @@
|
|
1
|
+
from boltz.data.types import Record
|
2
|
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
3
|
+
|
4
|
+
|
5
|
+
class MaxResiduesFilter(DynamicFilter):
    """A filter that keeps records whose total residue count lies in a range."""

    def __init__(self, min_residues: int = 1, max_residues: int = 500) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_residues : int, optional
            The minimum total number of residues allowed (inclusive).
        max_residues : int, optional
            The maximum total number of residues allowed (inclusive).

        """
        self.min_residues = min_residues
        self.max_residues = max_residues

    def filter(self, record: Record) -> bool:
        """Filter records based on their total number of residues.

        The residue count is the sum of ``num_residues`` over all
        entries of ``record.chains``.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the record passes the filter, i.e. its residue count
            is within ``[min_residues, max_residues]``; False otherwise.

        """
        num_residues = sum(chain.num_residues for chain in record.chains)
        return self.min_residues <= num_residues <= self.max_residues
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from boltz.data.types import Record
|
2
|
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
3
|
+
|
4
|
+
|
5
|
+
class ResolutionFilter(DynamicFilter):
    """Keep only complexes whose reported resolution meets a cutoff."""

    def __init__(self, resolution: float = 9.0) -> None:
        """Create the filter.

        Parameters
        ----------
        resolution : float, optional
            Inclusive upper bound on ``record.structure.resolution``.

        """
        self.resolution = resolution

    def filter(self, record: Record) -> bool:
        """Check a record against the resolution cutoff.

        Parameters
        ----------
        record : Record
            The record to check.

        Returns
        -------
        bool
            True when the record's resolution does not exceed the cutoff
            (the record passes the filter), False otherwise.

        """
        observed = record.structure.resolution
        return observed <= self.resolution
|
@@ -0,0 +1,38 @@
|
|
1
|
+
from boltz.data.types import Record
|
2
|
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
3
|
+
|
4
|
+
|
5
|
+
class SizeFilter(DynamicFilter):
    """A filter that filters structures based on their number of chains."""

    def __init__(self, min_chains: int = 1, max_chains: int = 300) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_chains : int, optional
            The minimum number of valid chains (``chain.valid`` set in the
            record) required.
        max_chains : int, optional
            The maximum number of chains allowed, as reported by
            ``record.structure.num_chains``.

        """
        self.min_chains = min_chains
        self.max_chains = max_chains

    def filter(self, record: Record) -> bool:
        """Filter structures based on their number of chains.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the record passes the filter: the structure has at most
            ``max_chains`` chains and at least ``min_chains`` valid chains.

        """
        # The two bounds deliberately count different things: the upper
        # bound uses the structure's total chain count, while the lower
        # bound only counts chains flagged as valid in the record.
        num_chains = record.structure.num_chains
        num_valid = sum(1 for chain in record.chains if chain.valid)
        return num_chains <= self.max_chains and num_valid >= self.min_chains
|
@@ -0,0 +1,42 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
from boltz.data.types import Record
|
4
|
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
5
|
+
|
6
|
+
|
7
|
+
class SubsetFilter(DynamicFilter):
    """Filter a data record based on a subset of the data."""

    def __init__(self, subset: str, reverse: bool = False) -> None:
        """Initialize the filter.

        Parameters
        ----------
        subset : str
            Path to a UTF-8 text file listing the record ids of the subset,
            one per line. Surrounding whitespace and blank lines are
            ignored, and ids are compared case-insensitively.
        reverse : bool, optional
            If True, keep the records that are NOT in the subset instead.

        """
        # Read with an explicit encoding so the result does not depend on
        # the platform's locale, and keep the path argument unrebound.
        with Path(subset).open("r", encoding="utf-8") as f:
            lines = f.read().splitlines()

        # Strip stray whitespace so "id " or "id\t" still matches "id".
        self.subset = {s.strip().lower() for s in lines if s.strip()}
        self.reverse = reverse

    def filter(self, record: Record) -> bool:
        """Filter a data record.

        Parameters
        ----------
        record : Record
            The object to consider filtering in / out.

        Returns
        -------
        bool
            True if the data passes the filter, False otherwise. Without
            ``reverse``, a record passes when its id is in the subset; with
            ``reverse``, it passes when its id is not in the subset.

        """
        in_subset = record.id.lower() in self.subset
        return not in_subset if self.reverse else in_subset
|
File without changes
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
from boltz.data.types import Structure
|
6
|
+
|
7
|
+
|
8
|
+
class StaticFilter(ABC):
    """Abstract interface for filters that select chains of a structure."""

    @abstractmethod
    def filter(self, structure: Structure) -> np.ndarray:
        """Compute which chains of a structure should be kept.

        Concrete subclasses must override this method.

        Parameters
        ----------
        structure : Structure
            The structure whose chains are evaluated.

        Returns
        -------
        np.ndarray
            A boolean mask over the structure's chains; True marks a chain
            to keep.

        """
        raise NotImplementedError
|
@@ -0,0 +1,37 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
from boltz.data import const
|
4
|
+
from boltz.data.filter.static.filter import StaticFilter
|
5
|
+
from boltz.data.types import Structure
|
6
|
+
|
7
|
+
|
8
|
+
class ExcludedLigands(StaticFilter):
    """Drop non-polymer chains containing a residue from the exclusion list."""

    def filter(self, structure: Structure) -> np.ndarray:
        """Mask out non-polymer chains that contain an excluded ligand.

        A non-polymer chain is dropped as soon as any of its residues has a
        name listed in ``const.ligand_exclusion``. Polymer chains are
        always kept.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        keep = np.ones(len(structure.chains), dtype=bool)

        for idx, chain in enumerate(structure.chains):
            # Only ligand (non-polymer) chains are subject to this filter.
            if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
                first = chain["res_idx"]
                chain_residues = structure.residues[first : first + chain["res_num"]]
                if any(r["name"] in const.ligand_exclusion for r in chain_residues):
                    keep[idx] = False

        return keep
|
@@ -0,0 +1,299 @@
|
|
1
|
+
import itertools
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
from sklearn.neighbors import KDTree
|
6
|
+
|
7
|
+
from boltz.data import const
|
8
|
+
from boltz.data.filter.static.filter import StaticFilter
|
9
|
+
from boltz.data.types import Structure
|
10
|
+
|
11
|
+
|
12
|
+
class MinimumLengthFilter(StaticFilter):
    """Filter polymers based on their length.

    We use the number of resolved residues when considering
    the minimum, and the sequence length for the maximum.

    """

    def __init__(self, min_len: int = 4, max_len: int = 5000) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_len : int, optional
            The minimum allowed number of resolved residues (inclusive).
        max_len : int, optional
            The maximum allowed sequence length (inclusive).

        """
        self._min = min_len
        self._max = max_len

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter polymer chains based on their length.

        Non-polymer chains are always kept.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        valid = np.ones(len(structure.chains), dtype=bool)

        for i, chain in enumerate(structure.chains):
            if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
                continue

            res_start = chain["res_idx"]
            res_end = res_start + chain["res_num"]
            residues = structure.residues[res_start:res_end]

            # As documented on the class: the lower bound applies to the
            # resolved residues, the upper bound to the full sequence
            # (resolved or not). Previously both used the resolved count.
            resolved = residues["is_present"].sum()
            seq_len = len(residues)

            if resolved < self._min or seq_len > self._max:
                valid[i] = False

        return valid
|
63
|
+
|
64
|
+
|
65
|
+
class UnknownFilter(StaticFilter):
    """Filter polymer chains made up entirely of unknown residues."""

    def filter(self, structure: Structure) -> np.ndarray:
        """Drop polymer chains whose residues are all the unknown token.

        Each polymer chain type (protein, DNA, RNA) is compared against its
        own unknown-residue token id. Non-polymer chains are always kept.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        keep = np.ones(len(structure.chains), dtype=bool)

        # Map each polymer chain type id to its "unknown residue" token id.
        unknown_by_type = {
            const.chain_type_ids[kind]: const.unk_token_ids[kind]
            for kind in ("PROTEIN", "DNA", "RNA")
        }

        for idx, chain in enumerate(structure.chains):
            mol_type = chain["mol_type"]
            if mol_type != const.chain_type_ids["NONPOLYMER"]:
                first = chain["res_idx"]
                res_types = structure.residues[first : first + chain["res_num"]]["res_type"]
                if np.all(res_types == unknown_by_type[mol_type]):
                    keep[idx] = False

        return keep
|
102
|
+
|
103
|
+
|
104
|
+
class ConsecutiveCA(StaticFilter):
    """Filter proteins whose consecutive CA-CA distance exceeds a threshold."""

    def __init__(self, max_dist: float = 10.0) -> None:
        """Initialize the filter.

        Parameters
        ----------
        max_dist : float, optional
            The maximum allowed distance between the center atoms of two
            consecutive residues.

        """
        self._max_dist = max_dist

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter protein chains with an over-long consecutive CA gap.

        Only pairs of consecutive residues whose residue and center atom are
        both present are considered. Non-protein chains are always kept.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        valid = np.ones(len(structure.chains), dtype=bool)

        for i, chain in enumerate(structure.chains):
            # Skip non-protein chains
            if chain["mol_type"] != const.chain_type_ids["PROTEIN"]:
                continue

            # Get residues
            res_start = chain["res_idx"]
            res_end = res_start + chain["res_num"]
            residues = structure.residues[res_start:res_end]

            # Get each residue's center atom (presumably the CA for protein
            # residues — confirm against the tokenizer/parser).
            ca_atoms = structure.atoms[residues["atom_center"]]
            ca_valid = ca_atoms["is_present"] & residues["is_present"]
            ca_coords = ca_atoms["coords"]

            # Only consecutive pairs where both atoms are usable count.
            pair_valid = ca_valid[1:] & ca_valid[:-1]
            dist = np.linalg.norm(ca_coords[1:] - ca_coords[:-1], axis=1)

            # Remove the chain if any valid pair is above threshold
            if np.any(dist[pair_valid] > self._max_dist):
                valid[i] = False

        return valid
|
163
|
+
|
164
|
+
|
165
|
+
@dataclass(frozen=True)
class Clash:
    """A directional clash record between two chains.

    ``chain`` is the chain whose atoms were counted and ``other`` is the
    chain it clashes with. Instances are immutable.
    """

    # Index of the chain whose clashing atoms are counted.
    chain: int
    # Index of the chain it clashes against.
    other: int
    # Number of atoms considered in ``chain``.
    num_atoms: int
    # Number of atoms of ``chain`` involved in a clash with ``other``.
    num_clashes: int
|
173
|
+
|
174
|
+
|
175
|
+
class ClashingChainsFilter(StaticFilter):
    """A filter that filters clashing chains.

    Clashing chains are defined as those with >30% of atoms
    within 1.7 Å of an atom in another chain. If two chains
    are clashing with each other, the chain with the greater
    percentage of clashing atoms will be removed. If the same
    fraction of atoms are clashing, the chain with fewer total
    atoms is removed. If the chains have the same number of
    atoms, then the chain with the larger chain id is removed.

    """

    def __init__(self, dist: float = 1.7, freq: float = 0.3) -> None:
        """Initialize the filter.

        Parameters
        ----------
        dist : float, optional
            The maximum distance for a clash.
        freq : float, optional
            The maximum allowed frequency of clashes.

        """
        self._dist = dist
        self._freq = freq

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter out clashing chains.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        num_chains = len(structure.chains)
        if num_chains < 2:  # noqa: PLR2004
            return np.ones(num_chains, dtype=bool)

        clashes = self._find_clashes(structure)
        removed = self._resolve_clashes(clashes)

        # Remove the chains
        valid = np.ones(num_chains, dtype=bool)
        for i in removed:
            valid[i] = False

        return valid

    @staticmethod
    def _present_atoms(structure: Structure, chain) -> np.ndarray:
        """Return the atoms of ``chain`` whose ``is_present`` flag is set."""
        start = chain["atom_idx"]
        atoms = structure.atoms[start : start + chain["atom_num"]]
        return atoms[atoms["is_present"]]

    def _find_clashes(self, structure: Structure) -> list[Clash]:
        """Return every directional clash whose frequency exceeds the threshold."""
        # Extract each chain's present atoms once, instead of once per pair.
        present = [self._present_atoms(structure, c) for c in structure.chains]

        clashes: list[Clash] = []
        for i, j in itertools.combinations(range(len(present)), 2):
            atoms1 = present[i]
            atoms2 = present[j]

            # Skip if either chain has no atoms
            if len(atoms1) == 0 or len(atoms2) == 0:
                continue

            # For every atom of chain j, find the atoms of chain i within
            # the clash distance.
            tree = KDTree(atoms1["coords"], metric="euclidean")
            query = tree.query_radius(atoms2["coords"], self._dist)

            # Atoms of j with at least one neighbor, and distinct atoms of i
            # that appear as a neighbor of some atom of j.
            c2_clashes = sum(len(neighbors) > 0 for neighbors in query)
            c1_clashes = len(set(itertools.chain.from_iterable(query)))

            if (c1_clashes / len(atoms1)) > self._freq:
                clashes.append(Clash(i, j, len(atoms1), c1_clashes))
            if (c2_clashes / len(atoms2)) > self._freq:
                clashes.append(Clash(j, i, len(atoms2), c2_clashes))

        return clashes

    @staticmethod
    def _resolve_clashes(clashes: list[Clash]) -> set[int]:
        """Apply the class ruleset to ``clashes`` and return removed chain ids."""
        removed: set[int] = set()
        ids_to_clash = {(c.chain, c.other): c for c in clashes}

        for clash in clashes:
            # If either is already removed, the clash is already resolved.
            if clash.chain in removed or clash.other in removed:
                continue

            # A one-sided clash: just remove the offending chain.
            other_clash = ids_to_clash.get((clash.other, clash.chain))
            if other_clash is None:
                removed.add(clash.chain)
                continue

            # Mutual clash: remove the chain with the higher clash frequency.
            clash1_freq = clash.num_clashes / clash.num_atoms
            clash2_freq = other_clash.num_clashes / other_clash.num_atoms
            if clash1_freq > clash2_freq:
                removed.add(clash.chain)
            elif clash1_freq < clash2_freq:
                removed.add(clash.other)
            # If same, remove the chain with fewer atoms
            elif clash.num_atoms < other_clash.num_atoms:
                removed.add(clash.chain)
            elif clash.num_atoms > other_clash.num_atoms:
                removed.add(clash.other)
            # If still tied, remove the chain with the larger chain id
            else:
                removed.add(max(clash.chain, clash.other))

        return removed
|
File without changes
|