stcrpy 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. examples/__init__.py +0 -0
  2. examples/egnn.py +425 -0
  3. stcrpy/__init__.py +5 -0
  4. stcrpy/tcr_datasets/__init__.py +0 -0
  5. stcrpy/tcr_datasets/tcr_graph_dataset.py +499 -0
  6. stcrpy/tcr_datasets/tcr_selector.py +0 -0
  7. stcrpy/tcr_datasets/tcr_structure_dataset.py +0 -0
  8. stcrpy/tcr_datasets/utils.py +350 -0
  9. stcrpy/tcr_formats/__init__.py +0 -0
  10. stcrpy/tcr_formats/tcr_formats.py +114 -0
  11. stcrpy/tcr_formats/tcr_haddock.py +556 -0
  12. stcrpy/tcr_geometry/TCRCoM.py +350 -0
  13. stcrpy/tcr_geometry/TCRCoM_LICENCE +168 -0
  14. stcrpy/tcr_geometry/TCRDock.py +261 -0
  15. stcrpy/tcr_geometry/TCRGeom.py +450 -0
  16. stcrpy/tcr_geometry/TCRGeomFiltering.py +273 -0
  17. stcrpy/tcr_geometry/__init__.py +0 -0
  18. stcrpy/tcr_geometry/reference_data/__init__.py +0 -0
  19. stcrpy/tcr_geometry/reference_data/dock_reference_1_imgt_numbered.pdb +6549 -0
  20. stcrpy/tcr_geometry/reference_data/dock_reference_2_imgt_numbered.pdb +6495 -0
  21. stcrpy/tcr_geometry/reference_data/reference_A.pdb +31 -0
  22. stcrpy/tcr_geometry/reference_data/reference_B.pdb +31 -0
  23. stcrpy/tcr_geometry/reference_data/reference_D.pdb +31 -0
  24. stcrpy/tcr_geometry/reference_data/reference_G.pdb +31 -0
  25. stcrpy/tcr_geometry/reference_data/reference_data.py +104 -0
  26. stcrpy/tcr_interactions/PLIPParser.py +147 -0
  27. stcrpy/tcr_interactions/TCRInteractionProfiler.py +433 -0
  28. stcrpy/tcr_interactions/TCRpMHC_PLIP_Model_Parser.py +133 -0
  29. stcrpy/tcr_interactions/__init__.py +0 -0
  30. stcrpy/tcr_interactions/utils.py +170 -0
  31. stcrpy/tcr_methods/__init__.py +0 -0
  32. stcrpy/tcr_methods/tcr_batch_operations.py +223 -0
  33. stcrpy/tcr_methods/tcr_methods.py +150 -0
  34. stcrpy/tcr_methods/tcr_reformatting.py +18 -0
  35. stcrpy/tcr_metrics/__init__.py +2 -0
  36. stcrpy/tcr_metrics/constants.py +39 -0
  37. stcrpy/tcr_metrics/tcr_interface_rmsd.py +237 -0
  38. stcrpy/tcr_metrics/tcr_rmsd.py +179 -0
  39. stcrpy/tcr_ml/__init__.py +0 -0
  40. stcrpy/tcr_ml/geometry_predictor.py +3 -0
  41. stcrpy/tcr_processing/AGchain.py +89 -0
  42. stcrpy/tcr_processing/Chemical_components.py +48915 -0
  43. stcrpy/tcr_processing/Entity.py +301 -0
  44. stcrpy/tcr_processing/Fragment.py +58 -0
  45. stcrpy/tcr_processing/Holder.py +24 -0
  46. stcrpy/tcr_processing/MHC.py +449 -0
  47. stcrpy/tcr_processing/MHCchain.py +149 -0
  48. stcrpy/tcr_processing/Model.py +37 -0
  49. stcrpy/tcr_processing/Select.py +145 -0
  50. stcrpy/tcr_processing/TCR.py +532 -0
  51. stcrpy/tcr_processing/TCRIO.py +47 -0
  52. stcrpy/tcr_processing/TCRParser.py +1230 -0
  53. stcrpy/tcr_processing/TCRStructure.py +148 -0
  54. stcrpy/tcr_processing/TCRchain.py +160 -0
  55. stcrpy/tcr_processing/__init__.py +3 -0
  56. stcrpy/tcr_processing/annotate.py +480 -0
  57. stcrpy/tcr_processing/utils/__init__.py +0 -0
  58. stcrpy/tcr_processing/utils/common.py +67 -0
  59. stcrpy/tcr_processing/utils/constants.py +367 -0
  60. stcrpy/tcr_processing/utils/region_definitions.py +782 -0
  61. stcrpy/utils/__init__.py +0 -0
  62. stcrpy/utils/error_stream.py +12 -0
  63. stcrpy-1.0.0.dist-info/METADATA +173 -0
  64. stcrpy-1.0.0.dist-info/RECORD +68 -0
  65. stcrpy-1.0.0.dist-info/WHEEL +5 -0
  66. stcrpy-1.0.0.dist-info/licenses/LICENCE +28 -0
  67. stcrpy-1.0.0.dist-info/licenses/stcrpy/tcr_geometry/TCRCoM_LICENCE +168 -0
  68. stcrpy-1.0.0.dist-info/top_level.txt +2 -0
examples/__init__.py ADDED
File without changes
examples/egnn.py ADDED
@@ -0,0 +1,425 @@
1
+ # Code adapted from https://github.com/lucidrains/egnn-pytorch/blob/main/egnn_pytorch/egnn_pytorch_geometric.py
2
+
3
+
4
+ # MIT License
5
+
6
+ # Copyright (c) 2021 Phil Wang, Eric Alcaide
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+
27
+ import torch
28
+ from torch import nn, einsum, broadcast_tensors
29
+ import torch.nn.functional as F
30
+
31
+ from einops import rearrange, repeat
32
+ from einops.layers.torch import Rearrange
33
+
34
+ # helper functions
35
+
36
+
37
def exists(val):
    """Return True when *val* is anything other than None."""
    return val is not None
39
+
40
+
41
def safe_div(num, den, eps=1e-8):
    """Elementwise division that is safe against zeros in the denominator.

    The denominator is clamped to at least `eps` before dividing, and any
    position where `den == 0` is forced to exactly 0.0 in the result.
    """
    quotient = num / den.clamp(min=eps)
    quotient.masked_fill_(den == 0, 0.0)
    return quotient
45
+
46
+
47
def batched_index_select(values, indices, dim=1):
    """Batched gather: select entries of `values` along `dim` using `indices`.

    `indices` shares its leading batch dimensions with `values`; any trailing
    dimensions of `values` after `dim` are preserved. The result is shaped
    like `indices` followed by those trailing value dimensions.
    """
    value_dims = values.shape[(dim + 1) :]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    # Append singleton axes to `indices` so it broadcasts over the trailing value dims.
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    # Insert singleton axes into `values` for the extra index dims beyond `dim`.
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    # Expand only the inserted singleton axes to match the index shape.
    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    # Gather along the target dimension, shifted by the inserted axes.
    dim += value_expand_len
    return values.gather(dim, indices)
62
+
63
+
64
+ def fourier_encode_dist(x, num_encodings=4, include_self=True):
65
+ x = x.unsqueeze(-1)
66
+ device, dtype, orig_x = x.device, x.dtype, x
67
+ scales = 2 ** torch.arange(num_encodings, device=device, dtype=dtype)
68
+ x = x / scales
69
+ x = torch.cat([x.sin(), x.cos()], dim=-1)
70
+ x = torch.cat((x, orig_x), dim=-1) if include_self else x
71
+ return x
72
+
73
+
74
def embedd_token(x, dims, layers):
    """Replace the trailing len(dims) columns of `x` with learned embeddings.

    The last len(dims) columns are interpreted as integer token ids; each is
    looked up in the corresponding embedding layer and the resulting vectors
    are concatenated onto the continuous prefix. On the first iteration the
    token columns themselves are dropped from `x`.
    """
    stop_concat = -len(dims)
    token_ids = x[:, stop_concat:].long()
    for idx, embedding in enumerate(layers):
        # After the first pass stop_concat covers all of x, so nothing more is dropped.
        x = torch.cat([x[:, :stop_concat], embedding(token_ids[:, idx])], dim=-1)
        stop_concat = x.shape[-1]
    return x
82
+
83
+
84
+ # swish activation fallback
85
+
86
+
87
class Swish_(nn.Module):
    """Swish / SiLU activation (x * sigmoid(x)) for PyTorch builds without nn.SiLU."""

    def forward(self, x):
        return torch.sigmoid(x) * x
90
+
91
+
92
# Prefer the native nn.SiLU (available in newer PyTorch); fall back to Swish_ otherwise.
SiLU = nn.SiLU if hasattr(nn, "SiLU") else Swish_
93
+
94
+ # helper classes
95
+
96
+ # this follows the same strategy for normalization as done in SE3 Transformers
97
+ # https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95
98
+
99
+
100
class CoorsNorm(nn.Module):
    """Normalize coordinate vectors to unit length, then rescale by a learned scalar.

    Follows the normalization strategy used in SE3 Transformers; `eps` guards
    against division by zero for zero-length vectors.
    """

    def __init__(self, eps=1e-8, scale_init=1.0):
        super().__init__()
        self.eps = eps
        # Learnable global scale, initialised to scale_init.
        self.scale = nn.Parameter(torch.full((1,), float(scale_init)))

    def forward(self, coors):
        lengths = coors.norm(dim=-1, keepdim=True).clamp(min=self.eps)
        return (coors / lengths) * self.scale
111
+
112
+
113
+ # global linear attention
114
+
115
+
116
class Attention(nn.Module):
    """Multi-head dot-product attention of `x` (queries) over `context` (keys/values).

    Used by GlobalLinearAttention for its two attention passes. `mask` is a
    boolean (batch, context_len) tensor; masked context positions are
    excluded from the softmax.
    """

    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        # 1/sqrt(d) scaling for numerical stability of the softmax logits.
        self.scale = dim_head**-0.5

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Keys and values are produced by a single projection, split in forward.
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context, mask=None):
        """Return attention output shaped like `x`: (batch, n, dim)."""
        h = self.heads

        q = self.to_q(x)
        kv = self.to_kv(context).chunk(2, dim=-1)

        # Split heads: (b, n, h*d) -> (b, h, n, d) for q, k, v.
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, *kv))
        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        if exists(mask):
            # Fill masked context positions with the most negative finite value
            # so they vanish after softmax.
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, "b n -> b () () n")
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim=-1)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        return self.to_out(out)
146
+
147
+
148
class GlobalLinearAttention(nn.Module):
    """Global attention via a small set of induced query tokens.

    The induced queries first attend over the full sequence (attn1), then the
    sequence attends back over the induced summaries (attn2). Both paths get
    residual connections and the sequence output passes through a
    feed-forward block. Returns the updated (sequence, queries) pair.
    """

    def __init__(self, *, dim, heads=8, dim_head=64):
        super().__init__()
        self.norm_seq = nn.LayerNorm(dim)
        self.norm_queries = nn.LayerNorm(dim)
        # attn1: queries attend over the sequence; attn2: sequence over induced tokens.
        self.attn1 = Attention(dim, heads, dim_head)
        self.attn2 = Attention(dim, heads, dim_head)

        self.ff = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x, queries, mask=None):
        """x: (b, n, dim) sequence; queries: (b, m, dim) induced tokens.

        `mask` (b, n) applies only when the queries attend over x.
        """
        res_x, res_queries = x, queries
        x, queries = self.norm_seq(x), self.norm_queries(queries)

        induced = self.attn1(queries, x, mask=mask)
        out = self.attn2(x, induced)

        # Residual connections for both streams.
        x = out + res_x
        queries = induced + res_queries

        x = self.ff(x) + x
        return x, queries
175
+
176
+
177
+ # classes
178
+
179
+
180
class EGNN(nn.Module):
    """E(n)-equivariant graph neural network layer.

    Computes per-edge messages from node-feature pairs and inter-node
    distances, then updates node features (invariant) and coordinates
    (equivariant). Optionally restricts edges to the k nearest neighbours
    and/or a sparse adjacency matrix, and can return the dense edge-message
    matrix.

    Fixes relative to the original:
      * `edges_out` is now allocated on the input's device with `m_ij`'s
        dtype (previously crashed on CUDA inputs).
      * `return_edges=True` without nearest-neighbour selection previously
        raised NameError on the undefined `nbhd_indices`; the dense `m_ij`
        is already the full (b, n, n, m_dim) edge matrix and is returned
        directly.
      * Removed unused locals (`num_nodes`, `i`, `j`, `num_neighbours`) and
        the dead `exists(num_nearest)` branch.
    """

    def __init__(
        self,
        dim,
        edge_dim=0,
        m_dim=16,
        fourier_features=0,
        num_nearest_neighbors=0,
        dropout=0.0,
        init_eps=1e-3,
        norm_feats=False,
        norm_coors=False,
        norm_coors_scale_init=1e-2,
        update_feats=True,
        update_coors=True,
        only_sparse_neighbors=False,
        valid_radius=float("inf"),
        m_pool_method="sum",
        soft_edges=False,
        coor_weights_clamp_value=None,
        return_edges=False,
    ):
        """
        Args:
            dim: node feature dimension.
            edge_dim: dimension of optional per-edge input features.
            m_dim: dimension of the learned edge messages.
            fourier_features: number of Fourier encodings for distances (0 = raw).
            num_nearest_neighbors: if > 0, restrict edges to k nearest neighbours.
            dropout: dropout probability inside the edge/node/coordinate MLPs.
            init_eps: std of the normal init for Linear weights.
            norm_feats / norm_coors: enable LayerNorm on features / CoorsNorm
                on relative coordinates.
            norm_coors_scale_init: initial scale for CoorsNorm.
            update_feats / update_coors: toggle the node / coordinate updates.
            only_sparse_neighbors: use only adjacency-matrix neighbours.
            valid_radius: squared-distance cutoff for neighbour validity.
            m_pool_method: "sum" or "mean" pooling of edge messages.
            soft_edges: gate messages with a learned sigmoid weight.
            coor_weights_clamp_value: optional clamp on coordinate weights.
            return_edges: also return the dense (b, n, n, m_dim) edge messages.
        """
        super().__init__()
        assert m_pool_method in {
            "sum",
            "mean",
        }, "pool method must be either sum or mean"
        assert (
            update_feats or update_coors
        ), "you must update either features, coordinates, or both"

        self.fourier_features = fourier_features

        # feats_i + feats_j + encoded distance (+ optional edge features).
        edge_input_dim = (fourier_features * 2) + (dim * 2) + edge_dim + 1
        dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            dropout,
            SiLU(),
            nn.Linear(edge_input_dim * 2, m_dim),
            SiLU(),
        )

        self.edge_gate = (
            nn.Sequential(nn.Linear(m_dim, 1), nn.Sigmoid()) if soft_edges else None
        )

        self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity()
        self.coors_norm = (
            CoorsNorm(scale_init=norm_coors_scale_init) if norm_coors else nn.Identity()
        )

        self.m_pool_method = m_pool_method

        self.node_mlp = (
            nn.Sequential(
                nn.Linear(dim + m_dim, dim * 2),
                dropout,
                SiLU(),
                nn.Linear(dim * 2, dim),
            )
            if update_feats
            else None
        )

        self.coors_mlp = (
            nn.Sequential(
                nn.Linear(m_dim, m_dim * 4), dropout, SiLU(), nn.Linear(m_dim * 4, 1)
            )
            if update_coors
            else None
        )

        self.num_nearest_neighbors = num_nearest_neighbors
        self.only_sparse_neighbors = only_sparse_neighbors
        self.valid_radius = valid_radius

        self.coor_weights_clamp_value = coor_weights_clamp_value

        self.init_eps = init_eps
        self.apply(self.init_)

        self.return_edges = return_edges

    def init_(self, module):
        """Small-std normal init for Linear weights."""
        if type(module) in {nn.Linear}:
            # seems to be needed to keep the network from exploding to NaN with greater depths
            nn.init.normal_(module.weight, std=self.init_eps)

    def forward(self, feats, coors, edges=None, mask=None, adj_mat=None):
        """One message-passing step.

        Args:
            feats: (b, n, dim) node features.
            coors: (b, n, 3) node coordinates.
            edges: optional (b, n, n, edge_dim) edge features.
            mask: optional (b, n) boolean node mask.
            adj_mat: optional (n, n) or (b, n, n) boolean adjacency matrix.

        Returns:
            (node_out, coors_out) — plus edges_out (b, n, n, m_dim) when
            return_edges is set.
        """
        b, n, d = feats.shape
        device = feats.device
        fourier_features = self.fourier_features
        num_nearest = self.num_nearest_neighbors
        valid_radius = self.valid_radius
        only_sparse_neighbors = self.only_sparse_neighbors

        use_nearest = num_nearest > 0 or only_sparse_neighbors

        # Pairwise relative coordinates and squared distances.
        rel_coors = rearrange(coors, "b i d -> b i () d") - rearrange(
            coors, "b j d -> b () j d"
        )
        rel_dist = (rel_coors**2).sum(dim=-1, keepdim=True)

        if use_nearest:
            ranking = rel_dist[..., 0].clone()

            if exists(mask):
                # Push masked pairs to the back of the nearest-neighbour ranking.
                rank_mask = mask[:, :, None] * mask[:, None, :]
                ranking.masked_fill_(~rank_mask, 1e5)

            if exists(adj_mat):
                if len(adj_mat.shape) == 2:
                    adj_mat = repeat(adj_mat.clone(), "i j -> b i j", b=b)

                if only_sparse_neighbors:
                    # Neighbour count becomes the max degree; radius 0 keeps
                    # only explicitly-adjacent (and self) entries valid.
                    num_nearest = int(adj_mat.float().sum(dim=-1).max().item())
                    valid_radius = 0

                self_mask = rearrange(
                    torch.eye(n, device=device, dtype=torch.bool), "i j -> () i j"
                )

                # Rank self first (-1), adjacency next (0), the rest by distance.
                adj_mat = adj_mat.masked_fill(self_mask, False)
                ranking.masked_fill_(self_mask, -1.0)
                ranking.masked_fill_(adj_mat, 0.0)

            nbhd_ranking, nbhd_indices = ranking.topk(
                num_nearest, dim=-1, largest=False
            )

            nbhd_mask = nbhd_ranking <= valid_radius

            rel_coors = batched_index_select(rel_coors, nbhd_indices, dim=2)
            rel_dist = batched_index_select(rel_dist, nbhd_indices, dim=2)

            if exists(edges):
                edges = batched_index_select(edges, nbhd_indices, dim=2)

        if fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings=fourier_features)
            rel_dist = rearrange(rel_dist, "b i j () d -> b i j d")

        if use_nearest:
            feats_j = batched_index_select(feats, nbhd_indices, dim=1)
        else:
            feats_j = rearrange(feats, "b j d -> b () j d")

        feats_i = rearrange(feats, "b i d -> b i () d")
        feats_i, feats_j = broadcast_tensors(feats_i, feats_j)

        edge_input = torch.cat((feats_i, feats_j, rel_dist), dim=-1)

        if exists(edges):
            edge_input = torch.cat((edge_input, edges), dim=-1)

        m_ij = self.edge_mlp(edge_input)

        if exists(self.edge_gate):
            m_ij = m_ij * self.edge_gate(m_ij)

        if exists(mask):
            mask_i = rearrange(mask, "b i -> b i ()")

            if use_nearest:
                mask_j = batched_index_select(mask, nbhd_indices, dim=1)
                mask = (mask_i * mask_j) & nbhd_mask
            else:
                mask_j = rearrange(mask, "b j -> b () j")
                mask = mask_i * mask_j

        if exists(self.coors_mlp):
            coor_weights = self.coors_mlp(m_ij)
            coor_weights = rearrange(coor_weights, "b i j () -> b i j")

            rel_coors = self.coors_norm(rel_coors)

            if exists(mask):
                coor_weights.masked_fill_(~mask, 0.0)

            if exists(self.coor_weights_clamp_value):
                clamp_value = self.coor_weights_clamp_value
                coor_weights.clamp_(min=-clamp_value, max=clamp_value)

            # Equivariant coordinate update: weighted sum of relative vectors.
            coors_out = (
                einsum("b i j, b i j c -> b i c", coor_weights, rel_coors) + coors
            )
        else:
            coors_out = coors

        if exists(self.node_mlp):
            if exists(mask):
                m_ij_mask = rearrange(mask, "... -> ... ()")
                m_ij = m_ij.masked_fill(~m_ij_mask, 0.0)

            if self.m_pool_method == "mean":
                if exists(mask):
                    # Masked mean over valid neighbours only.
                    mask_sum = m_ij_mask.sum(dim=-2)
                    m_i = safe_div(m_ij.sum(dim=-2), mask_sum)
                else:
                    m_i = m_ij.mean(dim=-2)

            elif self.m_pool_method == "sum":
                m_i = m_ij.sum(dim=-2)

            normed_feats = self.node_norm(feats)
            node_mlp_input = torch.cat((normed_feats, m_i), dim=-1)
            node_out = self.node_mlp(node_mlp_input) + feats
        else:
            node_out = feats

        if self.return_edges:
            if use_nearest:
                # Scatter sparse neighbour messages back into the dense
                # (b, n, n, m_dim) edge matrix; allocate on the input device
                # with the message dtype (fix: previously CPU/default dtype).
                edges_out = torch.zeros(
                    (b, n, n, m_ij.shape[-1]), device=device, dtype=m_ij.dtype
                )
                edges_out.scatter_(
                    2,
                    nbhd_indices.unsqueeze(-1).expand(-1, -1, -1, m_ij.shape[-1]),
                    m_ij,
                )
            else:
                # Dense case: m_ij already covers all (i, j) pairs
                # (fix: previously raised NameError on undefined nbhd_indices).
                edges_out = m_ij
            return node_out, coors_out, edges_out

        return node_out, coors_out
stcrpy/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .tcr_processing.TCRParser import TCRParser
2
+ from .tcr_processing.TCRIO import TCRIO
3
+ from .tcr_geometry.TCRDock import TCRDock
4
+ from .tcr_geometry.TCRGeom import TCRGeom
5
+ from .tcr_methods.tcr_methods import load_TCRs, fetch_TCR, yield_TCRs, load_TCR
File without changes