boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/parse/yaml.py
ADDED
@@ -0,0 +1,68 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
import yaml
|
4
|
+
from rdkit.Chem.rdchem import Mol
|
5
|
+
|
6
|
+
from boltz.data.parse.schema import parse_boltz_schema
|
7
|
+
from boltz.data.types import Target
|
8
|
+
|
9
|
+
|
10
|
+
def parse_yaml(
    path: Path,
    ccd: dict[str, Mol],
    mol_dir: Path,
    boltz2: bool = False,
) -> Target:
    """Load a Boltz input file with the YAML parser and build a Target.

    The file is read with ``yaml.safe_load`` and the resulting object is
    handed, unchanged, to ``parse_boltz_schema``. The target name is the
    file name without its suffix.

    Example input::

        sequences:
            - protein:
                id: A
                sequence: "MADQLTEEQIAEFKEAFSLF"
            - protein:
                id: [B, C]
                sequence: "AKLSILPWGHC"
            - rna:
                id: D
                sequence: "GCAUAGC"
            - ligand:
                id: E
                smiles: "CC1=CC=CC=C1"
            - ligand:
                id: [F, G]
                ccd: []
        constraints:
            - bond:
                atom1: [A, 1, CA]
                atom2: [A, 2, N]
            - pocket:
                binder: E
                contacts: [[B, 1], [B, 2]]
        templates:
            - path: /path/to/template.pdb
              ids: [A] # optional, specify which chains to template

        version: 1

    Parameters
    ----------
    path : Path
        Path to the YAML input file.
    ccd : dict[str, Mol]
        Dictionary of CCD components.
    mol_dir : Path
        Directory forwarded as-is to ``parse_boltz_schema``
        (presumably where molecule files live; confirm against the
        schema parser).
    boltz2 : bool
        Whether to parse the input for Boltz2.

    Returns
    -------
    Target
        The parsed target.

    """
    with path.open("r") as stream:
        schema = yaml.safe_load(stream)

    # The file stem doubles as the target name
    return parse_boltz_schema(path.stem, schema, ccd, mol_dir, boltz2)
|
File without changes
|
@@ -0,0 +1,283 @@
|
|
1
|
+
from typing import Dict, Iterator, List
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from numpy.random import RandomState
|
5
|
+
|
6
|
+
from boltz.data import const
|
7
|
+
from boltz.data.types import ChainInfo, InterfaceInfo, Record
|
8
|
+
from boltz.data.sample.sampler import Sample, Sampler
|
9
|
+
|
10
|
+
|
11
|
+
def get_chain_cluster(chain: ChainInfo, record: Record) -> str:  # noqa: ARG001
    """Return the cluster identifier stored on a chain.

    Parameters
    ----------
    chain : ChainInfo
        The chain whose cluster id is requested.
    record : Record
        The record the chain belongs to. Not used by this function.

    Returns
    -------
    str
        The value of ``chain.cluster_id``.

    """
    cluster_id = chain.cluster_id
    return cluster_id
|
28
|
+
|
29
|
+
|
30
|
+
def get_interface_cluster(
    interface: InterfaceInfo,
    record: Record,
) -> tuple[str, str]:
    """Get the cluster id for an interface.

    The id is built from the cluster ids of the two chains forming the
    interface. Both are converted to strings and sorted, so the result
    does not depend on which chain is listed first.

    Parameters
    ----------
    interface : InterfaceInfo
        The interface to get the cluster id for. Its ``chain_1`` and
        ``chain_2`` attributes index into ``record.chains``.
    record : Record
        The record the interface is part of.

    Returns
    -------
    tuple[str, str]
        The two chain cluster ids, as strings, in sorted order.

    """
    first = record.chains[interface.chain_1]
    second = record.chains[interface.chain_2]

    # Sort so that (A, B) and (B, A) map to the same cluster key
    pair = sorted((str(first.cluster_id), str(second.cluster_id)))
    return tuple(pair)
|
56
|
+
|
57
|
+
|
58
|
+
def get_chain_weight(
    chain: ChainInfo,
    record: Record,  # noqa: ARG001
    clusters: Dict[str, int],
    beta_chain: float,
    alpha_prot: float,
    alpha_nucl: float,
    alpha_ligand: float,
) -> float:
    """Compute the sampling weight of a single chain.

    The base weight is ``beta_chain`` divided by the size of the chain's
    cluster. It is then scaled by the alpha value matching the chain's
    molecule type. Chains of any other type keep the base weight.

    Parameters
    ----------
    chain : ChainInfo
        The chain to weight.
    record : Record
        The record the chain is part of. Not used by this function.
    clusters : Dict[str, int]
        Cluster sizes, looked up with ``chain.cluster_id``.
    beta_chain : float
        The beta value for chains.
    alpha_prot : float
        The alpha value for proteins.
    alpha_nucl : float
        The alpha value for nucleic acids (RNA or DNA).
    alpha_ligand : float
        The alpha value for ligands (non-polymers).

    Returns
    -------
    float
        The weight of the chain.

    """
    protein = const.chain_type_ids["PROTEIN"]
    rna = const.chain_type_ids["RNA"]
    dna = const.chain_type_ids["DNA"]
    ligand = const.chain_type_ids["NONPOLYMER"]

    base = beta_chain / clusters[chain.cluster_id]
    kind = chain.mol_type

    if kind == protein:
        return base * alpha_prot
    if kind in (rna, dna):
        return base * alpha_nucl
    if kind == ligand:
        return base * alpha_ligand

    # Unrecognised molecule types are left unscaled
    return base
|
106
|
+
|
107
|
+
|
108
|
+
def get_interface_weight(
    interface: InterfaceInfo,
    record: Record,
    clusters: Dict[str, int],
    beta_interface: float,
    alpha_prot: float,
    alpha_nucl: float,
    alpha_ligand: float,
) -> float:
    """Compute the sampling weight of an interface between two chains.

    The weight is::

        beta_interface / cluster_size
            * (alpha_prot * n_prot + alpha_nucl * n_nuc + alpha_ligand * n_ligand)

    where each ``n_*`` counts how many of the interface's two chains
    have the corresponding molecule type.

    Parameters
    ----------
    interface : InterfaceInfo
        The interface to weight.
    record : Record
        The record the interface is part of.
    clusters : Dict[str, int]
        Cluster sizes, looked up with the key returned by
        ``get_interface_cluster``.
    beta_interface : float
        The beta value for interfaces.
    alpha_prot : float
        The alpha value for proteins.
    alpha_nucl : float
        The alpha value for nucleic acids (RNA or DNA).
    alpha_ligand : float
        The alpha value for ligands (non-polymers).

    Returns
    -------
    float
        The weight of the interface.

    """
    protein = const.chain_type_ids["PROTEIN"]
    rna = const.chain_type_ids["RNA"]
    dna = const.chain_type_ids["DNA"]
    ligand = const.chain_type_ids["NONPOLYMER"]
    nucleic = [rna, dna]

    first = record.chains[interface.chain_1]
    second = record.chains[interface.chain_2]

    # Adding the two comparison results gives a per-type count of chains
    n_prot = (first.mol_type == protein) + (second.mol_type == protein)
    n_nuc = (first.mol_type in nucleic) + (second.mol_type in nucleic)
    n_ligand = (first.mol_type == ligand) + (second.mol_type == ligand)

    cluster_size = clusters[get_interface_cluster(interface, record)]
    weight = beta_interface / cluster_size
    weight *= alpha_prot * n_prot + alpha_nucl * n_nuc + alpha_ligand * n_ligand
    return weight
|
161
|
+
|
162
|
+
|
163
|
+
class ClusterSampler(Sampler):
    """Cluster-weighted sampler over chains and interfaces, as in AF3.

    Every valid chain and every valid interface of every record is a
    sampling candidate. Each candidate receives a weight of the form

        w = b / n_clust * (a_prot * n_prot + a_nuc * n_nuc
                           + a_ligand * n_ligand)

    and candidates are drawn with probability proportional to it.

    """

    def __init__(
        self,
        alpha_prot: float = 3.0,
        alpha_nucl: float = 3.0,
        alpha_ligand: float = 1.0,
        beta_chain: float = 0.5,
        beta_interface: float = 1.0,
    ) -> None:
        """Store the weighting coefficients.

        Parameters
        ----------
        alpha_prot : float, optional
            The alpha value for proteins.
        alpha_nucl : float, optional
            The alpha value for nucleic acids.
        alpha_ligand : float, optional
            The alpha value for ligands.
        beta_chain : float, optional
            The beta value for chains.
        beta_interface : float, optional
            The beta value for interfaces.

        """
        self.alpha_prot = alpha_prot
        self.alpha_nucl = alpha_nucl
        self.alpha_ligand = alpha_ligand
        self.beta_chain = beta_chain
        self.beta_interface = beta_interface

    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:  # noqa: C901, PLR0912
        """Yield weighted chain / interface samples forever.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A sample carrying the record and either a ``chain_id`` or an
            ``interface_id`` (an index into the record's chains or
            interfaces).

        """
        # Size of every chain cluster, counting valid chains only
        chain_clusters: Dict[str, int] = {}
        for record in records:
            for chain in record.chains:
                if chain.valid:
                    key = get_chain_cluster(chain, record)
                    chain_clusters[key] = chain_clusters.get(key, 0) + 1

        # Size of every interface cluster, counting valid interfaces only
        interface_clusters: Dict[str, int] = {}
        for record in records:
            for interface in record.interfaces:
                if interface.valid:
                    key = get_interface_cluster(interface, record)
                    interface_clusters[key] = interface_clusters.get(key, 0) + 1

        # Candidates are (record, kind, index); kind 0 = chain, 1 = interface.
        # Per record, chains are listed before interfaces.
        candidates, weights = [], []
        for record in records:
            for chain_id, chain in enumerate(record.chains):
                if chain.valid:
                    chain_weight = get_chain_weight(
                        chain,
                        record,
                        chain_clusters,
                        self.beta_chain,
                        self.alpha_prot,
                        self.alpha_nucl,
                        self.alpha_ligand,
                    )
                    candidates.append((record, 0, chain_id))
                    weights.append(chain_weight)

            for int_id, interface in enumerate(record.interfaces):
                if interface.valid:
                    interface_weight = get_interface_weight(
                        interface,
                        record,
                        interface_clusters,
                        self.beta_interface,
                        self.alpha_prot,
                        self.alpha_nucl,
                        self.alpha_ligand,
                    )
                    candidates.append((record, 1, int_id))
                    weights.append(interface_weight)

        # Normalise the weights into probabilities and draw indefinitely
        weights = np.array(weights) / np.sum(weights)
        while True:
            picked = random.choice(len(candidates), p=weights)
            record, kind, index = candidates[picked]
            if kind != 0:
                yield Sample(record=record, interface_id=index)
            else:
                yield Sample(record=record, chain_id=index)
|
@@ -0,0 +1,57 @@
|
|
1
|
+
from typing import Iterator, List
|
2
|
+
|
3
|
+
from numpy.random import RandomState
|
4
|
+
|
5
|
+
from boltz.data.types import Record
|
6
|
+
from boltz.data.sample.sampler import Sample, Sampler
|
7
|
+
|
8
|
+
|
9
|
+
class DistillationSampler(Sampler):
    """A sampler for monomer distillation data.

    Records are split into "small" and "large" by the residue count of
    their first chain. Each draw first chooses one of the two groups and
    then picks a record uniformly from it.

    """

    def __init__(self, small_size: int = 200, small_prob: float = 0.01) -> None:
        """Initialize the sampler.

        Parameters
        ----------
        small_size : int, optional
            The maximum size to be considered small.
        small_prob : float, optional
            The probability of sampling a small item.

        """
        self._size = small_size
        self._prob = small_prob

    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
        """Sample a structure from the dataset infinitely.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A data sample.

        Raises
        ------
        ValueError
            If no record has a valid first chain. The error is raised
            when the first sample is requested.

        """
        # Remove records with invalid chains
        records = [r for r in records if r.chains[0].valid]
        if not records:
            msg = "No records with a valid first chain to sample from."
            raise ValueError(msg)

        # Split in small and large proteins. We assume that there is only
        # one chain per record, as is the case for monomer distillation
        small = [r for r in records if r.chains[0].num_residues <= self._size]
        large = [r for r in records if r.chains[0].num_residues > self._size]

        # Sample infinitely
        while True:
            # Sample small or large
            samples = small if random.rand() < self._prob else large

            # One group may be empty (e.g. no small proteins at all).
            # randint(0, 0) would raise, so use the other group instead;
            # at least one group is non-empty because records is non-empty.
            if not samples:
                samples = small or large

            # Sample item from the list (the upper bound is exclusive)
            index = random.randint(0, len(samples))
            yield Sample(record=samples[index])
|
@@ -0,0 +1,39 @@
|
|
1
|
+
from dataclasses import replace
|
2
|
+
from typing import Iterator, List
|
3
|
+
|
4
|
+
from numpy.random import RandomState
|
5
|
+
|
6
|
+
from boltz.data.types import Record
|
7
|
+
from boltz.data.sample.sampler import Sample, Sampler
|
8
|
+
|
9
|
+
|
10
|
+
class RandomSampler(Sampler):
    """Uniform random sampler over records, with replacement."""

    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
        """Yield uniformly drawn records forever.

        Each yielded record is a copy of the drawn record in which only
        the chains and interfaces flagged as valid are kept.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A data sample.

        """
        while True:
            # Draw a position in the list (the upper bound is exclusive)
            picked = records[random.randint(0, len(records))]

            # Keep only the valid chains and interfaces on the copy
            valid_chains = [chain for chain in picked.chains if chain.valid]
            valid_interfaces = [
                interface for interface in picked.interfaces if interface.valid
            ]
            filtered = replace(
                picked,
                chains=valid_chains,
                interfaces=valid_interfaces,
            )

            yield Sample(record=filtered)
|
@@ -0,0 +1,49 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import Iterator, List, Optional
|
4
|
+
|
5
|
+
from numpy.random import RandomState
|
6
|
+
|
7
|
+
from boltz.data.types import Record
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
class Sample:
    """One drawn item: a record plus an optional chain or interface index.

    Attributes
    ----------
    record : Record
        The sampled record.
    chain_id : Optional[int]
        The chain ID, or None when no chain was selected.
    interface_id : Optional[int]
        The interface ID, or None when no interface was selected.

    """

    # Only the record is mandatory; both IDs default to None.
    record: Record
    chain_id: Optional[int] = None
    interface_id: Optional[int] = None
|
27
|
+
|
28
|
+
|
29
|
+
class Sampler(ABC):
    """Interface every sampler must implement."""

    @abstractmethod
    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
        """Produce an endless stream of samples drawn from the records.

        Subclasses must override this method; the base implementation
        only raises.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A data sample.

        Raises
        ------
        NotImplementedError
            Always, when the base implementation is called directly.

        """
        raise NotImplementedError
|
File without changes
|
@@ -0,0 +1,195 @@
|
|
1
|
+
from dataclasses import astuple, dataclass
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
from boltz.data import const
|
6
|
+
from boltz.data.tokenize.tokenizer import Tokenizer
|
7
|
+
from boltz.data.types import Input, Token, TokenBond, Tokenized
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
class TokenData:
    """Per-token record assembled by ``BoltzTokenizer.tokenize``.

    The field order matters: the tokenizer converts each instance with
    ``astuple`` and builds a ``Token`` array from the resulting tuples.

    """

    # Running index of the token
    token_idx: int
    # Index of the token's first atom, and how many atoms it covers
    atom_idx: int
    atom_num: int
    # Residue index and residue type of the token
    res_idx: int
    res_type: int
    # Values copied from the chain the token belongs to
    sym_id: int
    asym_id: int
    entity_id: int
    mol_type: int
    # Atom indices of the token's center and disto atoms
    center_idx: int
    disto_idx: int
    # Coordinates of those two atoms
    center_coords: np.ndarray
    disto_coords: np.ndarray
    # Presence flags for the center and the disto atom
    resolved_mask: bool
    disto_mask: bool
    # Copied from the chain's "cyclic_period" field
    cyclic_period: int
|
30
|
+
|
31
|
+
|
32
|
+
class BoltzTokenizer(Tokenizer):
    """Tokenizer that converts an input structure into tokens and token bonds."""

    def tokenize(self, data: Input) -> Tokenized:
        """Tokenize the input data.

        Only chains selected by the structure mask are processed. A
        residue flagged as standard becomes a single token. Any other
        residue is split into one token per atom, with the protein
        unknown token id as residue type. Bonds and connections are then
        mapped from atom indices to token indices, and entries that touch
        an atom without a token are dropped.

        Parameters
        ----------
        data : Input
            The input data.

        Returns
        -------
        Tokenized
            The tokens and token bonds, together with the structure, MSA,
            record and residue constraints taken from ``data``.

        """
        struct = data.structure

        rows = []  # one astuple(TokenData) per token
        atom_to_token = {}  # atom index -> index of the token holding it
        next_token = 0

        # Restrict to the chains selected by the structure mask
        for chain in struct.chains[struct.mask]:
            first_res = chain["res_idx"]
            last_res = chain["res_idx"] + chain["res_num"]

            # Chain-level values shared by every token of this chain
            chain_fields = {
                "sym_id": chain["sym_id"],
                "asym_id": chain["asym_id"],
                "entity_id": chain["entity_id"],
                "mol_type": chain["mol_type"],
                "cyclic_period": chain["cyclic_period"],
            }

            for res in struct.residues[first_res:last_res]:
                atom_start = res["atom_idx"]
                atom_end = res["atom_idx"] + res["atom_num"]

                if not res["is_standard"]:
                    # Non-standard residue: one token per atom, typed with
                    # the protein unknown token
                    unk_id = const.token_ids[const.unk_token["PROTEIN"]]

                    atoms = struct.atoms[atom_start:atom_end]
                    coords = atoms["coords"]

                    for offset, atom in enumerate(atoms):
                        # The token is present only if its atom is
                        present = res["is_present"] & atom["is_present"]
                        atom_index = atom_start + offset

                        row = TokenData(
                            token_idx=next_token,
                            atom_idx=atom_index,
                            atom_num=1,
                            res_idx=res["res_idx"],
                            res_type=unk_id,
                            center_idx=atom_index,
                            disto_idx=atom_index,
                            center_coords=coords[offset],
                            disto_coords=coords[offset],
                            resolved_mask=present,
                            disto_mask=present,
                            **chain_fields,
                        )
                        rows.append(astuple(row))

                        atom_to_token[atom_index] = next_token
                        next_token += 1
                    continue

                # Standard residue: a single token for the whole residue
                center_idx = res["atom_center"]
                disto_idx = res["atom_disto"]
                center = struct.atoms[center_idx]
                disto = struct.atoms[disto_idx]

                row = TokenData(
                    token_idx=next_token,
                    atom_idx=res["atom_idx"],
                    atom_num=res["atom_num"],
                    res_idx=res["res_idx"],
                    res_type=res["res_type"],
                    center_idx=center_idx,
                    disto_idx=disto_idx,
                    center_coords=center["coords"],
                    disto_coords=disto["coords"],
                    # Present only if both the residue and the atom are
                    resolved_mask=res["is_present"] & center["is_present"],
                    disto_mask=res["is_present"] & disto["is_present"],
                    **chain_fields,
                )
                rows.append(astuple(row))

                # Every atom of the residue maps to this one token
                for atom_index in range(atom_start, atom_end):
                    atom_to_token[atom_index] = next_token

                next_token += 1

        # Translate atom-level links into token-level links: bonds first,
        # then connections. Links touching an untokenized atom are skipped.
        pairs = []
        for table in (struct.bonds, struct.connections):
            for link in table:
                first, second = link["atom_1"], link["atom_2"]
                if first in atom_to_token and second in atom_to_token:
                    pairs.append((atom_to_token[first], atom_to_token[second]))

        return Tokenized(
            tokens=np.array(rows, dtype=Token),
            bonds=np.array(pairs, dtype=TokenBond),
            structure=data.structure,
            msa=data.msa,
            record=data.record,
            residue_constraints=data.residue_constraints,
        )
|