boltz-vsynthes 1.0.8__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -621,9 +621,6 @@ def get_mol(ccd: str, mols: dict, moldir: str) -> Mol:
621
621
  Return mol with ccd from mols if it is in mols. Otherwise load it from moldir,
622
622
  add it to mols, and return the mol.
623
623
  """
624
- # Skip if it's a SMILES string (starts with LIG)
625
- if ccd.startswith("LIG"):
626
- return None
627
624
  mol = mols.get(ccd)
628
625
  if mol is None:
629
626
  mol = load_molecules(moldir, [ccd])[ccd]
@@ -658,10 +655,6 @@ def parse_ccd_residue(
658
655
  The output ParsedResidue, if successful.
659
656
 
660
657
  """
661
- # Skip if it's a SMILES string (starts with LIG)
662
- if name.startswith("LIG"):
663
- return None
664
-
665
658
  unk_chirality = const.chirality_type_ids[const.unk_chirality_type]
666
659
 
667
660
  # Check if this is a single heavy atom CCD residue
@@ -936,111 +929,100 @@ def token_spec_to_ids(
936
929
  contacts.append((chain_to_idx[chain_name], residue_index_or_atom_name - 1))
937
930
 
938
931
 
939
- def parse_boltz_schema(schema: dict) -> dict:
940
- """Parse the Boltz input schema.
932
+ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912
933
+ name: str,
934
+ schema: dict,
935
+ ccd: Mapping[str, Mol],
936
+ mol_dir: Optional[Path] = None,
937
+ boltz_2: bool = False,
938
+ ) -> Target:
939
+ """Parse a Boltz input yaml / json.
940
+
941
+ The input file should be a dictionary with the following format:
942
+
943
+ version: 1
944
+ sequences:
945
+ - protein:
946
+ id: A
947
+ sequence: "MADQLTEEQIAEFKEAFSLF"
948
+ msa: path/to/msa1.a3m
949
+ - protein:
950
+ id: [B, C]
951
+ sequence: "AKLSILPWGHC"
952
+ msa: path/to/msa2.a3m
953
+ - rna:
954
+ id: D
955
+ sequence: "GCAUAGC"
956
+ - ligand:
957
+ id: E
958
+ smiles: "CC1=CC=CC=C1"
959
+ constraints:
960
+ - bond:
961
+ atom1: [A, 1, CA]
962
+ atom2: [A, 2, N]
963
+ - pocket:
964
+ binder: E
965
+ contacts: [[B, 1], [B, 2]]
966
+ max_distance: 6
967
+ - contact:
968
+ token1: [A, 1]
969
+ token2: [B, 1]
970
+ max_distance: 6
971
+ templates:
972
+ - cif: path/to/template.cif
973
+ properties:
974
+ - affinity:
975
+ binder: E
941
976
 
942
977
  Parameters
943
978
  ----------
979
+ name : str
980
+ A name for the input.
944
981
  schema : dict
945
982
  The input schema.
983
+ components : dict
984
+ Dictionary of CCD components.
985
+ mol_dir: Path
986
+ Path to the directory containing the molecules.
987
+ boltz2: bool
988
+ Whether to parse the input for Boltz2.
946
989
 
947
990
  Returns
948
991
  -------
949
- dict
950
- The parsed schema.
992
+ Target
993
+ The parsed target.
951
994
 
952
995
  """
953
- # Check version
954
- if "version" not in schema:
955
- msg = "Schema must have a version field"
996
+ # Assert version 1
997
+ version = schema.get("version", 1)
998
+ if version != 1:
999
+ msg = f"Invalid version {version} in input!"
956
1000
  raise ValueError(msg)
957
1001
 
958
- # Group items by entity type and sequence
1002
+ # Disable rdkit warnings
1003
+ blocker = rdBase.BlockLogs() # noqa: F841
1004
+
1005
+ # First group items that have the same type, sequence and modifications
959
1006
  items_to_group = {}
960
1007
  chain_name_to_entity_type = {}
961
1008
 
962
- # Keep track of ligand IDs
963
- ligand_id = 1
964
- ligand_id_map = {}
965
-
966
- # Parse sequences
967
1009
  for item in schema["sequences"]:
968
- entity_type = list(item.keys())[0]
969
- entity_id = item[entity_type]["id"]
970
- entity_id = [entity_id] if isinstance(entity_id, str) else entity_id
1010
+ # Get entity type
1011
+ entity_type = next(iter(item.keys())).lower()
1012
+ if entity_type not in {"protein", "dna", "rna", "ligand"}:
1013
+ msg = f"Invalid entity type: {entity_type}"
1014
+ raise ValueError(msg)
971
1015
 
972
1016
  # Get sequence
973
- if entity_type == "protein":
974
- if "sequence" in item[entity_type]:
975
- seq = item[entity_type]["sequence"]
976
- elif "pdb" in item[entity_type]:
977
- pdb_input = item[entity_type]["pdb"]
978
- if pdb_input.startswith(("http://", "https://")):
979
- # It's a PDB ID
980
- import requests
981
- response = requests.get(f"https://files.rcsb.org/download/{pdb_input}.pdb")
982
- if response.status_code != 200:
983
- msg = f"Failed to download PDB file: {pdb_input}"
984
- raise FileNotFoundError(msg)
985
- pdb_data = response.text
986
- else:
987
- # It's a file path
988
- pdb_path = Path(pdb_input)
989
- if not pdb_path.exists():
990
- msg = f"PDB file not found: {pdb_path}"
991
- raise FileNotFoundError(msg)
992
- with pdb_path.open("r") as f:
993
- pdb_data = f.read()
994
-
995
- # Parse PDB data
996
- from Bio.PDB import PDBParser
997
- from io import StringIO
998
- parser = PDBParser()
999
- structure = parser.get_structure("protein", StringIO(pdb_data))
1000
-
1001
- # Extract sequence
1002
- seq = ""
1003
- for model in structure:
1004
- for chain in model:
1005
- for residue in chain:
1006
- if residue.id[0] == " ": # Only standard residues
1007
- seq += residue.resname
1008
- seq = "".join(seq)
1009
- else:
1010
- msg = "Protein must have either 'sequence' or 'pdb' field"
1011
- raise ValueError(msg)
1017
+ if entity_type in {"protein", "dna", "rna"}:
1018
+ seq = str(item[entity_type]["sequence"])
1012
1019
  elif entity_type == "ligand":
1013
- # Support for SMILES, CCD, and SDF
1020
+ assert "smiles" in item[entity_type] or "ccd" in item[entity_type]
1021
+ assert "smiles" not in item[entity_type] or "ccd" not in item[entity_type]
1014
1022
  if "smiles" in item[entity_type]:
1015
1023
  seq = str(item[entity_type]["smiles"])
1016
- # Map user-provided ID to internal LIG1, LIG2, etc.
1017
- for id in entity_id:
1018
- ligand_id_map[id] = f"LIG{ligand_id}"
1019
- ligand_id += 1
1020
- elif "ccd" in item[entity_type]:
1021
- seq = str(item[entity_type]["ccd"])
1022
- # For CCD ligands, use the CCD code as the internal ID
1023
- for id in entity_id:
1024
- ligand_id_map[id] = seq
1025
- elif "sdf" in item[entity_type]:
1026
- sdf_path = Path(item[entity_type]["sdf"])
1027
- if not sdf_path.exists():
1028
- msg = f"SDF file not found: {sdf_path}"
1029
- raise FileNotFoundError(msg)
1030
- # Read SDF and convert to SMILES
1031
- from rdkit import Chem
1032
- mol = Chem.SDMolSupplier(str(sdf_path))[0]
1033
- if mol is None:
1034
- msg = f"Failed to read SDF file: {sdf_path}"
1035
- raise ValueError(msg)
1036
- seq = Chem.MolToSmiles(mol)
1037
- # Map user-provided ID to internal LIG1, LIG2, etc.
1038
- for id in entity_id:
1039
- ligand_id_map[id] = f"LIG{ligand_id}"
1040
- ligand_id += 1
1041
1024
  else:
1042
- msg = "Ligand must have either 'smiles', 'ccd', or 'sdf' field"
1043
- raise ValueError(msg)
1025
+ seq = str(item[entity_type]["ccd"])
1044
1026
 
1045
1027
  # Group items by entity
1046
1028
  items_to_group.setdefault((entity_type, seq), []).append(item)
@@ -1051,60 +1033,739 @@ def parse_boltz_schema(schema: dict) -> dict:
1051
1033
  for chain_name in chain_names:
1052
1034
  chain_name_to_entity_type[chain_name] = entity_type
1053
1035
 
1054
- # Get all proteins and ligands
1055
- proteins = []
1056
- ligands = []
1057
- for item in schema["sequences"]:
1058
- entity_type = list(item.keys())[0]
1059
- entity_id = item[entity_type]["id"]
1060
- entity_id = [entity_id] if isinstance(entity_id, str) else entity_id
1036
+ # Check if any affinity ligand is present
1037
+ affinity_ligands = set()
1038
+ properties = schema.get("properties", [])
1039
+ if properties and not boltz_2:
1040
+ msg = "Affinity prediction is only supported for Boltz2!"
1041
+ raise ValueError(msg)
1042
+
1043
+ for prop in properties:
1044
+ prop_type = next(iter(prop.keys())).lower()
1045
+ if prop_type == "affinity":
1046
+ binder = prop["affinity"]["binder"]
1047
+ if not isinstance(binder, str):
1048
+ # TODO: support multi residue ligands and ccd's
1049
+ msg = "Binder must be a single chain."
1050
+ raise ValueError(msg)
1051
+
1052
+ if binder not in chain_name_to_entity_type:
1053
+ msg = f"Could not find binder with name {binder} in the input!"
1054
+ raise ValueError(msg)
1055
+
1056
+ if chain_name_to_entity_type[binder] != "ligand":
1057
+ msg = (
1058
+ f"Chain {binder} is not a ligand! "
1059
+ "Affinity is currently only supported for ligands."
1060
+ )
1061
+ raise ValueError(msg)
1062
+
1063
+ affinity_ligands.add(binder)
1064
+
1065
+ # Check only one affinity ligand is present
1066
+ if len(affinity_ligands) > 1:
1067
+ msg = "Only one affinity ligand is currently supported!"
1068
+ raise ValueError(msg)
1069
+
1070
+ # Go through entities and parse them
1071
+ extra_mols: dict[str, Mol] = {}
1072
+ chains: dict[str, ParsedChain] = {}
1073
+ chain_to_msa: dict[str, str] = {}
1074
+ entity_to_seq: dict[str, str] = {}
1075
+ is_msa_custom = False
1076
+ is_msa_auto = False
1077
+ ligand_id = 1
1078
+ for entity_id, items in enumerate(items_to_group.values()):
1079
+ # Get entity type and sequence
1080
+ entity_type = next(iter(items[0].keys())).lower()
1081
+
1082
+ # Get ids
1083
+ ids = []
1084
+ for item in items:
1085
+ if isinstance(item[entity_type]["id"], str):
1086
+ ids.append(item[entity_type]["id"])
1087
+ elif isinstance(item[entity_type]["id"], list):
1088
+ ids.extend(item[entity_type]["id"])
1089
+
1090
+ # Check if any affinity ligand is present
1091
+ if len(ids) == 1:
1092
+ affinity = ids[0] in affinity_ligands
1093
+ elif (len(ids) > 1) and any(x in affinity_ligands for x in ids):
1094
+ msg = "Cannot compute affinity for a ligand that has multiple copies!"
1095
+ raise ValueError(msg)
1096
+ else:
1097
+ affinity = False
1098
+
1099
+ # Ensure all the items share the same msa
1100
+ msa = -1
1061
1101
  if entity_type == "protein":
1062
- proteins.extend(entity_id)
1063
- elif entity_type == "ligand":
1064
- ligands.extend(entity_id)
1065
-
1066
- # Generate properties for each protein-ligand pair
1067
- new_properties = []
1068
- for prop in schema.get("properties", []):
1069
- if "affinity" in prop:
1070
- affinity = prop["affinity"]
1071
- # Handle protein as binder
1072
- if "protein" in affinity:
1073
- binder = affinity["protein"]
1074
- if binder not in proteins:
1075
- msg = f"Protein {binder} not found in sequences"
1102
+ # Get the msa, default to 0, meaning auto-generated
1103
+ msa = items[0][entity_type].get("msa", 0)
1104
+ if (msa is None) or (msa == ""):
1105
+ msa = 0
1106
+
1107
+ # Check if all MSAs are the same within the same entity
1108
+ for item in items:
1109
+ item_msa = item[entity_type].get("msa", 0)
1110
+ if (item_msa is None) or (item_msa == ""):
1111
+ item_msa = 0
1112
+
1113
+ if item_msa != msa:
1114
+ msg = "All proteins with the same sequence must share the same MSA!"
1076
1115
  raise ValueError(msg)
1077
- # Generate pairs with all ligands
1078
- for ligand in ligands:
1079
- if ligand in ligand_id_map:
1080
- ligand = ligand_id_map[ligand] # Convert to internal LIG1, LIG2, etc.
1081
- new_properties.append({
1082
- "affinity": {
1083
- "binder": binder,
1084
- "ligand": ligand
1085
- }
1086
- })
1087
- # Handle ligand as binder (backward compatibility)
1088
- elif "binder" in affinity:
1089
- binder = affinity["binder"]
1090
- if binder not in proteins:
1091
- msg = f"Protein {binder} not found in sequences"
1116
+
1117
+ # Set the MSA, warn if passed in single-sequence mode
1118
+ if msa == "empty":
1119
+ msa = -1
1120
+ msg = (
1121
+ "Found explicit empty MSA for some proteins, will run "
1122
+ "these in single sequence mode. Keep in mind that the "
1123
+ "model predictions will be suboptimal without an MSA."
1124
+ )
1125
+ click.echo(msg)
1126
+
1127
+ if msa not in (0, -1):
1128
+ is_msa_custom = True
1129
+ elif msa == 0:
1130
+ is_msa_auto = True
1131
+
1132
+ # Parse a polymer
1133
+ if entity_type in {"protein", "dna", "rna"}:
1134
+ # Get token map
1135
+ if entity_type == "rna":
1136
+ token_map = const.rna_letter_to_token
1137
+ elif entity_type == "dna":
1138
+ token_map = const.dna_letter_to_token
1139
+ elif entity_type == "protein":
1140
+ token_map = const.prot_letter_to_token
1141
+ else:
1142
+ msg = f"Unknown polymer type: {entity_type}"
1143
+ raise ValueError(msg)
1144
+
1145
+ # Get polymer info
1146
+ chain_type = const.chain_type_ids[entity_type.upper()]
1147
+ unk_token = const.unk_token[entity_type.upper()]
1148
+
1149
+ # Extract sequence
1150
+ raw_seq = items[0][entity_type]["sequence"]
1151
+ entity_to_seq[entity_id] = raw_seq
1152
+
1153
+ # Convert sequence to tokens
1154
+ seq = [token_map.get(c, unk_token) for c in list(raw_seq)]
1155
+
1156
+ # Apply modifications
1157
+ for mod in items[0][entity_type].get("modifications", []):
1158
+ code = mod["ccd"]
1159
+ idx = mod["position"] - 1 # 1-indexed
1160
+ seq[idx] = code
1161
+
1162
+ cyclic = items[0][entity_type].get("cyclic", False)
1163
+
1164
+ # Parse a polymer
1165
+ parsed_chain = parse_polymer(
1166
+ sequence=seq,
1167
+ raw_sequence=raw_seq,
1168
+ entity=entity_id,
1169
+ chain_type=chain_type,
1170
+ components=ccd,
1171
+ cyclic=cyclic,
1172
+ mol_dir=mol_dir,
1173
+ )
1174
+
1175
+ # Parse a non-polymer
1176
+ elif (entity_type == "ligand") and "ccd" in (items[0][entity_type]):
1177
+ seq = items[0][entity_type]["ccd"]
1178
+
1179
+ if isinstance(seq, str):
1180
+ seq = [seq]
1181
+
1182
+ if affinity and len(seq) > 1:
1183
+ msg = "Cannot compute affinity for multi residue ligands!"
1184
+ raise ValueError(msg)
1185
+
1186
+ residues = []
1187
+ affinity_mw = None
1188
+ for res_idx, code in enumerate(seq):
1189
+ # Get mol
1190
+ ref_mol = get_mol(code, ccd, mol_dir)
1191
+
1192
+ if affinity:
1193
+ affinity_mw = AllChem.Descriptors.MolWt(ref_mol)
1194
+
1195
+ # Parse residue
1196
+ residue = parse_ccd_residue(
1197
+ name=code,
1198
+ ref_mol=ref_mol,
1199
+ res_idx=res_idx,
1200
+ )
1201
+ residues.append(residue)
1202
+
1203
+ # Create multi ligand chain
1204
+ parsed_chain = ParsedChain(
1205
+ entity=entity_id,
1206
+ residues=residues,
1207
+ type=const.chain_type_ids["NONPOLYMER"],
1208
+ cyclic_period=0,
1209
+ sequence=None,
1210
+ affinity=affinity,
1211
+ affinity_mw=affinity_mw,
1212
+ )
1213
+
1214
+ assert not items[0][entity_type].get(
1215
+ "cyclic", False
1216
+ ), "Cyclic flag is not supported for ligands"
1217
+
1218
+ elif (entity_type == "ligand") and ("smiles" in items[0][entity_type]):
1219
+ seq = items[0][entity_type]["smiles"]
1220
+
1221
+ if affinity:
1222
+ seq = standardize(seq)
1223
+
1224
+ mol = AllChem.MolFromSmiles(seq)
1225
+ mol = AllChem.AddHs(mol)
1226
+
1227
+ # Set atom names
1228
+ canonical_order = AllChem.CanonicalRankAtoms(mol)
1229
+ for atom, can_idx in zip(mol.GetAtoms(), canonical_order):
1230
+ atom_name = atom.GetSymbol().upper() + str(can_idx + 1)
1231
+ if len(atom_name) > 4:
1232
+ msg = (
1233
+ f"{seq} has an atom with a name longer than "
1234
+ f"4 characters: {atom_name}."
1235
+ )
1092
1236
  raise ValueError(msg)
1093
- # Generate pairs with all ligands
1094
- for ligand in ligands:
1095
- if ligand in ligand_id_map:
1096
- ligand = ligand_id_map[ligand] # Convert to internal LIG1, LIG2, etc.
1097
- new_properties.append({
1098
- "affinity": {
1099
- "binder": binder,
1100
- "ligand": ligand
1101
- }
1102
- })
1103
-
1104
- # Update schema with generated properties
1105
- schema["properties"] = new_properties
1106
-
1107
- return schema
1237
+ atom.SetProp("name", atom_name)
1238
+
1239
+ success = compute_3d_conformer(mol)
1240
+ if not success:
1241
+ msg = f"Failed to compute 3D conformer for {seq}"
1242
+ raise ValueError(msg)
1243
+
1244
+ mol_no_h = AllChem.RemoveHs(mol, sanitize=False)
1245
+ affinity_mw = AllChem.Descriptors.MolWt(mol_no_h) if affinity else None
1246
+ extra_mols[f"LIG{ligand_id}"] = mol_no_h
1247
+ residue = parse_ccd_residue(
1248
+ name=f"LIG{ligand_id}",
1249
+ ref_mol=mol,
1250
+ res_idx=0,
1251
+ )
1252
+
1253
+ ligand_id += 1
1254
+ parsed_chain = ParsedChain(
1255
+ entity=entity_id,
1256
+ residues=[residue],
1257
+ type=const.chain_type_ids["NONPOLYMER"],
1258
+ cyclic_period=0,
1259
+ sequence=None,
1260
+ affinity=affinity,
1261
+ affinity_mw=affinity_mw,
1262
+ )
1263
+
1264
+ assert not items[0][entity_type].get(
1265
+ "cyclic", False
1266
+ ), "Cyclic flag is not supported for ligands"
1267
+
1268
+ else:
1269
+ msg = f"Invalid entity type: {entity_type}"
1270
+ raise ValueError(msg)
1271
+
1272
+ # Add as many chains as provided ids
1273
+ for item in items:
1274
+ ids = item[entity_type]["id"]
1275
+ if isinstance(ids, str):
1276
+ ids = [ids]
1277
+ for chain_name in ids:
1278
+ chains[chain_name] = parsed_chain
1279
+ chain_to_msa[chain_name] = msa
1280
+
1281
+ # Check if msa is custom or auto
1282
+ if is_msa_custom and is_msa_auto:
1283
+ msg = "Cannot mix custom and auto-generated MSAs in the same input!"
1284
+ raise ValueError(msg)
1285
+
1286
+ # If no chains parsed fail
1287
+ if not chains:
1288
+ msg = "No chains parsed!"
1289
+ raise ValueError(msg)
1290
+
1291
+ # Create tables
1292
+ atom_data = []
1293
+ bond_data = []
1294
+ res_data = []
1295
+ chain_data = []
1296
+ protein_chains = set()
1297
+ affinity_info = None
1298
+
1299
+ rdkit_bounds_constraint_data = []
1300
+ chiral_atom_constraint_data = []
1301
+ stereo_bond_constraint_data = []
1302
+ planar_bond_constraint_data = []
1303
+ planar_ring_5_constraint_data = []
1304
+ planar_ring_6_constraint_data = []
1305
+
1306
+ # Convert parsed chains to tables
1307
+ atom_idx = 0
1308
+ res_idx = 0
1309
+ asym_id = 0
1310
+ sym_count = {}
1311
+ chain_to_idx = {}
1312
+
1313
+ # Keep a mapping of (chain_name, residue_idx, atom_name) to atom_idx
1314
+ atom_idx_map = {}
1315
+
1316
+ for asym_id, (chain_name, chain) in enumerate(chains.items()):
1317
+ # Compute number of atoms and residues
1318
+ res_num = len(chain.residues)
1319
+ atom_num = sum(len(res.atoms) for res in chain.residues)
1320
+
1321
+ # Save protein chains for later
1322
+ if chain.type == const.chain_type_ids["PROTEIN"]:
1323
+ protein_chains.add(chain_name)
1324
+
1325
+ # Add affinity info
1326
+ if chain.affinity and affinity_info is not None:
1327
+ msg = "Cannot compute affinity for multiple ligands!"
1328
+ raise ValueError(msg)
1329
+
1330
+ if chain.affinity:
1331
+ affinity_info = AffinityInfo(
1332
+ chain_id=asym_id,
1333
+ mw=chain.affinity_mw,
1334
+ )
1335
+
1336
+ # Find all copies of this chain in the assembly
1337
+ entity_id = int(chain.entity)
1338
+ sym_id = sym_count.get(entity_id, 0)
1339
+ chain_data.append(
1340
+ (
1341
+ chain_name,
1342
+ chain.type,
1343
+ entity_id,
1344
+ sym_id,
1345
+ asym_id,
1346
+ atom_idx,
1347
+ atom_num,
1348
+ res_idx,
1349
+ res_num,
1350
+ chain.cyclic_period,
1351
+ )
1352
+ )
1353
+ chain_to_idx[chain_name] = asym_id
1354
+ sym_count[entity_id] = sym_id + 1
1355
+
1356
+ # Add residue, atom, bond, data
1357
+ for res in chain.residues:
1358
+ atom_center = atom_idx + res.atom_center
1359
+ atom_disto = atom_idx + res.atom_disto
1360
+ res_data.append(
1361
+ (
1362
+ res.name,
1363
+ res.type,
1364
+ res.idx,
1365
+ atom_idx,
1366
+ len(res.atoms),
1367
+ atom_center,
1368
+ atom_disto,
1369
+ res.is_standard,
1370
+ res.is_present,
1371
+ )
1372
+ )
1373
+
1374
+ if res.rdkit_bounds_constraints is not None:
1375
+ for constraint in res.rdkit_bounds_constraints:
1376
+ rdkit_bounds_constraint_data.append( # noqa: PERF401
1377
+ (
1378
+ tuple(
1379
+ c_atom_idx + atom_idx
1380
+ for c_atom_idx in constraint.atom_idxs
1381
+ ),
1382
+ constraint.is_bond,
1383
+ constraint.is_angle,
1384
+ constraint.upper_bound,
1385
+ constraint.lower_bound,
1386
+ )
1387
+ )
1388
+ if res.chiral_atom_constraints is not None:
1389
+ for constraint in res.chiral_atom_constraints:
1390
+ chiral_atom_constraint_data.append( # noqa: PERF401
1391
+ (
1392
+ tuple(
1393
+ c_atom_idx + atom_idx
1394
+ for c_atom_idx in constraint.atom_idxs
1395
+ ),
1396
+ constraint.is_reference,
1397
+ constraint.is_r,
1398
+ )
1399
+ )
1400
+ if res.stereo_bond_constraints is not None:
1401
+ for constraint in res.stereo_bond_constraints:
1402
+ stereo_bond_constraint_data.append( # noqa: PERF401
1403
+ (
1404
+ tuple(
1405
+ c_atom_idx + atom_idx
1406
+ for c_atom_idx in constraint.atom_idxs
1407
+ ),
1408
+ constraint.is_check,
1409
+ constraint.is_e,
1410
+ )
1411
+ )
1412
+ if res.planar_bond_constraints is not None:
1413
+ for constraint in res.planar_bond_constraints:
1414
+ planar_bond_constraint_data.append( # noqa: PERF401
1415
+ (
1416
+ tuple(
1417
+ c_atom_idx + atom_idx
1418
+ for c_atom_idx in constraint.atom_idxs
1419
+ ),
1420
+ )
1421
+ )
1422
+ if res.planar_ring_5_constraints is not None:
1423
+ for constraint in res.planar_ring_5_constraints:
1424
+ planar_ring_5_constraint_data.append( # noqa: PERF401
1425
+ (
1426
+ tuple(
1427
+ c_atom_idx + atom_idx
1428
+ for c_atom_idx in constraint.atom_idxs
1429
+ ),
1430
+ )
1431
+ )
1432
+ if res.planar_ring_6_constraints is not None:
1433
+ for constraint in res.planar_ring_6_constraints:
1434
+ planar_ring_6_constraint_data.append( # noqa: PERF401
1435
+ (
1436
+ tuple(
1437
+ c_atom_idx + atom_idx
1438
+ for c_atom_idx in constraint.atom_idxs
1439
+ ),
1440
+ )
1441
+ )
1442
+
1443
+ for bond in res.bonds:
1444
+ atom_1 = atom_idx + bond.atom_1
1445
+ atom_2 = atom_idx + bond.atom_2
1446
+ bond_data.append(
1447
+ (
1448
+ asym_id,
1449
+ asym_id,
1450
+ res_idx,
1451
+ res_idx,
1452
+ atom_1,
1453
+ atom_2,
1454
+ bond.type,
1455
+ )
1456
+ )
1457
+
1458
+ for atom in res.atoms:
1459
+ # Add atom to map
1460
+ atom_idx_map[(chain_name, res.idx, atom.name)] = (
1461
+ asym_id,
1462
+ res_idx,
1463
+ atom_idx,
1464
+ )
1465
+
1466
+ # Add atom to data
1467
+ atom_data.append(
1468
+ (
1469
+ atom.name,
1470
+ atom.element,
1471
+ atom.charge,
1472
+ atom.coords,
1473
+ atom.conformer,
1474
+ atom.is_present,
1475
+ atom.chirality,
1476
+ )
1477
+ )
1478
+ atom_idx += 1
1479
+
1480
+ res_idx += 1
1481
+
1482
+ # Parse constraints
1483
+ connections = []
1484
+ pocket_constraints = []
1485
+ contact_constraints = []
1486
+ constraints = schema.get("constraints", [])
1487
+ for constraint in constraints:
1488
+ if "bond" in constraint:
1489
+ if "atom1" not in constraint["bond"] or "atom2" not in constraint["bond"]:
1490
+ msg = f"Bond constraint was not properly specified"
1491
+ raise ValueError(msg)
1492
+
1493
+ c1, r1, a1 = tuple(constraint["bond"]["atom1"])
1494
+ c2, r2, a2 = tuple(constraint["bond"]["atom2"])
1495
+ c1, r1, a1 = atom_idx_map[(c1, r1 - 1, a1)] # 1-indexed
1496
+ c2, r2, a2 = atom_idx_map[(c2, r2 - 1, a2)] # 1-indexed
1497
+ connections.append((c1, c2, r1, r2, a1, a2))
1498
+ elif "pocket" in constraint:
1499
+ if (
1500
+ "binder" not in constraint["pocket"]
1501
+ or "contacts" not in constraint["pocket"]
1502
+ ):
1503
+ msg = f"Pocket constraint was not properly specified"
1504
+ raise ValueError(msg)
1505
+
1506
+ if len(pocket_constraints) > 0 and not boltz_2:
1507
+ msg = f"Only one pocket binders is supported in Boltz-1!"
1508
+ raise ValueError(msg)
1509
+
1510
+ max_distance = constraint["pocket"].get("max_distance", 6.0)
1511
+ if max_distance != 6.0 and not boltz_2:
1512
+ msg = f"Max distance != 6.0 is not supported in Boltz-1!"
1513
+ raise ValueError(msg)
1514
+
1515
+ binder = constraint["pocket"]["binder"]
1516
+ binder = chain_to_idx[binder]
1517
+
1518
+ contacts = []
1519
+ for chain_name, residue_index_or_atom_name in constraint["pocket"][
1520
+ "contacts"
1521
+ ]:
1522
+ if chains[chain_name].type == const.chain_type_ids["NONPOLYMER"]:
1523
+ # Non-polymer chains are indexed by atom name
1524
+ _, _, atom_idx = atom_idx_map[
1525
+ (chain_name, 0, residue_index_or_atom_name)
1526
+ ]
1527
+ contact = (chain_to_idx[chain_name], atom_idx)
1528
+ else:
1529
+ # Polymer chains are indexed by residue index
1530
+ contact = (chain_to_idx[chain_name], residue_index_or_atom_name - 1)
1531
+ contacts.append(contact)
1532
+
1533
+ pocket_constraints.append((binder, contacts, max_distance))
1534
+ elif "contact" in constraint:
1535
+ if (
1536
+ "token1" not in constraint["contact"]
1537
+ or "token2" not in constraint["contact"]
1538
+ ):
1539
+ msg = f"Contact constraint was not properly specified"
1540
+ raise ValueError(msg)
1541
+
1542
+ if not boltz_2:
1543
+ msg = f"Contact constraint is not supported in Boltz-1!"
1544
+ raise ValueError(msg)
1545
+
1546
+ max_distance = constraint["contact"].get("max_distance", 6.0)
1547
+
1548
+ chain_name1, residue_index_or_atom_name1 = constraint["contact"]["token1"]
1549
+ if chains[chain_name1].type == const.chain_type_ids["NONPOLYMER"]:
1550
+ # Non-polymer chains are indexed by atom name
1551
+ _, _, atom_idx = atom_idx_map[
1552
+ (chain_name1, 0, residue_index_or_atom_name1)
1553
+ ]
1554
+ token1 = (chain_to_idx[chain_name1], atom_idx)
1555
+ else:
1556
+ # Polymer chains are indexed by residue index
1557
+ token1 = (chain_to_idx[chain_name1], residue_index_or_atom_name1 - 1)
1558
+
1559
+ pocket_constraints.append((binder, contacts, max_distance))
1560
+ else:
1561
+ msg = f"Invalid constraint: {constraint}"
1562
+ raise ValueError(msg)
1563
+
1564
+ # Get protein sequences in this YAML
1565
+ protein_seqs = {name: chains[name].sequence for name in protein_chains}
1566
+
1567
+ # Parse templates
1568
+ template_schema = schema.get("templates", [])
1569
+ if template_schema and not boltz_2:
1570
+ msg = "Templates are not supported in Boltz 1.0!"
1571
+ raise ValueError(msg)
1572
+
1573
+ templates = {}
1574
+ template_records = []
1575
+ for template in template_schema:
1576
+ if "cif" not in template:
1577
+ msg = "Template was not properly specified, missing CIF path!"
1578
+ raise ValueError(msg)
1579
+
1580
+ path = template["cif"]
1581
+ template_id = Path(path).stem
1582
+ chain_ids = template.get("chain_id", None)
1583
+ template_chain_ids = template.get("template_id", None)
1584
+
1585
+ # Check validity of input
1586
+ matched = False
1587
+
1588
+ if chain_ids is not None and not isinstance(chain_ids, list):
1589
+ chain_ids = [chain_ids]
1590
+ if template_chain_ids is not None and not isinstance(template_chain_ids, list):
1591
+ template_chain_ids = [template_chain_ids]
1592
+
1593
+ if (
1594
+ template_chain_ids is not None
1595
+ and chain_ids is not None
1596
+ and len(template_chain_ids) != len(chain_ids)
1597
+ ):
1598
+ matched = True
1599
+ if len(template_chain_ids) != len(chain_ids):
1600
+ msg = (
1601
+ "When providing both the chain_id and template_id, the number of"
1602
+ "template_ids provided must match the number of chain_ids!"
1603
+ )
1604
+ raise ValueError(msg)
1605
+
1606
+ # Get relevant chains ids
1607
+ if chain_ids is None:
1608
+ chain_ids = list(protein_chains)
1609
+
1610
+ for chain_id in chain_ids:
1611
+ if chain_id not in protein_chains:
1612
+ msg = (
1613
+ f"Chain {chain_id} assigned for template"
1614
+ f"{template_id} is not one of the protein chains!"
1615
+ )
1616
+ raise ValueError(msg)
1617
+
1618
+ # Get relevant template chain ids
1619
+ parsed_template = parse_mmcif(
1620
+ path,
1621
+ mols=ccd,
1622
+ moldir=mol_dir,
1623
+ use_assembly=False,
1624
+ compute_interfaces=False,
1625
+ )
1626
+ template_proteins = {
1627
+ str(c["name"])
1628
+ for c in parsed_template.data.chains
1629
+ if c["mol_type"] == const.chain_type_ids["PROTEIN"]
1630
+ }
1631
+ if template_chain_ids is None:
1632
+ template_chain_ids = list(template_proteins)
1633
+
1634
+ for chain_id in template_chain_ids:
1635
+ if chain_id not in template_proteins:
1636
+ msg = (
1637
+ f"Template chain {chain_id} assigned for template"
1638
+ f"{template_id} is not one of the protein chains!"
1639
+ )
1640
+ raise ValueError(msg)
1641
+
1642
+ # Compute template records
1643
+ if matched:
1644
+ template_records.extend(
1645
+ get_template_records_from_matching(
1646
+ template_id=template_id,
1647
+ chain_ids=chain_ids,
1648
+ sequences=protein_seqs,
1649
+ template_chain_ids=template_chain_ids,
1650
+ template_sequences=parsed_template.sequences,
1651
+ )
1652
+ )
1653
+ else:
1654
+ template_records.extend(
1655
+ get_template_records_from_search(
1656
+ template_id=template_id,
1657
+ chain_ids=chain_ids,
1658
+ sequences=protein_seqs,
1659
+ template_chain_ids=template_chain_ids,
1660
+ template_sequences=parsed_template.sequences,
1661
+ )
1662
+ )
1663
+ # Save template
1664
+ templates[template_id] = parsed_template.data
1665
+
1666
+ # Convert into datatypes
1667
+ residues = np.array(res_data, dtype=Residue)
1668
+ chains = np.array(chain_data, dtype=Chain)
1669
+ interfaces = np.array([], dtype=Interface)
1670
+ mask = np.ones(len(chain_data), dtype=bool)
1671
+ rdkit_bounds_constraints = np.array(
1672
+ rdkit_bounds_constraint_data, dtype=RDKitBoundsConstraint
1673
+ )
1674
+ chiral_atom_constraints = np.array(
1675
+ chiral_atom_constraint_data, dtype=ChiralAtomConstraint
1676
+ )
1677
+ stereo_bond_constraints = np.array(
1678
+ stereo_bond_constraint_data, dtype=StereoBondConstraint
1679
+ )
1680
+ planar_bond_constraints = np.array(
1681
+ planar_bond_constraint_data, dtype=PlanarBondConstraint
1682
+ )
1683
+ planar_ring_5_constraints = np.array(
1684
+ planar_ring_5_constraint_data, dtype=PlanarRing5Constraint
1685
+ )
1686
+ planar_ring_6_constraints = np.array(
1687
+ planar_ring_6_constraint_data, dtype=PlanarRing6Constraint
1688
+ )
1689
+
1690
+ if boltz_2:
1691
+ atom_data = [(a[0], a[3], a[5], 0.0, 1.0) for a in atom_data]
1692
+ connections = [(*c, const.bond_type_ids["COVALENT"]) for c in connections]
1693
+ bond_data = bond_data + connections
1694
+ atoms = np.array(atom_data, dtype=AtomV2)
1695
+ bonds = np.array(bond_data, dtype=BondV2)
1696
+ coords = [(x,) for x in atoms["coords"]]
1697
+ coords = np.array(coords, Coords)
1698
+ ensemble = np.array([(0, len(coords))], dtype=Ensemble)
1699
+ data = StructureV2(
1700
+ atoms=atoms,
1701
+ bonds=bonds,
1702
+ residues=residues,
1703
+ chains=chains,
1704
+ interfaces=interfaces,
1705
+ mask=mask,
1706
+ coords=coords,
1707
+ ensemble=ensemble,
1708
+ )
1709
+ else:
1710
+ bond_data = [(b[4], b[5], b[6]) for b in bond_data]
1711
+ atom_data = [(convert_atom_name(a[0]), *a[1:]) for a in atom_data]
1712
+ atoms = np.array(atom_data, dtype=Atom)
1713
+ bonds = np.array(bond_data, dtype=Bond)
1714
+ connections = np.array(connections, dtype=Connection)
1715
+ data = Structure(
1716
+ atoms=atoms,
1717
+ bonds=bonds,
1718
+ residues=residues,
1719
+ chains=chains,
1720
+ connections=connections,
1721
+ interfaces=interfaces,
1722
+ mask=mask,
1723
+ )
1724
+
1725
+ # Create metadata
1726
+ struct_info = StructureInfo(num_chains=len(chains))
1727
+ chain_infos = []
1728
+ for chain in chains:
1729
+ chain_info = ChainInfo(
1730
+ chain_id=int(chain["asym_id"]),
1731
+ chain_name=chain["name"],
1732
+ mol_type=int(chain["mol_type"]),
1733
+ cluster_id=-1,
1734
+ msa_id=chain_to_msa[chain["name"]],
1735
+ num_residues=int(chain["res_num"]),
1736
+ valid=True,
1737
+ entity_id=int(chain["entity_id"]),
1738
+ )
1739
+ chain_infos.append(chain_info)
1740
+
1741
+ options = InferenceOptions(pocket_constraints=pocket_constraints)
1742
+ record = Record(
1743
+ id=name,
1744
+ structure=struct_info,
1745
+ chains=chain_infos,
1746
+ interfaces=[],
1747
+ inference_options=options,
1748
+ templates=template_records,
1749
+ affinity=affinity_info,
1750
+ )
1751
+
1752
+ residue_constraints = ResidueConstraints(
1753
+ rdkit_bounds_constraints=rdkit_bounds_constraints,
1754
+ chiral_atom_constraints=chiral_atom_constraints,
1755
+ stereo_bond_constraints=stereo_bond_constraints,
1756
+ planar_bond_constraints=planar_bond_constraints,
1757
+ planar_ring_5_constraints=planar_ring_5_constraints,
1758
+ planar_ring_6_constraints=planar_ring_6_constraints,
1759
+ )
1760
+
1761
+ return Target(
1762
+ record=record,
1763
+ structure=data,
1764
+ sequences=entity_to_seq,
1765
+ residue_constraints=residue_constraints,
1766
+ templates=templates,
1767
+ extra_mols=extra_mols,
1768
+ )
1108
1769
 
1109
1770
 
1110
1771
  def standardize(smiles: str) -> Optional[str]: