boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/msa/mmseqs2.py ADDED
@@ -0,0 +1,235 @@
+ # From https://github.com/sokrypton/ColabFold/blob/main/colabfold/colabfold.py
+
+ import logging
+ import os
+ import random
+ import tarfile
+ import time
+ from typing import Union
+
+ import requests
+ from tqdm import tqdm
+
+ logger = logging.getLogger(__name__)
+
+ TQDM_BAR_FORMAT = (
+     "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]"
+ )
+
+
+ def run_mmseqs2( # noqa: PLR0912, D103, C901, PLR0915
+     x: Union[str, list[str]],
+     prefix: str = "tmp",
+     use_env: bool = True,
+     use_filter: bool = True,
+     use_pairing: bool = False,
+     pairing_strategy: str = "greedy",
+     host_url: str = "https://api.colabfold.com",
+ ) -> list[str]:
+     submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"
+
+     # Set header agent as boltz
+     headers = {}
+     headers["User-Agent"] = "boltz"
+
+     def submit(seqs, mode, N=101):
+         n, query = N, ""
+         for seq in seqs:
+             query += f">{n}\n{seq}\n"
+             n += 1
+
+         error_count = 0
+         while True:
+             try:
+                 # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
+                 # "good practice to set connect timeouts to slightly larger than a multiple of 3"
+                 res = requests.post(
+                     f"{host_url}/{submission_endpoint}",
+                     data={"q": query, "mode": mode},
+                     timeout=6.02,
+                     headers=headers,
+                 )
+             except Exception as e:
+                 error_count += 1
+                 logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
+                 logger.warning(f"Error: {e}")
+                 if error_count > 5:
+                     raise Exception("Too many failed attempts for the MSA generation request.")
+                 time.sleep(5)
+             else:
+                 break
+
+         try:
+             out = res.json()
+         except ValueError:
+             logger.error(f"Server didn't reply with json: {res.text}")
+             out = {"status": "ERROR"}
+         return out
+
+     def status(ID):
+         error_count = 0
+         while True:
+             try:
+                 res = requests.get(
+                     f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers
+                 )
+             except Exception as e:
+                 error_count += 1
+                 logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
+                 logger.warning(f"Error: {e}")
+                 if error_count > 5:
+                     raise Exception("Too many failed attempts for the MSA generation request.")
+                 time.sleep(5)
+             else:
+                 break
+         try:
+             out = res.json()
+         except ValueError:
+             logger.error(f"Server didn't reply with json: {res.text}")
+             out = {"status": "ERROR"}
+         return out
+
+     def download(ID, path):
+         error_count = 0
+         while True:
+             try:
+                 res = requests.get(
+                     f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers
+                 )
+             except Exception as e:
+                 error_count += 1
+                 logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
+                 logger.warning(f"Error: {e}")
+                 if error_count > 5:
+                     raise Exception("Too many failed attempts for the MSA generation request.")
+                 time.sleep(5)
+             else:
+                 break
+         with open(path, "wb") as out:
+             out.write(res.content)
+
+     # process input x
+     seqs = [x] if isinstance(x, str) else x
+
+     # setup mode
+     if use_filter:
+         mode = "env" if use_env else "all"
+     else:
+         mode = "env-nofilter" if use_env else "nofilter"
+
+     if use_pairing:
+         mode = ""
+         # greedy is default, complete was the previous behavior
+         if pairing_strategy == "greedy":
+             mode = "pairgreedy"
+         elif pairing_strategy == "complete":
+             mode = "paircomplete"
+         if use_env:
+             mode = mode + "-env"
+
+     # define path
+     path = f"{prefix}_{mode}"
+     if not os.path.isdir(path):
+         os.mkdir(path)
+
+     # call mmseqs2 api
+     tar_gz_file = f"{path}/out.tar.gz"
+     N, REDO = 101, True
+
+     # deduplicate and keep track of order
+     seqs_unique = []
+     # TODO this might be slow for large sets
+     [seqs_unique.append(x) for x in seqs if x not in seqs_unique]
+     Ms = [N + seqs_unique.index(seq) for seq in seqs]
+     # lets do it!
+     if not os.path.isfile(tar_gz_file):
+         TIME_ESTIMATE = 150 * len(seqs_unique)
+         with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
+             while REDO:
+                 pbar.set_description("SUBMIT")
+
+                 # Resubmit job until it goes through
+                 out = submit(seqs_unique, mode, N)
+                 while out["status"] in ["UNKNOWN", "RATELIMIT"]:
+                     sleep_time = 5 + random.randint(0, 5)
+                     logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
+                     # resubmit
+                     time.sleep(sleep_time)
+                     out = submit(seqs_unique, mode, N)
+
+                 if out["status"] == "ERROR":
+                     msg = (
+                         "MMseqs2 API is giving errors. Please confirm your "
+                         "input is a valid protein sequence. If error persists, "
+                         "please try again an hour later."
+                     )
+                     raise Exception(msg)
+
+                 if out["status"] == "MAINTENANCE":
+                     msg = (
+                         "MMseqs2 API is undergoing maintenance. "
+                         "Please try again in a few minutes."
+                     )
+                     raise Exception(msg)
+
+                 # wait for job to finish
+                 ID, TIME = out["id"], 0
+                 pbar.set_description(out["status"])
+                 while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
+                     t = 5 + random.randint(0, 5)
+                     logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
+                     time.sleep(t)
+                     out = status(ID)
+                     pbar.set_description(out["status"])
+                     if out["status"] == "RUNNING":
+                         TIME += t
+                         pbar.update(n=t)
+
+                 if out["status"] == "COMPLETE":
+                     if TIME < TIME_ESTIMATE:
+                         pbar.update(n=(TIME_ESTIMATE - TIME))
+                     REDO = False
+
+                 if out["status"] == "ERROR":
+                     REDO = False
+                     msg = (
+                         "MMseqs2 API is giving errors. Please confirm your "
+                         "input is a valid protein sequence. If error persists, "
+                         "please try again an hour later."
+                     )
+                     raise Exception(msg)
+
+             # Download results
+             download(ID, tar_gz_file)
+
+     # prep list of a3m files
+     if use_pairing:
+         a3m_files = [f"{path}/pair.a3m"]
+     else:
+         a3m_files = [f"{path}/uniref.a3m"]
+         if use_env:
+             a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")
+
+     # extract a3m files
+     if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
+         with tarfile.open(tar_gz_file) as tar_gz:
+             tar_gz.extractall(path)
+
+     # gather a3m lines
+     a3m_lines = {}
+     for a3m_file in a3m_files:
+         update_M, M = True, None
+         for line in open(a3m_file, "r"):
+             if len(line) > 0:
+                 if "\x00" in line:
+                     line = line.replace("\x00", "")
+                     update_M = True
+                 if line.startswith(">") and update_M:
+                     M = int(line[1:].rstrip())
+                     update_M = False
+                     if M not in a3m_lines:
+                         a3m_lines[M] = []
+                 a3m_lines[M].append(line)
+
+     a3m_lines = ["".join(a3m_lines[n]) for n in Ms]
+     return a3m_lines
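
A minimal usage sketch for run_mmseqs2 as published above. The sequence and prefix are illustrative; a real call submits the query to the public ColabFold server and blocks until the MSA is ready:

    from boltz.data.msa.mmseqs2 import run_mmseqs2

    # Illustrative input; any valid protein sequence works. Results are cached
    # under "{prefix}_{mode}/out.tar.gz", so reruns with the same prefix skip
    # the server round-trip.
    seqs = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"]
    a3m_lines = run_mmseqs2(seqs, prefix="example", use_env=True, use_pairing=False)
    print(a3m_lines[0].splitlines()[0])  # first record header, e.g. ">101"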
boltz/data/pad.py ADDED
@@ -0,0 +1,84 @@
+ import torch
+ from torch import Tensor
+ from torch.nn.functional import pad
+
+
+ def pad_dim(data: Tensor, dim: int, pad_len: int, value: float = 0) -> Tensor:
+     """Pad a tensor along a given dimension.
+
+     Parameters
+     ----------
+     data : Tensor
+         The input tensor.
+     dim : int
+         The dimension to pad.
+     pad_len : int
+         The padding length.
+     value : float, optional
+         The value to pad with.
+
+     Returns
+     -------
+     Tensor
+         The padded tensor.
+
+     """
+     if pad_len == 0:
+         return data
+
+     total_dims = len(data.shape)
+     padding = [0] * (2 * (total_dims - dim))
+     padding[2 * (total_dims - 1 - dim) + 1] = pad_len
+     return pad(data, tuple(padding), value=value)
+
+
+ def pad_to_max(data: list[Tensor], value: float = 0) -> tuple[Tensor, Tensor]:
+     """Pad the data in all dimensions to the maximum found.
+
+     Parameters
+     ----------
+     data : list[Tensor]
+         List of tensors to pad.
+     value : float
+         The value to use for padding.
+
+     Returns
+     -------
+     Tensor
+         The padded tensor.
+     Tensor
+         The padding mask (ones for data, zeros for padding); 0 if none needed.
+
+     """
+     if isinstance(data[0], str):
+         return data, 0
+
+     # Check if all have the same shape
+     if all(d.shape == data[0].shape for d in data):
+         return torch.stack(data, dim=0), 0
+
+     # Get the maximum in each dimension
+     num_dims = len(data[0].shape)
+     max_dims = [max(d.shape[i] for d in data) for i in range(num_dims)]
+
+     # Get the padding lengths
+     pad_lengths = []
+     for d in data:
+         dims = []
+         for i in range(num_dims):
+             dims.append(0)
+             dims.append(max_dims[num_dims - i - 1] - d.shape[num_dims - i - 1])
+         pad_lengths.append(dims)
+
+     # Pad the data
+     padding = [
+         pad(torch.ones_like(d), pad_len, value=0)
+         for d, pad_len in zip(data, pad_lengths)
+     ]
+     data = [pad(d, pad_len, value=value) for d, pad_len in zip(data, pad_lengths)]
+
+     # Stack the data
+     padding = torch.stack(padding, dim=0)
+     data = torch.stack(data, dim=0)
+
+     return data, padding
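
A minimal sketch of how the two helpers above behave; the tensor shapes are illustrative:

    import torch

    from boltz.data.pad import pad_dim, pad_to_max

    x = torch.ones(2, 3)
    y = torch.ones(4, 5)

    # Right-pad dimension 1 of x by two columns: (2, 3) -> (2, 5).
    print(pad_dim(x, dim=1, pad_len=2).shape)

    # Batch ragged tensors to the per-dimension maximum. The mask is 1 where
    # real data sits and 0 over padding (the scalar 0 if no padding was needed).
    batch, mask = pad_to_max([x, y], value=0)
    print(batch.shape, mask.shape)  # torch.Size([2, 4, 5]) for both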
boltz/data/parse/a3m.py ADDED
@@ -0,0 +1,134 @@
+ import gzip
+ from pathlib import Path
+ from typing import Optional, TextIO
+
+ import numpy as np
+
+ from boltz.data import const
+ from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence
+
+
+ def _parse_a3m( # noqa: C901
+     lines: TextIO,
+     taxonomy: Optional[dict[str, str]],
+     max_seqs: Optional[int] = None,
+ ) -> MSA:
+     """Process an MSA file.
+
+     Parameters
+     ----------
+     lines : TextIO
+         The lines of the MSA file.
+     taxonomy : dict[str, str], optional
+         The taxonomy database, if available.
+     max_seqs : int, optional
+         The maximum number of sequences.
+
+     Returns
+     -------
+     MSA
+         The MSA object.
+
+     """
+     visited = set()
+     sequences = []
+     deletions = []
+     residues = []
+
+     seq_idx = 0
+     for line in lines:
+         line: str
+         line = line.strip()  # noqa: PLW2901
+         if not line or line.startswith("#"):
+             continue
+
+         # Get taxonomy, if annotated
+         if line.startswith(">"):
+             header = line.split()[0]
+             if taxonomy and header.startswith(">UniRef100"):
+                 uniref_id = header.split("_")[1]
+                 taxonomy_id = taxonomy.get(uniref_id)
+                 if taxonomy_id is None:
+                     taxonomy_id = -1
+             else:
+                 taxonomy_id = -1
+             continue
+
+         # Skip if duplicate sequence
+         str_seq = line.replace("-", "").upper()
+         if str_seq not in visited:
+             visited.add(str_seq)
+         else:
+             continue
+
+         # Process sequence
+         residue = []
+         deletion = []
+         count = 0
+         res_idx = 0
+         for c in line:
+             if c != "-" and c.islower():
+                 count += 1
+                 continue
+             token = const.prot_letter_to_token[c]
+             token = const.token_ids[token]
+             residue.append(token)
+             if count > 0:
+                 deletion.append((res_idx, count))
+                 count = 0
+             res_idx += 1
+
+         res_start = len(residues)
+         res_end = res_start + len(residue)
+
+         del_start = len(deletions)
+         del_end = del_start + len(deletion)
+
+         sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
+         residues.extend(residue)
+         deletions.extend(deletion)
+
+         seq_idx += 1
+         if (max_seqs is not None) and (seq_idx >= max_seqs):
+             break
+
+     # Create MSA object
+     msa = MSA(
+         residues=np.array(residues, dtype=MSAResidue),
+         deletions=np.array(deletions, dtype=MSADeletion),
+         sequences=np.array(sequences, dtype=MSASequence),
+     )
+     return msa
+
+
+ def parse_a3m(
+     path: Path,
+     taxonomy: Optional[dict[str, str]],
+     max_seqs: Optional[int] = None,
+ ) -> MSA:
+     """Process an A3M file.
+
+     Parameters
+     ----------
+     path : Path
+         The path to the a3m(.gz) file.
+     taxonomy : dict[str, str], optional
+         The taxonomy database, if available.
+     max_seqs : int, optional
+         The maximum number of sequences.
+
+     Returns
+     -------
+     MSA
+         The MSA object.
+
+     """
+     # Read the file
+     if path.suffix == ".gz":
+         with gzip.open(str(path), "rt") as f:
+             msa = _parse_a3m(f, taxonomy, max_seqs)
+     else:
+         with path.open("r") as f:
+             msa = _parse_a3m(f, taxonomy, max_seqs)
+
+     return msa
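
A minimal usage sketch; the input path is hypothetical, and taxonomy may be None when no taxonomy database is loaded:

    from pathlib import Path

    from boltz.data.parse.a3m import parse_a3m

    # Hypothetical path; both plain .a3m and gzipped .a3m.gz files are handled.
    msa = parse_a3m(Path("query.a3m"), taxonomy=None, max_seqs=4096)
    print(len(msa.sequences), "sequences,", len(msa.residues), "residues")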
boltz/data/parse/csv.py ADDED
@@ -0,0 +1,100 @@
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+ import pandas as pd
+
+ from boltz.data import const
+ from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence
+
+
+ def parse_csv(
+     path: Path,
+     max_seqs: Optional[int] = None,
+ ) -> MSA:
+     """Process a CSV MSA file.
+
+     Parameters
+     ----------
+     path : Path
+         The path to the CSV file.
+     max_seqs : int, optional
+         The maximum number of sequences.
+
+     Returns
+     -------
+     MSA
+         The MSA object.
+
+     """
+     # Read file
+     data = pd.read_csv(path)
+
+     # Check columns
+     if tuple(sorted(data.columns)) != ("key", "sequence"):
+         msg = "Invalid CSV format, expected columns: ['sequence', 'key']"
+         raise ValueError(msg)
+
+     # Parse the sequences
+     visited = set()
+     sequences = []
+     deletions = []
+     residues = []
+
+     seq_idx = 0
+     for line, key in zip(data["sequence"], data["key"]):
+         line: str
+         line = line.strip()  # noqa: PLW2901
+         if not line:
+             continue
+
+         # Get taxonomy, if annotated
+         taxonomy_id = -1
+         if (str(key) != "nan") and (key is not None) and (key != ""):
+             taxonomy_id = key
+
+         # Skip if duplicate sequence
+         str_seq = line.replace("-", "").upper()
+         if str_seq not in visited:
+             visited.add(str_seq)
+         else:
+             continue
+
+         # Process sequence
+         residue = []
+         deletion = []
+         count = 0
+         res_idx = 0
+         for c in line:
+             if c != "-" and c.islower():
+                 count += 1
+                 continue
+             token = const.prot_letter_to_token[c]
+             token = const.token_ids[token]
+             residue.append(token)
+             if count > 0:
+                 deletion.append((res_idx, count))
+                 count = 0
+             res_idx += 1
+
+         res_start = len(residues)
+         res_end = res_start + len(residue)
+
+         del_start = len(deletions)
+         del_end = del_start + len(deletion)
+
+         sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
+         residues.extend(residue)
+         deletions.extend(deletion)
+
+         seq_idx += 1
+         if (max_seqs is not None) and (seq_idx >= max_seqs):
+             break
+
+     # Create MSA object
+     msa = MSA(
+         residues=np.array(residues, dtype=MSAResidue),
+         deletions=np.array(deletions, dtype=MSADeletion),
+         sequences=np.array(sequences, dtype=MSASequence),
+     )
+     return msa
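
A minimal usage sketch; the path is hypothetical, and the file must contain exactly the columns key and sequence:

    from pathlib import Path

    from boltz.data.parse.csv import parse_csv

    # Hypothetical input: "key" holds an optional taxonomy id, "sequence" the
    # aligned sequence; duplicate sequences are dropped, as in the A3M parser.
    msa = parse_csv(Path("msa.csv"), max_seqs=2048)
    print(len(msa.sequences), "unique sequences parsed")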
boltz/data/parse/fasta.py ADDED
@@ -0,0 +1,138 @@
+ from collections.abc import Mapping
+ from pathlib import Path
+
+ from Bio import SeqIO
+ from rdkit.Chem.rdchem import Mol
+
+ from boltz.data.parse.yaml import parse_boltz_schema
+ from boltz.data.types import Target
+
+
+ def parse_fasta( # noqa: C901, PLR0912
+     path: Path,
+     ccd: Mapping[str, Mol],
+     mol_dir: Path,
+     boltz2: bool = False,
+ ) -> Target:
+     """Parse a fasta file.
+
+     The name of the fasta file is used as the name of this job.
+     We rely on the fasta record id to determine the entity type.
+
+     >CHAIN_ID|ENTITY_TYPE|MSA_ID
+     SEQUENCE
+     >CHAIN_ID|ENTITY_TYPE|MSA_ID
+     ...
+
+     Where ENTITY_TYPE is one of protein, rna, dna, ccd, or smiles,
+     and CHAIN_ID is the chain identifier, which should be unique.
+     The MSA_ID is optional and should only be used on proteins.
+
+     Parameters
+     ----------
+     path : Path
+         Path to the fasta file.
+     ccd : Mapping[str, Mol]
+         Mapping of CCD component names to molecules.
+     mol_dir : Path
+         Path to the directory containing the molecules.
+     boltz2 : bool
+         Whether to parse the input for Boltz2.
+
+     Returns
+     -------
+     Target
+         The parsed target.
+
+     """
+     # Read fasta file
+     with path.open("r") as f:
+         records = list(SeqIO.parse(f, "fasta"))
+
+     # Make sure all records have a chain id and entity
+     for seq_record in records:
+         if "|" not in seq_record.id:
+             msg = f"Invalid record id: {seq_record.id}"
+             raise ValueError(msg)
+
+         header = seq_record.id.split("|")
+         assert len(header) >= 2, f"Invalid record id: {seq_record.id}"
+
+         chain_id, entity_type = header[:2]
+         if entity_type.lower() not in {"protein", "dna", "rna", "ccd", "smiles"}:
+             msg = f"Invalid entity type: {entity_type}"
+             raise ValueError(msg)
+         if chain_id == "":
+             msg = "Empty chain id in input fasta!"
+             raise ValueError(msg)
+         if entity_type == "":
+             msg = "Empty entity type in input fasta!"
+             raise ValueError(msg)
+
+     # Convert to yaml format
+     sequences = []
+     for seq_record in records:
+         # Get chain id, entity type and sequence
+         header = seq_record.id.split("|")
+         chain_id, entity_type = header[:2]
+         if len(header) == 3 and header[2] != "":
+             assert (
+                 entity_type.lower() == "protein"
+             ), "MSA_ID is only allowed for proteins"
+             msa_id = header[2]
+         else:
+             msa_id = None
+
+         entity_type = entity_type.upper()
+         seq = str(seq_record.seq)
+
+         if entity_type == "PROTEIN":
+             molecule = {
+                 "protein": {
+                     "id": chain_id,
+                     "sequence": seq,
+                     "modifications": [],
+                     "msa": msa_id,
+                 },
+             }
+         elif entity_type == "RNA":
+             molecule = {
+                 "rna": {
+                     "id": chain_id,
+                     "sequence": seq,
+                     "modifications": [],
+                 },
+             }
+         elif entity_type == "DNA":
+             molecule = {
+                 "dna": {
+                     "id": chain_id,
+                     "sequence": seq,
+                     "modifications": [],
+                 }
+             }
+         elif entity_type == "CCD":
+             molecule = {
+                 "ligand": {
+                     "id": chain_id,
+                     "ccd": seq,
+                 }
+             }
+         elif entity_type == "SMILES":
+             molecule = {
+                 "ligand": {
+                     "id": chain_id,
+                     "smiles": seq,
+                 }
+             }
+
+         sequences.append(molecule)
+
+     data = {
+         "sequences": sequences,
+         "bonds": [],
+         "version": 1,
+     }
+
+     name = path.stem
+     return parse_boltz_schema(name, data, ccd, mol_dir, boltz2)
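
A minimal usage sketch of the header layout described in the docstring. The paths are hypothetical, the empty ccd mapping is only a placeholder (real runs pass the loaded CCD dictionary), and whether a protein entry without an MSA id passes downstream validation depends on parse_boltz_schema:

    from pathlib import Path

    from boltz.data.parse.fasta import parse_fasta

    fasta = Path("example.fasta")
    fasta.write_text(
        ">A|protein|\n"                        # chain A: protein, no MSA id
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ\n"
        ">B|smiles|\n"                         # chain B: ligand as SMILES
        "CC(=O)Oc1ccccc1C(=O)O\n"
    )
    target = parse_fasta(fasta, ccd={}, mol_dir=Path("mols"), boltz2=True)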