gemmi-protools 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,911 @@
1
+ import gzip
2
+ import io
3
+ import pathlib
4
+ import random
5
+ import re
6
+ import string
7
+ from collections import defaultdict
8
+ from datetime import datetime
9
+ from typing import Dict, Optional, List
10
+
11
+ import gemmi
12
+ import numpy as np
13
+ import pandas as pd
14
+ from scipy.spatial import cKDTree
15
+
16
# numpy structured-dtype description of one atom record, as produced by
# StructureParser.get_atoms(): chain/residue identity, atom identity,
# charge, B-factor, occupancy and the (x, y, z) coordinate.
ATOM = [
    ("chain_name", "U5"),
    ("residue_num", "i4"),
    ("residue_icode", "U3"),
    ("residue_name", "U5"),
    ("atom_name", "U5"),
    ("element", "U3"),
    ("charge", "i1"),
    ("b_factor", "f4"),
    ("occupancy", "f4"),
    ("coordinate", ("f4", (3,))),
]
27
+
28
+
29
def is_pdb(path: str) -> bool:
    """
    Tell whether ``path`` names a PDB file, i.e. ends in .pdb or .pdb.gz.

    Only the file-name suffixes are inspected (case-sensitively); the file
    itself is not opened.

    :param path: str, file path
    :return:
        bool
    """
    suffixes = pathlib.Path(path).suffixes
    if not suffixes:
        return False
    return suffixes[-1] == ".pdb" or "".join(suffixes[-2:]) == ".pdb.gz"
47
+
48
+
49
def is_cif(path: str) -> bool:
    """
    Tell whether ``path`` names an mmCIF file, i.e. ends in .cif or .cif.gz.

    Only the file-name suffixes are inspected (case-sensitively); the file
    itself is not opened.

    :param path: str, file path
    :return:
        bool
    """
    suffixes = pathlib.Path(path).suffixes
    if not suffixes:
        return False
    return suffixes[-1] == ".cif" or "".join(suffixes[-2:]) == ".cif.gz"
67
+
68
+
69
def get_release_date(block: "gemmi.cif.Block"):
    """
    Date of the earliest revision listed in _pdbx_audit_revision_history.

    Revisions are ordered numerically by (major_revision, minor_revision);
    the revision_date of the first one is returned.

    :param block: gemmi.cif.Block (anything with a compatible
        ``get_mmcif_category(name=..., raw=...)`` method)
    :return:
        the revision_date value of the earliest revision, or "" when the
        category is missing/empty or lacks the needed columns
    """
    history = pd.DataFrame(block.get_mmcif_category(name="_pdbx_audit_revision_history", raw=False))
    required = {"major_revision", "minor_revision", "revision_date"}
    if len(history) == 0 or not required.issubset(history.columns):
        return ""

    # revision numbers are stored as text; compare them as numbers so that
    # e.g. minor revision 10 sorts after minor revision 2
    order = pd.DataFrame({
        "major": pd.to_numeric(history["major_revision"], errors="coerce"),
        "minor": pd.to_numeric(history["minor_revision"], errors="coerce"),
    }).sort_values(["major", "minor"], kind="mergesort")
    return history.loc[order.index[0], "revision_date"]
76
+
77
+
78
def parse_cif(path: str) -> dict:
    """
    Parse an mmCIF structure (.cif or .cif.gz) together with its metadata.

    :param path: str, path to the .cif / .cif.gz file
    :return:
        dict with keys
          "structure": gemmi.Structure (entities set up, serial numbers assigned,
                       resolution filled in)
          "info": dict with description, source, resolution, pdb_id, method,
                  deposition_date, release_date and title
    :raises TypeError: if ``path`` does not have a .cif / .cif.gz suffix
    """
    if not is_cif(path):
        raise TypeError("Input file is not a cif file [.cif or .cif.gz]: %s" % path)

    doc = gemmi.cif.Document()
    st = gemmi.read_structure(path, save_doc=doc)
    st.setup_entities()
    st.assign_serial_numbers()
    block = doc.sole_block()

    def _read_category(category):
        """Category as a DataFrame with missing values replaced by ""."""
        table = pd.DataFrame(block.get_mmcif_category(name=category, raw=False))
        table[table.isna()] = ""
        return table

    def _read_src(category, name_col, taxid_col):
        """entity_id -> [organism name, taxonomy id] from one source category."""
        table = _read_category(category)
        columns = ["entity_id", name_col, taxid_col]
        if table.shape[0] > 0 and set(columns).issubset(table.columns):
            return {eid: [name, taxid] for eid, name, taxid in table[columns].to_numpy()}
        return dict()

    desc = _read_category("_entity")
    entityid2description = dict()
    if desc.shape[0] > 0 and {"id", "pdbx_description"}.issubset(desc.columns):
        entityid2description = dict(zip(desc["id"], desc["pdbx_description"]))

    # source organism priority: _entity_src_gen, then _pdbx_entity_src_syn,
    # then _entity_src_nat (earlier categories are never overwritten)
    entityid2src = _read_src("_entity_src_gen.",
                             "pdbx_gene_src_scientific_name",
                             "pdbx_gene_src_ncbi_taxonomy_id")
    fallbacks = (_read_src("_pdbx_entity_src_syn.",
                           "organism_scientific",
                           "ncbi_taxonomy_id"),
                 _read_src("_entity_src_nat.",
                           "pdbx_organism_scientific",
                           "pdbx_ncbi_taxonomy_id"))
    for fallback in fallbacks:
        for k, v in fallback.items():
            entityid2src.setdefault(k, v)

    info_map = dict(st.info)
    pdb_code = info_map.get("_entry.id", "").lower()

    # X-ray/refinement resolution first, EM resolution as fallback; the first
    # tag with a real value decides, unparsable text gives 0.0
    resolution = 0.0
    for tag in ("_refine.ls_d_res_high", "_em_3d_reconstruction.resolution"):
        value = block.find_value(tag)
        if value not in (".", "?", None):
            try:
                resolution = float(value)
            except ValueError:
                resolution = 0.0
            break

    st.resolution = resolution

    info = dict(description={k: v for k, v in entityid2description.items() if v and v != "?"},
                source=entityid2src,
                resolution=st.resolution,
                pdb_id=pdb_code if gemmi.is_pdb_code(pdb_code) else "",
                method=info_map.get("_exptl.method", "").lower(),
                deposition_date=info_map.get("_pdbx_database_status.recvd_initial_deposition_date", ""),
                release_date=get_release_date(block),
                title=info_map.get("_struct.title", "")
                )
    return dict(structure=st, info=info)
161
+
162
+
163
def _get_pdb_header(path: str):
    """
    Molecule descriptions, sources and release date from a PDB header
    (.pdb or .pdb.gz).

    :param path: str, .pdb or .pdb.gz file
    :return:
        dict with keys
          "description": {mol_id: MOLECULE text from COMPND}
          "source": {mol_id: [ORGANISM_SCIENTIFIC, ORGANISM_TAXID] from SOURCE}
          "ch_id2mol_id": {chain id: mol_id} from the COMPND CHAIN lists
          "release_date": earliest REVDAT date as "YYYY-MM-DD", or ""
    :raises ValueError: if ``path`` is not a .pdb / .pdb.gz file
    """
    if not is_pdb(path):
        raise ValueError("Only support .pdb or .pdb.gz file, but got %s" % path)

    if pathlib.Path(path).suffixes[-1] == ".pdb":
        with open(path, "r", encoding="utf-8") as text_io:
            lines = text_io.readlines()
    else:
        with gzip.open(path, "rt", encoding="utf-8") as text_io:
            lines = text_io.readlines()

    values = {"COMPND": defaultdict(dict),
              "SOURCE": defaultdict(dict),
              }

    comp_molid = ""
    last_comp_key = ""
    release_date = None

    for raw_line in lines:
        line = raw_line.strip()
        key = line[:6].strip()
        text = line[10:].strip().strip(";")

        if key in values:
            # split only on the first ":" so values may contain colons
            tok = text.split(":", 1)
            if len(tok) == 2:
                ckey = tok[0].lower().strip()
                cval = tok[1].strip()
                if ckey == "mol_id":
                    comp_molid = cval
                    values[key][comp_molid] = dict()
                    # a new molecule starts: later continuation lines must not
                    # extend a key that belongs to the previous molecule
                    last_comp_key = ""
                else:
                    values[key][comp_molid][ckey] = cval
                    last_comp_key = ckey
            elif last_comp_key != "":
                # continuation line: append to the last key seen in this entry
                entry = values[key][comp_molid]
                if last_comp_key in entry:
                    entry[last_comp_key] += " " + tok[0].strip()
        elif key == "REVDAT":
            match = re.search(r"\d\d-\w\w\w-\d\d", text)
            if match is not None:
                try:
                    date_obj = datetime.strptime(match.group(), "%d-%b-%y")
                except ValueError:
                    # e.g. a month abbreviation not understood by the current locale
                    continue
                # keep the earliest date
                if release_date is None or date_obj < release_date:
                    release_date = date_obj

    outputs = dict(description=dict(),
                   source=dict())

    ch_id2mol_id = dict()
    for mol_id, val in values["COMPND"].items():
        for ch in val.get("chain", "").split(","):
            ch = ch.strip()
            if ch != "":
                ch_id2mol_id[ch] = mol_id

    for mol_id, val in values["COMPND"].items():
        molecule = val.get("molecule", "").strip()
        if molecule != "":
            outputs["description"][mol_id] = molecule

    for mol_id, val in values["SOURCE"].items():
        name = val.get("organism_scientific", "").strip()
        taxid = val.get("organism_taxid", "").strip()
        if name not in ["", "?", "."] or taxid not in ["", "?", "."]:
            outputs["source"][mol_id] = [name, taxid]
    outputs["ch_id2mol_id"] = ch_id2mol_id

    outputs["release_date"] = release_date.strftime("%Y-%m-%d") if release_date is not None else ""
    return outputs
249
+
250
+
251
def parse_pdb(path: str) -> dict:
    """
    Parse a PDB structure (.pdb or .pdb.gz) together with its header information.

    COMPND/SOURCE entries (keyed by MOL_ID) are attached to gemmi entities
    through the chain ids listed in COMPND CHAIN. Entity names are then
    renamed to digits by assign_digital_entity_names() when needed.

    :param path: str
    :return:
        dict with keys "structure" (gemmi.Structure) and "info" (dict with
        description, source, resolution, pdb_id, method, deposition_date,
        release_date and title)
    :raises TypeError: if ``path`` does not have a .pdb / .pdb.gz suffix
    """
    if not is_pdb(path):
        raise TypeError("Input file is not a pdb file [.pdb or .pdb.gz]: %s" % path)

    st = gemmi.read_structure(path)
    st.setup_entities()
    st.assign_serial_numbers()

    header = _get_pdb_header(path)
    chain2mol = header["ch_id2mol_id"]

    # MOL_ID -> entity name, for entities whose name is a COMPND chain id
    # (a later entity overrides an earlier one with the same MOL_ID)
    mol_id2entity_name = {chain2mol[ent.name]: ent.name
                          for ent in st.entities if ent.name in chain2mol}

    def _rekey(per_mol):
        """Re-key a MOL_ID-indexed dict by entity name, dropping unmatched MOL_IDs."""
        return {mol_id2entity_name[mol_id]: value
                for mol_id, value in per_mol.items() if mol_id in mol_id2entity_name}

    description = _rekey(header["description"])

    # ligand and water entities: gemmi uses the ligand name or water as the
    # entity name, which is taken as the description
    for ent in st.entities:
        if ent.name in description:
            continue
        if ent.polymer_type.name == "Unknown" and len(ent.name) > 1:
            description[ent.name] = ent.name

    source = _rekey(header["source"])

    renamed = assign_digital_entity_names(st)

    info_map = dict(st.info)
    pdb_code = info_map.get("_entry.id", "").lower()
    info = dict(
        description={renamed.get(name, name): text for name, text in description.items()},
        source={renamed.get(name, name): src for name, src in source.items()},
        resolution=st.resolution,
        pdb_id=pdb_code if gemmi.is_pdb_code(pdb_code) else "",
        method=info_map.get("_exptl.method", "").lower(),
        deposition_date=info_map.get("_pdbx_database_status.recvd_initial_deposition_date", ""),
        release_date=header["release_date"],
        title=info_map.get("_struct.title", ""),
    )
    return dict(structure=st, info=info)
297
+
298
+
299
def assign_digital_entity_names(structure: gemmi.Structure) -> Optional[Dict[str, str]]:
    """
    Rename entities to "1", "2", ... unless every entity name is already numeric.

    Entities are renamed in place, numbered by their position in
    ``structure.entities``.

    :param structure: gemmi.Structure whose entities may be renamed
    :return:
        dict, original entity name to new digital entity name
        (empty when no renaming was necessary)
    """
    if all(ent.name.isdigit() for ent in structure.entities):
        return dict()

    renamed = dict()
    for number, ent in enumerate(structure.entities, start=1):
        digital_name = str(number)
        renamed[ent.name] = digital_name
        ent.name = digital_name
    return renamed
314
+
315
+
316
class StructureParser(object):
    """
    Structure reader for .cif, .cif.gz, .pdb or .pdb.gz

    Read the first model.

    Attributes:
        STRUCT: gemmi.Structure, reduced to its first model
        MODEL: the first (and only) model of STRUCT
        INFO: dict with keys description, source, resolution, pdb_id, method,
              deposition_date, release_date and title; description and source
              are keyed by entity name
    """

    def __init__(self, structure: Optional[gemmi.Structure] = None):
        """
        :param structure: gemmi.Structure to wrap (it is cloned, the caller's
            object is not modified), or None to start with one empty model
        :raises ValueError: if ``structure`` is neither None nor a gemmi.Structure
        """
        if structure is None:
            # init with an empty model
            self.STRUCT = gemmi.Structure()
            self.MODEL = gemmi.Model(1)
            self.STRUCT.add_model(self.MODEL)
        elif isinstance(structure, gemmi.Structure):
            self.STRUCT = structure.clone()
        else:
            raise ValueError("structure must be gemmi.Structure or None")

        self._init_struct()

        info_map = dict(self.STRUCT.info)
        pdb_code = info_map.get("_entry.id", "").lower()
        self.INFO = dict(description=dict(),
                         source=dict(),
                         resolution=self.STRUCT.resolution,
                         pdb_id=pdb_code if gemmi.is_pdb_code(pdb_code) else "",
                         method=info_map.get("_exptl.method", "").lower(),
                         deposition_date=info_map.get("_pdbx_database_status.recvd_initial_deposition_date", ""),
                         release_date="",
                         title=info_map.get("_struct.title", ""),
                         )
        self.update_entity()

    def _init_struct(self):
        """Set up entities/serial numbers, keep only the first model, normalise sequences."""
        self.STRUCT.setup_entities()
        self.STRUCT.assign_serial_numbers()
        self.STRUCT.renumber_models()

        # keep the first model; delete from the back so indexes stay valid
        for idx in range(len(self.STRUCT) - 1, 0, -1):
            del self.STRUCT[idx]

        self.MODEL = self.STRUCT[0]
        self.STRUCT.remove_empty_chains()
        self._update_full_sequences()

    def load_from_file(self, path: str):
        """
        Load model from file, default use the first model.

        :param path: .cif, .cif.gz, .pdb or .pdb.gz file
        :raises ValueError: for any other suffix
        """
        if is_pdb(path):
            val = parse_pdb(path)
        elif is_cif(path):
            val = parse_cif(path)
        else:
            raise ValueError("path must be files with suffixes [ .cif, .cif.gz, .pdb or .pdb.gz]")
        self.STRUCT, self.INFO = val["structure"], val["info"]

        self._init_struct()
        self.update_entity()

    def _update_full_sequences(self):
        """
        Normalise polymer entity sequences in place.

        Each full_sequence item is reduced with gemmi.Entity.first_mon. A
        polymer entity left without a sequence takes the sequence extracted
        from its first sub-chain that is present in the model.
        """
        present = set(self.subchain_ids)
        for idx, ent in enumerate(self.STRUCT.entities):
            if ent.entity_type.name != "Polymer":
                continue
            entity = self.STRUCT.entities[idx]
            entity.full_sequence = [gemmi.Entity.first_mon(item) for item in entity.full_sequence]

            if len(entity.full_sequence) == 0:
                sub_id = next((s for s in entity.subchains if s in present), None)
                if sub_id is not None:
                    entity.full_sequence = self.get_subchain(sub_id).extract_sequence()

    @property
    def chain_ids(self):
        """Chain names of the model, in model order."""
        return [ch.name for ch in self.MODEL]

    @property
    def subchain_ids(self):
        """Sub-chain ids of the model, in model order."""
        return [ch.subchain_id() for ch in self.MODEL.subchains()]

    @property
    def assembly_names(self):
        """Names of the assemblies defined in the structure."""
        return [assem.name for assem in self.STRUCT.assemblies]

    @property
    def polymer_types(self):
        """
        :return: dict, chain name -> polymer_type of the entity of the chain's
            polymer part (chains without a polymer sequence are omitted)
        """
        subchain_id2polymer = {sub: ent.polymer_type
                               for ent in self.STRUCT.entities
                               if ent.entity_type.name == "Polymer"
                               for sub in ent.subchains}

        out = dict()
        for chain in self.MODEL:
            polymer_ch = chain.get_polymer()
            if polymer_ch.extract_sequence():
                subchain_id = polymer_ch.subchain_id()
                if subchain_id in subchain_id2polymer:
                    out[chain.name] = subchain_id2polymer[subchain_id]
        return out

    def polymer_sequences(self, pdbx: bool = False):
        """
        Entity sequences for polymer chains.

        :param pdbx: bool, if True convert with gemmi.pdbx_one_letter_code,
            otherwise with one_letter_code()
        :return: dict, chain name -> sequence of the chain's polymer entity
        """
        out = dict()
        subchain_id2entity_id = self.subchain_id_to_entity_id
        entity_dict = {ent.name: ent for ent in self.STRUCT.entities}

        for ch, polymer_type in self.polymer_types.items():
            polymer = self.get_chain(ch).get_polymer()
            ent = entity_dict[subchain_id2entity_id[polymer.subchain_id()]]

            if pdbx:
                out[ch] = gemmi.pdbx_one_letter_code(ent.full_sequence, gemmi.sequence_kind(polymer_type))
            else:
                out[ch] = self.one_letter_code(ent.full_sequence)
        return out

    @staticmethod
    def one_letter_code(sequences: List[str]):
        """
        Convert residue names to an upper-case one-letter sequence.

        Residues not found by gemmi.find_tabulated_residue, or tabulated with
        a blank one-letter code, are written as "X".
        """
        codes = []
        for name in sequences:
            residue_info = gemmi.find_tabulated_residue(name)
            # unknown residue -> blank, mapped to "X" below like blank codes
            codes.append(residue_info.one_letter_code if residue_info is not None else " ")
        return "".join(codes).upper().replace(" ", "X")

    def get_subchain(self, subchain_id: str):
        """
        :param subchain_id: sub-chain id in the model
        :return: the matching sub-chain (residue span)
        :raises ValueError: if no sub-chain has that id
        """
        for ch in self.MODEL.subchains():
            if ch.subchain_id() == subchain_id:
                return ch
        raise ValueError("Sub-Chain %s not found (only [%s])" % (subchain_id, " ".join(self.subchain_ids)))

    @property
    def subchain_id_to_entity_id(self):
        """dict, sub-chain id -> entity name."""
        return {ch: ent.name for ent in self.STRUCT.entities for ch in ent.subchains}

    @property
    def subchain_id_to_chain_id(self):
        """dict, sub-chain id -> chain name, for sub-chains in the model."""
        return {sch.subchain_id(): chain.name for chain in self.MODEL for sch in chain.subchains()}

    def get_chain(self, chain_id: str):
        """Chain of the model named ``chain_id``."""
        return self.MODEL[chain_id]

    def pick_chains(self, chain_names: List[str]):
        """
        Build a new StructureParser holding copies of the given chains.

        Entity definitions are carried over so that entity names, and with
        them the description/source entries of INFO, are preserved.

        :param chain_names: chain names to keep, in output order
        :return: StructureParser
        """
        struct = gemmi.Structure()
        struct.name = self.STRUCT.name
        model = gemmi.Model(1)
        for ch_id in chain_names:
            model.add_chain(self.get_chain(ch_id))

        struct.add_model(model)
        # without entities, setup_entities() would invent new entity names that
        # no longer match the keys of INFO; unused entities are dropped later
        # by update_entity() of the new parser
        struct.entities = self.STRUCT.entities

        # add basic information
        struct.resolution = self.STRUCT.resolution

        vals = {"_exptl.method": self.INFO["method"],
                "_struct.title": "(Chains %s): " % " ".join(chain_names) + self.INFO["title"],
                "_pdbx_database_status.recvd_initial_deposition_date": self.INFO["deposition_date"],
                }
        if self.INFO["pdb_id"] != "":
            vals["_entry.id"] = self.INFO["pdb_id"]

        struct.info = gemmi.InfoMap(vals)
        new_struct = StructureParser(struct)

        new_struct.INFO["description"] = {ent.name: self.INFO["description"][ent.name]
                                          for ent in new_struct.STRUCT.entities
                                          if ent.name in self.INFO["description"]
                                          }
        new_struct.INFO["source"] = {ent.name: self.INFO["source"][ent.name]
                                     for ent in new_struct.STRUCT.entities
                                     if ent.name in self.INFO["source"]
                                     }
        return new_struct

    @staticmethod
    def _spec_lines(record: str, items):
        """
        Format (key, value) pairs as COMPND/SOURCE specification lines.

        The first line has no continuation number (text starts in column 11);
        following lines carry the right-justified continuation number in
        columns 8-10, as in the PDB format.
        """
        lines = []
        for n_line, (key, value) in enumerate(items, start=1):
            if n_line == 1:
                lines.append("%s    %s: %s;" % (record, key, value))
            else:
                lines.append("%s %3d %s: %s;" % (record, n_line, key, value))
        return lines

    def _raw_marks(self):
        """
        Header lines (COMPND, SOURCE, REMARK 2) used when writing PDB files.

        One MOL_ID is emitted per polymer entity that has chains in the model,
        numbered consecutively and identically in COMPND and SOURCE.

        :return: list of str, one record line per item
        """
        subchain2chain = self.subchain_id_to_chain_id

        polymers = []
        for ent in self.STRUCT.entities:
            if ent.entity_type.name != "Polymer":
                continue
            chains = [subchain2chain[sub_ch] for sub_ch in ent.subchains if sub_ch in subchain2chain]
            if chains:
                polymers.append((ent.name, chains))

        compound_items = []
        source_items = []
        for mol_id, (ent_name, chains) in enumerate(polymers, start=1):
            compound_items += [("MOL_ID", mol_id),
                               ("MOLECULE", self.INFO["description"].get(ent_name, "")),
                               ("CHAIN", ", ".join(chains))]

            src = self.INFO["source"].get(ent_name)
            organism_scientific, organism_taxid = ("", "") if src is None else src
            source_items += [("MOL_ID", mol_id),
                             ("ORGANISM_SCIENTIFIC", organism_scientific),
                             ("ORGANISM_TAXID", organism_taxid)]

        outputs = self._spec_lines("COMPND", compound_items) + self._spec_lines("SOURCE", source_items)
        # REMARK number in columns 8-10, resolution value in columns 24-30
        outputs.extend(["REMARK   2",
                        "REMARK   2 RESOLUTION. %7.2f ANGSTROMS." % self.STRUCT.resolution])
        return outputs

    def to_pdb(self, outfile: str, write_minimal_pdb=False):
        """
        Write the structure in PDB format.

        :param outfile: output path
        :param write_minimal_pdb: bool, if True use gemmi's write_minimal_pdb
            and skip the COMPND/SOURCE/REMARK header lines
        """
        # work on a copy so raw_remarks of self.STRUCT stay untouched
        struct = self.STRUCT.clone()
        if write_minimal_pdb:
            struct.write_minimal_pdb(outfile)
        else:
            struct.raw_remarks = self._raw_marks()
            struct.write_pdb(outfile)

    @staticmethod
    def _item_index(block: gemmi.cif.Block, tag: str):
        """Index of the block item (loop or pair) that holds ``tag``, or None."""
        mapper = dict()
        for idx, item in enumerate(block):
            if item.loop is not None:
                for k in item.loop.tags:
                    mapper[k] = idx
            elif item.pair is not None:
                mapper[item.pair[0]] = idx
        return mapper.get(tag)

    def to_cif(self, outfile: str):
        """
        Write the structure in mmCIF format, adding resolution, entity
        descriptions and entity sources from INFO.

        :param outfile: output path
        """
        block = self.STRUCT.make_mmcif_block()

        # resolution
        block.set_pair(tag="_refine.ls_d_res_high", value=gemmi.cif.quote(str(self.INFO["resolution"])))

        # entity descriptions: rebuild the _entity loop with pdbx_description filled in
        ta = block.find_mmcif_category(category="_entity.")
        da = pd.DataFrame(list(ta), columns=list(ta.tags))
        da["_entity.pdbx_description"] = da["_entity.id"].apply(
            lambda i: gemmi.cif.quote(self.INFO["description"].get(i, "?")))

        rows_1 = da.to_numpy().tolist()
        tags_1 = [s.replace("_entity.", "") for s in da.columns.tolist()]

        # erase the original loop
        qitem = block.find_loop_item("_entity.id")
        if isinstance(qitem, gemmi.cif.Item):
            qitem.erase()

        # add the rebuilt loop
        loop_1 = block.init_loop(prefix="_entity.", tags=tags_1)
        for r in rows_1:
            loop_1.add_row(r)

        idx_1b = self._item_index(block, tag="_entity.id")
        idx_2b = self._item_index(block, tag="_entity_poly.entity_id")

        # place _entity. before _entity_poly.
        if isinstance(idx_1b, int) and isinstance(idx_2b, int):
            block.move_item(idx_1b, idx_2b - 1)

        # source name and taxid
        loop_2 = block.init_loop(prefix="_entity_src_gen.", tags=["entity_id",
                                                                  "pdbx_gene_src_scientific_name",
                                                                  "pdbx_gene_src_ncbi_taxonomy_id"])

        for k, (name, taxid) in self.INFO["source"].items():
            name = name if name != "" else "?"
            taxid = taxid if taxid != "" else "?"
            loop_2.add_row([gemmi.cif.quote(k),
                            gemmi.cif.quote(name),
                            gemmi.cif.quote(taxid)]
                           )

        idx_1c = self._item_index(block, tag="_entity_src_gen.entity_id")
        idx_2c = self._item_index(block, tag="_entity_poly_seq.entity_id")
        # place _entity_src_gen. after _entity_poly_seq.
        if isinstance(idx_1c, int) and isinstance(idx_2c, int):
            block.move_item(idx_1c, idx_2c + 1)

        block.write_file(outfile)

    def update_entity(self):
        """
        Update ENTITY, .entities .assemblies according to subchains

        Entities keep only sub-chains present in the model (entities left with
        none are dropped, together with their INFO description/source).
        Assembly generators keep only existing chains/sub-chains; empty
        generators, and assemblies left without generators, are removed.
        """
        subchains = set(self.subchain_ids)

        # update .entities
        new_entities = gemmi.EntityList()
        ent_names = set()
        for ent in self.STRUCT.entities:
            kept = [i for i in ent.subchains if i in subchains]
            if kept:
                ent.subchains = kept
                new_entities.append(ent)
                ent_names.add(ent.name)
        self.STRUCT.entities = new_entities

        # update INFO
        self.INFO["description"] = {k: v for k, v in self.INFO["description"].items() if k in ent_names}
        self.INFO["source"] = {k: v for k, v in self.INFO["source"].items() if k in ent_names}

        # update .assemblies
        all_cid = set(self.chain_ids)
        del_assembly_indexes = []

        for a_i, assembly in enumerate(self.STRUCT.assemblies):
            del_gen_indexes = []
            for g_i, gen in enumerate(assembly.generators):
                gen.chains = [i for i in gen.chains if i in all_cid]
                gen.subchains = [i for i in gen.subchains if i in subchains]
                if not gen.chains and not gen.subchains:
                    del_gen_indexes.append(g_i)

            for dgi in sorted(del_gen_indexes, reverse=True):
                del assembly.generators[dgi]

            # checked after the deletions: the assembly is gone once no generator is left
            if len(assembly.generators) == 0:
                del_assembly_indexes.append(a_i)

        for dai in sorted(del_assembly_indexes, reverse=True):
            del self.STRUCT.assemblies[dai]

    def rename_chain(self, origin_name: str, target_name: str):
        """
        Rename a chain, also in the assembly generators.

        :raises ValueError: if ``origin_name`` does not exist or ``target_name``
            is already used by another chain
        """
        if origin_name not in self.chain_ids:
            raise ValueError("Chain %s not found" % origin_name)

        other_chain_names = set(self.chain_ids) - {origin_name}
        if target_name in other_chain_names:
            raise ValueError("Chain %s has existed, please set a different target_name." % target_name)

        self.STRUCT.rename_chain(origin_name, target_name)

        for assembly in self.STRUCT.assemblies:
            for gen in assembly.generators:
                gen.chains = [target_name if c == origin_name else c for c in gen.chains]

    def swap_chain_names(self, chain_name_1: str, chain_name_2: str):
        """
        Exchange the names of two chains via a random unused 4-character name.

        :raises ValueError: if either chain does not exist
        """
        if chain_name_1 not in self.chain_ids:
            raise ValueError("Chain %s not found" % chain_name_1)
        if chain_name_2 not in self.chain_ids:
            raise ValueError("Chain %s not found" % chain_name_2)

        characters = string.ascii_letters + string.digits
        while True:
            sw_name = "".join(random.choices(characters, k=4))
            if sw_name not in self.chain_ids:
                break

        self.rename_chain(chain_name_1, sw_name)
        self.rename_chain(chain_name_2, chain_name_1)
        self.rename_chain(sw_name, chain_name_2)

    def make_one_letter_chain(self, only_uppercase: bool = True):
        """
        Rename chains so that every chain name is a single character.

        Chains already named with an allowed character keep their name; the
        others receive unused characters in the order A-Z, then (when
        ``only_uppercase`` is False) a-z, then 0-9.

        :return: dict, original chain name -> new name, for renamed chains
        :raises RuntimeError: if there are more chains than allowed characters
        """
        uppercase_letters = sorted(string.ascii_uppercase, reverse=True)
        lowercase_letters = sorted(string.ascii_lowercase, reverse=True)
        digit_letters = sorted(string.digits, reverse=True)

        if only_uppercase:
            letters = uppercase_letters
            msg = "The number of chains exceed the number of uppercase letters: %d > %d"
        else:
            letters = digit_letters + lowercase_letters + uppercase_letters
            msg = "The number of chains exceed the number of one-letter characters: %d > %d"

        chain_ids = self.chain_ids
        if len(chain_ids) > len(letters):
            raise RuntimeError(msg % (len(chain_ids), len(letters)))

        # not used yet; the lists are reversed so pop() yields A, B, ... first
        letters_valid = [c for c in letters if c not in chain_ids]
        mapper = {ch: letters_valid.pop() for ch in chain_ids if ch not in letters}

        for origin_name, target_name in mapper.items():
            self.rename_chain(origin_name, target_name)
        return mapper

    def get_assembly(self, assembly_name: str,
                     how: gemmi.HowToNameCopiedChain = gemmi.HowToNameCopiedChain.AddNumber):
        """
        Build a new StructureParser for a biological assembly.

        Entity descriptions/sources are transferred to new entities that match
        an original entity (same polymer sequence, or same residue sequence for
        non-polymers).

        :raises ValueError: if the assembly does not exist
        """
        if assembly_name not in self.assembly_names:
            raise ValueError("Assembly %s not found (only [%s])" % (assembly_name, ", ".join(self.assembly_names)))

        struct = self.STRUCT.clone()
        struct.transform_to_assembly(assembly_name, how)
        # the title may be absent from the info map
        title = dict(struct.info).get("_struct.title", "")
        struct.info["_struct.title"] = "(Assembly %s): " % assembly_name + title

        new_struct = StructureParser(struct)

        # find perfect match entities
        entity_mapper = dict()
        for new_ent in new_struct.STRUCT.entities:
            for ent in self.STRUCT.entities:
                if new_ent.entity_type != ent.entity_type:
                    continue
                if ent.entity_type.name == "Polymer":
                    same = new_ent.full_sequence == ent.full_sequence
                else:
                    new_s = new_struct.get_subchain(new_ent.subchains[0]).extract_sequence()
                    s = self.get_subchain(ent.subchains[0]).extract_sequence()
                    same = new_s == s
                if same:
                    entity_mapper[new_ent.name] = ent.name
                    break

        # update Info
        desc = dict()
        src = dict()
        for ent in new_struct.STRUCT.entities:
            origin = entity_mapper.get(ent.name)
            if origin is None:
                continue
            if origin in self.INFO["description"]:
                desc[ent.name] = self.INFO["description"][origin]
            if origin in self.INFO["source"]:
                src[ent.name] = self.INFO["source"][origin]

        new_struct.INFO["description"] = desc
        new_struct.INFO["source"] = src
        return new_struct

    def clean_structure(self, remove_ligand=False, remove_hydrogen=True):
        """
        Remove water by default

        Alternative conformations are always removed.

        :param remove_ligand: bool, default False, also remove ligands
        :param remove_hydrogen: bool, default True
        """
        self.STRUCT.remove_alternative_conformations()

        if remove_hydrogen:
            self.STRUCT.remove_hydrogens()

        if remove_ligand:
            self.STRUCT.remove_ligands_and_waters()
        else:
            self.STRUCT.remove_waters()

        self.STRUCT.remove_empty_chains()
        self.update_entity()

    def met_to_mse(self):
        """Rename MET residues to MSE and their SD atoms to SE (element Se), in place."""
        for chain in self.MODEL:
            for residue in chain:
                if residue.name != 'MET':
                    continue
                residue.name = 'MSE'
                for atom in residue:
                    if atom.name == 'SD':
                        atom.name = 'SE'
                        atom.element = gemmi.Element('Se')

    def get_atoms(self, arg: str = "*", exclude_hydrogen=False):
        """
        Atoms matching a gemmi selection as a structured array.

        :param arg: str, "*", "/1/*//N,CA,C,O", "/1/*"
            see gemmi.Selection
        :param exclude_hydrogen: bool, default False
        :return:
            np.ndarray with dtype ATOM
        """
        sel = gemmi.Selection(arg)
        records = []

        for model in sel.models(self.STRUCT):
            for chain in sel.chains(model):
                for residue in sel.residues(chain):
                    for atom in sel.atoms(residue):
                        if exclude_hydrogen and atom.is_hydrogen():
                            continue
                        records.append((chain.name,
                                        residue.seqid.num,
                                        residue.seqid.icode,
                                        residue.name,
                                        atom.name,
                                        atom.element.name,
                                        atom.charge,
                                        atom.b_iso,
                                        atom.occ,
                                        tuple(atom.pos.tolist()),
                                        ))

        return np.array(records, dtype=ATOM)

    def compute_interface(self,
                          chains_x: List[str],
                          chains_y: List[str],
                          threshold: float = 5.0):
        """
        Residues of chains_x and chains_y having heavy atoms within ``threshold``
        of each other.

        :param chains_x: chain names of the first partner
        :param chains_y: chain names of the second partner
        :param threshold: distance cutoff between heavy atoms
        :return:
            PPI residues of chains_x, PPI residues of chains_y (unique
            structured arrays with chain_name, residue_num, residue_icode,
            residue_name)
        :raises ValueError: if a chain name does not exist
        """
        for ch in chains_x + chains_y:
            if ch not in self.chain_ids:
                raise ValueError("Chain %s not found (only [%s])" % (ch, " ".join(self.chain_ids)))

        fields = ["chain_name", "residue_num", "residue_icode", "residue_name"]

        def _heavy_atoms(chains):
            # an empty list would produce the selection "/1/", which does not
            # restrict chains, so treat it as "no atoms"
            if not chains:
                return np.array([], dtype=ATOM)
            return self.get_atoms("/1/%s" % ",".join(chains), exclude_hydrogen=True)

        atom_x = _heavy_atoms(chains_x)
        atom_y = _heavy_atoms(chains_y)

        if len(atom_x) == 0 or len(atom_y) == 0:
            # nothing can be in contact; avoid building KD-trees on empty data
            return np.unique(atom_x[:0][fields]), np.unique(atom_y[:0][fields])

        kd_tree_x = cKDTree(atom_x["coordinate"])
        kd_tree_y = cKDTree(atom_y["coordinate"])

        pairs = kd_tree_x.sparse_distance_matrix(kd_tree_y, threshold, output_type='coo_matrix')
        x_res = np.unique(atom_x[pairs.row][fields])
        y_res = np.unique(atom_y[pairs.col][fields])

        return x_res, y_res