gemmi-protools 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,862 @@
1
+ """
2
+ @Author: Luo Jiejian
3
+ """
4
+ import hashlib
5
+ import os
6
+ import re
7
+ import shutil
8
+ import subprocess
9
+ import uuid
10
+ from collections import defaultdict
11
+ from importlib.resources import files
12
+
13
+ import numpy as np
14
+ from Bio import SeqIO
15
+ from anarci import run_anarci
16
+ from anarci.germlines import all_germlines
17
+ from gemmi_protools import StructureParser
18
+ from gemmi_protools.tools.align import align_sequences
19
+
20
+
21
def hash_sequence(seq: str) -> str:
    """Return the hexadecimal SHA-256 digest of ``seq``.

    The string is encoded with ``str.encode()``'s default encoding before
    hashing.
    """
    digest = hashlib.sha256()
    digest.update(seq.encode())
    return digest.hexdigest()
24
+
25
+
26
def _germline_assignment(details: dict, chain_type: str, segment: str):
    """Return ``(species, gene, aligned_sequence)`` for the V or J germline.

    ``segment`` is ``"V"`` or ``"J"``. When ANARCI reports no assignment for
    the segment (the first element of the entry is ``None``), three empty
    strings are returned so callers can skip germline-based repair.
    """
    assignment = details["germlines"]["%s_gene" % segment.lower()][0]
    if assignment is None:
        return "", "", ""
    species, gene = assignment
    return species, gene, all_germlines[segment][chain_type][species][gene]


def _fill_gap_from_germline(site, germline_seq: str) -> str:
    """Return the residue for ``site``, borrowing the germline residue for a gap.

    Only non-insertion positions (insertion code ``' '``) that are gaps in the
    query and non-gaps in the germline are replaced; everything else keeps the
    residue reported by ANARCI.
    """
    (idx, ins), aa = site
    # Germline strings are indexed by IMGT position (1-based -> idx - 1).
    germline_aa = germline_seq[idx - 1] if idx <= len(germline_seq) else "-"
    if ins == " " and aa == "-" and germline_aa != "-":
        return germline_aa
    return aa


def _gene_label(species: str, gene: str) -> str:
    """Format a germline as ``species/gene``; empty when nothing was assigned."""
    if not species and not gene:
        return ""
    return species + "/" + gene


def get_fv_region(in_sequence: str):
    """Locate Fv (variable) domains in a sequence with ANARCI (IMGT scheme).

    Args:
        in_sequence: one-letter amino-acid sequence of a single chain.

    Returns:
        A list with one dict per variable domain found (empty when ANARCI
        numbers nothing). Each dict holds:

        * ``Fv_aa``: domain sequence in which FR1 / FR4 gaps were filled from
          the assigned V / J germline,
        * ``classification``: first two characters of the V gene name,
        * ``chain_type``: chain type reported by ANARCI,
        * ``v_gene`` / ``j_gene``: ``species/gene`` labels (empty string when
          ANARCI assigned no germline for that segment),
        * ``cdr1_aa`` / ``cdr2_aa`` / ``cdr3_aa``: CDR sequences without gaps,
        * ``mask``: one digit per residue of ``in_sequence``; ``9`` outside
          this domain, ``0`` for framework and ``1``/``2``/``3`` for CDR1-3.
    """
    # IMGT number, include start and end
    # https://www.imgt.org/IMGTScientificChart/Nomenclature/IMGT-FRCDRdefinition.html
    # alpha-beta TCR: light chain alpha, heavy chain beta
    # gamma-delta TCR: light chain gamma, heavy chain delta
    imgt_scheme = dict(
        fr1=(1, 26),
        cdr1=(27, 38),
        fr2=(39, 55),
        cdr2=(56, 65),
        fr3=(66, 104),
        cdr3=(105, 117),
        fr4=(118, 128),
    )
    cdr_codes = {"cdr1": 1, "cdr2": 2, "cdr3": 3}

    # IMGT position -> region name
    region_of = dict()
    for name, (first, last) in imgt_scheme.items():
        for pos in range(first, last + 1):
            region_of[pos] = name

    inputs = [("input", in_sequence)]
    _, numbered, alignment_details, _ = run_anarci(inputs, scheme="imgt", assign_germline=True)
    if numbered[0] is None:
        return []

    outputs = []
    for (aligned_sites, start, end), cur_details in zip(numbered[0], alignment_details[0]):
        if not aligned_sites:
            continue

        chain_type = cur_details["chain_type"]
        # ANARCI may leave a germline unassigned; the original code crashed
        # when unpacking it. Repair is simply skipped for that segment.
        v_gene_specie, v_gene, v_germline = _germline_assignment(cur_details, chain_type, "V")
        j_gene_specie, j_gene, j_germline = _germline_assignment(cur_details, chain_type, "J")

        # Residue mask over the full input sequence:
        # 9 outside the Fv, 0 for framework, 1/2/3 for the CDRs of this Fv.
        mask = np.full(len(in_sequence), fill_value=9, dtype=np.int8)
        mask[start: end + 1] = 0

        regions = defaultdict(list)
        seq_pos = start  # index into in_sequence; advances on non-gap sites only
        for site in aligned_sites:
            (site_num, _), site_aa = site
            region_name = region_of[site_num]

            if site_aa != "-":
                mask[seq_pos] = cdr_codes.get(region_name, 0)
                seq_pos += 1

            # Framework 1 and 4 gaps are repaired from the V and J germlines.
            if region_name == "fr1":
                site_aa = _fill_gap_from_germline(site, v_germline)
            elif region_name == "fr4":
                site_aa = _fill_gap_from_germline(site, j_germline)
            regions[region_name].append(site_aa)

        cdr1_seq = "".join(aa for aa in regions["cdr1"] if aa != "-")
        cdr2_seq = "".join(aa for aa in regions["cdr2"] if aa != "-")
        cdr3_seq = "".join(aa for aa in regions["cdr3"] if aa != "-")

        fixed_fv_seq = "".join(
            aa
            for r_name in ["fr1", "cdr1", "fr2", "cdr2", "fr3", "cdr3", "fr4"]
            for aa in regions[r_name]
            if aa != "-"
        )

        outputs.append(dict(Fv_aa=fixed_fv_seq,
                            classification=v_gene[0:2],
                            chain_type=chain_type,
                            v_gene=_gene_label(v_gene_specie, v_gene),
                            j_gene=_gene_label(j_gene_specie, j_gene),
                            cdr1_aa=cdr1_seq,
                            cdr2_aa=cdr2_seq,
                            cdr3_aa=cdr3_seq,
                            mask="".join(str(code) for code in mask.tolist())
                            )
                       )
    return outputs
151
+
152
+
153
def fv_region_type(inputs: list[dict]):
    """Classify a chain from the Fv annotations found on it.

    Args:
        inputs: list of dicts (one per Fv domain), each with the keys
            ``classification`` and ``chain_type``.

    Returns:
        ``"not-Fv"`` for no domain; ``"<clf>/VH"`` or ``"<clf>/VL"`` for a
        single recognised domain; ``"<clf>/scFv"`` for a recognised pair of
        two domains; ``"other"`` for everything else.
    """
    heavy_like = ("IGH", "TRB", "TRD")
    light_like = ("IGK", "IGL", "TRA", "TRG")
    scfv_pairs = ({"IGH", "IGL"}, {"IGH", "IGK"}, {"TRA", "TRB"}, {"TRG", "TRD"})

    domain_count = len(inputs)

    if domain_count == 0:
        return "not-Fv"

    if domain_count == 1:
        only = inputs[0]
        clf = only["classification"]
        label = "%s%s" % (clf, only["chain_type"])
        if label in heavy_like:
            return "%s/VH" % clf
        if label in light_like:
            return "%s/VL" % clf
        return "other"

    if domain_count == 2:
        labels = set()
        for item in inputs:
            labels.add("%s%s" % (item["classification"], item["chain_type"]))
        if any(labels == pair for pair in scfv_pairs):
            # both labels of a valid pair share the same two-letter class
            return "%s/scFv" % labels.pop()[0:2]
        return "other"

    return "other"
177
+
178
+
179
def annotate_mhc(seq_dict: dict):
    """Annotate protein chains against the bundled MHC HMM library via hmmscan.

    Args:
        seq_dict: dict,
            key: ch_id
            val: protein seq

    Returns:
        dict mapping a chain id (the hmmscan query name) to the name of a
        matching HMM (first column of the ``--tblout`` table). Chains without
        a reported hit are absent.

    Raises:
        RuntimeError: if the ``hmmscan`` executable is not on ``PATH`` or the
            hmmscan run exits with a non-zero status.
    """
    # Resolve the executable first, so a missing hmmscan leaves no scratch
    # directory behind.
    _path = shutil.which("hmmscan")
    if _path is None:
        raise RuntimeError("hmmscan is not found.")

    hmm_model = str(files("gemmi_protools.data") / "MHC" / "MHC_combined.hmm")

    # Scratch directory with a unique name under the user's home directory.
    home_dir = os.path.expanduser("~")
    tmp_dir = os.path.join(home_dir, str(uuid.uuid4()))
    os.makedirs(tmp_dir)

    try:
        # save sequences to fasta (all chains of the biomolecule)
        fasta_file = os.path.join(tmp_dir, "input.fasta")
        with open(fasta_file, "w") as fo:
            for ch_id, seq in seq_dict.items():
                fo.write(">%s\n%s\n" % (ch_id, seq))

        result_file = os.path.join(tmp_dir, "result.txt")

        # Argument list without a shell: paths containing spaces or shell
        # metacharacters are passed through untouched.
        cmd = [_path, "--tblout", result_file, "--cut_ga", hmm_model, fasta_file]
        try:
            subprocess.run(cmd, check=True,
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as ce:
            detail = ce.stderr.decode(errors="replace").strip() if ce.stderr else ""
            raise RuntimeError("%s %s" % (ce, detail)) from ce

        out = dict()
        with open(result_file, "r") as fi:
            for li in fi:
                li = li.strip()
                if not li or li.startswith("#"):
                    continue
                fields = li.split()
                if len(fields) < 3:
                    continue
                # column 1: target (HMM) name, column 3: query (chain) name
                # NOTE(review): when a chain has several hits the last line
                # wins, as in the original implementation - confirm that this
                # is the intended hit.
                out[fields[2]] = fields[0]
        return out
    finally:
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)
227
+
228
+
229
def annotate_cd1(seq_dict: dict, identity_thres: float = 0.8, coverage_thres: float = 0.8):
    """Annotate chains as CD1a-CD1e by alignment to bundled reference sequences.

    Args:
        seq_dict: dict mapping chain id to protein sequence.
        identity_thres: a reference matches only when the alignment
            ``identity`` is strictly greater than this value.
        coverage_thres: a reference matches only when the alignment
            ``coverage_1`` is strictly greater than this value.

    Returns:
        dict mapping chain id to the CD1 tag (``"CD1a"`` ... ``"CD1e"``) of
        the first reference sequence that passes both thresholds. Chains
        without a match are absent.
    """
    ref_fa = str(files("gemmi_protools.data") / "CD1" / "CD21029_review.fasta")
    keys = ["CD1a", "CD1b", "CD1c", "CD1d", "CD1e"]

    # Load reference sequences; a record is tagged with the first key found
    # in its description and skipped when none is present.
    references = []
    for record in SeqIO.parse(ref_fa, "fasta"):
        for k in keys:
            if k in record.description:
                references.append((k, str(record.seq)))
                break

    outputs = dict()
    for query_ch, query_seq in seq_dict.items():
        for tag, target_seq in references:
            v = align_sequences(query_seq, target_seq)
            if v["identity"] > identity_thres and v["coverage_1"] > coverage_thres:
                outputs[query_ch] = tag
                break
    return outputs
256
+
257
+
258
class ImmuneComplex(object):
    """Annotate the chains of a structure and assemble immune complexes.

    On construction the structure file is loaded, cleaned and renumbered,
    every chain is sorted into protein / nucleic acid / other polymer /
    ligand, and protein chains are annotated with ANARCI (IG / TR variable
    domains), hmmscan (MHC) and reference alignment (CD1). Heavy/light pairs
    and unpaired VH chains are then derived from inter-chain contacts.
    ``run()`` returns the detected complexes.
    """
    MAX_MHC_I_PEPTIDE_LEN = 13
    MAX_MHC_II_PEPTIDE_LEN = 25

    def __init__(self, struct_file: str,
                 min_fv_ppi_residues: int = 25,
                 min_cdr_ppi_residues: int = 5,
                 min_b2globulin_ppi_residues: int = 40,
                 min_mhc2ab_ppi_residues: int = 40
                 ):
        """Load ``struct_file`` and annotate all of its chains.

        Args:
            struct_file: path of the structure file to load.
            min_fv_ppi_residues: minimum number of Fv residues (both chains
                summed) at the interface for two chains to count as a pair.
            min_cdr_ppi_residues: minimum number of CDR residues (summed over
                the query chains) contacting a chain for it to be a target.
            min_b2globulin_ppi_residues: minimum interface size (both sides
                summed) used by ``find_b2mg``.
            min_mhc2ab_ppi_residues: minimum interface size used by
                ``find_pair_mhc2ab``.
        """
        self.struct_file = struct_file
        self.min_fv_ppi_residues = min_fv_ppi_residues
        self.min_cdr_ppi_residues = min_cdr_ppi_residues
        self.min_b2globulin_ppi_residues = min_b2globulin_ppi_residues
        self.min_mhc2ab_ppi_residues = min_mhc2ab_ppi_residues

        self.st = StructureParser()
        self.st.load_from_file(struct_file)
        self.st.clean_structure(remove_ligand=False)
        self.renumber_structure(self.st)

        # chain id -> list of residue names as present in the model
        self.model_chains = dict()

        # Consider protein chains with non-standard residues (X) less than 0.1
        self.protein_chains = []
        self.ligand_chains = []
        self.nucl_chains = []
        self.other_polymer_chains = []
        # chain id -> one-letter sequence, protein chains only
        self.pro = dict()

        for ch_id in self.st.chain_ids:
            seq = [r.name for r in self.st.get_chain(ch_id)]
            one_letter_seq = self.st.one_letter_code(seq)
            # An empty chain is never "good"; this also avoids dividing by zero.
            is_good = (len(one_letter_seq) > 0
                       and one_letter_seq.count("X") / len(one_letter_seq) < 0.1)

            if ch_id in self.st.polymer_types and is_good:
                ch_type = self.st.polymer_types[ch_id]

                if ch_type.name == 'PeptideL':
                    self.protein_chains.append(ch_id)
                    self.pro[ch_id] = one_letter_seq
                elif ch_type.name in ['Dna', 'Rna']:
                    self.nucl_chains.append(ch_id)
                else:
                    self.other_polymer_chains.append(ch_id)
            else:
                self.ligand_chains.append(ch_id)

            self.model_chains[ch_id] = seq

        self.anarci_ann = self._annotate_proteins_anarci()
        self.mhc_ann = self._annotate_proteins_mhc()
        self.cd1_ann = self._annotate_proteins_cd1()

        self.ig_chains = []
        self.ig_H = []
        self.ig_L = []
        self.ig_scfv_chains = []

        self.tr_chains = []
        self.tr_H = []
        self.tr_L = []
        self.tr_scfv_chains = []

        self.other_ig_tr_chains = []

        for ch, val in self.anarci_ann.items():
            fv_type = val["fv_type"]
            if fv_type == "TR/scFv":
                self.tr_scfv_chains.append(ch)
            elif fv_type == "IG/scFv":
                self.ig_scfv_chains.append(ch)
            elif fv_type in ["TR/VH", "TR/VL"]:
                self.tr_chains.append(ch)
                if fv_type == "TR/VH":
                    self.tr_H.append(ch)
                else:
                    self.tr_L.append(ch)
            elif fv_type in ["IG/VH", "IG/VL"]:
                self.ig_chains.append(ch)
                if fv_type == "IG/VH":
                    self.ig_H.append(ch)
                else:
                    self.ig_L.append(ch)
            else:
                self.other_ig_tr_chains.append(ch)
                print("Warning: fv_type %s for chain %s: %s" % (fv_type, ch, struct_file))

        self.cd1_chains = list(self.cd1_ann.keys())

        # exclude cd1 chains with MHC annotations, due to annotation accuracy
        # CD1 annotations with higher accuracy
        self.mhc_chains = [ch for ch in self.mhc_ann.keys() if ch not in self.cd1_chains]

        self.ig_pairs = self.get_ig_pairs()
        self.tr_pairs = self.get_tr_pairs()
        self.vhh_chains = self.get_vhh()

        # chain id -> polymer type name, or "SmallMol" for non-polymer chains
        self.ch_types = dict()

        for ch in self.st.chain_ids:
            if ch in self.st.polymer_types:
                self.ch_types[ch] = self.st.polymer_types[ch].name
            else:
                self.ch_types[ch] = "SmallMol"

    def _annotate_proteins_anarci(self):
        """Run the Fv annotation on every protein chain.

        Returns:
            dict mapping chain id to ``dict(fv_type, cdr_mask, fv_mask)`` for
            chains that carry at least one Fv domain. The masks are boolean
            arrays over the chain sequence, merged over all domains.
        """
        outputs = dict()

        for ch, seq in self.pro.items():
            anarci_info = get_fv_region(seq)
            fv_type = fv_region_type(anarci_info)

            if fv_type != "not-Fv":
                # rows: domains, columns: residues; 9 = outside, 0 = framework, 1-3 = CDR
                mask = np.array([list(ann["mask"]) for ann in anarci_info], dtype="int")
                cdr_mask = np.any(np.logical_and(mask > 0, mask < 9), axis=0)
                fv_mask = np.any(np.logical_and(mask >= 0, mask < 9), axis=0)
                outputs[ch] = dict(fv_type=fv_type,
                                   cdr_mask=cdr_mask,
                                   fv_mask=fv_mask)
        return outputs

    def _annotate_proteins_mhc(self):
        """Return the MHC annotation of the protein chains (empty dict if none)."""
        if len(self.pro) > 0:
            return annotate_mhc(self.pro)
        return dict()

    def _annotate_proteins_cd1(self):
        """Return the CD1 annotation of the protein chains (empty dict if none).

        Any letter outside the 20 standard amino acids and ``X`` is replaced
        by ``X`` before alignment.
        """
        allowed = set("ACDEFGHIKLMNPQRSTVWXY")

        if len(self.pro) > 0:
            std_pro = {k: ''.join(r if r in allowed else 'X' for r in v)
                       for k, v in self.pro.items()}
            return annotate_cd1(std_pro)
        return dict()

    @staticmethod
    def renumber_structure(struct: "StructureParser"):
        """Renumber the residues of every chain in ``struct.MODEL`` from 1.

        Insertion codes are reset to a blank so residue numbers are contiguous.
        """
        for chain in struct.MODEL:
            for count, residue in enumerate(chain, start=1):
                residue.seqid.num = count
                residue.seqid.icode = " "

    def get_interface_mask(self, ch_x: str, ch_y: str):
        """Return boolean interface masks of two chains.

        Returns:
            ``(xm, ym)``: one boolean array per chain, as long as the chain in
            the model, ``True`` at residues reported by
            ``self.st.compute_interface``.
        """
        res_x, res_y = self.st.compute_interface([ch_x], [ch_y])
        # NOTE(review): assumes "residue_num" is a 1-based integer numpy array
        # (consistent with renumber_structure) - confirm against compute_interface.
        num_x = res_x["residue_num"]
        num_y = res_y["residue_num"]

        xm = np.zeros(len(self.model_chains[ch_x]), dtype="bool")
        ym = np.zeros(len(self.model_chains[ch_y]), dtype="bool")

        xm[num_x - 1] = True
        ym[num_y - 1] = True
        return xm, ym

    def show_chains(self):
        """Print every chain grouping computed at construction time."""
        for key in ['protein_chains', 'ligand_chains', 'nucl_chains', 'other_polymer_chains',
                    'ig_chains', 'ig_H', 'ig_L', 'ig_scfv_chains',
                    'tr_chains', 'tr_H', 'tr_L', 'tr_scfv_chains', 'mhc_chains', 'cd1_chains',
                    'ig_pairs', 'tr_pairs', "vhh_chains", "other_ig_tr_chains"
                    ]:
            print("%s: %s" % (key, str(getattr(self, key))))

    def _fv_contact_count(self, ch_a: str, ch_b: str):
        """Count Fv residues of both chains that lie at their mutual interface."""
        fv_mask_a = self.anarci_ann[ch_a]["fv_mask"]
        fv_mask_b = self.anarci_ann[ch_b]["fv_mask"]

        ppi_a, ppi_b = self.get_interface_mask(ch_a, ch_b)
        return np.logical_and(fv_mask_a, ppi_a).sum() + np.logical_and(fv_mask_b, ppi_b).sum()

    def _hl_pairs(self, h_chains: list, l_chains: list):
        """Return ``(heavy, light, n_contacts)`` for every H/L combination
        whose Fv interface reaches ``min_fv_ppi_residues``."""
        if len(h_chains) == 0 or len(l_chains) == 0:
            return []

        candidate_pairs = []
        for ch_h in h_chains:
            for ch_l in l_chains:
                n_ppi = self._fv_contact_count(ch_h, ch_l)
                if n_ppi >= self.min_fv_ppi_residues:
                    candidate_pairs.append((ch_h, ch_l, n_ppi))
        return candidate_pairs

    def _pairs(self, chains: list):
        """Return ``(ch_i, ch_j, n_contacts)`` for same-class chain pairs.

        Covers double H chains or double L chains, because the ANARCI chain
        class is not always right. Chains are compared in sorted order.
        """
        chains_sort = sorted(chains)

        candidate_pairs = []
        for i, ch_i in enumerate(chains_sort):
            for ch_j in chains_sort[i + 1:]:
                n_ppi = self._fv_contact_count(ch_i, ch_j)
                if n_ppi >= self.min_fv_ppi_residues:
                    candidate_pairs.append((ch_i, ch_j, n_ppi))
        return candidate_pairs

    def _search_pairs(self, h_chains: list, l_chains: list):
        """Greedily pair chains, largest Fv interface first.

        Each chain is used in at most one pair.

        Returns:
            list of ``(chain_1, chain_2)`` tuples.
        """
        candidate_pairs = (self._hl_pairs(h_chains=h_chains, l_chains=l_chains)
                           + self._pairs(chains=h_chains)
                           + self._pairs(chains=l_chains)
                           )
        candidate_pairs.sort(reverse=True, key=lambda x: x[2])

        outputs = []
        _status = {ch: 0 for ch in h_chains + l_chains}  # 1 once a chain is paired

        for ch_1, ch_2, _ in candidate_pairs:
            if _status[ch_1] == 0 and _status[ch_2] == 0:
                outputs.append((ch_1, ch_2))
                _status[ch_1] = 1
                _status[ch_2] = 1
        return outputs

    def get_ig_pairs(self):
        """Return the paired immunoglobulin chains."""
        return self._search_pairs(h_chains=self.ig_H, l_chains=self.ig_L)

    def get_tr_pairs(self):
        """Return the paired T-cell receptor chains."""
        return self._search_pairs(h_chains=self.tr_H, l_chains=self.tr_L)

    def get_vhh(self):
        """Return IG chains typed ``IG/VH`` that are not part of any IG pair."""
        paired_chains = set()
        for pair in self.ig_pairs:
            paired_chains.update(pair)

        return [ch for ch in self.ig_chains
                if ch not in paired_chains and self.anarci_ann[ch]["fv_type"] == "IG/VH"]

    def _search_target_chains(self, query_chains: list, query_type: str):
        """
        Return the chains contacted by the CDRs of ``query_chains``.

        Not Consider IG-IG, TR-TR complexes, if exist: chains of the same
        receptor class as the query (and chains with an unclear IG/TR type)
        are never targets. A chain qualifies when the CDR residues in contact
        with it, summed over all query chains, reach ``min_cdr_ppi_residues``.
        """
        assert query_type in ["IG", "TR"]

        if query_type == "IG":
            excluded = self.ig_chains + self.ig_scfv_chains + self.other_ig_tr_chains
        else:
            excluded = self.tr_chains + self.tr_scfv_chains + self.other_ig_tr_chains
        target_chains = sorted(set(self.st.chain_ids).difference(excluded))

        candidates = []
        for ch_t in target_chains:
            n_cdr = 0
            for ch_q in query_chains:
                cdr_mask_q = self.anarci_ann[ch_q]["cdr_mask"]
                ppi_q, _ = self.get_interface_mask(ch_q, ch_t)
                # cdr interactions
                n_cdr += np.logical_and(cdr_mask_q, ppi_q).sum()

            if n_cdr >= self.min_cdr_ppi_residues:
                candidates.append(ch_t)
        return candidates

    def _complex_record(self, query, target_chains: list, complex_type: str):
        """Build the output dict describing one complex."""
        return dict(query_chains=list(query),
                    target_chains=target_chains,
                    target_chains_types=[self.ch_types[ch] for ch in target_chains],
                    complex_type=complex_type
                    )

    def _antigen_complexes(self, queries, query_type: str, complex_type: str):
        """Return a complex record for every query that has target chains."""
        outputs = []
        for query in queries:
            tmp = self._search_target_chains(query_chains=list(query), query_type=query_type)
            if tmp:
                outputs.append(self._complex_record(query, tmp, complex_type))
        return outputs

    def get_ig_complexes(self):
        """Return ``IG_Ag`` complexes, one per IG pair with target chains."""
        return self._antigen_complexes(self.ig_pairs, "IG", "IG_Ag")

    def get_vhh_complexes(self):
        """Return ``VHH_Ag`` complexes, one per unpaired VH chain with targets."""
        return self._antigen_complexes([(ch,) for ch in self.vhh_chains], "IG", "VHH_Ag")

    def get_scfv_complexes(self):
        """Return ``scFv_Ag`` complexes, one per IG scFv chain with targets."""
        return self._antigen_complexes([(ch,) for ch in self.ig_scfv_chains], "IG", "scFv_Ag")

    def find_b2mg(self, query_ch: str):
        """Return the chain most likely to be the partner of ``query_ch``.

        Only right for an MHC I chain or a CD1 chain. Candidates are protein
        chains longer than ``MAX_MHC_II_PEPTIDE_LEN`` (i.e. not peptides) that
        are not IG / TR / MHC / CD1 chains and whose interface with
        ``query_ch`` (both sides summed) reaches
        ``min_b2globulin_ppi_residues``.

        Returns:
            chain id with the largest interface, or ``""`` if none qualifies.
        """
        assert (query_ch in self.mhc_ann and self.mhc_ann[
            query_ch] == "MHC_I") or query_ch in self.cd1_chains, "Not MHC_I chain or CD1 chain: %s" % query_ch

        exclude_chains = (self.ig_chains
                          + self.ig_scfv_chains
                          + self.tr_chains
                          + self.tr_scfv_chains
                          + self.other_ig_tr_chains
                          + self.mhc_chains
                          + self.cd1_chains
                          )

        # Fetched once instead of once per candidate chain.
        sequences = self.st.polymer_sequences()

        # not peptide chains
        candidates = []
        for cur_ch in self.protein_chains:
            seq_n = len(sequences[cur_ch])
            if seq_n > self.MAX_MHC_II_PEPTIDE_LEN and cur_ch not in exclude_chains:
                m1, m2 = self.get_interface_mask(cur_ch, query_ch)
                n_ppi = m1.sum() + m2.sum()

                if n_ppi >= self.min_b2globulin_ppi_residues:
                    candidates.append((cur_ch, n_ppi))

        candidates.sort(reverse=True, key=lambda s: s[1])

        if len(candidates) > 0:
            return candidates[0][0]
        return ""

    def find_pair_mhc2ab(self, query_ch: str):
        """Return the MHC chain with the largest interface to an MHC II chain.

        Args:
            query_ch: chain annotated as ``MHC_II_alpha`` or ``MHC_II_beta``.

        Returns:
            id of the other MHC chain whose residues in contact with
            ``query_ch`` are most numerous and reach
            ``min_mhc2ab_ppi_residues``, or ``""`` if there is none.

        Raises:
            RuntimeError: if ``query_ch`` is not annotated as an MHC II chain.
        """
        # query_ch must be MHC2 alpha or beta chain
        if self.mhc_ann.get(query_ch) not in ["MHC_II_alpha", "MHC_II_beta"]:
            raise RuntimeError("Not MHC_II chain: %s" % query_ch)

        candidates = []
        for cur_ch in self.mhc_chains:
            if cur_ch != query_ch:
                m1, _ = self.get_interface_mask(cur_ch, query_ch)
                n_ppi = m1.sum()

                if n_ppi >= self.min_mhc2ab_ppi_residues:
                    candidates.append((cur_ch, n_ppi))

        candidates.sort(reverse=True, key=lambda s: s[1])

        if len(candidates) > 0:
            return candidates[0][0]
        return ""

    def find_best_chain(self, query_chains: list, ref_chains: list):
        """
        find best chain from query_chains, which owns most interactions with ref_chains

        A residue contacting several reference chains is counted once; chains
        without any contact are ignored.

        Return str, chain id ("" when no query chain has contacts)
        """
        tmp = []
        for q_ch in query_chains:
            q_masks = []
            for r_ch in ref_chains:
                qm, _ = self.get_interface_mask(q_ch, r_ch)
                q_masks.append(qm)

            n_contact_residues = np.any(np.array(q_masks), axis=0).sum()
            if n_contact_residues > 0:
                tmp.append((q_ch, n_contact_residues))

        tmp.sort(reverse=True, key=lambda s: s[1])

        if tmp:
            return tmp[0][0]
        return ""

    def _best_mhc2_partner(self, partner_type: str, anchor_chain: str):
        """Return the chain annotated ``partner_type`` with most contacts to
        ``anchor_chain``, or ``""`` when there is none."""
        candidates = [ch for ch, t in self.mhc_ann.items() if t == partner_type]
        if not candidates:
            return ""
        return self.find_best_chain(query_chains=candidates, ref_chains=[anchor_chain])

    def check_mhc(self, query_chains: list, target_chains: list):
        """Interpret the target chains of a TR query as an MHC complex.

        Returns:
            ``(msg, complex_type, output_chains)``. ``msg`` is ``"Success"``
            when a consistent MHC I or MHC II assembly was found, ``""`` when
            the targets hold no MHC chain, otherwise a description of the
            problem. On success ``output_chains`` lists the MHC chains first,
            followed by the non-MHC target chains.
        """
        status = defaultdict(list)  # MHC annotation -> target chains
        ligand_chains = []

        for ch in target_chains:
            if ch in self.mhc_chains:
                status[self.mhc_ann[ch]].append(ch)
            else:
                ligand_chains.append(ch)

        n_status = len(status)

        complex_type = ""
        msg = ""
        output_chains = target_chains.copy()

        if n_status == 1:
            if "MHC_I" in status:
                mhc_chain = self.find_best_chain(query_chains=status["MHC_I"],
                                                 ref_chains=query_chains)

                b2 = self.find_b2mg(mhc_chain)
                # reset output_chains
                if b2 != "":
                    output_chains = [mhc_chain, b2] + ligand_chains
                else:
                    output_chains = [mhc_chain] + ligand_chains

                complex_type = "TR_MHC1"
                msg = "Success"
            elif "MHC_II_alpha" in status:
                # Only the alpha chain is among the targets: look for its
                # beta partner among all annotated chains.
                alpha_chain = self.find_best_chain(query_chains=status["MHC_II_alpha"],
                                                   ref_chains=query_chains)
                beta_chain = self._best_mhc2_partner("MHC_II_beta", alpha_chain)

                if beta_chain == "":
                    msg = "Missing MHC_II_beta chain"
                else:
                    msg = "Success"
                    complex_type = "TR_MHC2"

                # place MHC chains first
                output_chains = [alpha_chain, beta_chain] + ligand_chains
            elif "MHC_II_beta" in status:
                # Mirror case: only the beta chain is among the targets.
                beta_chain = self.find_best_chain(query_chains=status["MHC_II_beta"],
                                                  ref_chains=query_chains)
                alpha_chain = self._best_mhc2_partner("MHC_II_alpha", beta_chain)

                if alpha_chain == "":
                    msg = "Missing MHC_II_alpha chain"
                else:
                    msg = "Success"
                    complex_type = "TR_MHC2"

                output_chains = [alpha_chain, beta_chain] + ligand_chains

        elif n_status == 2:
            if "MHC_II_alpha" in status and "MHC_II_beta" in status:
                if len(status["MHC_II_alpha"]) == 1 and len(status["MHC_II_beta"]) == 1:
                    msg = "Success"
                    complex_type = "TR_MHC2"

                    # place MHC chains first
                    output_chains = [status["MHC_II_alpha"][0], status["MHC_II_beta"][0]] + ligand_chains
                else:
                    msg = "Multiple MHC_II_alpha or MHC_II_beta"
            else:
                msg = "Confusing MHC"
        elif n_status > 2:
            msg = "Confusing MHC"

        return msg, complex_type, output_chains

    def check_cd1(self, query_chains: list, target_chains: list):
        """Interpret the target chains of a TR query as a CD1 complex.

        Returns:
            ``(msg, complex_type, output_chains)``. ``msg`` is ``"Success"``
            when the targets hold CD1 chains of exactly one CD1 type, ``""``
            when they hold none and ``"Multiple CD1 chains"`` otherwise. On
            success ``output_chains`` is the CD1 chain, its partner from
            ``find_b2mg`` if any, then the non-CD1 target chains.
        """
        status = defaultdict(list)  # CD1 type -> target chains
        ligand_chains = []

        for ch in target_chains:
            if ch in self.cd1_chains:
                status[self.cd1_ann[ch]].append(ch)
            else:
                ligand_chains.append(ch)

        n_status = len(status)

        complex_type = ""
        msg = ""
        output_chains = target_chains.copy()

        if n_status == 1:
            cd1_type = list(status.keys())[0]

            if len(status[cd1_type]) == 1:
                cd1_chain = status[cd1_type][0]
            else:
                # multiple CD1 chains, pick the one with most interactions
                tmp = []
                for t_ch in status[cd1_type]:
                    t_masks = []
                    for q_ch in query_chains:
                        tm, _ = self.get_interface_mask(t_ch, q_ch)
                        t_masks.append(tm)

                    n_contact_residues = np.any(np.array(t_masks), axis=0).sum()
                    tmp.append((t_ch, n_contact_residues))

                tmp.sort(reverse=True, key=lambda s: s[1])
                cd1_chain = tmp[0][0]

            b2 = self.find_b2mg(cd1_chain)

            # reset output_chains
            if b2 != "":
                output_chains = [cd1_chain, b2] + ligand_chains
            else:
                output_chains = [cd1_chain] + ligand_chains

            complex_type = "TR_CD1"
            msg = "Success"

        elif n_status > 1:
            msg = "Multiple CD1 chains"
        return msg, complex_type, output_chains

    def get_tr_complexes(self):
        """Return TR complexes and the queries that could not be resolved.

        A CD1 interpretation takes precedence over an MHC one; targets with
        neither MHC nor CD1 chains are reported as ``TR_Ag``.

        Returns:
            ``(outputs, wrongs)``: ``outputs`` holds complex records,
            ``wrongs`` holds records carrying the ``msg_MHC`` / ``msg_CD1``
            messages of the failed interpretations.
        """
        qt = "TR"
        outputs = []
        wrongs = []

        scfv_chains = [(ch,) for ch in self.tr_scfv_chains]

        for query in self.tr_pairs + scfv_chains:
            tmp = self._search_target_chains(query_chains=list(query), query_type=qt)
            if not tmp:
                continue

            msg_b, complex_type_b, target_chains_b = self.check_cd1(query_chains=list(query),
                                                                    target_chains=tmp)

            msg_a, complex_type_a, target_chains_a = self.check_mhc(query_chains=list(query),
                                                                    target_chains=tmp)
            if msg_b == "Success":
                outputs.append(self._complex_record(query, target_chains_b, complex_type_b))
            elif msg_a == "Success":
                outputs.append(self._complex_record(query, target_chains_a, complex_type_a))
            elif msg_a == "" and msg_b == "":
                # TR-Ag
                outputs.append(self._complex_record(query, tmp, "TR_Ag"))
            else:
                wrongs.append(dict(query_chains=list(query),
                                   target_chains=tmp,
                                   target_chains_types=[self.ch_types[ch] for ch in tmp],
                                   msg_MHC=msg_a,
                                   msg_CD1=msg_b
                                   )
                              )
        return outputs, wrongs

    def run(self):
        """Return ``(outputs, wrongs)`` for the first receptor class present.

        Classes are tried in the order IG pairs, unpaired VH chains, IG scFv
        chains, then TR pairs / TR scFv chains; only the first non-empty
        class is analysed.
        NOTE(review): a structure holding several classes therefore reports
        only one of them - confirm this is intended.
        """
        outputs = []
        wrongs = []
        if self.ig_pairs:
            outputs.extend(self.get_ig_complexes())
        elif self.vhh_chains:
            outputs.extend(self.get_vhh_complexes())
        elif self.ig_scfv_chains:
            outputs.extend(self.get_scfv_complexes())
        elif self.tr_pairs or self.tr_scfv_chains:
            val, wrong = self.get_tr_complexes()
            outputs.extend(val)
            wrongs.extend(wrong)
        return outputs, wrongs