stcrpy 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +0 -0
- examples/egnn.py +425 -0
- stcrpy/__init__.py +5 -0
- stcrpy/tcr_datasets/__init__.py +0 -0
- stcrpy/tcr_datasets/tcr_graph_dataset.py +499 -0
- stcrpy/tcr_datasets/tcr_selector.py +0 -0
- stcrpy/tcr_datasets/tcr_structure_dataset.py +0 -0
- stcrpy/tcr_datasets/utils.py +350 -0
- stcrpy/tcr_formats/__init__.py +0 -0
- stcrpy/tcr_formats/tcr_formats.py +114 -0
- stcrpy/tcr_formats/tcr_haddock.py +556 -0
- stcrpy/tcr_geometry/TCRCoM.py +350 -0
- stcrpy/tcr_geometry/TCRCoM_LICENCE +168 -0
- stcrpy/tcr_geometry/TCRDock.py +261 -0
- stcrpy/tcr_geometry/TCRGeom.py +450 -0
- stcrpy/tcr_geometry/TCRGeomFiltering.py +273 -0
- stcrpy/tcr_geometry/__init__.py +0 -0
- stcrpy/tcr_geometry/reference_data/__init__.py +0 -0
- stcrpy/tcr_geometry/reference_data/dock_reference_1_imgt_numbered.pdb +6549 -0
- stcrpy/tcr_geometry/reference_data/dock_reference_2_imgt_numbered.pdb +6495 -0
- stcrpy/tcr_geometry/reference_data/reference_A.pdb +31 -0
- stcrpy/tcr_geometry/reference_data/reference_B.pdb +31 -0
- stcrpy/tcr_geometry/reference_data/reference_D.pdb +31 -0
- stcrpy/tcr_geometry/reference_data/reference_G.pdb +31 -0
- stcrpy/tcr_geometry/reference_data/reference_data.py +104 -0
- stcrpy/tcr_interactions/PLIPParser.py +147 -0
- stcrpy/tcr_interactions/TCRInteractionProfiler.py +433 -0
- stcrpy/tcr_interactions/TCRpMHC_PLIP_Model_Parser.py +133 -0
- stcrpy/tcr_interactions/__init__.py +0 -0
- stcrpy/tcr_interactions/utils.py +170 -0
- stcrpy/tcr_methods/__init__.py +0 -0
- stcrpy/tcr_methods/tcr_batch_operations.py +223 -0
- stcrpy/tcr_methods/tcr_methods.py +150 -0
- stcrpy/tcr_methods/tcr_reformatting.py +18 -0
- stcrpy/tcr_metrics/__init__.py +2 -0
- stcrpy/tcr_metrics/constants.py +39 -0
- stcrpy/tcr_metrics/tcr_interface_rmsd.py +237 -0
- stcrpy/tcr_metrics/tcr_rmsd.py +179 -0
- stcrpy/tcr_ml/__init__.py +0 -0
- stcrpy/tcr_ml/geometry_predictor.py +3 -0
- stcrpy/tcr_processing/AGchain.py +89 -0
- stcrpy/tcr_processing/Chemical_components.py +48915 -0
- stcrpy/tcr_processing/Entity.py +301 -0
- stcrpy/tcr_processing/Fragment.py +58 -0
- stcrpy/tcr_processing/Holder.py +24 -0
- stcrpy/tcr_processing/MHC.py +449 -0
- stcrpy/tcr_processing/MHCchain.py +149 -0
- stcrpy/tcr_processing/Model.py +37 -0
- stcrpy/tcr_processing/Select.py +145 -0
- stcrpy/tcr_processing/TCR.py +532 -0
- stcrpy/tcr_processing/TCRIO.py +47 -0
- stcrpy/tcr_processing/TCRParser.py +1230 -0
- stcrpy/tcr_processing/TCRStructure.py +148 -0
- stcrpy/tcr_processing/TCRchain.py +160 -0
- stcrpy/tcr_processing/__init__.py +3 -0
- stcrpy/tcr_processing/annotate.py +480 -0
- stcrpy/tcr_processing/utils/__init__.py +0 -0
- stcrpy/tcr_processing/utils/common.py +67 -0
- stcrpy/tcr_processing/utils/constants.py +367 -0
- stcrpy/tcr_processing/utils/region_definitions.py +782 -0
- stcrpy/utils/__init__.py +0 -0
- stcrpy/utils/error_stream.py +12 -0
- stcrpy-1.0.0.dist-info/METADATA +173 -0
- stcrpy-1.0.0.dist-info/RECORD +68 -0
- stcrpy-1.0.0.dist-info/WHEEL +5 -0
- stcrpy-1.0.0.dist-info/licenses/LICENCE +28 -0
- stcrpy-1.0.0.dist-info/licenses/stcrpy/tcr_geometry/TCRCoM_LICENCE +168 -0
- stcrpy-1.0.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,499 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import itertools
|
|
3
|
+
import os
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
from ..tcr_processing import TCR, TCRParser
|
|
9
|
+
from . import utils
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from torch_geometric.data import Data, Dataset
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
except ImportError:
|
|
16
|
+
pass
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TCRGraphConstructor:
    """Convert parsed TCR structures into torch_geometric ``Data`` graphs.

    Behaviour is controlled by a configuration dictionary. Recognised keys:

    * ``node_level`` (required): only ``"residue"`` is implemented.
    * ``residue_coord``: atoms representing each residue node; only ``["CA"]``
      is implemented (also used when the key is absent).
    * ``node_features`` (required): only ``"one_hot"`` is implemented.
    * ``edge_features`` (required): ``"distance"``, ``"fully_connected"`` or
      ``"interactions"``.
    * ``tcr_regions``: only ``["all"]`` / ``None`` / absent is implemented.
    * ``include_antigen`` / ``include_mhc``: add nodes for the first antigen /
      first MHC returned by the TCR object.
    * ``mhc_distance_threshold``: when set (not ``None``), keep only MHC nodes
      closer than this distance to at least one TCR node.
    """

    def __init__(self, config=None, **kwargs):
        """Create a constructor from ``config`` updated with ``kwargs``.

        Args:
            config: configuration dictionary (see class docstring). ``None``
                selects the default residue-level / one-hot / distance setup.
                The dictionary is copied and never modified.
            **kwargs: configuration entries overriding those in ``config``.

        Raises:
            AssertionError: if a required configuration key is missing.
            NotImplementedError: if a configured option is not supported.
        """
        if config is None:
            config = {
                "node_level": "residue",
                "residue_coord": ["CA"],
                "node_features": "one_hot",
                "edge_features": "distance",
                "tcr_regions": ["all"],
                "include_antigen": True,
                "include_mhc": True,
                "mhc_distance_threshold": 15.0,
            }
        # Merge into a fresh dict so keyword overrides never leak into the
        # dictionary the caller passed in.
        config = {**config, **kwargs}

        # assert that minimum amount of configuration is set
        missing = {"node_level", "node_features", "edge_features"} - config.keys()
        assert not missing, (
            f"Graph configuration is missing required keys: {sorted(missing)}"
        )

        self.config = config

        self.node_selector = self._get_node_selector()
        self.node_featuriser = self._get_node_featuriser()
        self.edge_featuriser = self._get_edge_featuriser()

    @staticmethod
    def _make_test_atom(name, coord, serial_number=0):
        """Build a standalone ``Bio.PDB`` atom used to validate user callables."""
        from Bio.PDB.Atom import Atom

        return Atom(
            name=name,
            coord=np.array(coord),
            bfactor=67.02,
            occupancy=1.0,
            altloc=" ",
            fullname=f" {name:<3}",
            serial_number=serial_number,
            element=name[0],
        )

    def set_node_selector(self, node_selector_function):
        """Replace the node selector after checking it accepts a residue.

        Args:
            node_selector_function: callable mapping a ``Bio.PDB`` residue to a
                list of node objects; ``None`` entries are ignored when graphs
                are built.

        Raises:
            ValueError: if the function fails on a synthetic glycine residue.
        """
        from Bio.PDB.Residue import Residue

        test_res = Residue(id=(" ", 3, " "), resname="GLY", segid=" ")
        test_atoms = [
            ("N", [23.399, -5.842, 19.395]),
            ("O", [24.17, -8.195, 21.998]),
            ("C", [23.617, -7.263, 21.414]),
            ("CA", [24.316, -6.528, 20.288]),
        ]
        for serial, (atom_name, atom_coord) in enumerate(test_atoms, start=1):
            test_res.add(self._make_test_atom(atom_name, atom_coord, serial))
        try:
            node_selector_function(test_res)
        except Exception as e:
            raise ValueError(
                f"Node selector function should generate node from Bio.PDB.Residue instance. Raised error {e}"
            ) from e
        self.node_selector = node_selector_function

    def set_node_featuriser(self, node_featuriser_function, test_input=None):
        """Replace the node featuriser after a trial call.

        Args:
            node_featuriser_function: callable mapping a node to a
                ``torch.Tensor`` or ``None`` (``None`` drops the node).
            test_input: node used for the trial call. Defaults to a standalone
                CA atom without parent residue/chain, so featurisers that
                navigate ``atom.parent`` need an explicit ``test_input``.

        Raises:
            ValueError: if the trial call raises.
            AssertionError: if the trial call returns neither a tensor nor None.
        """
        if test_input is None:
            warnings.warn(
                "No test input provided for new node featuriser, using Bio.PDB.Atom instance"
            )
            test_input = self._make_test_atom("CA", [24.316, -6.528, 20.288])
        try:
            features = node_featuriser_function(test_input)
        except Exception as e:
            raise ValueError(
                f"Node featuriser function could not featurise node {test_input}. Raised error {e}"
            ) from e
        # torch.Tensor is the tensor type; torch.tensor is only a factory function.
        assert features is None or isinstance(
            features, torch.Tensor
        ), "Node featuriser should generate torch tensor"
        self.node_featuriser = node_featuriser_function

    def set_edge_featuriser(self, edge_featuriser_function, test_input=None):
        """Replace the edge featuriser after a trial call.

        The featuriser must return ``(edges, edge_features, edge_weights)``
        where ``edges`` has shape ``(num_edges, 2)`` (``build_graph`` transposes
        it into torch_geometric's ``(2, num_edges)`` layout), and the optional
        features/weights are tensors with one entry per edge. The trial call
        passes only the node list; ``build_graph`` additionally passes
        ``tcr=`` as a keyword argument.

        Args:
            edge_featuriser_function: the featuriser callable.
            test_input: node list for the trial call (defaults to three
                standalone CA atoms).

        Raises:
            ValueError: if the trial call raises.
            AssertionError: if the returned values violate the contract above.
        """
        if test_input is None:
            warnings.warn(
                "No test input provided for new edge featuriser, using Bio.PDB.Atom instance"
            )
            test_input = [
                self._make_test_atom("CA", [24.316, -6.528, 20.288], 1),
                self._make_test_atom("CA", [27.623, -12.28, 23.288], 2),
                self._make_test_atom("CA", [16.36, 8.58, 30.288], 3),
            ]
        try:
            edges, edge_features, edge_weights = edge_featuriser_function(test_input)
        except Exception as e:
            raise ValueError(
                f"Edge featuriser function could not featurise edge {test_input}. Raised error {e}"
            ) from e
        assert (
            edges.ndim == 2 and edges.shape[1] == 2
        ), f"Edge indices must have shape (num_edges, 2) to define connected nodes. Edge shape was {edges.shape}"
        num_edges = edges.shape[0]
        if edge_features is not None:
            assert (
                isinstance(edge_features, torch.Tensor)
                and len(edge_features) == num_edges
            ), "Edge features configuration invalid"
        if edge_weights is not None:
            assert (
                isinstance(edge_weights, torch.Tensor)
                and len(edge_weights) == num_edges
            ), "Edge weights configuration invalid"

        self.edge_featuriser = edge_featuriser_function

    def _calculate_distance_matrix(self, coord_1, coord_2):
        """Return pairwise Euclidean distances between two coordinate sets.

        Args:
            coord_1: array-like of shape ``(n_1, 3)``.
            coord_2: array-like of shape ``(n_2, 3)``.

        Returns:
            ``np.ndarray`` of shape ``(n_2, n_1)`` where entry ``[i, j]`` is the
            distance between ``coord_2[i]`` and ``coord_1[j]``. The result is
            always 2-D, also when either set holds a single point.
        """
        coord_1 = np.atleast_2d(np.asarray(coord_1))
        coord_2 = np.atleast_2d(np.asarray(coord_2))
        assert coord_1.ndim == coord_2.ndim == 2
        assert coord_1.shape[-1] == coord_2.shape[-1] == 3

        diff = coord_2[:, None, :] - coord_1[None, :, :]
        return np.sqrt(np.sum(diff**2, axis=-1))

    @staticmethod
    def _node_coords(nodes):
        """Stack node coordinates into an ``(n, 3)`` array (``(0, 3)`` if empty)."""
        return np.asarray([node.get_coord() for node in nodes]).reshape(-1, 3)

    def _select_nodes(self, residues):
        """Apply the node selector to standard residues and drop ``None`` nodes."""
        return [
            node
            for residue in residues
            # a blank hetero flag filters out waters and other hetero residues
            if residue.id[0].strip() == ""
            for node in self.node_selector(residue)
            if node is not None
        ]

    def _get_node_selector(self):
        """Return the node selector for ``config["node_level"]``.

        Raises:
            NotImplementedError: for unsupported node levels or residue atoms.
        """
        node_level = self.config["node_level"]
        if node_level != "residue":
            raise NotImplementedError(f"Node level {node_level!r} not recognised")
        residue_coord = self.config.get("residue_coord", ["CA"])
        if residue_coord != ["CA"]:
            raise NotImplementedError(
                f"Residue coordinate selection {residue_coord!r} not recognised"
            )

        # generate single node per residue with coordinate of CA atom.
        def node_selector(residue):
            if "CA" in residue.child_dict:
                return [residue["CA"]]
            # residues without a CA atom contribute no node
            return []

        return node_selector

    def _get_node_featuriser(self):
        """Return the node featuriser for ``config["node_features"]``.

        Raises:
            NotImplementedError: for unsupported featurisation methods.
        """
        if self.config["node_features"] != "one_hot":
            raise NotImplementedError("Node featurisation method not recognised")

        def one_hot_encoding(atom_node):
            """Concatenate one-hot encodings describing ``atom_node``.

            The 69-dimensional output consists of:

            * 4 dims for chain type: [TCR alpha, TCR beta, peptide, MHC]
              (indices from ``utils.CHAIN_TYPE_ONE_HOT_ENCODING``),
            * 7 dims for CDR loop: [Not CDR, CDRA1, CDRA2, CDRA3, CDRB1, CDRB2, CDRB3]
              (indices from ``utils.TCR_REGION_ONE_HOT_ENCODING``),
            * 21 dims for residue encoding:
              ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLU', 'GLN', 'GLY', 'HIS', 'ILE',
              'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'XXX'],
            * 37 dims for atom encoding (``utils.ATOM37_ATOM_ONEHOT_ENCODING``).

            Args:
                atom_node: atom whose parent residue and chain are set.

            Returns:
                ``torch.Tensor`` of length 69, or ``None`` (with a warning)
                when the chain type cannot be resolved.
            """
            chain = atom_node.parent.parent
            if hasattr(chain, "chain_type"):
                chain_type = chain.chain_type
            elif hasattr(chain, "type"):
                chain_type = chain.type
                if chain_type == "peptide":
                    chain_type = "antigen"
            else:
                # Accessing MHC_type raises AttributeError for non-MHC chains;
                # if it succeeds, the chain is an MHC chain.
                _ = chain.MHC_type
                chain_type = "MHC"
            if chain_type in utils.MHC_CHAIN_TYPES:
                chain_type = "MHC"

            if chain_type not in utils.CHAIN_TYPE_ONE_HOT_ENCODING:
                warnings.warn(
                    f"Could not resolve chain type: {chain_type} of node: "
                    f"{atom_node}|{chain}|{chain.parent.id}"
                )
                return None
            chain_type_index = utils.CHAIN_TYPE_ONE_HOT_ENCODING[chain_type]

            # Only TCR chains (A/B/G/D) are looked up for a CDR region.
            if chain_type in ["A", "B", "G", "D"]:
                region = atom_node.parent.region.capitalize()
            else:
                region = "NOT_CDR"
            if region not in utils.TCR_REGION_ONE_HOT_ENCODING:
                region = "NOT_CDR"
            region_index = utils.TCR_REGION_ONE_HOT_ENCODING[region]

            residue_index = utils.AMINO_ACID_ONEHOT_ENCODING[
                atom_node.parent.resname.strip()
            ]
            atom37_index = utils.ATOM37_ATOM_ONEHOT_ENCODING[
                atom_node.fullname.strip()
            ]
            return torch.concat(
                [
                    F.one_hot(torch.tensor(chain_type_index), num_classes=4),
                    F.one_hot(torch.tensor(region_index), num_classes=7),
                    F.one_hot(torch.tensor(residue_index), num_classes=21),
                    F.one_hot(torch.tensor(atom37_index), num_classes=37),
                ]
            )

        return one_hot_encoding

    def _get_edge_featuriser(self):
        """Return the edge featuriser for ``config["edge_features"]``.

        Every featuriser returns ``(edges, edge_features, edge_weights)`` with
        ``edges`` of shape ``(num_edges, 2)``.

        Raises:
            NotImplementedError: for unsupported edge featurisation methods.
        """
        method = self.config["edge_features"]

        if method == "distance":
            # Import the submodule explicitly; ``import scipy`` alone does not
            # guarantee ``scipy.spatial`` is available.
            from scipy.spatial.distance import pdist

            def distance_edges(nodes, distance_cutoff=15.0, **kwargs):
                n_nodes = len(nodes)
                coords = np.asarray([a.get_coord() for a in nodes]).reshape(-1, 3)
                dist_mat = np.zeros((n_nodes, n_nodes))
                # pdist lists the strict upper triangle in row-major order,
                # which is exactly the order of this boolean mask.
                upper = np.arange(n_nodes)[:, None] < np.arange(n_nodes)
                dist_mat[upper] = pdist(coords)
                # symmetrise; put the cutoff on the diagonal so self edges fail
                # the strict "< cutoff" test below
                dist_mat = dist_mat + dist_mat.T + distance_cutoff * np.eye(n_nodes)
                edges = np.argwhere(dist_mat < distance_cutoff)
                edge_features = dist_mat[edges[:, 0], edges[:, 1]]
                return torch.from_numpy(edges), torch.from_numpy(edge_features), None

            return distance_edges

        if method == "fully_connected":

            def fully_connected(nodes, **kwargs):
                edges = np.argwhere(np.ones((len(nodes), len(nodes)), dtype=bool))
                return torch.from_numpy(edges), None, None

            return fully_connected

        if method == "interactions":

            def get_interactions(nodes, tcr, **kwargs):
                interactions_dict = {
                    "hydrophobic": 0,
                    "hbond": 1,
                    "pistack": 2,
                    "saltbridge": 3,
                }

                def node_index(atom):
                    # Match by coordinates; returns the first matching node or None.
                    coord = atom.get_coord()
                    return next(
                        (
                            idx
                            for idx, node in enumerate(nodes)
                            if (coord == node.get_coord()).all()
                        ),
                        None,
                    )

                edges = {}
                interactions_df = tcr.profile_peptide_interactions()
                for _, row in interactions_df.iterrows():
                    if row.ligand_residue == "HOH":
                        continue  # skip rows whose ligand residue is water
                    tcr_residue = tcr.parent[row.protein_chain][
                        (" ", row.protein_number, " ")
                    ]
                    assert tcr_residue.resname == row.protein_residue
                    antigen_residue = tcr.get_antigen()[0][
                        (" ", row.ligand_number, " ")
                    ]
                    if (
                        "CA" not in tcr_residue.child_dict
                        or "CA" not in antigen_residue.child_dict
                    ):
                        continue
                    source = node_index(tcr_residue["CA"])
                    target = node_index(antigen_residue["CA"])
                    if source is None or target is None:
                        continue  # residue is not represented in the graph
                    edges[(source, target)] = interactions_dict[row.type]

                # explicit dtype/shape keep the empty case valid for one_hot and .T
                edge_indices = torch.tensor(
                    list(edges.keys()), dtype=torch.long
                ).reshape(-1, 2)
                edge_features = F.one_hot(
                    torch.tensor(list(edges.values()), dtype=torch.long),
                    num_classes=4,
                )
                assert len(edge_indices) == len(edge_features)
                return edge_indices, edge_features, None

            return get_interactions

        raise NotImplementedError("Edge featurisation method not recognised")

    def build_graph(self, tcr: "TCR", label=None):
        """Build a torch_geometric ``Data`` graph for a single TCR.

        Args:
            tcr: TCR object exposing ``get_residues``, ``get_antigen`` and
                ``get_MHC``.
            label: stored as the graph's ``y`` attribute.

        Returns:
            ``Data`` with ``x`` (stacked node features), ``edge_index`` of
            shape ``(2, num_edges)``, ``edge_attr``, ``edge_weight``, ``pos``
            (node coordinates), ``y`` and ``tcr_id``.

        Raises:
            NotImplementedError: if ``tcr_regions`` selects anything but all regions.
            ValueError: if no node could be featurised.
        """
        if self.config.get("tcr_regions") not in (None, ["all"]):
            raise NotImplementedError(
                f"TCR region selection {self.config['tcr_regions']} is not implemented; use ['all']"
            )
        tcr_id = f"{tcr.parent.parent.id}_{tcr.id}"

        tcr_nodes = self._select_nodes(tcr.get_residues())
        nodes = list(tcr_nodes)

        if self.config.get("include_antigen"):
            antigens = tcr.get_antigen()
            if len(antigens) == 0:
                warnings.warn(
                    f"No antigen found for TCR {tcr}. Antigen not included in graph."
                )
            else:
                nodes.extend(self._select_nodes(antigens[0].get_residues()))

        if self.config.get("include_mhc"):
            mhcs = tcr.get_MHC()
            if len(mhcs) == 0:
                warnings.warn(f"No MHC found for TCR {tcr}. MHC not included in graph.")
            else:
                mhc_nodes = self._select_nodes(mhcs[0].get_residues())
                threshold = self.config.get("mhc_distance_threshold")
                if threshold is not None:
                    # shape is (len(mhc_nodes), len(tcr_nodes))
                    dist_mat = self._calculate_distance_matrix(
                        self._node_coords(tcr_nodes), self._node_coords(mhc_nodes)
                    )
                    near_tcr = (dist_mat < threshold).any(axis=-1)
                    mhc_nodes = [n for n, keep in zip(mhc_nodes, near_tcr) if keep]
                nodes.extend(mhc_nodes)

        # Featurise and drop nodes that could not be featurised, keeping nodes,
        # features and coordinates aligned.
        featurised = [(node, self.node_featuriser(node)) for node in nodes]
        kept = [(node, feat) for node, feat in featurised if feat is not None]
        n_removed = len(featurised) - len(kept)
        if n_removed > 0:
            warnings.warn(
                f"{n_removed} nodes removed from original node list of TCR {tcr_id}"
            )
        if not kept:
            raise ValueError(f"No featurisable nodes found for TCR {tcr_id}")

        nodes = [node for node, _ in kept]
        node_features = torch.stack([feat for _, feat in kept])
        coordinates = self._node_coords(nodes)
        assert len(nodes) == len(node_features) == len(coordinates)

        edge_index, edge_features, edge_weight = self.edge_featuriser(nodes, tcr=tcr)

        return Data(
            x=node_features,
            edge_index=edge_index.T,
            edge_attr=edge_features,
            edge_weight=edge_weight,
            pos=torch.from_numpy(coordinates),
            y=label,
            tcr_id=tcr_id,
        )
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
class TCRGraphDataset(Dataset):
    """torch_geometric ``Dataset`` of TCR graphs stored as ``.pt`` files.

    ``process`` parses every structure file, builds one graph per TCR found
    with :class:`TCRGraphConstructor` and saves it to
    ``<root>/processed/<tcr_id>.pt``; ``get`` loads those files back.
    """

    def __init__(self, root, data_paths, graph_config=None, *args, **kwargs):
        """Collect structure files and initialise the dataset.

        Args:
            root: dataset root directory; graphs are written to
                ``<root>/processed``.
            data_paths: a ``.csv`` file with a ``path`` column, a directory
                containing structure files, a single structure file path, or
                an iterable of structure file paths.
            graph_config: configuration passed to :class:`TCRGraphConstructor`.
            *args, **kwargs: forwarded to :class:`TCRGraphConstructor`.

        Raises:
            ValueError: if no ``.pdb`` or ``.cif`` file is found.
        """
        self.graph_constructor = TCRGraphConstructor(
            config=graph_config, *args, **kwargs
        )

        if isinstance(data_paths, os.PathLike):
            data_paths = os.fspath(data_paths)

        if isinstance(data_paths, str):
            if data_paths.endswith(".csv"):
                data_files = pd.read_csv(data_paths)
            elif os.path.isdir(data_paths):
                data_files = pd.DataFrame(
                    [
                        os.path.join(data_paths, p)
                        for p in os.listdir(data_paths)
                        if p.endswith((".pdb", ".cif", ".mmcif"))
                    ],
                    columns=["path"],
                )
            else:
                # treat any other string as a single structure file path
                data_files = pd.DataFrame([data_paths], columns=["path"])
        else:
            data_files = pd.DataFrame(data_paths, columns=["path"])

        # Only .pdb and .cif files are passed on to the parser.
        # NOTE(review): ".mmcif" files collected from a directory are dropped
        # here — confirm whether TCRParser supports them before widening this.
        structures = [
            (data.name, data.path)  # data.name is the row's index label
            for _, data in data_files.iterrows()
            if isinstance(data.path, str) and data.path.endswith((".pdb", ".cif"))
        ]
        if not structures:
            raise ValueError(f"No .pdb or .cif structure files found in {data_paths!r}")
        self._ids, self._raw_file_names = zip(*structures)

        self._processed_file_names = []

        super(TCRGraphDataset, self).__init__(root=root)

    @property
    def raw_file_names(self):
        return self._raw_file_names

    @property
    def processed_file_names(self):
        return self._processed_file_names

    @staticmethod
    def _tcr_generator(tcr_parser, tcr_pdb_iter):
        """Yield the TCRs of each structure file (file stem used as its id)."""
        for tcr in tcr_pdb_iter:
            tcr_id = tcr.split("/")[-1].split(".")[0]
            yield tcr_parser.get_tcr_structure(tcr_id, tcr).get_TCRs()

    def process(self):
        """Build and save one graph per TCR; failures are reported as warnings.

        Parsing errors are isolated per structure file and graph construction
        errors per TCR, so one bad input does not stop the remaining ones.
        """
        tcr_parser = TCRParser.TCRParser()
        for structure_path in self.raw_file_names:
            try:
                tcrs = [
                    tcr
                    for tcr_object in self._tcr_generator(tcr_parser, [structure_path])
                    for tcr in tcr_object
                ]
            except Exception as e:
                warnings.warn(f"Dataset parsing error: {str(e)}")
                continue
            for tcr in tcrs:
                try:
                    tcr_graph = self.graph_constructor.build_graph(tcr)
                    processed_file_path = os.path.join(
                        self.root, "processed", f"{tcr_graph.tcr_id}.pt"
                    )

                    torch.save(tcr_graph, processed_file_path)
                    self._processed_file_names.append(processed_file_path)
                except Exception as e:
                    warnings.warn(f"Dataset parsing error: {str(e)} for TCR: {tcr}")

    def len(self):
        return len(self._processed_file_names)

    # NOTE: weights_only=False unpickles arbitrary objects; only load .pt files
    # produced by this dataset, never untrusted ones.
    def get(self, idx):
        graph = torch.load(self._processed_file_names[idx], weights_only=False)
        return graph

    def pop(self, idx):
        """Remove graph ``idx`` from the dataset index and return it."""
        graph = torch.load(self._processed_file_names.pop(idx), weights_only=False)
        return graph

    def set_y(self, idx, label):
        """Overwrite the stored label ``y`` of graph ``idx`` on disk."""
        processed_path = self._processed_file_names[idx]
        graph = torch.load(processed_path, weights_only=False)
        new_graph = Data(
            x=graph.x,
            edge_index=graph.edge_index,
            edge_attr=graph.edge_attr,
            edge_weight=graph.edge_weight,
            pos=graph.pos,
            y=label,
            tcr_id=graph.tcr_id,
        )
        torch.save(new_graph, processed_path)
|
|
File without changes
|
|
File without changes
|