boltz-vsynthes 1.0.8__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/data/mol.py +0 -4
- boltz/data/parse/__init__.py +21 -0
- boltz/data/parse/pdb.py +71 -0
- boltz/data/parse/pdb_download.py +114 -0
- boltz/data/parse/schema.py +802 -141
- boltz/data/parse/sdf.py +60 -0
- boltz/main.py +176 -208
- {boltz_vsynthes-1.0.8.dist-info → boltz_vsynthes-1.0.10.dist-info}/METADATA +2 -2
- {boltz_vsynthes-1.0.8.dist-info → boltz_vsynthes-1.0.10.dist-info}/RECORD +13 -10
- {boltz_vsynthes-1.0.8.dist-info → boltz_vsynthes-1.0.10.dist-info}/WHEEL +0 -0
- {boltz_vsynthes-1.0.8.dist-info → boltz_vsynthes-1.0.10.dist-info}/entry_points.txt +0 -0
- {boltz_vsynthes-1.0.8.dist-info → boltz_vsynthes-1.0.10.dist-info}/licenses/LICENSE +0 -0
- {boltz_vsynthes-1.0.8.dist-info → boltz_vsynthes-1.0.10.dist-info}/top_level.txt +0 -0
boltz/data/parse/schema.py
CHANGED
@@ -621,9 +621,6 @@ def get_mol(ccd: str, mols: dict, moldir: str) -> Mol:
     Return mol with ccd from mols if it is in mols. Otherwise load it from moldir,
     add it to mols, and return the mol.
     """
-    # Skip if it's a SMILES string (starts with LIG)
-    if ccd.startswith("LIG"):
-        return None
     mol = mols.get(ccd)
     if mol is None:
         mol = load_molecules(moldir, [ccd])[ccd]
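The new get_mol body is a plain cache-with-fallback: return the molecule from the in-memory mols mapping when present, otherwise load it from the molecule directory and remember it. A minimal standalone sketch of that pattern follows; load_molecule_from_dir is a hypothetical stand-in for the package's load_molecules helper, and the one-SDF-per-CCD-code layout is an assumption, not the package's actual on-disk format.

from pathlib import Path
from typing import Dict

from rdkit import Chem
from rdkit.Chem import Mol


def load_molecule_from_dir(moldir: Path, ccd: str) -> Mol:
    # Hypothetical loader: read <CCD>.sdf from the molecule directory.
    return Chem.SDMolSupplier(str(moldir / f"{ccd}.sdf"), removeHs=False)[0]


def get_mol_cached(ccd: str, mols: Dict[str, Mol], moldir: Path) -> Mol:
    # Return the cached Mol if we already have it; otherwise load it from
    # disk, add it to the cache, and return it (mirroring the docstring above).
    mol = mols.get(ccd)
    if mol is None:
        mol = load_molecule_from_dir(moldir, ccd)
        mols[ccd] = mol
    return mol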
@@ -658,10 +655,6 @@ def parse_ccd_residue(
         The output ParsedResidue, if successful.

     """
-    # Skip if it's a SMILES string (starts with LIG)
-    if name.startswith("LIG"):
-        return None
-
     unk_chirality = const.chirality_type_ids[const.unk_chirality_type]

     # Check if this is a single heavy atom CCD residue
@@ -936,111 +929,100 @@ def token_spec_to_ids(
         contacts.append((chain_to_idx[chain_name], residue_index_or_atom_name - 1))


-def parse_boltz_schema(
-
+def parse_boltz_schema(  # noqa: C901, PLR0915, PLR0912
+    name: str,
+    schema: dict,
+    ccd: Mapping[str, Mol],
+    mol_dir: Optional[Path] = None,
+    boltz_2: bool = False,
+) -> Target:
+    """Parse a Boltz input yaml / json.
+
+    The input file should be a dictionary with the following format:
+
+    version: 1
+    sequences:
+        - protein:
+            id: A
+            sequence: "MADQLTEEQIAEFKEAFSLF"
+            msa: path/to/msa1.a3m
+        - protein:
+            id: [B, C]
+            sequence: "AKLSILPWGHC"
+            msa: path/to/msa2.a3m
+        - rna:
+            id: D
+            sequence: "GCAUAGC"
+        - ligand:
+            id: E
+            smiles: "CC1=CC=CC=C1"
+    constraints:
+        - bond:
+            atom1: [A, 1, CA]
+            atom2: [A, 2, N]
+        - pocket:
+            binder: E
+            contacts: [[B, 1], [B, 2]]
+            max_distance: 6
+        - contact:
+            token1: [A, 1]
+            token2: [B, 1]
+            max_distance: 6
+    templates:
+        - cif: path/to/template.cif
+    properties:
+        - affinity:
+            binder: E

     Parameters
     ----------
+    name : str
+        A name for the input.
     schema : dict
         The input schema.
+    components : dict
+        Dictionary of CCD components.
+    mol_dir: Path
+        Path to the directory containing the molecules.
+    boltz2: bool
+        Whether to parse the input for Boltz2.

     Returns
     -------
-
-        The parsed
+    Target
+        The parsed target.

     """
-    #
-
-
+    # Assert version 1
+    version = schema.get("version", 1)
+    if version != 1:
+        msg = f"Invalid version {version} in input!"
         raise ValueError(msg)

-    #
+    # Disable rdkit warnings
+    blocker = rdBase.BlockLogs()  # noqa: F841
+
+    # First group items that have the same type, sequence and modifications
     items_to_group = {}
     chain_name_to_entity_type = {}

-    # Keep track of ligand IDs
-    ligand_id = 1
-    ligand_id_map = {}
-
-    # Parse sequences
     for item in schema["sequences"]:
-
-
-
+        # Get entity type
+        entity_type = next(iter(item.keys())).lower()
+        if entity_type not in {"protein", "dna", "rna", "ligand"}:
+            msg = f"Invalid entity type: {entity_type}"
+            raise ValueError(msg)

         # Get sequence
-        if entity_type
-
-            seq = item[entity_type]["sequence"]
-        elif "pdb" in item[entity_type]:
-            pdb_input = item[entity_type]["pdb"]
-            if pdb_input.startswith(("http://", "https://")):
-                # It's a PDB ID
-                import requests
-                response = requests.get(f"https://files.rcsb.org/download/{pdb_input}.pdb")
-                if response.status_code != 200:
-                    msg = f"Failed to download PDB file: {pdb_input}"
-                    raise FileNotFoundError(msg)
-                pdb_data = response.text
-            else:
-                # It's a file path
-                pdb_path = Path(pdb_input)
-                if not pdb_path.exists():
-                    msg = f"PDB file not found: {pdb_path}"
-                    raise FileNotFoundError(msg)
-                with pdb_path.open("r") as f:
-                    pdb_data = f.read()
-
-            # Parse PDB data
-            from Bio.PDB import PDBParser
-            from io import StringIO
-            parser = PDBParser()
-            structure = parser.get_structure("protein", StringIO(pdb_data))
-
-            # Extract sequence
-            seq = ""
-            for model in structure:
-                for chain in model:
-                    for residue in chain:
-                        if residue.id[0] == " ":  # Only standard residues
-                            seq += residue.resname
-            seq = "".join(seq)
-        else:
-            msg = "Protein must have either 'sequence' or 'pdb' field"
-            raise ValueError(msg)
+        if entity_type in {"protein", "dna", "rna"}:
+            seq = str(item[entity_type]["sequence"])
         elif entity_type == "ligand":
-
+            assert "smiles" in item[entity_type] or "ccd" in item[entity_type]
+            assert "smiles" not in item[entity_type] or "ccd" not in item[entity_type]
             if "smiles" in item[entity_type]:
                 seq = str(item[entity_type]["smiles"])
-                # Map user-provided ID to internal LIG1, LIG2, etc.
-                for id in entity_id:
-                    ligand_id_map[id] = f"LIG{ligand_id}"
-                ligand_id += 1
-            elif "ccd" in item[entity_type]:
-                seq = str(item[entity_type]["ccd"])
-                # For CCD ligands, use the CCD code as the internal ID
-                for id in entity_id:
-                    ligand_id_map[id] = seq
-            elif "sdf" in item[entity_type]:
-                sdf_path = Path(item[entity_type]["sdf"])
-                if not sdf_path.exists():
-                    msg = f"SDF file not found: {sdf_path}"
-                    raise FileNotFoundError(msg)
-                # Read SDF and convert to SMILES
-                from rdkit import Chem
-                mol = Chem.SDMolSupplier(str(sdf_path))[0]
-                if mol is None:
-                    msg = f"Failed to read SDF file: {sdf_path}"
-                    raise ValueError(msg)
-                seq = Chem.MolToSmiles(mol)
-                # Map user-provided ID to internal LIG1, LIG2, etc.
-                for id in entity_id:
-                    ligand_id_map[id] = f"LIG{ligand_id}"
-                ligand_id += 1
             else:
-
-                raise ValueError(msg)
+                seq = str(item[entity_type]["ccd"])

         # Group items by entity
         items_to_group.setdefault((entity_type, seq), []).append(item)
@@ -1051,60 +1033,739 @@ def parse_boltz_schema(schema: dict) -> dict:
         for chain_name in chain_names:
             chain_name_to_entity_type[chain_name] = entity_type

-    #
-
-
-
-
-
-
+    # Check if any affinity ligand is present
+    affinity_ligands = set()
+    properties = schema.get("properties", [])
+    if properties and not boltz_2:
+        msg = "Affinity prediction is only supported for Boltz2!"
+        raise ValueError(msg)
+
+    for prop in properties:
+        prop_type = next(iter(prop.keys())).lower()
+        if prop_type == "affinity":
+            binder = prop["affinity"]["binder"]
+            if not isinstance(binder, str):
+                # TODO: support multi residue ligands and ccd's
+                msg = "Binder must be a single chain."
+                raise ValueError(msg)
+
+            if binder not in chain_name_to_entity_type:
+                msg = f"Could not find binder with name {binder} in the input!"
+                raise ValueError(msg)
+
+            if chain_name_to_entity_type[binder] != "ligand":
+                msg = (
+                    f"Chain {binder} is not a ligand! "
+                    "Affinity is currently only supported for ligands."
+                )
+                raise ValueError(msg)
+
+            affinity_ligands.add(binder)
+
+    # Check only one affinity ligand is present
+    if len(affinity_ligands) > 1:
+        msg = "Only one affinity ligand is currently supported!"
+        raise ValueError(msg)
+
+    # Go through entities and parse them
+    extra_mols: dict[str, Mol] = {}
+    chains: dict[str, ParsedChain] = {}
+    chain_to_msa: dict[str, str] = {}
+    entity_to_seq: dict[str, str] = {}
+    is_msa_custom = False
+    is_msa_auto = False
+    ligand_id = 1
+    for entity_id, items in enumerate(items_to_group.values()):
+        # Get entity type and sequence
+        entity_type = next(iter(items[0].keys())).lower()
+
+        # Get ids
+        ids = []
+        for item in items:
+            if isinstance(item[entity_type]["id"], str):
+                ids.append(item[entity_type]["id"])
+            elif isinstance(item[entity_type]["id"], list):
+                ids.extend(item[entity_type]["id"])
+
+        # Check if any affinity ligand is present
+        if len(ids) == 1:
+            affinity = ids[0] in affinity_ligands
+        elif (len(ids) > 1) and any(x in affinity_ligands for x in ids):
+            msg = "Cannot compute affinity for a ligand that has multiple copies!"
+            raise ValueError(msg)
+        else:
+            affinity = False
+
+        # Ensure all the items share the same msa
+        msa = -1
         if entity_type == "protein":
-
-
-
-
-
-
-
-
-
-
-
-
-
-            msg = f"Protein {binder} not found in sequences"
+            # Get the msa, default to 0, meaning auto-generated
+            msa = items[0][entity_type].get("msa", 0)
+            if (msa is None) or (msa == ""):
+                msa = 0
+
+            # Check if all MSAs are the same within the same entity
+            for item in items:
+                item_msa = item[entity_type].get("msa", 0)
+                if (item_msa is None) or (item_msa == ""):
+                    item_msa = 0
+
+                if item_msa != msa:
+                    msg = "All proteins with the same sequence must share the same MSA!"
                     raise ValueError(msg)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Set the MSA, warn if passed in single-sequence mode
+        if msa == "empty":
+            msa = -1
+            msg = (
+                "Found explicit empty MSA for some proteins, will run "
+                "these in single sequence mode. Keep in mind that the "
+                "model predictions will be suboptimal without an MSA."
+            )
+            click.echo(msg)
+
+        if msa not in (0, -1):
+            is_msa_custom = True
+        elif msa == 0:
+            is_msa_auto = True
+
+        # Parse a polymer
+        if entity_type in {"protein", "dna", "rna"}:
+            # Get token map
+            if entity_type == "rna":
+                token_map = const.rna_letter_to_token
+            elif entity_type == "dna":
+                token_map = const.dna_letter_to_token
+            elif entity_type == "protein":
+                token_map = const.prot_letter_to_token
+            else:
+                msg = f"Unknown polymer type: {entity_type}"
+                raise ValueError(msg)
+
+            # Get polymer info
+            chain_type = const.chain_type_ids[entity_type.upper()]
+            unk_token = const.unk_token[entity_type.upper()]
+
+            # Extract sequence
+            raw_seq = items[0][entity_type]["sequence"]
+            entity_to_seq[entity_id] = raw_seq
+
+            # Convert sequence to tokens
+            seq = [token_map.get(c, unk_token) for c in list(raw_seq)]
+
+            # Apply modifications
+            for mod in items[0][entity_type].get("modifications", []):
+                code = mod["ccd"]
+                idx = mod["position"] - 1  # 1-indexed
+                seq[idx] = code
+
+            cyclic = items[0][entity_type].get("cyclic", False)
+
+            # Parse a polymer
+            parsed_chain = parse_polymer(
+                sequence=seq,
+                raw_sequence=raw_seq,
+                entity=entity_id,
+                chain_type=chain_type,
+                components=ccd,
+                cyclic=cyclic,
+                mol_dir=mol_dir,
+            )
+
+        # Parse a non-polymer
+        elif (entity_type == "ligand") and "ccd" in (items[0][entity_type]):
+            seq = items[0][entity_type]["ccd"]
+
+            if isinstance(seq, str):
+                seq = [seq]
+
+            if affinity and len(seq) > 1:
+                msg = "Cannot compute affinity for multi residue ligands!"
+                raise ValueError(msg)
+
+            residues = []
+            affinity_mw = None
+            for res_idx, code in enumerate(seq):
+                # Get mol
+                ref_mol = get_mol(code, ccd, mol_dir)
+
+                if affinity:
+                    affinity_mw = AllChem.Descriptors.MolWt(ref_mol)
+
+                # Parse residue
+                residue = parse_ccd_residue(
+                    name=code,
+                    ref_mol=ref_mol,
+                    res_idx=res_idx,
+                )
+                residues.append(residue)
+
+            # Create multi ligand chain
+            parsed_chain = ParsedChain(
+                entity=entity_id,
+                residues=residues,
+                type=const.chain_type_ids["NONPOLYMER"],
+                cyclic_period=0,
+                sequence=None,
+                affinity=affinity,
+                affinity_mw=affinity_mw,
+            )
+
+            assert not items[0][entity_type].get(
+                "cyclic", False
+            ), "Cyclic flag is not supported for ligands"
+
+        elif (entity_type == "ligand") and ("smiles" in items[0][entity_type]):
+            seq = items[0][entity_type]["smiles"]
+
+            if affinity:
+                seq = standardize(seq)
+
+            mol = AllChem.MolFromSmiles(seq)
+            mol = AllChem.AddHs(mol)
+
+            # Set atom names
+            canonical_order = AllChem.CanonicalRankAtoms(mol)
+            for atom, can_idx in zip(mol.GetAtoms(), canonical_order):
+                atom_name = atom.GetSymbol().upper() + str(can_idx + 1)
+                if len(atom_name) > 4:
+                    msg = (
+                        f"{seq} has an atom with a name longer than "
+                        f"4 characters: {atom_name}."
+                    )
                     raise ValueError(msg)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                atom.SetProp("name", atom_name)
+
+            success = compute_3d_conformer(mol)
+            if not success:
+                msg = f"Failed to compute 3D conformer for {seq}"
+                raise ValueError(msg)
+
+            mol_no_h = AllChem.RemoveHs(mol, sanitize=False)
+            affinity_mw = AllChem.Descriptors.MolWt(mol_no_h) if affinity else None
+            extra_mols[f"LIG{ligand_id}"] = mol_no_h
+            residue = parse_ccd_residue(
+                name=f"LIG{ligand_id}",
+                ref_mol=mol,
+                res_idx=0,
+            )
+
+            ligand_id += 1
+            parsed_chain = ParsedChain(
+                entity=entity_id,
+                residues=[residue],
+                type=const.chain_type_ids["NONPOLYMER"],
+                cyclic_period=0,
+                sequence=None,
+                affinity=affinity,
+                affinity_mw=affinity_mw,
+            )
+
+            assert not items[0][entity_type].get(
+                "cyclic", False
+            ), "Cyclic flag is not supported for ligands"
+
+        else:
+            msg = f"Invalid entity type: {entity_type}"
+            raise ValueError(msg)
+
+        # Add as many chains as provided ids
+        for item in items:
+            ids = item[entity_type]["id"]
+            if isinstance(ids, str):
+                ids = [ids]
+            for chain_name in ids:
+                chains[chain_name] = parsed_chain
+                chain_to_msa[chain_name] = msa
+
+    # Check if msa is custom or auto
+    if is_msa_custom and is_msa_auto:
+        msg = "Cannot mix custom and auto-generated MSAs in the same input!"
+        raise ValueError(msg)
+
+    # If no chains parsed fail
+    if not chains:
+        msg = "No chains parsed!"
+        raise ValueError(msg)
+
+    # Create tables
+    atom_data = []
+    bond_data = []
+    res_data = []
+    chain_data = []
+    protein_chains = set()
+    affinity_info = None
+
+    rdkit_bounds_constraint_data = []
+    chiral_atom_constraint_data = []
+    stereo_bond_constraint_data = []
+    planar_bond_constraint_data = []
+    planar_ring_5_constraint_data = []
+    planar_ring_6_constraint_data = []
+
+    # Convert parsed chains to tables
+    atom_idx = 0
+    res_idx = 0
+    asym_id = 0
+    sym_count = {}
+    chain_to_idx = {}
+
+    # Keep a mapping of (chain_name, residue_idx, atom_name) to atom_idx
+    atom_idx_map = {}
+
+    for asym_id, (chain_name, chain) in enumerate(chains.items()):
+        # Compute number of atoms and residues
+        res_num = len(chain.residues)
+        atom_num = sum(len(res.atoms) for res in chain.residues)
+
+        # Save protein chains for later
+        if chain.type == const.chain_type_ids["PROTEIN"]:
+            protein_chains.add(chain_name)
+
+        # Add affinity info
+        if chain.affinity and affinity_info is not None:
+            msg = "Cannot compute affinity for multiple ligands!"
+            raise ValueError(msg)
+
+        if chain.affinity:
+            affinity_info = AffinityInfo(
+                chain_id=asym_id,
+                mw=chain.affinity_mw,
+            )
+
+        # Find all copies of this chain in the assembly
+        entity_id = int(chain.entity)
+        sym_id = sym_count.get(entity_id, 0)
+        chain_data.append(
+            (
+                chain_name,
+                chain.type,
+                entity_id,
+                sym_id,
+                asym_id,
+                atom_idx,
+                atom_num,
+                res_idx,
+                res_num,
+                chain.cyclic_period,
+            )
+        )
+        chain_to_idx[chain_name] = asym_id
+        sym_count[entity_id] = sym_id + 1
+
+        # Add residue, atom, bond, data
+        for res in chain.residues:
+            atom_center = atom_idx + res.atom_center
+            atom_disto = atom_idx + res.atom_disto
+            res_data.append(
+                (
+                    res.name,
+                    res.type,
+                    res.idx,
+                    atom_idx,
+                    len(res.atoms),
+                    atom_center,
+                    atom_disto,
+                    res.is_standard,
+                    res.is_present,
+                )
+            )
+
+            if res.rdkit_bounds_constraints is not None:
+                for constraint in res.rdkit_bounds_constraints:
+                    rdkit_bounds_constraint_data.append(  # noqa: PERF401
+                        (
+                            tuple(
+                                c_atom_idx + atom_idx
+                                for c_atom_idx in constraint.atom_idxs
+                            ),
+                            constraint.is_bond,
+                            constraint.is_angle,
+                            constraint.upper_bound,
+                            constraint.lower_bound,
+                        )
+                    )
+            if res.chiral_atom_constraints is not None:
+                for constraint in res.chiral_atom_constraints:
+                    chiral_atom_constraint_data.append(  # noqa: PERF401
+                        (
+                            tuple(
+                                c_atom_idx + atom_idx
+                                for c_atom_idx in constraint.atom_idxs
+                            ),
+                            constraint.is_reference,
+                            constraint.is_r,
+                        )
+                    )
+            if res.stereo_bond_constraints is not None:
+                for constraint in res.stereo_bond_constraints:
+                    stereo_bond_constraint_data.append(  # noqa: PERF401
+                        (
+                            tuple(
+                                c_atom_idx + atom_idx
+                                for c_atom_idx in constraint.atom_idxs
+                            ),
+                            constraint.is_check,
+                            constraint.is_e,
+                        )
+                    )
+            if res.planar_bond_constraints is not None:
+                for constraint in res.planar_bond_constraints:
+                    planar_bond_constraint_data.append(  # noqa: PERF401
+                        (
+                            tuple(
+                                c_atom_idx + atom_idx
+                                for c_atom_idx in constraint.atom_idxs
+                            ),
+                        )
+                    )
+            if res.planar_ring_5_constraints is not None:
+                for constraint in res.planar_ring_5_constraints:
+                    planar_ring_5_constraint_data.append(  # noqa: PERF401
+                        (
+                            tuple(
+                                c_atom_idx + atom_idx
+                                for c_atom_idx in constraint.atom_idxs
+                            ),
+                        )
+                    )
+            if res.planar_ring_6_constraints is not None:
+                for constraint in res.planar_ring_6_constraints:
+                    planar_ring_6_constraint_data.append(  # noqa: PERF401
+                        (
+                            tuple(
+                                c_atom_idx + atom_idx
+                                for c_atom_idx in constraint.atom_idxs
+                            ),
+                        )
+                    )
+
+            for bond in res.bonds:
+                atom_1 = atom_idx + bond.atom_1
+                atom_2 = atom_idx + bond.atom_2
+                bond_data.append(
+                    (
+                        asym_id,
+                        asym_id,
+                        res_idx,
+                        res_idx,
+                        atom_1,
+                        atom_2,
+                        bond.type,
+                    )
+                )
+
+            for atom in res.atoms:
+                # Add atom to map
+                atom_idx_map[(chain_name, res.idx, atom.name)] = (
+                    asym_id,
+                    res_idx,
+                    atom_idx,
+                )
+
+                # Add atom to data
+                atom_data.append(
+                    (
+                        atom.name,
+                        atom.element,
+                        atom.charge,
+                        atom.coords,
+                        atom.conformer,
+                        atom.is_present,
+                        atom.chirality,
+                    )
+                )
+                atom_idx += 1
+
+            res_idx += 1
+
+    # Parse constraints
+    connections = []
+    pocket_constraints = []
+    contact_constraints = []
+    constraints = schema.get("constraints", [])
+    for constraint in constraints:
+        if "bond" in constraint:
+            if "atom1" not in constraint["bond"] or "atom2" not in constraint["bond"]:
+                msg = f"Bond constraint was not properly specified"
+                raise ValueError(msg)
+
+            c1, r1, a1 = tuple(constraint["bond"]["atom1"])
+            c2, r2, a2 = tuple(constraint["bond"]["atom2"])
+            c1, r1, a1 = atom_idx_map[(c1, r1 - 1, a1)]  # 1-indexed
+            c2, r2, a2 = atom_idx_map[(c2, r2 - 1, a2)]  # 1-indexed
+            connections.append((c1, c2, r1, r2, a1, a2))
+        elif "pocket" in constraint:
+            if (
+                "binder" not in constraint["pocket"]
+                or "contacts" not in constraint["pocket"]
+            ):
+                msg = f"Pocket constraint was not properly specified"
+                raise ValueError(msg)
+
+            if len(pocket_constraints) > 0 and not boltz_2:
+                msg = f"Only one pocket binders is supported in Boltz-1!"
+                raise ValueError(msg)
+
+            max_distance = constraint["pocket"].get("max_distance", 6.0)
+            if max_distance != 6.0 and not boltz_2:
+                msg = f"Max distance != 6.0 is not supported in Boltz-1!"
+                raise ValueError(msg)
+
+            binder = constraint["pocket"]["binder"]
+            binder = chain_to_idx[binder]
+
+            contacts = []
+            for chain_name, residue_index_or_atom_name in constraint["pocket"][
+                "contacts"
+            ]:
+                if chains[chain_name].type == const.chain_type_ids["NONPOLYMER"]:
+                    # Non-polymer chains are indexed by atom name
+                    _, _, atom_idx = atom_idx_map[
+                        (chain_name, 0, residue_index_or_atom_name)
+                    ]
+                    contact = (chain_to_idx[chain_name], atom_idx)
+                else:
+                    # Polymer chains are indexed by residue index
+                    contact = (chain_to_idx[chain_name], residue_index_or_atom_name - 1)
+                contacts.append(contact)
+
+            pocket_constraints.append((binder, contacts, max_distance))
+        elif "contact" in constraint:
+            if (
+                "token1" not in constraint["contact"]
+                or "token2" not in constraint["contact"]
+            ):
+                msg = f"Contact constraint was not properly specified"
+                raise ValueError(msg)
+
+            if not boltz_2:
+                msg = f"Contact constraint is not supported in Boltz-1!"
+                raise ValueError(msg)
+
+            max_distance = constraint["contact"].get("max_distance", 6.0)
+
+            chain_name1, residue_index_or_atom_name1 = constraint["contact"]["token1"]
+            if chains[chain_name1].type == const.chain_type_ids["NONPOLYMER"]:
+                # Non-polymer chains are indexed by atom name
+                _, _, atom_idx = atom_idx_map[
+                    (chain_name1, 0, residue_index_or_atom_name1)
+                ]
+                token1 = (chain_to_idx[chain_name1], atom_idx)
+            else:
+                # Polymer chains are indexed by residue index
+                token1 = (chain_to_idx[chain_name1], residue_index_or_atom_name1 - 1)
+
+            pocket_constraints.append((binder, contacts, max_distance))
+        else:
+            msg = f"Invalid constraint: {constraint}"
+            raise ValueError(msg)
+
+    # Get protein sequences in this YAML
+    protein_seqs = {name: chains[name].sequence for name in protein_chains}
+
+    # Parse templates
+    template_schema = schema.get("templates", [])
+    if template_schema and not boltz_2:
+        msg = "Templates are not supported in Boltz 1.0!"
+        raise ValueError(msg)
+
+    templates = {}
+    template_records = []
+    for template in template_schema:
+        if "cif" not in template:
+            msg = "Template was not properly specified, missing CIF path!"
+            raise ValueError(msg)
+
+        path = template["cif"]
+        template_id = Path(path).stem
+        chain_ids = template.get("chain_id", None)
+        template_chain_ids = template.get("template_id", None)
+
+        # Check validity of input
+        matched = False
+
+        if chain_ids is not None and not isinstance(chain_ids, list):
+            chain_ids = [chain_ids]
+        if template_chain_ids is not None and not isinstance(template_chain_ids, list):
+            template_chain_ids = [template_chain_ids]
+
+        if (
+            template_chain_ids is not None
+            and chain_ids is not None
+            and len(template_chain_ids) != len(chain_ids)
+        ):
+            matched = True
+            if len(template_chain_ids) != len(chain_ids):
+                msg = (
+                    "When providing both the chain_id and template_id, the number of"
+                    "template_ids provided must match the number of chain_ids!"
+                )
+                raise ValueError(msg)
+
+        # Get relevant chains ids
+        if chain_ids is None:
+            chain_ids = list(protein_chains)
+
+        for chain_id in chain_ids:
+            if chain_id not in protein_chains:
+                msg = (
+                    f"Chain {chain_id} assigned for template"
+                    f"{template_id} is not one of the protein chains!"
+                )
+                raise ValueError(msg)
+
+        # Get relevant template chain ids
+        parsed_template = parse_mmcif(
+            path,
+            mols=ccd,
+            moldir=mol_dir,
+            use_assembly=False,
+            compute_interfaces=False,
+        )
+        template_proteins = {
+            str(c["name"])
+            for c in parsed_template.data.chains
+            if c["mol_type"] == const.chain_type_ids["PROTEIN"]
+        }
+        if template_chain_ids is None:
+            template_chain_ids = list(template_proteins)
+
+        for chain_id in template_chain_ids:
+            if chain_id not in template_proteins:
+                msg = (
+                    f"Template chain {chain_id} assigned for template"
+                    f"{template_id} is not one of the protein chains!"
+                )
+                raise ValueError(msg)
+
+        # Compute template records
+        if matched:
+            template_records.extend(
+                get_template_records_from_matching(
+                    template_id=template_id,
+                    chain_ids=chain_ids,
+                    sequences=protein_seqs,
+                    template_chain_ids=template_chain_ids,
+                    template_sequences=parsed_template.sequences,
+                )
+            )
+        else:
+            template_records.extend(
+                get_template_records_from_search(
+                    template_id=template_id,
+                    chain_ids=chain_ids,
+                    sequences=protein_seqs,
+                    template_chain_ids=template_chain_ids,
+                    template_sequences=parsed_template.sequences,
+                )
+            )
+        # Save template
+        templates[template_id] = parsed_template.data
+
+    # Convert into datatypes
+    residues = np.array(res_data, dtype=Residue)
+    chains = np.array(chain_data, dtype=Chain)
+    interfaces = np.array([], dtype=Interface)
+    mask = np.ones(len(chain_data), dtype=bool)
+    rdkit_bounds_constraints = np.array(
+        rdkit_bounds_constraint_data, dtype=RDKitBoundsConstraint
+    )
+    chiral_atom_constraints = np.array(
+        chiral_atom_constraint_data, dtype=ChiralAtomConstraint
+    )
+    stereo_bond_constraints = np.array(
+        stereo_bond_constraint_data, dtype=StereoBondConstraint
+    )
+    planar_bond_constraints = np.array(
+        planar_bond_constraint_data, dtype=PlanarBondConstraint
+    )
+    planar_ring_5_constraints = np.array(
+        planar_ring_5_constraint_data, dtype=PlanarRing5Constraint
+    )
+    planar_ring_6_constraints = np.array(
+        planar_ring_6_constraint_data, dtype=PlanarRing6Constraint
+    )
+
+    if boltz_2:
+        atom_data = [(a[0], a[3], a[5], 0.0, 1.0) for a in atom_data]
+        connections = [(*c, const.bond_type_ids["COVALENT"]) for c in connections]
+        bond_data = bond_data + connections
+        atoms = np.array(atom_data, dtype=AtomV2)
+        bonds = np.array(bond_data, dtype=BondV2)
+        coords = [(x,) for x in atoms["coords"]]
+        coords = np.array(coords, Coords)
+        ensemble = np.array([(0, len(coords))], dtype=Ensemble)
+        data = StructureV2(
+            atoms=atoms,
+            bonds=bonds,
+            residues=residues,
+            chains=chains,
+            interfaces=interfaces,
+            mask=mask,
+            coords=coords,
+            ensemble=ensemble,
+        )
+    else:
+        bond_data = [(b[4], b[5], b[6]) for b in bond_data]
+        atom_data = [(convert_atom_name(a[0]), *a[1:]) for a in atom_data]
+        atoms = np.array(atom_data, dtype=Atom)
+        bonds = np.array(bond_data, dtype=Bond)
+        connections = np.array(connections, dtype=Connection)
+        data = Structure(
+            atoms=atoms,
+            bonds=bonds,
+            residues=residues,
+            chains=chains,
+            connections=connections,
+            interfaces=interfaces,
+            mask=mask,
+        )
+
+    # Create metadata
+    struct_info = StructureInfo(num_chains=len(chains))
+    chain_infos = []
+    for chain in chains:
+        chain_info = ChainInfo(
+            chain_id=int(chain["asym_id"]),
+            chain_name=chain["name"],
+            mol_type=int(chain["mol_type"]),
+            cluster_id=-1,
+            msa_id=chain_to_msa[chain["name"]],
+            num_residues=int(chain["res_num"]),
+            valid=True,
+            entity_id=int(chain["entity_id"]),
+        )
+        chain_infos.append(chain_info)
+
+    options = InferenceOptions(pocket_constraints=pocket_constraints)
+    record = Record(
+        id=name,
+        structure=struct_info,
+        chains=chain_infos,
+        interfaces=[],
+        inference_options=options,
+        templates=template_records,
+        affinity=affinity_info,
+    )
+
+    residue_constraints = ResidueConstraints(
+        rdkit_bounds_constraints=rdkit_bounds_constraints,
+        chiral_atom_constraints=chiral_atom_constraints,
+        stereo_bond_constraints=stereo_bond_constraints,
+        planar_bond_constraints=planar_bond_constraints,
+        planar_ring_5_constraints=planar_ring_5_constraints,
+        planar_ring_6_constraints=planar_ring_6_constraints,
+    )
+
+    return Target(
+        record=record,
+        structure=data,
+        sequences=entity_to_seq,
+        residue_constraints=residue_constraints,
+        templates=templates,
+        extra_mols=extra_mols,
+    )


 def standardize(smiles: str) -> Optional[str]:
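For SMILES ligands, the added code names every atom as its element symbol plus its canonical rank (C1, N7, ...) before building a 3D conformer. A standalone sketch of that naming step with plain RDKit is below; it mirrors the logic in the diff but calls RDKit's generic embedder where the package uses its own compute_3d_conformer helper.

from rdkit import Chem
from rdkit.Chem import AllChem

smiles = "CC1=CC=CC=C1"  # same example ligand as the docstring
mol = Chem.MolFromSmiles(smiles)
mol = Chem.AddHs(mol)

# Canonical ranks make the generated names independent of input atom order.
canonical_order = Chem.CanonicalRankAtoms(mol)
for atom, can_idx in zip(mol.GetAtoms(), canonical_order):
    atom_name = atom.GetSymbol().upper() + str(can_idx + 1)
    if len(atom_name) > 4:
        raise ValueError(f"atom name longer than 4 characters: {atom_name}")
    atom.SetProp("name", atom_name)

# Generic RDKit embedding; the package calls compute_3d_conformer(mol) here.
AllChem.EmbedMolecule(mol, randomSeed=0)

print(sorted(a.GetProp("name") for a in mol.GetAtoms()))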