flash-abb 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flash_abb/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .pretrained import pretrained
2
+ from .pretrained_tap import pretrained_sss, pretrained_tap
@@ -0,0 +1,75 @@
1
+ import os, subprocess, json, argparse,requests
2
+ from yaml import load, Loader
3
+ import torch
4
+
5
# Maps each public model name to the weight file expected in the package's
# ``weights`` directory.
list_of_models = {
    "flash-abb": "flabb_weights.pt",
    "flash-abb_masked": "flabb_masked_weights.pt",
}
# Model names that are served by ``fetch_flash_abb``.
flash_abb_models = ["flash-abb", "flash-abb_masked"]


def load_model(model_to_use="flash-abb", random_init=False, device='cpu'):
    """Return ``(model, hparams)`` for the requested Flash-ABB variant.

    Args:
        model_to_use: One of the names listed in ``flash_abb_models``.
        random_init: Forwarded to ``fetch_flash_abb``; when true the
            pretrained weights are not loaded.
        device: Forwarded to ``fetch_flash_abb``.

    Returns:
        The ``(flabb, hparams)`` tuple produced by ``fetch_flash_abb``.

    Raises:
        AssertionError: If ``model_to_use`` is not a known model name.
    """
    if model_to_use not in flash_abb_models:
        # Raised explicitly instead of ``assert False``: assert statements are
        # removed under ``python -O``, which used to let an unknown name fall
        # through to an UnboundLocalError. The exception type is kept as
        # AssertionError so existing callers behave the same.
        raise AssertionError(
            f"The selected model to use ({model_to_use}) does not exist. "
            "Please select a valid model."
        )

    return fetch_flash_abb(
        model_to_use,
        random_init=random_init,
        device=device,
    )
25
+
26
+
27
def fetch_flash_abb(model_to_use, random_init=False, device='cpu'):
    """Build a ``FlashABB`` model from the bundled hyper-parameters.

    Args:
        model_to_use: Key into ``list_of_models`` selecting the weight file.
        random_init: When true, skip loading the checkpoint and return the
            freshly constructed model.
        device: Device the checkpoint tensors are mapped to while loading.
            The model itself is not explicitly moved here.

    Returns:
        ``(flabb, hparams)`` where ``hparams`` is the ``model`` entry of
        ``weights/params.yaml``.
    """
    from .model.flash_abb import FlashABB

    weights_dir = os.path.join(os.path.dirname(__file__), "weights")
    weights_name = list_of_models[model_to_use]
    params_path = os.path.join(weights_dir, 'params.yaml')

    with open(params_path, 'r', encoding='utf-8') as handle:
        config = load(handle, Loader=Loader)
    hparams = argparse.Namespace(**config).model

    flabb = FlashABB(hparams)
    if random_init:
        return flabb, hparams

    # NOTE: weights_only=False lets torch unpickle arbitrary objects; this is
    # only appropriate because the file ships inside the package.
    state = torch.load(
        os.path.join(weights_dir, weights_name),
        map_location=torch.device(device),
        weights_only=False,
    )
    flabb.load_state_dict(state)
    return flabb, hparams
47
+
48
+
49
def fetch_sss(random_init=False, device='cpu'):
    """Construct the ``BERTCoords`` encoder, optionally with bundled weights.

    Args:
        random_init: When true, the checkpoint ``weights/sss_weights.pt`` is
            not loaded.
        device: Passed to ``BERTCoords`` and used as the load/target device.

    Returns:
        The model after ``.to(device)``.
    """
    from .model.seq2struct2seq import BERTCoords

    model = BERTCoords(device=device)
    if random_init:
        return model.to(device)

    package_dir = os.path.dirname(__file__)
    checkpoint_path = os.path.join(package_dir, "weights", "sss_weights.pt")
    # NOTE: weights_only=False unpickles arbitrary objects; acceptable only
    # for the checkpoint shipped with the package.
    state = torch.load(
        checkpoint_path,
        map_location=torch.device(device),
        weights_only=False,
    )
    model.load_state_dict(state)
    return model.to(device)
58
+
59
+
60
def fetch_tap(random_init=False, device='cpu'):
    """Construct the TAP encoder and head, optionally with bundled weights.

    Args:
        random_init: When true, ``weights/tap_weights.pt`` is not loaded.
        device: Passed to ``BERTCoords`` and used as the load/target device.

    Returns:
        ``(encoder, head)``, both after ``.to(device)``.
    """
    from .model.seq2struct2seq import BERTCoords
    from .model.tap_head import TAPHead

    encoder = BERTCoords(device=device)
    head = TAPHead()

    if not random_init:
        checkpoint_path = os.path.join(
            os.path.dirname(__file__), "weights", "tap_weights.pt"
        )
        # NOTE: weights_only=False unpickles arbitrary objects; acceptable
        # only for the checkpoint shipped with the package.
        ckpt = torch.load(
            checkpoint_path,
            map_location=torch.device(device),
            weights_only=False,
        )

        # The training wrapper stored encoder keys under a 'model.' prefix;
        # drop it so the keys line up with the bare encoder.
        encoder_state = {}
        for key, value in ckpt['encoder_state'].items():
            encoder_state[key.removeprefix('model.')] = value

        # Non-strict loading: missing or unexpected keys are tolerated.
        encoder.load_state_dict(encoder_state, strict=False)
        head.load_state_dict(ckpt['head_state'], strict=False)

        # Target normalisation statistics are copied in place from the
        # checkpoint.
        head.tgt_mean.copy_(ckpt['tgt_mean'])
        head.tgt_std.copy_(ckpt['tgt_std'])

    return encoder.to(device), head.to(device)
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from .load_model import load_model
5
+ from .model.flash_abb import featurize, FlashABBResult
6
+
7
+
8
class pretrained:
    """Wrapper around a loaded Flash-ABB model for structure prediction.

    Attributes set by ``__init__``: ``flabb`` (the model), ``hparams``, and
    ``device`` / ``used_device`` (both the resolved ``torch.device``).
    """

    def __init__(self, model_to_use="flash-abb", random_init=False, device='cuda'):
        """Load the model, move it to ``device`` and put it in eval mode.

        Args:
            model_to_use: Model name understood by ``load_model``.
            random_init: When true, pretrained weights are not loaded.
            device: Anything accepted by ``torch.device``.
        """
        super().__init__()

        self.used_device = torch.device(device)
        # ``device`` and ``used_device`` are both kept because both were
        # public attributes before.
        self.device = torch.device(device)

        self.flabb, self.hparams = load_model(model_to_use, random_init=random_init)
        self.flabb.to(self.used_device)
        self.flabb.eval()  # Default

    def freeze(self):
        """Switch the model to evaluation mode (``requires_grad`` is untouched)."""
        self.flabb.eval()

    def unfreeze(self):
        """Switch the model to training mode."""
        self.flabb.train()

    def _predict(self, features, seqs):
        """Run the model on a feature dict and wrap the output in a result."""
        pred = self.flabb.model(
            {'single': features['single']},
            features['aatype'],
            features['res_idx'],
            features['mask'],
        )
        return FlashABBResult(seqs, pred, features['mask'])

    def from_features(self, features, batch_size=50, seqs=None):
        """Predict from an already-featurized batch.

        Previously this method referenced an undefined ``seqs`` name and
        always raised ``NameError``; the sequences are now an optional
        argument.

        Args:
            features: Mapping with ``'single'``, ``'aatype'``, ``'res_idx'``
                and ``'mask'`` entries.
            batch_size: Currently unused; kept for interface compatibility.
            seqs: Sequences matching ``features``, forwarded to
                ``FlashABBResult``. NOTE(review): confirm whether
                ``FlashABBResult`` accepts ``None`` here.

        Returns:
            A ``FlashABBResult``.
        """
        return self._predict(features, seqs)

    def __call__(self, seqs, batch_size=50):
        """Featurize ``seqs`` on ``self.device`` and predict.

        Args:
            seqs: Sequences passed to ``featurize``.
            batch_size: Currently unused; all sequences go through the model
                in a single forward call.

        Returns:
            A ``FlashABBResult``.
        """
        features = featurize(seqs, self.device)
        return self._predict(features, seqs)
@@ -0,0 +1,167 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from .load_model import fetch_sss, fetch_tap
5
+ from .model.tokenizer import ABtokenizer
6
+ from .model.flag_calibrator import FlagCalibrator
7
+
8
# Default device for the SSS/TAP wrappers: CUDA when available, otherwise CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Location of the pickled flag calibrators inside the package's weights folder.
_CALIBRATOR_PATH = os.path.join(
    os.path.split(__file__)[0], 'weights', 'flag_calibrators.pkl'
)
11
+
12
+
13
class SSSResult:
    """Output of FlashABB-SSS: per-residue embeddings and their residue mask."""

    def __init__(self, embeddings: torch.Tensor, mask: torch.Tensor):
        self._embeddings, self._mask = embeddings, mask

    @property
    def embeddings(self) -> torch.Tensor:
        """Per-residue embeddings, shaped (batch, seq_len, emb_size)."""
        return self._embeddings

    @property
    def mask(self) -> torch.Tensor:
        """Boolean (batch, seq_len) mask; True marks a present residue."""
        return self._mask
29
+
30
+
31
+ class TAPResult:
32
+ """Result from FlashTAP: four antibody developability scores and flag probabilities."""
33
+
34
+ TAP_COLS = ['PSH', 'PPC', 'PNC', 'SFvCSP']
35
+
36
+ def __init__(self, tensor: torch.Tensor, flag_probs_array: np.ndarray | None = None):
37
+ self._tensor = tensor
38
+ self._flag_probs_array = flag_probs_array # (batch, 4) or None
39
+
40
+ @property
41
+ def tensor(self) -> torch.Tensor:
42
+ """(batch, 4) raw score tensor."""
43
+ return self._tensor
44
+
45
+ @property
46
+ def scores(self) -> list[dict]:
47
+ """List of dicts (one per antibody) mapping property name → float."""
48
+ return [
49
+ {col: self._tensor[i, j].item() for j, col in enumerate(self.TAP_COLS)}
50
+ for i in range(self._tensor.shape[0])
51
+ ]
52
+
53
+ @property
54
+ def flag_probs(self) -> list[dict] | None:
55
+ """List of dicts (one per antibody) mapping property name → P(flag).
56
+
57
+ Returns None if no calibrator was loaded.
58
+ """
59
+ if self._flag_probs_array is None:
60
+ return None
61
+ return [
62
+ {col: float(self._flag_probs_array[i, j]) for j, col in enumerate(self.TAP_COLS)}
63
+ for i in range(self._flag_probs_array.shape[0])
64
+ ]
65
+
66
+ @property
67
+ def any_flag_prob(self) -> list[float] | None:
68
+ """P(any flag) for each antibody, assuming property independence.
69
+
70
+ Returns None if no calibrator was loaded.
71
+ """
72
+ if self._flag_probs_array is None:
73
+ return None
74
+ any_flag = 1 - np.prod(1 - self._flag_probs_array, axis=1)
75
+ return any_flag.tolist()
76
+
77
+
78
def _tokenize(seqs, alphabet: ABtokenizer, device):
    """Tokenize ``seqs`` with padding and no extra tokens, then move to ``device``."""
    return alphabet(seqs, pad=True, w_extra_tkns=False).to(device)
81
+
82
+
83
def _emb_and_mask(model, seqs, tokens, alphabet, device):
    """Run the encoder and return ``(embeddings, mask)`` with the separator dropped.

    The mask is True for non-padding positions. Separator positions are
    removed from it, so its second dimension is one shorter than ``tokens``.
    """
    is_pad = (tokens == alphabet.pad_token).to(device)
    embeddings = model(seqs, tokens, is_pad, return_emb=True)

    # Drop the separator column from the mask. The view below assumes each
    # row loses exactly one position.
    not_sep = tokens.ne(alphabet.sep_token)
    target_shape = [*tokens.shape]
    target_shape[1] = target_shape[1] - 1
    residue_mask = is_pad.logical_not()[not_sep].view(target_shape)
    return embeddings, residue_mask
93
+
94
+
95
class pretrained_sss:
    """FlashABB-SSS: structure-aware encoder for antibody sequences.

    Usage::

        from flash_abb import pretrained_sss
        sss = pretrained_sss()
        result = sss(['EVQL...|DIQL...'])
        print(result.embeddings.shape)  # (1, seq_len, 128)
    """

    def __init__(self, random_init: bool = False, device=DEVICE):
        self.device = device
        # Inference only: eval mode with gradients disabled on all parameters.
        self.sss = fetch_sss(random_init=random_init, device=str(device))
        self.sss.eval()
        self.sss.requires_grad_(False)
        self.alphabet = self.sss.alphabet

    def __call__(self, seqs, batch_size: int = 50) -> SSSResult:
        """Embed ``seqs`` in chunks of ``batch_size`` and concatenate the results."""
        emb_chunks = []
        mask_chunks = []
        for start in range(0, len(seqs), batch_size):
            chunk = seqs[start:start + batch_size]
            chunk_tokens = _tokenize(chunk, self.alphabet, self.device)
            with torch.no_grad():
                chunk_emb, chunk_mask = _emb_and_mask(
                    self.sss, chunk, chunk_tokens, self.alphabet, self.device
                )
            emb_chunks.append(chunk_emb)
            mask_chunks.append(chunk_mask)
        embeddings = torch.cat(emb_chunks)
        mask = torch.cat(mask_chunks)
        return SSSResult(embeddings, mask)
123
+
124
+
125
class pretrained_tap:
    """FlashTAP: four TAP developability scores predicted from antibody sequences.

    Scores: PSH (patches of surface hydrophobicity), PPC (positive patches),
    PNC (negative patches), SFvCSP (structural Fv charge symmetry parameter).

    Usage::

        from flash_abb import pretrained_tap
        tap = pretrained_tap()
        result = tap(['EVQL...|DIQL...'])
        print(result.scores)         # [{'PSH': ..., 'PPC': ..., 'PNC': ..., 'SFvCSP': ...}]
        print(result.flag_probs)     # [{'PSH': 0.12, 'PPC': 0.03, 'PNC': 0.05, 'SFvCSP': 0.41}]
        print(result.any_flag_prob)  # [0.52]
    """

    def __init__(self, random_init: bool = False, device=DEVICE):
        self.device = device
        self.encoder, self.head = fetch_tap(random_init=random_init, device=str(device))

        # Inference only: eval mode with gradients disabled on both modules.
        self.encoder.eval()
        self.encoder.requires_grad_(False)
        self.head.eval()
        self.head.requires_grad_(False)

        self.alphabet = self.encoder.alphabet

        # No calibrator is loaded for a randomly initialised model.
        if random_init:
            self.calibrator = None
        else:
            self.calibrator = FlagCalibrator.load(_CALIBRATOR_PATH)

    def __call__(self, seqs, batch_size: int = 50) -> TAPResult:
        """Score ``seqs`` in chunks of ``batch_size``; add flag probabilities if calibrated."""
        score_chunks = []
        for start in range(0, len(seqs), batch_size):
            chunk = seqs[start:start + batch_size]
            chunk_tokens = _tokenize(chunk, self.alphabet, self.device)
            with torch.no_grad():
                chunk_emb, chunk_mask = _emb_and_mask(
                    self.encoder, chunk, chunk_tokens, self.alphabet, self.device
                )
                chunk_scores = self.head(chunk_emb, chunk_mask)
            score_chunks.append(chunk_scores)
        score_tensor = torch.cat(score_chunks)

        if self.calibrator is None:
            return TAPResult(score_tensor, None)

        # The calibrator is given the scores as a NumPy array on the CPU.
        flag_probs_array = self.calibrator.predict_proba(score_tensor.cpu().numpy())
        return TAPResult(score_tensor, flag_probs_array)
@@ -0,0 +1,105 @@
1
+ Metadata-Version: 2.4
2
+ Name: flash-abb
3
+ Version: 0.0.1
4
+ Summary: Flash-ABB: modelling antibody structures at the speed of language
5
+ Home-page: https://github.com/oxpig/FlashABB
6
+ Author: Isaac Ellmen
7
+ Maintainer: Isaac Ellmen
8
+ Maintainer-email: isaac.ellmen@stats.ox.ac.uk
9
+ License: BSD 3-clause license
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: torch>2
12
+ Requires-Dist: requests
13
+ Requires-Dist: einops
14
+ Requires-Dist: rotary-embedding-torch
15
+ Requires-Dist: ml_collections
16
+ Requires-Dist: numpy
17
+ Requires-Dist: dm-tree
18
+ Requires-Dist: pyyaml
19
+ Requires-Dist: scipy
20
+ Dynamic: author
21
+ Dynamic: description
22
+ Dynamic: description-content-type
23
+ Dynamic: home-page
24
+ Dynamic: license
25
+ Dynamic: maintainer
26
+ Dynamic: maintainer-email
27
+ Dynamic: requires-dist
28
+ Dynamic: summary
29
+
30
+ # FlashABB: modelling antibody structures at the speed of language
31
+
32
+ ![Inference speed comparison](figures/speedup_multiplier.png)
33
+
34
+ Installation:
35
+
36
+ PyPI coming soon
37
+
38
+ ```bash
39
+ git clone git@github.com:oxpig/FlashABB.git
40
+ cd FlashABB
41
+ pip install .
42
+ ```
43
+
44
+ ## Structure prediction
45
+
46
+ The following is also in `example.py` and can be used to create the structures in `sample_preds`.
47
+
48
+ ```python
49
+ from flash_abb import pretrained
50
+ import torch
51
+
52
+ flabb = pretrained(device='cuda')
53
+
54
+ # Sequences in heavy|light format
55
+ seqs = [
56
+ 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS|DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK',
57
+ 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVS|DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK',
58
+ ]
59
+
60
+ with torch.no_grad():
61
+ result = flabb(seqs)
62
+
63
+ print(result.coords.shape) # (2, n_residues, 37, 3)
64
+ print(result.bb_coords.shape) # (2, n_residues, 4, 3)
65
+
66
+ result.to_pdbs(['ab1', 'ab2'], pdb_dir='sample_preds')
67
+ ```
68
+
69
+ ## Developability scoring (FlashTAP)
70
+
71
+ FlashTAP predicts four [TAP](https://doi.org/10.1038/s42003-023-05744-8) developability scores: PSH, PPC, PNC, and SFvCSP.
72
+
73
+ ```python
74
+ from flash_abb import pretrained_tap
75
+
76
+ tap = pretrained_tap(device='cuda')
77
+
78
+ seqs = [
79
+ 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS|DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK',
80
+ ]
81
+
82
+ result = tap(seqs)
83
+ print(result.scores) # [{'PSH': ..., 'PPC': ..., 'PNC': ..., 'SFvCSP': ...}]
84
+ print(result.tensor) # (1, 4) raw score tensor
85
+ print(result.flag_probs) # [{'PSH': 0.12, 'PPC': 0.03, 'PNC': 0.05, 'SFvCSP': 0.41}]
86
+ print(result.any_flag_prob) # [0.52]
87
+ ```
88
+
89
+ ## Structure-aware embeddings (FlashABB-SSS)
90
+
91
+ FlashABB-SSS (seq2struct2seq) produces per-residue embeddings that combine sequence and predicted 3D structure. These can be used as features for downstream tasks.
92
+
93
+ ```python
94
+ from flash_abb import pretrained_sss
95
+
96
+ sss = pretrained_sss(device='cuda')
97
+
98
+ seqs = [
99
+ 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS|DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK',
100
+ ]
101
+
102
+ result = sss(seqs)
103
+ print(result.embeddings.shape) # (1, n_residues, 128)
104
+ print(result.mask.shape) # (1, n_residues)
105
+ ```
@@ -0,0 +1,8 @@
1
+ flash_abb/__init__.py,sha256=6NnMNvsu4FTcWyBtklZvyStkn3HYxxueSt_7LS4faDI,94
2
+ flash_abb/load_model.py,sha256=RaY6J1hfs3Q9N1Iy6PN-klMg0yei_NSRwrTq3gO35I0,2626
3
+ flash_abb/pretrained.py,sha256=zPgDi2Kk1u8zpKk6ywJxL12TgWLbPJTbge--8Wg8iQw,1297
4
+ flash_abb/pretrained_tap.py,sha256=ZhXKDBPhP2IvLM8W2q6xZ-NKYds10E-BNt2P_vaH29c,5983
5
+ flash_abb-0.0.1.dist-info/METADATA,sha256=6RxPsHdWYW8TvUeAnN97tjlsyr5E6oG9o0Iw1Wg21TE,3441
6
+ flash_abb-0.0.1.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
7
+ flash_abb-0.0.1.dist-info/top_level.txt,sha256=cXv5m3fquiDHdQpJpR811w7tPuTRUn4WGipZG0cFxkw,10
8
+ flash_abb-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ flash_abb