boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,68 @@
1
+ from pathlib import Path
2
+
3
+ import yaml
4
+ from rdkit.Chem.rdchem import Mol
5
+
6
+ from boltz.data.parse.schema import parse_boltz_schema
7
+ from boltz.data.types import Target
8
+
9
+
10
def parse_yaml(
    path: Path,
    ccd: dict[str, Mol],
    mol_dir: Path,
    boltz2: bool = False,
) -> Target:
    """Load a Boltz input file (yaml / json) and parse it into a target.

    The input file is expected to follow this layout:

        sequences:
            - protein:
                id: A
                sequence: "MADQLTEEQIAEFKEAFSLF"
            - protein:
                id: [B, C]
                sequence: "AKLSILPWGHC"
            - rna:
                id: D
                sequence: "GCAUAGC"
            - ligand:
                id: E
                smiles: "CC1=CC=CC=C1"
            - ligand:
                id: [F, G]
                ccd: []
        constraints:
            - bond:
                atom1: [A, 1, CA]
                atom2: [A, 2, N]
            - pocket:
                binder: E
                contacts: [[B, 1], [B, 2]]
        templates:
            - path: /path/to/template.pdb
              ids: [A] # optional, specify which chains to template

        version: 1

    Parameters
    ----------
    path : Path
        Location of the input file. Its stem is used as the target name.
    ccd : dict[str, Mol]
        Dictionary of CCD components, forwarded to the schema parser.
    mol_dir : Path
        Forwarded unchanged to the schema parser (presumably the directory
        holding molecule files -- confirm against ``parse_boltz_schema``).
    boltz2 : bool
        Whether to parse the input for Boltz2.

    Returns
    -------
    Target
        The parsed target.

    """
    # safe_load only builds plain Python objects from the document
    with path.open("r") as handle:
        schema = yaml.safe_load(handle)

    return parse_boltz_schema(path.stem, schema, ccd, mol_dir, boltz2)
File without changes
@@ -0,0 +1,283 @@
1
+ from typing import Dict, Iterator, List
2
+
3
+ import numpy as np
4
+ from numpy.random import RandomState
5
+
6
+ from boltz.data import const
7
+ from boltz.data.types import ChainInfo, InterfaceInfo, Record
8
+ from boltz.data.sample.sampler import Sample, Sampler
9
+
10
+
11
def get_chain_cluster(chain: ChainInfo, record: Record) -> str:  # noqa: ARG001
    """Look up the cluster identifier stored on a chain.

    Parameters
    ----------
    chain : ChainInfo
        The chain whose cluster id is requested.
    record : Record
        The record the chain belongs to. Unused; accepted so the signature
        mirrors ``get_interface_cluster``.

    Returns
    -------
    str
        The value of ``chain.cluster_id``.

    """
    cluster_id = chain.cluster_id
    return cluster_id
28
+
29
+
30
def get_interface_cluster(
    interface: InterfaceInfo, record: Record
) -> tuple[str, str]:
    """Get the cluster id for an interface.

    The interface cluster is the pair of cluster ids of its two chains,
    each converted to a string and sorted so the result does not depend
    on which chain is listed first.

    Parameters
    ----------
    interface : InterfaceInfo
        The interface to get the cluster id for. Its ``chain_1`` and
        ``chain_2`` fields index into ``record.chains``.
    record : Record
        The record the interface is part of.

    Returns
    -------
    tuple[str, str]
        The sorted pair of chain cluster ids identifying the interface.

    """
    first = str(record.chains[interface.chain_1].cluster_id)
    second = str(record.chains[interface.chain_2].cluster_id)

    # Sort so that (a, b) and (b, a) map to the same cluster key
    return tuple(sorted((first, second)))
56
+
57
+
58
def get_chain_weight(
    chain: ChainInfo,
    record: Record,  # noqa: ARG001
    clusters: Dict[str, int],
    beta_chain: float,
    alpha_prot: float,
    alpha_nucl: float,
    alpha_ligand: float,
) -> float:
    """Compute the sampling weight of a single chain.

    The weight is ``beta_chain`` divided by the size of the chain's
    cluster, scaled by the alpha matching the chain's molecule type.
    Chains whose type is none of protein / RNA / DNA / non-polymer keep
    the unscaled weight.

    Parameters
    ----------
    chain : ChainInfo
        The chain to weigh.
    record : Record
        The record the chain is part of (unused).
    clusters : Dict[str, int]
        Cluster sizes, keyed by chain cluster id.
    beta_chain : float
        The beta value for chains.
    alpha_prot : float
        The alpha value for proteins.
    alpha_nucl : float
        The alpha value for nucleic acids.
    alpha_ligand : float
        The alpha value for ligands.

    Returns
    -------
    float
        The weight of the chain.

    """
    type_ids = const.chain_type_ids
    protein = type_ids["PROTEIN"]
    nucleic = (type_ids["RNA"], type_ids["DNA"])
    ligand = type_ids["NONPOLYMER"]

    base = beta_chain / clusters[chain.cluster_id]
    kind = chain.mol_type

    if kind == protein:
        return base * alpha_prot
    if kind in nucleic:
        return base * alpha_nucl
    if kind == ligand:
        return base * alpha_ligand
    return base
106
+
107
+
108
def get_interface_weight(
    interface: InterfaceInfo,
    record: Record,
    clusters: Dict[str, int],
    beta_interface: float,
    alpha_prot: float,
    alpha_nucl: float,
    alpha_ligand: float,
) -> float:
    """Compute the sampling weight of an interface.

    The weight is ``beta_interface`` divided by the size of the
    interface's cluster, multiplied by the alpha-weighted count of
    protein, nucleic-acid and ligand chains among its two chains.

    Parameters
    ----------
    interface : InterfaceInfo
        The interface to weigh.
    record : Record
        The record the interface is part of.
    clusters : Dict[str, int]
        Cluster sizes, keyed by the value returned from
        ``get_interface_cluster``.
    beta_interface : float
        The beta value for interfaces.
    alpha_prot : float
        The alpha value for proteins.
    alpha_nucl : float
        The alpha value for nucleic acids.
    alpha_ligand : float
        The alpha value for ligands.

    Returns
    -------
    float
        The weight of the interface.

    """
    type_ids = const.chain_type_ids
    protein = type_ids["PROTEIN"]
    nucleic = [type_ids["RNA"], type_ids["DNA"]]
    ligand = type_ids["NONPOLYMER"]

    first = record.chains[interface.chain_1].mol_type
    second = record.chains[interface.chain_2].mol_type

    # Adding the two comparison results counts how many of the chains match
    n_prot = (first == protein) + (second == protein)
    n_nuc = (first in nucleic) + (second in nucleic)
    n_ligand = (first == ligand) + (second == ligand)

    base = beta_interface / clusters[get_interface_cluster(interface, record)]
    scale = alpha_prot * n_prot + alpha_nucl * n_nuc + alpha_ligand * n_ligand
    return base * scale
161
+
162
+
163
class ClusterSampler(Sampler):
    """Weighted chain / interface sampler, following the AF3 scheme.

    Every valid chain and interface receives a weight computed as

        w = b / n_clust *(a_prot * n_prot + a_nuc * n_nuc
            + a_ligand * n_ligand)

    and items are then drawn with probability proportional to it.

    """

    def __init__(
        self,
        alpha_prot: float = 3.0,
        alpha_nucl: float = 3.0,
        alpha_ligand: float = 1.0,
        beta_chain: float = 0.5,
        beta_interface: float = 1.0,
    ) -> None:
        """Store the weighting coefficients.

        Parameters
        ----------
        alpha_prot : float, optional
            The alpha value for proteins.
        alpha_nucl : float, optional
            The alpha value for nucleic acids.
        alpha_ligand : float, optional
            The alpha value for ligands.
        beta_chain : float, optional
            The beta value for chains.
        beta_interface : float, optional
            The beta value for interfaces.

        """
        self.alpha_prot = alpha_prot
        self.alpha_nucl = alpha_nucl
        self.alpha_ligand = alpha_ligand
        self.beta_chain = beta_chain
        self.beta_interface = beta_interface

    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:  # noqa: C901, PLR0912
        """Yield weighted samples from the dataset without end.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A sample carrying the record and either the drawn chain index
            or the drawn interface index.

        """
        # Count the valid chains in each chain cluster
        chain_clusters: Dict[str, int] = {}
        for record in records:
            for chain in record.chains:
                if chain.valid:
                    key = get_chain_cluster(chain, record)
                    chain_clusters[key] = chain_clusters.get(key, 0) + 1

        # Count the valid interfaces in each interface cluster
        interface_clusters: Dict[str, int] = {}
        for record in records:
            for interface in record.interfaces:
                if interface.valid:
                    key = get_interface_cluster(interface, record)
                    interface_clusters[key] = interface_clusters.get(key, 0) + 1

        # Build the candidate list; the middle tuple entry tags the kind
        # of item (0 = chain, 1 = interface)
        items, weights = [], []
        for record in records:
            for chain_id, chain in enumerate(record.chains):
                if chain.valid:
                    chain_weight = get_chain_weight(
                        chain,
                        record,
                        chain_clusters,
                        self.beta_chain,
                        self.alpha_prot,
                        self.alpha_nucl,
                        self.alpha_ligand,
                    )
                    items.append((record, 0, chain_id))
                    weights.append(chain_weight)

            for int_id, interface in enumerate(record.interfaces):
                if interface.valid:
                    interface_weight = get_interface_weight(
                        interface,
                        record,
                        interface_clusters,
                        self.beta_interface,
                        self.alpha_prot,
                        self.alpha_nucl,
                        self.alpha_ligand,
                    )
                    items.append((record, 1, int_id))
                    weights.append(interface_weight)

        # Normalise the weights into probabilities and draw forever
        probabilities = np.array(weights) / np.sum(weights)
        n_items = len(items)
        while True:
            record, kind, index = items[random.choice(n_items, p=probabilities)]
            if kind == 0:
                yield Sample(record=record, chain_id=index)
            else:
                yield Sample(record=record, interface_id=index)
@@ -0,0 +1,57 @@
1
+ from typing import Iterator, List
2
+
3
+ from numpy.random import RandomState
4
+
5
+ from boltz.data.types import Record
6
+ from boltz.data.sample.sampler import Sample, Sampler
7
+
8
+
9
class DistillationSampler(Sampler):
    """A sampler for monomer distillation data."""

    def __init__(self, small_size: int = 200, small_prob: float = 0.01) -> None:
        """Initialize the sampler.

        Parameters
        ----------
        small_size : int, optional
            The maximum size to be considered small.
        small_prob : float, optional
            The probability of sampling a small item.

        """
        self._size = small_size
        self._prob = small_prob

    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
        """Sample a structure from the dataset infinitely.

        Records are split into "small" and "large" by the residue count
        of their first chain. Each draw picks the small pool with
        probability ``small_prob`` and the large pool otherwise. If the
        picked pool is empty, the other pool is used instead, so a
        dataset without any small (or any large) records can still be
        sampled.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A data sample.

        Raises
        ------
        ValueError
            If no record has a valid first chain, raised when the first
            sample is requested.

        """
        # Remove records with invalid chains
        records = [r for r in records if r.chains[0].valid]
        if not records:
            msg = "No records with a valid first chain to sample from."
            raise ValueError(msg)

        # Split in small and large proteins. We assume that there is only
        # one chain per record, as is the case for monomer distillation
        small = [r for r in records if r.chains[0].num_residues <= self._size]
        large = [r for r in records if r.chains[0].num_residues > self._size]

        # Sample infinitely
        while True:
            # Sample small or large
            if random.rand() < self._prob:
                preferred, fallback = small, large
            else:
                preferred, fallback = large, small

            # randint(0, 0) would raise, so use the other pool when the
            # preferred one is empty; at least one pool is non-empty here
            samples = preferred or fallback

            # Sample item from the list
            index = random.randint(0, len(samples))
            yield Sample(record=samples[index])
@@ -0,0 +1,39 @@
1
+ from dataclasses import replace
2
+ from typing import Iterator, List
3
+
4
+ from numpy.random import RandomState
5
+
6
+ from boltz.data.types import Record
7
+ from boltz.data.sample.sampler import Sample, Sampler
8
+
9
+
10
class RandomSampler(Sampler):
    """Uniform sampler that draws records with replacement."""

    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
        """Yield uniformly drawn records without end.

        Each yielded record is a copy of the drawn one in which only the
        valid chains and valid interfaces are kept.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A data sample.

        """
        while True:
            # Draw one record uniformly
            picked = records[random.randint(0, len(records))]

            # Keep only the valid chains and interfaces
            kept_chains = [chain for chain in picked.chains if chain.valid]
            kept_interfaces = [
                interface for interface in picked.interfaces if interface.valid
            ]
            filtered = replace(
                picked, chains=kept_chains, interfaces=kept_interfaces
            )

            yield Sample(record=filtered)
@@ -0,0 +1,49 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Iterator, List, Optional
4
+
5
+ from numpy.random import RandomState
6
+
7
+ from boltz.data.types import Record
8
+
9
+
10
@dataclass
class Sample:
    """One drawn item: a record plus an optional chain or interface index.

    Attributes
    ----------
    record : Record
        The record that was drawn.
    chain_id : Optional[int]
        The chain ID, or None when no chain was selected.
    interface_id : Optional[int]
        The interface ID, or None when no interface was selected.
    """

    record: Record
    # Both indices default to None
    chain_id: Optional[int] = None
    interface_id: Optional[int] = None
27
+
28
+
29
class Sampler(ABC):
    """Interface that every record sampler implements."""

    @abstractmethod
    def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
        """Yield samples drawn from the dataset without end.

        Subclasses must override this method; the base implementation
        only raises.

        Parameters
        ----------
        records : List[Record]
            The records to sample from.
        random : RandomState
            The random state for reproducibility.

        Yields
        ------
        Sample
            A data sample.

        """
        raise NotImplementedError
File without changes
@@ -0,0 +1,195 @@
1
+ from dataclasses import astuple, dataclass
2
+
3
+ import numpy as np
4
+
5
+ from boltz.data import const
6
+ from boltz.data.tokenize.tokenizer import Tokenizer
7
+ from boltz.data.types import Input, Token, TokenBond, Tokenized
8
+
9
+
10
@dataclass
class TokenData:
    """Per-token record assembled by the tokenizer.

    The tokenizer flattens instances with ``dataclasses.astuple``, so the
    declaration order of the fields below is significant.
    """

    # Position of the token and the atoms it covers
    token_idx: int
    atom_idx: int
    atom_num: int

    # Residue the token comes from
    res_idx: int
    res_type: int

    # Identifiers copied from the owning chain
    sym_id: int
    asym_id: int
    entity_id: int
    mol_type: int

    # Indices and coordinates of the center and disto atoms
    center_idx: int
    disto_idx: int
    center_coords: np.ndarray
    disto_coords: np.ndarray

    # Presence flags for the center and disto atoms
    resolved_mask: bool
    disto_mask: bool

    # Copied from the owning chain
    cyclic_period: int
30
+
31
+
32
class BoltzTokenizer(Tokenizer):
    """Turn an input structure into token and token-bond arrays."""

    def tokenize(self, data: Input) -> Tokenized:
        """Tokenize the input data.

        Every standard residue becomes a single token. Every atom of a
        non-standard residue becomes its own token, typed as the unknown
        protein token. Atom-level bonds and connections are then mapped
        onto the tokens that own their atoms.

        Parameters
        ----------
        data : Input
            The input data.

        Returns
        -------
        Tokenized
            The tokenized data.

        """
        structure = data.structure

        rows = []
        # Maps an atom index to the index of the token that owns it
        atom_to_token = {}
        next_token = 0

        # Only chains selected by the structure mask are tokenized
        for chain in structure.chains[structure.mask]:
            first_res = chain["res_idx"]
            last_res = chain["res_idx"] + chain["res_num"]

            for res in structure.residues[first_res:last_res]:
                first_atom = res["atom_idx"]
                last_atom = res["atom_idx"] + res["atom_num"]

                if not res["is_standard"]:
                    # Non-standard residues are tokenized per atom, using
                    # the unk protein token as res_type
                    unk_token = const.unk_token["PROTEIN"]
                    unk_id = const.token_ids[unk_token]

                    residue_atoms = structure.atoms[first_atom:last_atom]
                    pairs = zip(residue_atoms, residue_atoms["coords"])

                    for offset, (atom, coords) in enumerate(pairs):
                        # Token is present if both residue and atom are
                        present = res["is_present"] & atom["is_present"]
                        index = first_atom + offset

                        token = TokenData(
                            token_idx=next_token,
                            atom_idx=index,
                            atom_num=1,
                            res_idx=res["res_idx"],
                            res_type=unk_id,
                            sym_id=chain["sym_id"],
                            asym_id=chain["asym_id"],
                            entity_id=chain["entity_id"],
                            mol_type=chain["mol_type"],
                            center_idx=index,
                            disto_idx=index,
                            center_coords=coords,
                            disto_coords=coords,
                            resolved_mask=present,
                            disto_mask=present,
                            # Enforced to be False in chain parser
                            cyclic_period=chain["cyclic_period"],
                        )
                        rows.append(astuple(token))

                        atom_to_token[index] = next_token
                        next_token += 1
                    continue

                # Standard residue: a single token for the whole residue
                center_atom = structure.atoms[res["atom_center"]]
                disto_atom = structure.atoms[res["atom_disto"]]

                # Masks require the residue and the respective atom
                center_present = res["is_present"] & center_atom["is_present"]
                disto_present = res["is_present"] & disto_atom["is_present"]

                token = TokenData(
                    token_idx=next_token,
                    atom_idx=res["atom_idx"],
                    atom_num=res["atom_num"],
                    res_idx=res["res_idx"],
                    res_type=res["res_type"],
                    sym_id=chain["sym_id"],
                    asym_id=chain["asym_id"],
                    entity_id=chain["entity_id"],
                    mol_type=chain["mol_type"],
                    center_idx=res["atom_center"],
                    disto_idx=res["atom_disto"],
                    center_coords=center_atom["coords"],
                    disto_coords=disto_atom["coords"],
                    resolved_mask=center_present,
                    disto_mask=disto_present,
                    cyclic_period=chain["cyclic_period"],
                )
                rows.append(astuple(token))

                # All atoms of the residue map to this one token
                atom_to_token.update(
                    dict.fromkeys(range(first_atom, last_atom), next_token)
                )
                next_token += 1

        # Translate atom-level bonds, then connections, to token level.
        # Links touching an atom that was not tokenized are dropped.
        token_bonds = []
        for links in (structure.bonds, structure.connections):
            for link in links:
                if (
                    link["atom_1"] in atom_to_token
                    and link["atom_2"] in atom_to_token
                ):
                    token_bonds.append(
                        (
                            atom_to_token[link["atom_1"]],
                            atom_to_token[link["atom_2"]],
                        )
                    )

        return Tokenized(
            tokens=np.array(rows, dtype=Token),
            bonds=np.array(token_bonds, dtype=TokenBond),
            structure=data.structure,
            msa=data.msa,
            record=data.record,
            residue_constraints=data.residue_constraints,
        )