stcrpy 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. examples/__init__.py +0 -0
  2. examples/egnn.py +425 -0
  3. stcrpy/__init__.py +5 -0
  4. stcrpy/tcr_datasets/__init__.py +0 -0
  5. stcrpy/tcr_datasets/tcr_graph_dataset.py +499 -0
  6. stcrpy/tcr_datasets/tcr_selector.py +0 -0
  7. stcrpy/tcr_datasets/tcr_structure_dataset.py +0 -0
  8. stcrpy/tcr_datasets/utils.py +350 -0
  9. stcrpy/tcr_formats/__init__.py +0 -0
  10. stcrpy/tcr_formats/tcr_formats.py +114 -0
  11. stcrpy/tcr_formats/tcr_haddock.py +556 -0
  12. stcrpy/tcr_geometry/TCRCoM.py +350 -0
  13. stcrpy/tcr_geometry/TCRCoM_LICENCE +168 -0
  14. stcrpy/tcr_geometry/TCRDock.py +261 -0
  15. stcrpy/tcr_geometry/TCRGeom.py +450 -0
  16. stcrpy/tcr_geometry/TCRGeomFiltering.py +273 -0
  17. stcrpy/tcr_geometry/__init__.py +0 -0
  18. stcrpy/tcr_geometry/reference_data/__init__.py +0 -0
  19. stcrpy/tcr_geometry/reference_data/dock_reference_1_imgt_numbered.pdb +6549 -0
  20. stcrpy/tcr_geometry/reference_data/dock_reference_2_imgt_numbered.pdb +6495 -0
  21. stcrpy/tcr_geometry/reference_data/reference_A.pdb +31 -0
  22. stcrpy/tcr_geometry/reference_data/reference_B.pdb +31 -0
  23. stcrpy/tcr_geometry/reference_data/reference_D.pdb +31 -0
  24. stcrpy/tcr_geometry/reference_data/reference_G.pdb +31 -0
  25. stcrpy/tcr_geometry/reference_data/reference_data.py +104 -0
  26. stcrpy/tcr_interactions/PLIPParser.py +147 -0
  27. stcrpy/tcr_interactions/TCRInteractionProfiler.py +433 -0
  28. stcrpy/tcr_interactions/TCRpMHC_PLIP_Model_Parser.py +133 -0
  29. stcrpy/tcr_interactions/__init__.py +0 -0
  30. stcrpy/tcr_interactions/utils.py +170 -0
  31. stcrpy/tcr_methods/__init__.py +0 -0
  32. stcrpy/tcr_methods/tcr_batch_operations.py +223 -0
  33. stcrpy/tcr_methods/tcr_methods.py +150 -0
  34. stcrpy/tcr_methods/tcr_reformatting.py +18 -0
  35. stcrpy/tcr_metrics/__init__.py +2 -0
  36. stcrpy/tcr_metrics/constants.py +39 -0
  37. stcrpy/tcr_metrics/tcr_interface_rmsd.py +237 -0
  38. stcrpy/tcr_metrics/tcr_rmsd.py +179 -0
  39. stcrpy/tcr_ml/__init__.py +0 -0
  40. stcrpy/tcr_ml/geometry_predictor.py +3 -0
  41. stcrpy/tcr_processing/AGchain.py +89 -0
  42. stcrpy/tcr_processing/Chemical_components.py +48915 -0
  43. stcrpy/tcr_processing/Entity.py +301 -0
  44. stcrpy/tcr_processing/Fragment.py +58 -0
  45. stcrpy/tcr_processing/Holder.py +24 -0
  46. stcrpy/tcr_processing/MHC.py +449 -0
  47. stcrpy/tcr_processing/MHCchain.py +149 -0
  48. stcrpy/tcr_processing/Model.py +37 -0
  49. stcrpy/tcr_processing/Select.py +145 -0
  50. stcrpy/tcr_processing/TCR.py +532 -0
  51. stcrpy/tcr_processing/TCRIO.py +47 -0
  52. stcrpy/tcr_processing/TCRParser.py +1230 -0
  53. stcrpy/tcr_processing/TCRStructure.py +148 -0
  54. stcrpy/tcr_processing/TCRchain.py +160 -0
  55. stcrpy/tcr_processing/__init__.py +3 -0
  56. stcrpy/tcr_processing/annotate.py +480 -0
  57. stcrpy/tcr_processing/utils/__init__.py +0 -0
  58. stcrpy/tcr_processing/utils/common.py +67 -0
  59. stcrpy/tcr_processing/utils/constants.py +367 -0
  60. stcrpy/tcr_processing/utils/region_definitions.py +782 -0
  61. stcrpy/utils/__init__.py +0 -0
  62. stcrpy/utils/error_stream.py +12 -0
  63. stcrpy-1.0.0.dist-info/METADATA +173 -0
  64. stcrpy-1.0.0.dist-info/RECORD +68 -0
  65. stcrpy-1.0.0.dist-info/WHEEL +5 -0
  66. stcrpy-1.0.0.dist-info/licenses/LICENCE +28 -0
  67. stcrpy-1.0.0.dist-info/licenses/stcrpy/tcr_geometry/TCRCoM_LICENCE +168 -0
  68. stcrpy-1.0.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,499 @@
1
+ import warnings
2
+ import itertools
3
+ import os
4
+ import pandas as pd
5
+ import numpy as np
6
+
7
+
8
+ from ..tcr_processing import TCR, TCRParser
9
+ from . import utils
10
+
11
+ try:
12
+ from torch_geometric.data import Data, Dataset
13
+ import torch
14
+ import torch.nn.functional as F
15
+ except ImportError:
16
+ pass
17
+
18
+
19
class TCRGraphConstructor:
    """Build ``torch_geometric.data.Data`` graphs from parsed TCR structures.

    Graph construction is driven by ``self.config`` and three replaceable
    callables derived from it:

    * ``node_selector(residue) -> list`` returns the node object(s) for a
      residue; the built-in selector returns ``[CA atom]`` or ``[None]``.
    * ``node_featuriser(node) -> torch.Tensor | None`` featurises one node;
      nodes for which it returns ``None`` are dropped from the graph.
    * ``edge_featuriser(nodes, **kwargs) -> (edges, edge_features,
      edge_weights)`` where ``edges`` is laid out as ``(num_edges, 2)``;
      :meth:`build_graph` transposes it to PyG's ``(2, num_edges)`` layout.
    """

    def __init__(self, config=None, **kwargs):
        """Create a graph constructor.

        Args:
            config: Graph configuration dict. Must contain at least
                ``node_level``, ``node_features`` and ``edge_features``. When
                ``None``, a residue-level / one-hot / distance-edge default
                configuration is used. The dict is copied, never mutated.
            **kwargs: Configuration entries that override or extend ``config``.

        Raises:
            AssertionError: If a required configuration key is missing.
            NotImplementedError: If the configured node selection, node
                featurisation or edge featurisation method is not supported.
        """
        if config is None:
            config = {
                "node_level": "residue",
                "residue_coord": ["CA"],
                "node_features": "one_hot",
                "edge_features": "distance",
                "tcr_regions": ["all"],
                "include_antigen": True,
                "include_mhc": True,
                "mhc_distance_threshold": 15.0,
            }
        else:
            # Copy so keyword overrides do not leak into the caller's dict.
            config = dict(config)
        config.update(kwargs)

        # assert that minimum amount of configuration is set
        missing = {"node_level", "node_features", "edge_features"} - set(config)
        assert not missing, f"Graph config is missing required keys: {sorted(missing)}"

        self.config = config

        self.node_selector = self._get_node_selector()
        self.node_featuriser = self._get_node_featuriser()
        self.edge_featuriser = self._get_edge_featuriser()

    @staticmethod
    def _make_test_atom(name, coord, serial_number=1):
        """Create a standalone Bio.PDB Atom used to smoke-test custom callables.

        Biopython's ``Atom`` requires ``occupancy``, ``altloc``, ``fullname``
        and ``serial_number``; they are filled with neutral placeholder values.
        """
        from Bio.PDB.Atom import Atom

        return Atom(
            name=name,
            coord=np.array(coord),
            bfactor=67.02,
            occupancy=1.0,
            altloc=" ",
            fullname=f" {name:<3}",
            serial_number=serial_number,
            element=name[0],
        )

    def set_node_selector(self, node_selector_function):
        """Replace the node selector after smoke-testing it on a glycine residue.

        Args:
            node_selector_function: Callable taking a ``Bio.PDB.Residue`` and
                returning the list of node objects for that residue.

        Raises:
            ValueError: If the function raises on the test residue.
        """
        from Bio.PDB.Residue import Residue

        test_res = Residue(id=(" ", 3, " "), resname="GLY", segid=" ")
        test_atoms = [
            ("N", [23.399, -5.842, 19.395]),
            ("O", [24.17, -8.195, 21.998]),
            ("C", [23.617, -7.263, 21.414]),
            ("CA", [24.316, -6.528, 20.288]),
        ]
        for serial_number, (name, coord) in enumerate(test_atoms, start=1):
            test_res.add(self._make_test_atom(name, coord, serial_number))
        try:
            node_selector_function(test_res)
        except Exception as e:
            raise ValueError(
                f"Node selector function should generate node from Bio.PDB.Residue instance. Raised error {e}"
            ) from e
        self.node_selector = node_selector_function

    def set_node_featuriser(self, node_featuriser_function, test_input=None):
        """Replace the node featuriser after smoke-testing it.

        Args:
            node_featuriser_function: Callable mapping a node to a
                ``torch.Tensor`` (or ``None`` if the node cannot be featurised).
            test_input: Node used for the smoke test. Defaults to a standalone
                CA ``Bio.PDB.Atom`` with no parent residue.

        Raises:
            ValueError: If the featuriser raises on ``test_input``.
            AssertionError: If it returns something other than a tensor/None.
        """
        if test_input is None:
            warnings.warn(
                "No test input provided for new node featuriser, using Bio.PDB.Atom instance"
            )
            test_input = self._make_test_atom("CA", [24.316, -6.528, 20.288])
        try:
            features = node_featuriser_function(test_input)
        except Exception as e:
            raise ValueError(
                f"Node featuriser function could not featurise node {test_input}. Raised error {e}"
            ) from e
        assert features is None or isinstance(
            features, torch.Tensor
        ), "Node featuriser should generate torch tensor"
        self.node_featuriser = node_featuriser_function

    def set_edge_featuriser(self, edge_featuriser_function, test_input=None):
        """Replace the edge featuriser after smoke-testing it.

        Args:
            edge_featuriser_function: Callable taking a list of nodes and
                returning ``(edges, edge_features, edge_weights)``. ``edges``
                must be ``(num_edges, 2)`` (it is transposed in
                :meth:`build_graph`); features/weights are tensors with one
                entry per edge, or ``None``.
            test_input: Node list used for the smoke test. Defaults to three
                standalone CA atoms.

        Raises:
            ValueError: If the featuriser raises on ``test_input``.
            AssertionError: If the returned shapes/types are inconsistent.
        """
        if test_input is None:
            warnings.warn(
                "No test input provided for new edge featuriser, using Bio.PDB.Atom instance"
            )
            test_input = [
                self._make_test_atom("CA", [24.316, -6.528, 20.288], 1),
                self._make_test_atom("CA", [27.623, -12.28, 23.288], 2),
                self._make_test_atom("CA", [16.36, 8.58, 30.288], 3),
            ]
        try:
            edges, edge_features, edge_weights = edge_featuriser_function(test_input)
        except Exception as e:
            raise ValueError(
                f"Edge featuriser function could not featurise edge {test_input}. Raised error {e}"
            ) from e
        # build_graph transposes the edges, so they must be (num_edges, 2).
        assert (
            edges.ndim == 2 and edges.shape[1] == 2
        ), f"Edge indices must have shape (num_edges, 2) to define connected node pairs. Edge shape was {tuple(edges.shape)}"
        num_edges = edges.shape[0]
        if edge_features is not None:
            assert (
                isinstance(edge_features, torch.Tensor)
                and len(edge_features) == num_edges
            ), "Edge features configuration invalid"
        if edge_weights is not None:
            assert (
                isinstance(edge_weights, torch.Tensor)
                and len(edge_weights) == num_edges
            ), "Edge weights configuration invalid"

        self.edge_featuriser = edge_featuriser_function

    def _calculate_distance_matrix(self, coord_1, coord_2):
        """Return pairwise Euclidean distances between two coordinate sets.

        Args:
            coord_1: Array of shape ``(n_1, 3)``.
            coord_2: Array of shape ``(n_2, 3)``.

        Returns:
            Array of shape ``(n_2, n_1)`` where entry ``[j, i]`` is the distance
            between ``coord_2[j]`` and ``coord_1[i]``, with size-1 axes squeezed.
        """
        coord_1 = np.asarray(coord_1)
        coord_2 = np.asarray(coord_2)
        assert coord_1.shape[-1] == coord_2.shape[-1] == 3
        # Broadcast (n_2, 1, 3) against (1, n_1, 3) -> (n_2, n_1, 3).
        diff = coord_2[:, None, :] - coord_1[None, :, :]
        return np.sqrt(np.sum(diff**2, axis=-1)).squeeze()

    def _get_node_selector(self):
        """Return the node selector for the configured ``node_level``.

        Raises:
            NotImplementedError: For any configuration other than residue-level
                nodes placed on the CA atom.
        """
        if self.config["node_level"] == "residue" and self.config.get(
            "residue_coord", ["CA"]
        ) == ["CA"]:
            # generate single node per residue with coordinate of CA atom.
            def node_selector(residue):
                """Return ``[CA atom]`` of ``residue``, or ``[None]`` if absent."""
                if "CA" in residue.child_dict:
                    return [residue["CA"]]
                return [None]

            return node_selector
        raise NotImplementedError(
            "Node selection method not recognised: "
            f"node_level={self.config['node_level']!r}, "
            f"residue_coord={self.config.get('residue_coord')!r}"
        )

    def _get_node_featuriser(self):
        """Return the node featuriser for the configured ``node_features``.

        Raises:
            NotImplementedError: If ``node_features`` is not ``"one_hot"``.
        """
        if self.config["node_features"] == "one_hot":

            def one_hot_encoding(atom_node):
                """One-hot encode an atom node as a 69-dim tensor (4 + 7 + 21 + 37).

                Blocks, concatenated in this order (class indices come from the
                ``utils.*_ONE_HOT_ENCODING`` tables):
                    4 dims for chain type: [TCR alpha, TCR beta, peptide, MHC]
                    7 dims for CDR loop: [Not CDR, CDRA1, CDRA2, CDRA3, CDRB1, CDRB2, CDRB3]
                    21 dims for residue encoding:
                        ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLU', 'GLN', 'GLY', 'HIS', 'ILE',
                        'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'XXX']
                    37 dims for atom encoding

                Args:
                    atom_node: Bio.PDB atom whose parent residue and chain are
                        read to determine chain type and region.

                Returns:
                    The concatenated one-hot tensor, or ``None`` (with a
                    warning) when the chain type cannot be resolved.
                """
                chain = atom_node.parent.parent
                if hasattr(chain, "chain_type"):
                    chain_type = chain.chain_type
                elif hasattr(chain, "type"):
                    chain_type = chain.type
                    if chain_type == "peptide":
                        chain_type = "antigen"
                else:
                    # Accessing MHC_type raises for non-MHC chains; if it
                    # succeeds the chain is treated as MHC.
                    chain_type = chain.MHC_type
                    chain_type = "MHC"
                chain_type = "MHC" if chain_type in utils.MHC_CHAIN_TYPES else chain_type

                if chain_type in utils.CHAIN_TYPE_ONE_HOT_ENCODING:
                    chain_type_onehot_encoding = utils.CHAIN_TYPE_ONE_HOT_ENCODING[
                        chain_type
                    ]  # one hot as integer
                else:
                    warnings.warn(
                        f"""
                    Could not resolve chain type: {chain_type} of node:
                    {atom_node}|{atom_node.parent.parent}|{atom_node.parent.parent.parent.id}"""
                    )
                    return None

                # Only TCR chains (alpha/beta/gamma/delta) carry region labels.
                if chain_type in ["A", "B", "G", "D"]:
                    region = atom_node.parent.region.capitalize()
                else:
                    region = "NOT_CDR"
                region_onehot_encoding = utils.TCR_REGION_ONE_HOT_ENCODING.get(
                    region, utils.TCR_REGION_ONE_HOT_ENCODING["NOT_CDR"]
                )

                residue_onehot_encoding = utils.AMINO_ACID_ONEHOT_ENCODING[
                    atom_node.parent.resname.strip()
                ]

                atom37_onehot_encoding = utils.ATOM37_ATOM_ONEHOT_ENCODING[
                    atom_node.fullname.strip()
                ]
                return torch.concat(
                    [
                        F.one_hot(
                            torch.tensor(chain_type_onehot_encoding), num_classes=4
                        ),
                        F.one_hot(torch.tensor(region_onehot_encoding), num_classes=7),
                        F.one_hot(
                            torch.tensor(residue_onehot_encoding), num_classes=21
                        ),
                        F.one_hot(torch.tensor(atom37_onehot_encoding), num_classes=37),
                    ]
                )

            return one_hot_encoding
        raise NotImplementedError("Node featurisation method not recognised")

    def _get_edge_featuriser(self):
        """Return the edge featuriser for the configured ``edge_features``.

        Supported values: ``"distance"``, ``"fully_connected"`` and
        ``"interactions"``. Every featuriser returns
        ``(edges, edge_features, edge_weights)`` with ``edges`` of shape
        ``(num_edges, 2)``.

        Raises:
            NotImplementedError: For any other ``edge_features`` value.
        """
        edge_method = self.config["edge_features"]

        if edge_method == "distance":
            from scipy.spatial.distance import pdist, squareform

            def distance_edges(nodes, distance_cutoff=15.0, **kwargs):
                """Connect distinct nodes closer than ``distance_cutoff``.

                Returns the ``(num_edges, 2)`` index pairs (both directions),
                their distances as edge features, and no edge weights.
                """
                if len(nodes) < 2:
                    # No pair of distinct nodes can exist.
                    return (
                        torch.zeros((0, 2), dtype=torch.long),
                        torch.zeros(0, dtype=torch.float64),
                        None,
                    )
                coords = np.asarray([a.get_coord() for a in nodes])
                dist_mat = squareform(pdist(coords))
                # Put the cutoff on the diagonal so self edges are excluded.
                np.fill_diagonal(dist_mat, distance_cutoff)
                edges = np.argwhere(dist_mat < distance_cutoff)
                edge_features = dist_mat[edges[:, 0], edges[:, 1]]
                return torch.from_numpy(edges), torch.from_numpy(edge_features), None

            return distance_edges

        if edge_method == "fully_connected":

            def fully_connected(nodes, **kwargs):
                """Connect every ordered node pair (self loops included)."""
                edges = np.argwhere(np.ones((len(nodes), len(nodes)), dtype=bool))
                return torch.from_numpy(edges), None, None

            return fully_connected

        if edge_method == "interactions":
            interactions_dict = {
                "hydrophobic": 0,
                "hbond": 1,
                "pistack": 2,
                "saltbridge": 3,
            }

            def get_interactions(nodes, tcr, **kwargs):
                """Connect TCR residue nodes to interacting antigen residue nodes.

                Interactions come from ``tcr.profile_peptide_interactions()``.
                Water (``HOH``) ligands, and interactions whose residues have no
                node with a matching CA coordinate, are skipped. Each edge gets
                a 4-class one-hot interaction type; when the same node pair is
                listed more than once, the last listed type is kept.
                """
                # First node index per coordinate (same as a first-match scan).
                coord_to_index = {}
                for idx, node in enumerate(nodes):
                    coord_to_index.setdefault(tuple(node.get_coord()), idx)

                edges = {}
                interactions_df = tcr.profile_peptide_interactions()
                for _, row in interactions_df.iterrows():
                    if row.ligand_residue == "HOH":
                        continue
                    protein_residue = tcr.parent[row.protein_chain][
                        (" ", row.protein_number, " ")
                    ]
                    assert protein_residue.resname == row.protein_residue
                    ligand_residue = tcr.get_antigen()[0][(" ", row.ligand_number, " ")]

                    src = coord_to_index.get(tuple(protein_residue["CA"].get_coord()))
                    if src is None:
                        continue
                    dst = coord_to_index.get(tuple(ligand_residue["CA"].get_coord()))
                    if dst is None:
                        continue
                    edges[(src, dst)] = interactions_dict[row.type]

                # Explicit dtype/shape keep the no-interaction case valid
                # (one_hot needs int64; edge index must stay (num_edges, 2)).
                edge_indices = torch.tensor(
                    list(edges.keys()), dtype=torch.long
                ).reshape(-1, 2)
                edge_features = F.one_hot(
                    torch.tensor(list(edges.values()), dtype=torch.long), num_classes=4
                )
                return edge_indices, edge_features, None

            return get_interactions

        raise NotImplementedError("Edge featurisation method not recognised")

    def _select_nodes(self, residues):
        """Apply the node selector to standard residues, dropping ``None`` nodes.

        Residues with a non-blank hetero flag (e.g. waters) are skipped.
        """
        return [
            node
            for res in residues
            if res.id[0].strip() == ""
            for node in self.node_selector(res)
            if node is not None
        ]

    @staticmethod
    def _node_coordinates(nodes):
        """Stack node coordinates into an ``(len(nodes), 3)`` array, even if empty."""
        return np.array([node.get_coord() for node in nodes]).reshape(-1, 3)

    def build_graph(self, tcr: TCR, label=None):
        """Convert one TCR into a ``torch_geometric.data.Data`` graph.

        Nodes are taken from the TCR's residues when ``tcr_regions`` is missing,
        ``None`` or ``["all"]``. Nodes from the first antigen are added if
        ``include_antigen`` is set, and nodes from the first MHC if
        ``include_mhc`` is set. When ``mhc_distance_threshold`` is configured,
        only MHC nodes within that distance of at least one TCR node are kept.
        NOTE(review): other ``tcr_regions`` values currently add no TCR nodes;
        region filtering appears unimplemented.

        Args:
            tcr: Parsed TCR object.
            label: Stored as the graph's ``y`` attribute.

        Returns:
            ``Data`` with ``x``, ``edge_index``, ``edge_attr``, ``edge_weight``,
            ``pos``, ``y`` and ``tcr_id``.

        Raises:
            ValueError: If no node could be featurised.
        """
        tcr_id = f"{tcr.parent.parent.id}_{tcr.id}"
        nodes = []
        tcr_nodes = []

        tcr_regions = self.config.get("tcr_regions")
        if tcr_regions is None or tcr_regions == ["all"]:
            tcr_nodes = self._select_nodes(tcr.get_residues())
            nodes.extend(tcr_nodes)
        tcr_coords = self._node_coordinates(tcr_nodes)

        if self.config.get("include_antigen"):
            antigens = tcr.get_antigen()
            if len(antigens) == 0:
                warnings.warn(
                    f"No antigen found for TCR {tcr}. Antigen not included in graph."
                )
            else:
                nodes.extend(self._select_nodes(antigens[0].get_residues()))

        if self.config.get("include_mhc"):
            mhcs = tcr.get_MHC()
            if len(mhcs) == 0:
                warnings.warn(f"No MHC found for TCR {tcr}. MHC not included in graph.")
            else:
                mhc_nodes = self._select_nodes(mhcs[0].get_residues())
                if "mhc_distance_threshold" in self.config:
                    mhc_coords = self._node_coordinates(mhc_nodes)
                    # Reshape undoes the squeeze so masking works for 0/1 nodes.
                    dist_mat = self._calculate_distance_matrix(
                        tcr_coords, mhc_coords
                    ).reshape(len(mhc_coords), len(tcr_coords))
                    # shape is (len(mhc_nodes), len(tcr_nodes))
                    keep = (dist_mat < self.config["mhc_distance_threshold"]).any(
                        axis=-1
                    )
                    mhc_nodes = [n for n, k in zip(mhc_nodes, keep) if k]
                nodes.extend(mhc_nodes)

        # Featurise, then drop nodes that could not be featurised while keeping
        # nodes, features and coordinates aligned.
        featurised = [(node, self.node_featuriser(node)) for node in nodes]
        kept = [(node, feat) for node, feat in featurised if feat is not None]
        n_removed = len(featurised) - len(kept)
        if n_removed > 0:
            warnings.warn(
                f"{n_removed} nodes removed from original node list of TCR {tcr.parent.parent.id}_{tcr.id}"
            )
        if not kept:
            raise ValueError(f"No featurisable nodes found for TCR {tcr_id}")

        nodes = [node for node, _ in kept]
        node_features = torch.stack([feat for _, feat in kept])
        coordinates = self._node_coordinates(nodes)

        edge_index, edge_features, edge_weight = self.edge_featuriser(nodes, tcr=tcr)

        return Data(
            x=node_features,
            edge_index=edge_index.T,
            edge_attr=edge_features,
            edge_weight=edge_weight,
            pos=torch.from_numpy(coordinates),
            y=label,
            tcr_id=tcr_id,
        )
405
+
406
+
407
class TCRGraphDataset(Dataset):
    """PyTorch Geometric dataset of TCR graphs built from structure files.

    On construction, each structure file is parsed with ``TCRParser``. Every
    TCR found in it is converted to a graph by ``TCRGraphConstructor`` and
    saved to ``<root>/processed/<tcr_id>.pt``. Files or TCRs that fail are
    skipped with a warning.
    """

    # Structure file suffixes that are handed to the parser.
    _STRUCTURE_EXTENSIONS = (".pdb", ".cif")

    def __init__(self, root, data_paths, graph_config=None, *args, **kwargs):
        """Create the dataset and process all structures.

        Args:
            root: Dataset root directory; graphs are written to
                ``<root>/processed``.
            data_paths: One of: a CSV file path with a ``path`` column; a
                directory path (its ``.pdb``/``.cif``/``.mmcif`` files are
                listed); a single structure file path; or an iterable of
                structure file paths. Only ``.pdb``/``.cif`` paths are parsed.
            graph_config: Configuration for ``TCRGraphConstructor``.
            *args, **kwargs: Forwarded to ``TCRGraphConstructor``.

        Raises:
            ValueError: If ``data_paths`` is a string that is not a CSV file,
                directory or existing file.
        """
        self.graph_constructor = TCRGraphConstructor(
            config=graph_config, *args, **kwargs
        )

        data_files = self._load_data_files(data_paths)
        records = [
            (data.name, data.path)
            for _, data in data_files.iterrows()
            if data.path.endswith(self._STRUCTURE_EXTENSIONS)
        ]
        if records:
            self._ids, self._raw_file_names = zip(*records)
        else:
            warnings.warn(
                f"No .pdb or .cif structure files found in {data_paths}; dataset is empty."
            )
            self._ids, self._raw_file_names = (), ()

        self._processed_file_names = []

        super().__init__(root=root)

    @staticmethod
    def _load_data_files(data_paths):
        """Return a DataFrame with a ``path`` column for the given ``data_paths``."""
        if not isinstance(data_paths, str):
            return pd.DataFrame(data_paths, columns=["path"])
        if data_paths.endswith(".csv"):
            return pd.read_csv(data_paths)
        if os.path.isdir(data_paths):
            # Sorted for a reproducible dataset order across file systems.
            return pd.DataFrame(
                [
                    os.path.join(data_paths, p)
                    for p in sorted(os.listdir(data_paths))
                    if p.endswith((".pdb", ".cif", ".mmcif"))
                ],
                columns=["path"],
            )
        if os.path.isfile(data_paths):
            return pd.DataFrame([data_paths], columns=["path"])
        raise ValueError(
            f"data_paths must be a CSV file, a directory, a structure file or a list of paths; got {data_paths!r}"
        )

    @property
    def raw_file_names(self):
        return self._raw_file_names

    @property
    def processed_file_names(self):
        return self._processed_file_names

    @staticmethod
    def _parse_tcrs(tcr_parser, tcr_path):
        """Parse one structure file and return the list of TCRs it contains."""
        tcr_id = tcr_path.split("/")[-1].split(".")[0]
        return list(tcr_parser.get_tcr_structure(tcr_id, tcr_path).get_TCRs())

    @staticmethod
    def _tcr_generator(tcr_parser, tcr_pdb_iter):
        """Yield the TCRs of each structure path in ``tcr_pdb_iter``."""
        for tcr_path in tcr_pdb_iter:
            yield TCRGraphDataset._parse_tcrs(tcr_parser, tcr_path)

    def process(self):
        """Build and save a graph for every TCR in every raw structure file.

        A failure in one file or one TCR is reported as a warning and does
        not stop processing of the remaining inputs.
        """
        tcr_parser = TCRParser.TCRParser()
        for tcr_path in self.raw_file_names:
            try:
                tcrs = self._parse_tcrs(tcr_parser, tcr_path)
            except Exception as e:
                warnings.warn(f"Dataset parsing error: {str(e)} for file: {tcr_path}")
                continue
            for tcr in tcrs:
                try:
                    tcr_graph = self.graph_constructor.build_graph(tcr)
                    processed_file_path = os.path.join(
                        self.root, "processed", f"{tcr_graph.tcr_id}.pt"
                    )
                    torch.save(tcr_graph, processed_file_path)
                except Exception as e:
                    warnings.warn(f"Dataset parsing error: {str(e)} for TCR: {tcr}")
                    continue
                self._processed_file_names.append(processed_file_path)

    def len(self):
        return len(self._processed_file_names)

    def get(self, idx):
        # weights_only=False unpickles arbitrary objects; only load graph files
        # written by this dataset, never untrusted .pt files.
        return torch.load(self._processed_file_names[idx], weights_only=False)

    def pop(self, idx):
        """Remove graph ``idx`` from the index and return it (file is kept)."""
        return torch.load(self._processed_file_names.pop(idx), weights_only=False)

    def set_y(self, idx, label):
        """Overwrite the stored label ``y`` of graph ``idx`` on disk.

        All other attributes of the stored graph are preserved.
        """
        processed_path = self._processed_file_names[idx]
        graph = torch.load(processed_path, weights_only=False)
        graph.y = label
        torch.save(graph, processed_path)
File without changes
File without changes