boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/crop/affinity.py
@@ -0,0 +1,164 @@
+ from dataclasses import replace
+ from typing import Optional
+
+ import numpy as np
+
+ from boltz.data import const
+ from boltz.data.crop.cropper import Cropper
+ from boltz.data.types import Tokenized
+
+
+ class AffinityCropper(Cropper):
+     """Interpolate between contiguous and spatial crops."""
+
+     def __init__(
+         self,
+         neighborhood_size: int = 10,
+         max_tokens_protein: int = 200,
+     ) -> None:
+         """Initialize the cropper.
+
+         Parameters
+         ----------
+         neighborhood_size : int
+             Modulates the type of cropping to be performed.
+             Smaller neighborhoods result in more spatial
+             cropping. Larger neighborhoods result in more
+             contiguous cropping.
+         max_tokens_protein : int
+             The maximum number of protein tokens to include
+             in the crop.
+
+         """
+         self.neighborhood_size = neighborhood_size
+         self.max_tokens_protein = max_tokens_protein
+
+     def crop(
+         self,
+         data: Tokenized,
+         max_tokens: int,
+         max_atoms: Optional[int] = None,
+     ) -> Tokenized:
+         """Crop the data to a maximum number of tokens.
+
+         Parameters
+         ----------
+         data : Tokenized
+             The tokenized data.
+         max_tokens : int
+             The maximum number of tokens to crop.
+         max_atoms : Optional[int]
+             The maximum number of atoms to consider.
+
+         Returns
+         -------
+         Tokenized
+             The cropped data.
+
+         """
+         # Get token data
+         token_data = data.tokens
+         token_bonds = data.bonds
+
+         # Filter to resolved tokens
+         valid_tokens = token_data[token_data["resolved_mask"]]
+
+         # Check if we have any valid tokens
+         if not valid_tokens.size:
+             msg = "No valid tokens in structure"
+             raise ValueError(msg)
+
+         # Compute each token's minimum distance to the ligand
+         ligand_coords = valid_tokens[valid_tokens["affinity_mask"]]["center_coords"]
+         dists = np.min(
+             np.sum(
+                 (valid_tokens["center_coords"][:, None] - ligand_coords[None]) ** 2,
+                 axis=-1,
+             )
+             ** 0.5,
+             axis=1,
+         )
+
+         indices = np.argsort(dists)
+
+         # Select cropped indices
+         cropped: set[int] = set()
+         total_atoms = 0
+
+         # Track protein tokens separately from ligand tokens
+         cropped_protein: set[int] = set()
+         ligand_ids = set(
+             valid_tokens[
+                 valid_tokens["mol_type"] == const.chain_type_ids["NONPOLYMER"]
+             ]["token_idx"]
+         )
+
+         for idx in indices:
+             # Get the token
+             token = valid_tokens[idx]
+
+             # Get all tokens from this chain
+             chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]
+
+             # Pick the whole chain if possible, otherwise select
+             # a contiguous subset centered at the query token
+             if len(chain_tokens) <= self.neighborhood_size:
+                 new_tokens = chain_tokens
+             else:
+                 # First limit to the maximum set of tokens, with the
+                 # neighborhood on both sides to handle edges. This
+                 # is mostly for efficiency with the while loop below.
+                 min_idx = token["res_idx"] - self.neighborhood_size
+                 max_idx = token["res_idx"] + self.neighborhood_size
+
+                 max_token_set = chain_tokens
+                 max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx]
+                 max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx]
+
+                 # Start by adding just the query token
+                 new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]]
+
+                 # Expand the neighborhood until we have enough tokens, one
+                 # by one to handle some edge cases with non-standard chains.
+                 # We switch to the res_idx instead of the token_idx to always
+                 # include all tokens from modified residues or from ligands.
+                 min_idx = max_idx = token["res_idx"]
+                 while new_tokens.size < self.neighborhood_size:
+                     min_idx = min_idx - 1
+                     max_idx = max_idx + 1
+                     new_tokens = max_token_set
+                     new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx]
+                     new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx]
+
+             # Compute new tokens and new atoms
+             new_indices = set(new_tokens["token_idx"]) - cropped
+             new_tokens = token_data[list(new_indices)]
+             new_atoms = np.sum(new_tokens["atom_num"])
+
+             # Stop if we exceed the max number of tokens or atoms,
+             # or the protein-token budget
+             if (
+                 (len(new_indices) > (max_tokens - len(cropped)))
+                 or ((max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms))
+                 or (
+                     len(cropped_protein | (new_indices - ligand_ids))
+                     > self.max_tokens_protein
+                 )
+             ):
+                 break
+
+             # Add new indices
+             cropped.update(new_indices)
+             total_atoms += new_atoms
+
+             # Add protein indices
+             cropped_protein.update(new_indices - ligand_ids)
+
+         # Get the cropped tokens sorted by index
+         token_data = token_data[sorted(cropped)]
+
+         # Only keep bonds within the cropped tokens
+         indices = token_data["token_idx"]
+         token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)]
+         token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]
+
+         # Return the cropped tokens
+         return replace(data, tokens=token_data, bonds=token_bonds)
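To make the ligand-distance ordering above concrete, here is a small self-contained sketch of the same computation on synthetic coordinates. The arrays are hypothetical stand-ins for the structured token fields the cropper actually reads (center_coords and the affinity_mask subset); they are not part of the package.

import numpy as np

# Hypothetical stand-ins for valid_tokens["center_coords"] and the
# ligand subset selected via the affinity_mask.
token_coords = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
ligand_coords = np.array([[0.5, 0.0, 0.0], [6.0, 0.0, 0.0]])

# Same computation as AffinityCropper.crop: each token's minimum
# Euclidean distance to any ligand token, then an ordering from
# closest to farthest.
dists = np.min(
    np.sum((token_coords[:, None] - ligand_coords[None]) ** 2, axis=-1) ** 0.5,
    axis=1,
)
order = np.argsort(dists)  # neighborhoods get pulled in ligand-first

Iterating chains in this order is what makes the crop ligand-centered: the protein context nearest the binding site is added before anything else, until the token, atom, or protein budgets run out.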
boltz/data/crop/boltz.py
@@ -0,0 +1,296 @@
+ from dataclasses import replace
+ from typing import Optional
+
+ import numpy as np
+ from scipy.spatial.distance import cdist
+
+ from boltz.data import const
+ from boltz.data.crop.cropper import Cropper
+ from boltz.data.types import Tokenized
+
+
+ def pick_random_token(
+     tokens: np.ndarray,
+     random: np.random.RandomState,
+ ) -> np.ndarray:
+     """Pick a random token from the data.
+
+     Parameters
+     ----------
+     tokens : np.ndarray
+         The token data.
+     random : np.random.RandomState
+         The random state for reproducibility.
+
+     Returns
+     -------
+     np.ndarray
+         The selected token.
+
+     """
+     return tokens[random.randint(len(tokens))]
+
+
+ def pick_chain_token(
+     tokens: np.ndarray,
+     chain_id: int,
+     random: np.random.RandomState,
+ ) -> np.ndarray:
+     """Pick a random token from a chain.
+
+     Parameters
+     ----------
+     tokens : np.ndarray
+         The token data.
+     chain_id : int
+         The chain ID.
+     random : np.random.RandomState
+         The random state for reproducibility.
+
+     Returns
+     -------
+     np.ndarray
+         The selected token.
+
+     """
+     # Filter to chain
+     chain_tokens = tokens[tokens["asym_id"] == chain_id]
+
+     # Pick from chain, fall back to all tokens
+     if chain_tokens.size:
+         query = pick_random_token(chain_tokens, random)
+     else:
+         query = pick_random_token(tokens, random)
+
+     return query
+
+
+ def pick_interface_token(
+     tokens: np.ndarray,
+     interface: np.ndarray,
+     random: np.random.RandomState,
+ ) -> np.ndarray:
+     """Pick a random token from an interface.
+
+     Parameters
+     ----------
+     tokens : np.ndarray
+         The token data.
+     interface : np.ndarray
+         The interface data.
+     random : np.random.RandomState
+         The random state for reproducibility.
+
+     Returns
+     -------
+     np.ndarray
+         The selected token.
+
+     """
+     # Get the two chains of the interface
+     chain_1 = int(interface["chain_1"])
+     chain_2 = int(interface["chain_2"])
+
+     tokens_1 = tokens[tokens["asym_id"] == chain_1]
+     tokens_2 = tokens[tokens["asym_id"] == chain_2]
+
+     # If either chain has no tokens, fall back to the other
+     # chain, or to all tokens
+     if tokens_1.size and (not tokens_2.size):
+         query = pick_random_token(tokens_1, random)
+     elif tokens_2.size and (not tokens_1.size):
+         query = pick_random_token(tokens_2, random)
+     elif (not tokens_1.size) and (not tokens_2.size):
+         query = pick_random_token(tokens, random)
+     else:
+         # If both chains have tokens, compute distances
+         tokens_1_coords = tokens_1["center_coords"]
+         tokens_2_coords = tokens_2["center_coords"]
+
+         dists = cdist(tokens_1_coords, tokens_2_coords)
+         cutoff = dists < const.interface_cutoff
+
+         # In rare cases the interface cutoff is slightly too
+         # small, so we expand it a bit when that happens
+         if not np.any(cutoff):
+             cutoff = dists < (const.interface_cutoff + 5.0)
+
+         tokens_1 = tokens_1[np.any(cutoff, axis=1)]
+         tokens_2 = tokens_2[np.any(cutoff, axis=0)]
+
+         # Select random token
+         candidates = np.concatenate([tokens_1, tokens_2])
+         query = pick_random_token(candidates, random)
+
+     return query
+
+
+ class BoltzCropper(Cropper):
+     """Interpolate between contiguous and spatial crops."""
+
+     def __init__(self, min_neighborhood: int = 0, max_neighborhood: int = 40) -> None:
+         """Initialize the cropper.
+
+         Modulates the type of cropping to be performed.
+         Smaller neighborhoods result in more spatial
+         cropping. Larger neighborhoods result in more
+         contiguous cropping. A mix can be achieved by
+         providing a range over which to sample.
+
+         Parameters
+         ----------
+         min_neighborhood : int
+             The minimum neighborhood size, by default 0.
+         max_neighborhood : int
+             The maximum neighborhood size, by default 40.
+
+         """
+         sizes = list(range(min_neighborhood, max_neighborhood + 1, 2))
+         self.neighborhood_sizes = sizes
+
+     def crop(  # noqa: PLR0915
+         self,
+         data: Tokenized,
+         max_tokens: int,
+         random: np.random.RandomState,
+         max_atoms: Optional[int] = None,
+         chain_id: Optional[int] = None,
+         interface_id: Optional[int] = None,
+     ) -> Tokenized:
+         """Crop the data to a maximum number of tokens.
+
+         Parameters
+         ----------
+         data : Tokenized
+             The tokenized data.
+         max_tokens : int
+             The maximum number of tokens to crop.
+         random : np.random.RandomState
+             The random state for reproducibility.
+         max_atoms : int, optional
+             The maximum number of atoms to consider.
+         chain_id : int, optional
+             The chain ID to crop.
+         interface_id : int, optional
+             The interface ID to crop.
+
+         Returns
+         -------
+         Tokenized
+             The cropped data.
+
+         """
+         # Check inputs
+         if chain_id is not None and interface_id is not None:
+             msg = "Only one of chain_id or interface_id can be provided."
+             raise ValueError(msg)
+
+         # Randomly select a neighborhood size
+         neighborhood_size = random.choice(self.neighborhood_sizes)
+
+         # Get token data
+         token_data = data.tokens
+         token_bonds = data.bonds
+         mask = data.structure.mask
+         chains = data.structure.chains
+         interfaces = data.structure.interfaces
+
+         # Filter to valid chains
+         valid_chains = chains[mask]
+
+         # Filter to valid interfaces
+         valid_interfaces = interfaces
+         valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_1"]]]
+         valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_2"]]]
+
+         # Filter to resolved tokens
+         valid_tokens = token_data[token_data["resolved_mask"]]
+
+         # Check if we have any valid tokens
+         if not valid_tokens.size:
+             msg = "No valid tokens in structure"
+             raise ValueError(msg)
+
+         # Pick a random token, chain, or interface
+         if chain_id is not None:
+             query = pick_chain_token(valid_tokens, chain_id, random)
+         elif interface_id is not None:
+             interface = interfaces[interface_id]
+             query = pick_interface_token(valid_tokens, interface, random)
+         elif valid_interfaces.size:
+             idx = random.randint(len(valid_interfaces))
+             interface = valid_interfaces[idx]
+             query = pick_interface_token(valid_tokens, interface, random)
+         else:
+             idx = random.randint(len(valid_chains))
+             chain_id = valid_chains[idx]["asym_id"]
+             query = pick_chain_token(valid_tokens, chain_id, random)
+
+         # Sort all tokens by distance to the query token
+         dists = valid_tokens["center_coords"] - query["center_coords"]
+         indices = np.argsort(np.linalg.norm(dists, axis=1))
+
+         # Select cropped indices
+         cropped: set[int] = set()
+         total_atoms = 0
+         for idx in indices:
+             # Get the token
+             token = valid_tokens[idx]
+
+             # Get all tokens from this chain
+             chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]
+
+             # Pick the whole chain if possible, otherwise select
+             # a contiguous subset centered at the query token
+             if len(chain_tokens) <= neighborhood_size:
+                 new_tokens = chain_tokens
+             else:
+                 # First limit to the maximum set of tokens, with the
+                 # neighborhood on both sides to handle edges. This
+                 # is mostly for efficiency with the while loop below.
+                 min_idx = token["res_idx"] - neighborhood_size
+                 max_idx = token["res_idx"] + neighborhood_size
+
+                 max_token_set = chain_tokens
+                 max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx]
+                 max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx]
+
+                 # Start by adding just the query token
+                 new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]]
+
+                 # Expand the neighborhood until we have enough tokens, one
+                 # by one to handle some edge cases with non-standard chains.
+                 # We switch to the res_idx instead of the token_idx to always
+                 # include all tokens from modified residues or from ligands.
+                 min_idx = max_idx = token["res_idx"]
+                 while new_tokens.size < neighborhood_size:
+                     min_idx = min_idx - 1
+                     max_idx = max_idx + 1
+                     new_tokens = max_token_set
+                     new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx]
+                     new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx]
+
+             # Compute new tokens and new atoms
+             new_indices = set(new_tokens["token_idx"]) - cropped
+             new_tokens = token_data[list(new_indices)]
+             new_atoms = np.sum(new_tokens["atom_num"])
+
+             # Stop if we exceed the max number of tokens or atoms
+             if (len(new_indices) > (max_tokens - len(cropped))) or (
+                 (max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms)
+             ):
+                 break
+
+             # Add new indices
+             cropped.update(new_indices)
+             total_atoms += new_atoms
+
+         # Get the cropped tokens sorted by index
+         token_data = token_data[sorted(cropped)]
+
+         # Only keep bonds within the cropped tokens
+         indices = token_data["token_idx"]
+         token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)]
+         token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]
+
+         # Return the cropped tokens
+         return replace(data, tokens=token_data, bonds=token_bonds)
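The interface branch of pick_interface_token is the subtle part: a token is kept only if it sits within const.interface_cutoff of some token on the other chain, with a one-time widening when the annotated cutoff comes up empty. A self-contained sketch of that logic on synthetic coordinates follows; the cutoff value and coordinate arrays are hypothetical stand-ins, not values from the package.

import numpy as np
from scipy.spatial.distance import cdist

# Hypothetical stand-ins for the two chains' center_coords.
coords_1 = np.array([[0.0, 0.0, 0.0], [30.0, 0.0, 0.0]])
coords_2 = np.array([[3.0, 0.0, 0.0], [50.0, 0.0, 0.0]])
interface_cutoff = 5.0  # stand-in for const.interface_cutoff

# Pairwise cross-chain distances; a token is "at the interface"
# if any of its distances to the other chain falls under the cutoff.
dists = cdist(coords_1, coords_2)
within = dists < interface_cutoff
if not np.any(within):
    # Mirror of the fallback above: widen the cutoff by 5.0 when
    # the interface annotation is slightly too tight.
    within = dists < (interface_cutoff + 5.0)

interface_1 = coords_1[np.any(within, axis=1)]  # [[0., 0., 0.]]
interface_2 = coords_2[np.any(within, axis=0)]  # [[3., 0., 0.]]

The query token is then sampled uniformly from the concatenated interface tokens, so the subsequent distance sort grows the crop outward from a randomly chosen contact point rather than from a fixed chain.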
boltz/data/crop/cropper.py
@@ -0,0 +1,45 @@
+ from abc import ABC, abstractmethod
+ from typing import Optional
+
+ import numpy as np
+
+ from boltz.data.types import Tokenized
+
+
+ class Cropper(ABC):
+     """Abstract base class for cropper."""
+
+     @abstractmethod
+     def crop(
+         self,
+         data: Tokenized,
+         max_tokens: int,
+         random: np.random.RandomState,
+         max_atoms: Optional[int] = None,
+         chain_id: Optional[int] = None,
+         interface_id: Optional[int] = None,
+     ) -> Tokenized:
+         """Crop the data to a maximum number of tokens.
+
+         Parameters
+         ----------
+         data : Tokenized
+             The tokenized data.
+         max_tokens : int
+             The maximum number of tokens to crop.
+         random : np.random.RandomState
+             The random state for reproducibility.
+         max_atoms : Optional[int]
+             The maximum number of atoms to consider.
+         chain_id : Optional[int]
+             The chain ID to crop.
+         interface_id : Optional[int]
+             The interface ID to crop.
+
+         Returns
+         -------
+         Tokenized
+             The cropped data.
+
+         """
+         raise NotImplementedError
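For completeness, here is a hypothetical minimal subclass showing what the Cropper contract requires of implementations like the two above: accept the full signature, return a Tokenized via dataclasses.replace, and drop bonds that cross the crop boundary. The class name and the naive prefix-truncation strategy are illustrative only, not part of the package.

from dataclasses import replace
from typing import Optional

import numpy as np

from boltz.data.crop.cropper import Cropper
from boltz.data.types import Tokenized


class PrefixCropper(Cropper):
    """Illustrative subclass: keep the first max_tokens tokens."""

    def crop(
        self,
        data: Tokenized,
        max_tokens: int,
        random: np.random.RandomState,
        max_atoms: Optional[int] = None,
        chain_id: Optional[int] = None,
        interface_id: Optional[int] = None,
    ) -> Tokenized:
        token_data = data.tokens[:max_tokens]

        # Keep only bonds whose endpoints survive the crop, as the
        # concrete croppers above do.
        indices = token_data["token_idx"]
        token_bonds = data.bonds[np.isin(data.bonds["token_1"], indices)]
        token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]

        return replace(data, tokens=token_data, bonds=token_bonds)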