boltz_vsynthes-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/crop/affinity.py
ADDED
@@ -0,0 +1,164 @@
from dataclasses import replace
from typing import Optional

import numpy as np

from boltz.data import const
from boltz.data.crop.cropper import Cropper
from boltz.data.types import Tokenized


class AffinityCropper(Cropper):
    """Interpolate between contiguous and spatial crops."""

    def __init__(
        self,
        neighborhood_size: int = 10,
        max_tokens_protein: int = 200,
    ) -> None:
        """Initialize the cropper.

        Parameters
        ----------
        neighborhood_size : int
            Modulates the type of cropping to be performed.
            Smaller neighborhoods result in more spatial
            cropping, larger neighborhoods in more contiguous
            cropping.
        max_tokens_protein : int
            The maximum number of protein tokens to include.

        """
        self.neighborhood_size = neighborhood_size
        self.max_tokens_protein = max_tokens_protein

    def crop(
        self,
        data: Tokenized,
        max_tokens: int,
        max_atoms: Optional[int] = None,
    ) -> Tokenized:
        """Crop the data to a maximum number of tokens.

        Parameters
        ----------
        data : Tokenized
            The tokenized data.
        max_tokens : int
            The maximum number of tokens to crop.
        max_atoms : Optional[int]
            The maximum number of atoms to consider.

        Returns
        -------
        Tokenized
            The cropped data.

        """
        # Get token data
        token_data = data.tokens
        token_bonds = data.bonds

        # Filter to resolved tokens
        valid_tokens = token_data[token_data["resolved_mask"]]

        # Check if we have any valid tokens
        if not valid_tokens.size:
            msg = "No valid tokens in structure"
            raise ValueError(msg)

        # Compute the minimum distance of each token to the ligand
        ligand_coords = valid_tokens[valid_tokens["affinity_mask"]]["center_coords"]
        dists = np.min(
            np.sum(
                (valid_tokens["center_coords"][:, None] - ligand_coords[None]) ** 2,
                axis=-1,
            )
            ** 0.5,
            axis=1,
        )

        indices = np.argsort(dists)

        # Select cropped indices
        cropped: set[int] = set()
        total_atoms = 0

        # Track protein (non-ligand) tokens separately
        cropped_protein: set[int] = set()
        ligand_ids = set(
            valid_tokens[
                valid_tokens["mol_type"] == const.chain_type_ids["NONPOLYMER"]
            ]["token_idx"]
        )

        for idx in indices:
            # Get the token
            token = valid_tokens[idx]

            # Get all tokens from this chain
            chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]

            # Pick the whole chain if possible, otherwise select
            # a contiguous subset centered at the query token
            if len(chain_tokens) <= self.neighborhood_size:
                new_tokens = chain_tokens
            else:
                # First limit to the maximum set of tokens, with the
                # neighborhood on both sides to handle edges. This
                # is mostly for efficiency with the while loop below.
                min_idx = token["res_idx"] - self.neighborhood_size
                max_idx = token["res_idx"] + self.neighborhood_size

                max_token_set = chain_tokens
                max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx]
                max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx]

                # Start by adding just the query token
                new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]]

                # Expand the neighborhood until we have enough tokens, one
                # by one to handle some edge cases with non-standard chains.
                # We switch to the res_idx instead of the token_idx to always
                # include all tokens from modified residues or from ligands.
                min_idx = max_idx = token["res_idx"]
                while new_tokens.size < self.neighborhood_size:
                    min_idx = min_idx - 1
                    max_idx = max_idx + 1
                    new_tokens = max_token_set
                    new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx]
                    new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx]

            # Compute new tokens and new atoms
            new_indices = set(new_tokens["token_idx"]) - cropped
            new_tokens = token_data[list(new_indices)]
            new_atoms = np.sum(new_tokens["atom_num"])

            # Stop if we exceed the max number of tokens, atoms,
            # or protein tokens
            if (
                (len(new_indices) > (max_tokens - len(cropped)))
                or ((max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms))
                or (
                    len(cropped_protein | (new_indices - ligand_ids))
                    > self.max_tokens_protein
                )
            ):
                break

            # Add new indices
            cropped.update(new_indices)
            total_atoms += new_atoms

            # Add protein indices
            cropped_protein.update(new_indices - ligand_ids)

        # Get the cropped tokens sorted by index
        token_data = token_data[sorted(cropped)]

        # Only keep bonds within the cropped tokens
        indices = token_data["token_idx"]
        token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)]
        token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]

        # Return the cropped tokens
        return replace(data, tokens=token_data, bonds=token_bonds)
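For orientation, here is a minimal usage sketch, not part of the package: it assumes a prebuilt `Tokenized` object (e.g. from the package's tokenizers) whose tokens carry the `affinity_mask` used above to locate the ligand, and the wrapper name and budget values are illustrative only.

from typing import Optional

from boltz.data.crop.affinity import AffinityCropper
from boltz.data.types import Tokenized


def crop_for_affinity(tokenized: Tokenized, max_atoms: Optional[int] = 4096) -> Tokenized:
    # Hypothetical wrapper: the budgets below are illustrative, not package defaults.
    cropper = AffinityCropper(neighborhood_size=10, max_tokens_protein=200)
    # Tokens closest to the ligand are added first; cropping stops once the
    # token, atom, or protein-token budget would be exceeded.
    return cropper.crop(tokenized, max_tokens=256, max_atoms=max_atoms)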
boltz/data/crop/boltz.py
ADDED
@@ -0,0 +1,296 @@
from dataclasses import replace
from typing import Optional

import numpy as np
from scipy.spatial.distance import cdist

from boltz.data import const
from boltz.data.crop.cropper import Cropper
from boltz.data.types import Tokenized


def pick_random_token(
    tokens: np.ndarray,
    random: np.random.RandomState,
) -> np.ndarray:
    """Pick a random token from the data.

    Parameters
    ----------
    tokens : np.ndarray
        The token data.
    random : np.random.RandomState
        The random state for reproducibility.

    Returns
    -------
    np.ndarray
        The selected token.

    """
    return tokens[random.randint(len(tokens))]


def pick_chain_token(
    tokens: np.ndarray,
    chain_id: int,
    random: np.random.RandomState,
) -> np.ndarray:
    """Pick a random token from a chain.

    Parameters
    ----------
    tokens : np.ndarray
        The token data.
    chain_id : int
        The chain ID.
    random : np.random.RandomState
        The random state for reproducibility.

    Returns
    -------
    np.ndarray
        The selected token.

    """
    # Filter to chain
    chain_tokens = tokens[tokens["asym_id"] == chain_id]

    # Pick from the chain, falling back to all tokens
    if chain_tokens.size:
        query = pick_random_token(chain_tokens, random)
    else:
        query = pick_random_token(tokens, random)

    return query


def pick_interface_token(
    tokens: np.ndarray,
    interface: np.ndarray,
    random: np.random.RandomState,
) -> np.ndarray:
    """Pick a random token from an interface.

    Parameters
    ----------
    tokens : np.ndarray
        The token data.
    interface : np.ndarray
        The interface data.
    random : np.random.RandomState
        The random state for reproducibility.

    Returns
    -------
    np.ndarray
        The selected token.

    """
    # Get the two chains of the interface
    chain_1 = int(interface["chain_1"])
    chain_2 = int(interface["chain_2"])

    tokens_1 = tokens[tokens["asym_id"] == chain_1]
    tokens_2 = tokens[tokens["asym_id"] == chain_2]

    # If either chain has no tokens, fall back to the other
    # chain, or to all tokens if both chains are empty
    if tokens_1.size and (not tokens_2.size):
        query = pick_random_token(tokens_1, random)
    elif tokens_2.size and (not tokens_1.size):
        query = pick_random_token(tokens_2, random)
    elif (not tokens_1.size) and (not tokens_2.size):
        query = pick_random_token(tokens, random)
    else:
        # Both chains have tokens, so compute pairwise distances
        tokens_1_coords = tokens_1["center_coords"]
        tokens_2_coords = tokens_2["center_coords"]

        dists = cdist(tokens_1_coords, tokens_2_coords)
        cutoff = dists < const.interface_cutoff

        # In rare cases the interface cutoff is slightly too
        # small; expand it slightly when that happens
        if not np.any(cutoff):
            cutoff = dists < (const.interface_cutoff + 5.0)

        tokens_1 = tokens_1[np.any(cutoff, axis=1)]
        tokens_2 = tokens_2[np.any(cutoff, axis=0)]

        # Select random token
        candidates = np.concatenate([tokens_1, tokens_2])
        query = pick_random_token(candidates, random)

    return query


class BoltzCropper(Cropper):
    """Interpolate between contiguous and spatial crops."""

    def __init__(self, min_neighborhood: int = 0, max_neighborhood: int = 40) -> None:
        """Initialize the cropper.

        Modulates the type of cropping to be performed.
        Smaller neighborhoods result in more spatial
        cropping, larger neighborhoods in more contiguous
        cropping. A mix can be achieved by providing a
        range over which to sample.

        Parameters
        ----------
        min_neighborhood : int
            The minimum neighborhood size, by default 0.
        max_neighborhood : int
            The maximum neighborhood size, by default 40.

        """
        sizes = list(range(min_neighborhood, max_neighborhood + 1, 2))
        self.neighborhood_sizes = sizes

    def crop(  # noqa: PLR0915
        self,
        data: Tokenized,
        max_tokens: int,
        random: np.random.RandomState,
        max_atoms: Optional[int] = None,
        chain_id: Optional[int] = None,
        interface_id: Optional[int] = None,
    ) -> Tokenized:
        """Crop the data to a maximum number of tokens.

        Parameters
        ----------
        data : Tokenized
            The tokenized data.
        max_tokens : int
            The maximum number of tokens to crop.
        random : np.random.RandomState
            The random state for reproducibility.
        max_atoms : int, optional
            The maximum number of atoms to consider.
        chain_id : int, optional
            The chain ID to crop.
        interface_id : int, optional
            The interface ID to crop.

        Returns
        -------
        Tokenized
            The cropped data.

        """
        # Check inputs
        if chain_id is not None and interface_id is not None:
            msg = "Only one of chain_id or interface_id can be provided."
            raise ValueError(msg)

        # Randomly select a neighborhood size
        neighborhood_size = random.choice(self.neighborhood_sizes)

        # Get token data
        token_data = data.tokens
        token_bonds = data.bonds
        mask = data.structure.mask
        chains = data.structure.chains
        interfaces = data.structure.interfaces

        # Filter to valid chains
        valid_chains = chains[mask]

        # Filter to valid interfaces
        valid_interfaces = interfaces
        valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_1"]]]
        valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_2"]]]

        # Filter to resolved tokens
        valid_tokens = token_data[token_data["resolved_mask"]]

        # Check if we have any valid tokens
        if not valid_tokens.size:
            msg = "No valid tokens in structure"
            raise ValueError(msg)

        # Pick a random token, chain, or interface
        if chain_id is not None:
            query = pick_chain_token(valid_tokens, chain_id, random)
        elif interface_id is not None:
            interface = interfaces[interface_id]
            query = pick_interface_token(valid_tokens, interface, random)
        elif valid_interfaces.size:
            idx = random.randint(len(valid_interfaces))
            interface = valid_interfaces[idx]
            query = pick_interface_token(valid_tokens, interface, random)
        else:
            idx = random.randint(len(valid_chains))
            chain_id = valid_chains[idx]["asym_id"]
            query = pick_chain_token(valid_tokens, chain_id, random)

        # Sort all tokens by distance to the query token's center coordinates
        dists = valid_tokens["center_coords"] - query["center_coords"]
        indices = np.argsort(np.linalg.norm(dists, axis=1))

        # Select cropped indices
        cropped: set[int] = set()
        total_atoms = 0
        for idx in indices:
            # Get the token
            token = valid_tokens[idx]

            # Get all tokens from this chain
            chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]

            # Pick the whole chain if possible, otherwise select
            # a contiguous subset centered at the query token
            if len(chain_tokens) <= neighborhood_size:
                new_tokens = chain_tokens
            else:
                # First limit to the maximum set of tokens, with the
                # neighborhood on both sides to handle edges. This
                # is mostly for efficiency with the while loop below.
                min_idx = token["res_idx"] - neighborhood_size
                max_idx = token["res_idx"] + neighborhood_size

                max_token_set = chain_tokens
                max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx]
                max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx]

                # Start by adding just the query token
                new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]]

                # Expand the neighborhood until we have enough tokens, one
                # by one to handle some edge cases with non-standard chains.
                # We switch to the res_idx instead of the token_idx to always
                # include all tokens from modified residues or from ligands.
                min_idx = max_idx = token["res_idx"]
                while new_tokens.size < neighborhood_size:
                    min_idx = min_idx - 1
                    max_idx = max_idx + 1
                    new_tokens = max_token_set
                    new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx]
                    new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx]

            # Compute new tokens and new atoms
            new_indices = set(new_tokens["token_idx"]) - cropped
            new_tokens = token_data[list(new_indices)]
            new_atoms = np.sum(new_tokens["atom_num"])

            # Stop if we exceed the max number of tokens or atoms
            if (len(new_indices) > (max_tokens - len(cropped))) or (
                (max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms)
            ):
                break

            # Add new indices
            cropped.update(new_indices)
            total_atoms += new_atoms

        # Get the cropped tokens sorted by index
        token_data = token_data[sorted(cropped)]

        # Only keep bonds within the cropped tokens
        indices = token_data["token_idx"]
        token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)]
        token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]

        # Return the cropped tokens
        return replace(data, tokens=token_data, bonds=token_bonds)
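The same caveat applies to this sketch of BoltzCropper, which is not part of the package: the `Tokenized` input and the wrapper name are assumptions, and the explicit `RandomState` makes the sampled neighborhood size, and hence the crop, reproducible.

import numpy as np

from boltz.data.crop.boltz import BoltzCropper
from boltz.data.types import Tokenized


def crop_for_training(tokenized: Tokenized, seed: int = 0) -> Tokenized:
    # Hypothetical wrapper. min_neighborhood=0 biases toward purely spatial
    # crops, max_neighborhood=40 toward contiguous ones; a size is sampled
    # from that range on every call.
    cropper = BoltzCropper(min_neighborhood=0, max_neighborhood=40)
    rng = np.random.RandomState(seed)  # fixed seed => reproducible crop
    return cropper.crop(tokenized, max_tokens=384, random=rng)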
boltz/data/crop/cropper.py
ADDED
@@ -0,0 +1,45 @@
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np

from boltz.data.types import Tokenized


class Cropper(ABC):
    """Abstract base class for croppers."""

    @abstractmethod
    def crop(
        self,
        data: Tokenized,
        max_tokens: int,
        random: np.random.RandomState,
        max_atoms: Optional[int] = None,
        chain_id: Optional[int] = None,
        interface_id: Optional[int] = None,
    ) -> Tokenized:
        """Crop the data to a maximum number of tokens.

        Parameters
        ----------
        data : Tokenized
            The tokenized data.
        max_tokens : int
            The maximum number of tokens to crop.
        random : np.random.RandomState
            The random state for reproducibility.
        max_atoms : Optional[int]
            The maximum number of atoms to consider.
        chain_id : Optional[int]
            The chain ID to crop.
        interface_id : Optional[int]
            The interface ID to crop.

        Returns
        -------
        Tokenized
            The cropped data.

        """
        raise NotImplementedError
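Since `Cropper` is a plain ABC, a new cropping strategy only needs to implement `crop` with this signature. A toy subclass, hypothetical and not in the package, that simply keeps the first `max_tokens` tokens might look like:

from dataclasses import replace
from typing import Optional

import numpy as np

from boltz.data.crop.cropper import Cropper
from boltz.data.types import Tokenized


class FirstTokensCropper(Cropper):
    """Toy cropper (hypothetical): keep the first max_tokens tokens."""

    def crop(
        self,
        data: Tokenized,
        max_tokens: int,
        random: np.random.RandomState,
        max_atoms: Optional[int] = None,
        chain_id: Optional[int] = None,
        interface_id: Optional[int] = None,
    ) -> Tokenized:
        tokens = data.tokens[:max_tokens]
        keep = tokens["token_idx"]
        # Drop bonds that reference tokens outside the crop, mirroring
        # the bookkeeping done by the croppers above.
        bonds = data.bonds
        bonds = bonds[np.isin(bonds["token_1"], keep)]
        bonds = bonds[np.isin(bonds["token_2"], keep)]
        return replace(data, tokens=tokens, bonds=bonds)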