RP3Net 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RP3Net/__init__.py +8 -0
- RP3Net/fm_cfg/esm2_650m/config.json +29 -0
- RP3Net/fm_cfg/esm2_650m/special_tokens_map.json +7 -0
- RP3Net/fm_cfg/esm2_650m/tokenizer_config.json +4 -0
- RP3Net/fm_cfg/esm2_650m/vocab.txt +33 -0
- RP3Net/model/__init__.py +1 -0
- RP3Net/model/layers.py +171 -0
- RP3Net/model/model.py +233 -0
- RP3Net/rp3_main.py +85 -0
- RP3Net/rp3_train.py +18 -0
- RP3Net/training/__init__.py +6 -0
- RP3Net/training/cli.py +166 -0
- RP3Net/training/data.py +300 -0
- RP3Net/training/data_emlc.py +94 -0
- RP3Net/training/lm.py +123 -0
- RP3Net/training/lm_emlc.py +400 -0
- RP3Net/training/metrics.py +357 -0
- RP3Net/util/__init__.py +3 -0
- RP3Net/util/fasta.py +26 -0
- RP3Net/util/torch.py +89 -0
- RP3Net/util/util.py +65 -0
- rp3net-0.0.1.dist-info/METADATA +77 -0
- rp3net-0.0.1.dist-info/RECORD +27 -0
- rp3net-0.0.1.dist-info/WHEEL +5 -0
- rp3net-0.0.1.dist-info/entry_points.txt +3 -0
- rp3net-0.0.1.dist-info/licenses/LICENSE +21 -0
- rp3net-0.0.1.dist-info/top_level.txt +1 -0
RP3Net/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from .model import RP3Net, RP3_DEFAULT_CONFIG, RP3_CONFIG_B, load_model
|
|
2
|
+
from .rp3_main import rp3_main
|
|
3
|
+
# `import importlib` on its own does not guarantee that the `importlib.util`
# submodule has been loaded, so import the submodule explicitly before using
# `find_spec`. This still binds the name `importlib` in the package namespace.
import importlib.util

# Training is an optional extra: the real entry point is only exposed when the
# `lightning` package can be located. Otherwise a stub with the same name is
# provided so that calling it fails with an actionable message.
if importlib.util.find_spec('lightning') is not None:
    from .rp3_train import rp3_train
else:
    def rp3_train():
        """Stand-in for the training entry point; always raises ``ImportError``."""
        raise ImportError("Please install 'RP3Net[training]' to enable training")
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
{
|
|
2
|
+
"architectures": [
|
|
3
|
+
"EsmForMaskedLM"
|
|
4
|
+
],
|
|
5
|
+
"attention_probs_dropout_prob": 0.0,
|
|
6
|
+
"classifier_dropout": null,
|
|
7
|
+
"emb_layer_norm_before": false,
|
|
8
|
+
"esmfold_config": null,
|
|
9
|
+
"hidden_act": "gelu",
|
|
10
|
+
"hidden_dropout_prob": 0.0,
|
|
11
|
+
"hidden_size": 1280,
|
|
12
|
+
"initializer_range": 0.02,
|
|
13
|
+
"intermediate_size": 5120,
|
|
14
|
+
"is_folding_model": false,
|
|
15
|
+
"layer_norm_eps": 1e-05,
|
|
16
|
+
"mask_token_id": 32,
|
|
17
|
+
"max_position_embeddings": 1026,
|
|
18
|
+
"model_type": "esm",
|
|
19
|
+
"num_attention_heads": 20,
|
|
20
|
+
"num_hidden_layers": 33,
|
|
21
|
+
"pad_token_id": 1,
|
|
22
|
+
"position_embedding_type": "rotary",
|
|
23
|
+
"token_dropout": true,
|
|
24
|
+
"torch_dtype": "float32",
|
|
25
|
+
"transformers_version": "4.25.0.dev0",
|
|
26
|
+
"use_cache": true,
|
|
27
|
+
"vocab_list": null,
|
|
28
|
+
"vocab_size": 33
|
|
29
|
+
}
|
RP3Net/model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .model import RP3Net, load_model, Mode_Training, RP3_DEFAULT_CONFIG, RP3_CONFIG_B
|
RP3Net/model/layers.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
import torch.nn.attention as attn
|
|
5
|
+
from torch.nn import init
|
|
6
|
+
import math
|
|
7
|
+
# from lightning.pytorch.cli import instantiate_class
|
|
8
|
+
from typing import Optional, Mapping, Sequence, Union
|
|
9
|
+
# from .util import get_rank, safe_gather
|
|
10
|
+
|
|
11
|
+
TNonLin = Optional[Union[str,Mapping]]
|
|
12
|
+
|
|
13
|
+
def apply_p_drop(m:nn.Module, p:float):
    """Set the drop probability of every ``nn.Dropout`` inside ``m`` to ``p``.

    The whole module tree is visited (``m`` itself included) and matching
    layers are updated in place. Only instances of ``nn.Dropout`` are touched.
    Nothing is returned.
    """
    for sub in m.modules():
        if isinstance(sub, nn.Dropout):
            sub.p = p
|
|
18
|
+
|
|
19
|
+
class StackedLinear(torch.nn.Module):
    """A stack of ``n_stack`` independent linear maps evaluated with one matmul.

    Parameters:
        d_in: size of the last dimension of the input.
        d_out: size of the last dimension of the output.
        n_stack: number of independent weight matrices.
        bias: when true, a separate bias is learned for every stack entry.
        device, dtype: forwarded to the parameter constructors.
        transpose: when true, ``forward`` swaps dimensions 0 and 1 of its result.

    ``weight`` has shape ``(n_stack, d_in, d_out)``; ``bias`` (if present) has
    shape ``(n_stack, 1, d_out)``.
    """

    def __init__(self, d_in:int, d_out:int, n_stack:int, bias:bool=True,
                 device=None, dtype=None, transpose=False) -> None:
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.n_stack = n_stack
        self.transpose = transpose
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = nn.Parameter(torch.empty((n_stack, d_in, d_out), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(n_stack, 1, d_out, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise every stack entry the way ``torch.nn.Linear`` does.

        ``nn.Linear`` uses ``kaiming_uniform_(a=sqrt(5))`` for the weight, which
        works out to ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``, and the same bound
        for the bias, with ``fan_in`` equal to the number of input features.
        """
        # NB: this is important; you could get really weird results (loss off the scale) if you get this wrong
        # The weight slices are stored as (d_in, d_out), i.e. transposed with
        # respect to nn.Linear's (out, in) layout, while torch's fan-in helpers
        # read fan-in from dimension 1. Feeding the slices (or the whole 3-D
        # tensor) to those helpers therefore yields d_out (resp. d_in * d_out)
        # instead of d_in, so the bound is computed explicitly here.
        bound = 1 / math.sqrt(self.d_in) if self.d_in > 0 else 0
        init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x @ weight`` (+ ``bias``), optionally with dims 0 and 1 swapped."""
        out = x @ self.weight
        if self.bias is not None:
            out += self.bias
        if self.transpose:
            out = out.transpose(0,1)
        return out

    def extra_repr(self) -> str:
        """Summarise the layer geometry for ``repr(module)``."""
        return 'd_in={}, d_out={}, n_stack={}, bias={}'.format(
            self.d_in, self.d_out, self.n_stack, self.bias is not None
        )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ClsHeadBlock(nn.Module):
    """One stage of a classification head, assembled as an ``nn.Sequential``.

    The stage contains, in this order: a linear map (``StackedLinear`` when
    ``n_stack`` is given, ``nn.Linear`` otherwise), an optional
    ``nn.LayerNorm(out_dim)``, an optional ``nn.Dropout(p_drop)`` and an
    optional activation whose class is looked up by name in ``torch.nn``.
    """

    def __init__(self, in_dim:int, out_dim:int, n_stack:Optional[int], layer_norm: bool, p_drop:Optional[float],
                 nonlinearity:TNonLin, bias:bool) -> None:
        super().__init__()
        if n_stack is None:
            linear = nn.Linear(in_dim, out_dim, bias=bias)
        else:
            linear = StackedLinear(in_dim, out_dim, n_stack, bias=bias)
        stages = [linear]
        if layer_norm:
            stages.append(nn.LayerNorm(out_dim))
        # Dropout is skipped both when it is unset and when it is exactly zero.
        if p_drop not in (None, 0):
            stages.append(nn.Dropout(p_drop))
        if nonlinearity is not None:
            # The activation is resolved by attribute name on torch.nn, e.g. 'SiLU'.
            activation_cls = getattr(nn, nonlinearity)
            stages.append(activation_cls())
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage."""
        return self.layers(x)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class StackedClsHead(nn.Module):
    """Classification head built from ``n_stack`` parallel (stacked) layers.

    ``layers`` optionally describes the hidden part as ``{'d': width, 'n': count}``
    (both values are rounded). Hidden stages are ``ClsHeadBlock`` instances; the
    final stage is a ``StackedLinear`` with two outputs and ``transpose=True``.
    """

    def __init__(self, embedding_dim:int, n_stack:int, layers:Optional[Mapping[str, int]]=None,
                 layer_norm:bool=False, p_drop:Optional[float]=None,
                 nonlinearity:TNonLin=None, bias:bool=False, end_bias:bool=False) -> None:
        super().__init__()
        hidden_widths = [round(layers['d'])] * round(layers['n']) if layers else []
        widths = [embedding_dim, *hidden_widths]
        # One block per consecutive (input width, output width) pair.
        stages = [
            ClsHeadBlock(width_in, width_out, n_stack, layer_norm, p_drop, nonlinearity, bias)
            for width_in, width_out in zip(widths, widths[1:])
        ]
        stages.append(StackedLinear(widths[-1], 2, n_stack, bias=end_bias, transpose=True)) # todo p_drop
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply all stages to ``x`` and return the result."""
        return self.layers(x)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class ClsHead(nn.Module):
    """Classification head: optional hidden ``ClsHeadBlock`` stages plus an output stage.

    ``layers`` optionally describes the hidden part as ``{'d': width, 'n': count}``
    (both values are rounded). Hidden stages receive ``n_stack``, ``bias`` and
    ``nonlinearity``; the output stage maps to ``out_dim`` with a plain
    ``nn.Linear`` (``n_stack=None``), no activation, and ``end_bias`` as its
    bias flag. Note that ``layer_norm`` and ``p_drop`` are passed to the output
    stage as well.
    """

    def __init__(self, embedding_dim:int, layers:Optional[Mapping[str, int]]=None,
                 layer_norm:bool=False, p_drop:Optional[float]=None,
                 nonlinearity:TNonLin=None, bias:bool=False, end_bias:bool=False,
                 n_stack:Optional[int]=None, out_dim:int=2) -> None:
        super().__init__()
        hidden_widths = [round(layers['d'])] * round(layers['n']) if layers else []
        widths = [embedding_dim, *hidden_widths]
        # One hidden block per consecutive (input width, output width) pair.
        stages = [
            ClsHeadBlock(width_in, width_out, n_stack, layer_norm, p_drop, nonlinearity, bias)
            for width_in, width_out in zip(widths, widths[1:])
        ]
        output_stage = ClsHeadBlock(widths[-1], out_dim, n_stack=None,
                                    layer_norm=layer_norm, p_drop=p_drop, nonlinearity=None, bias=end_bias)
        stages.append(output_stage)
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply all stages to ``x`` and return the result."""
        return self.layers(x)
|
|
116
|
+
|
|
117
|
+
class SetTransformerPooling(nn.Module):
|
|
118
|
+
def __init__(self, d:int, num_heads:int, seq_dim:int|None=None, num_seeds:int=1, layer_norm:bool=False, p_drop:float=.0, keep_dim:bool=False,
|
|
119
|
+
can_use_efficient:bool=True) -> None:
|
|
120
|
+
"""
|
|
121
|
+
[Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks. Lee et al, ICLM 2019](https://arxiv.org/pdf/1810.00825v3.pdf)
|
|
122
|
+
Section 3.2
|
|
123
|
+
"""
|
|
124
|
+
super().__init__()
|
|
125
|
+
self.seed = nn.Parameter(torch.empty(1, num_seeds, d))
|
|
126
|
+
nn.init.xavier_uniform_(self.seed)
|
|
127
|
+
if seq_dim is None:
|
|
128
|
+
seq_dim = d
|
|
129
|
+
self.mha = nn.MultiheadAttention(d, num_heads, dropout=p_drop, batch_first=True, kdim=seq_dim, vdim=seq_dim)
|
|
130
|
+
self.layer_norm = nn.LayerNorm(d) if layer_norm else None
|
|
131
|
+
self.keep_dim = keep_dim
|
|
132
|
+
self.attn_flags = [attn.SDPBackend.FLASH_ATTENTION, attn.SDPBackend.MATH, attn.SDPBackend.CUDNN_ATTENTION]
|
|
133
|
+
if can_use_efficient:
|
|
134
|
+
self.attn_flags.append(attn.SDPBackend.EFFICIENT_ATTENTION)
|
|
135
|
+
|
|
136
|
+
def forward(self, x:torch.Tensor, mask:Optional[torch.Tensor]=None, **kwargs) -> torch.Tensor:
|
|
137
|
+
if mask is not None:
|
|
138
|
+
mask = ~(mask.to(dtype=bool))
|
|
139
|
+
seed = self.seed.expand(x.shape[0], -1, -1)
|
|
140
|
+
# Required for EMLC to work.
|
|
141
|
+
# Forward mode differentiation is not implemented for memory efficient attention,
|
|
142
|
+
# so need to disable optimization here. NB: this might change in future version of PyTorch
|
|
143
|
+
with attn.sdpa_kernel(self.attn_flags):
|
|
144
|
+
out, _ = self.mha(seed, x, x, key_padding_mask=mask, need_weights=False)
|
|
145
|
+
if self.layer_norm is not None:
|
|
146
|
+
out = self.layer_norm(out)
|
|
147
|
+
if not self.keep_dim:
|
|
148
|
+
out = out.flatten(start_dim=1)
|
|
149
|
+
return out
|
|
150
|
+
|
|
151
|
+
class MeanPooling(nn.Module):
    """Masked average over dimension 1 of the input."""

    def forward(self, x:torch.Tensor, *, mask:torch.Tensor, **kwargs) -> torch.Tensor:
        """Average the entries of ``x`` along dim 1 where ``mask`` is truthy.

        ``mask`` is cast to bool and given a trailing singleton dimension so it
        broadcasts over the last dimension of ``x``. Masked-out entries count as
        zero and the sum is divided by the number of truthy mask entries per row
        (a row without any truthy entry therefore divides by zero). Extra
        keyword arguments are accepted and ignored.
        """
        keep = mask.to(dtype=torch.bool).unsqueeze(-1)
        n_valid = keep.sum(dim=1)
        summed = x.masked_fill(keep.logical_not(), 0).sum(dim=1)
        return summed / n_valid
|
|
161
|
+
|
|
162
|
+
class MaxPooling(nn.Module):
    """Masked maximum over dimension 1 of the input."""

    def forward(self, x:torch.Tensor, *, mask:torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the element-wise maximum of ``x`` along dim 1 where ``mask`` is truthy.

        ``mask`` is cast to bool and given a trailing singleton dimension so it
        broadcasts over the last dimension of ``x``. Masked-out entries are
        replaced by ``-inf`` before the maximum is taken. Extra keyword
        arguments are accepted and ignored.
        """
        keep = mask.to(dtype=torch.bool).unsqueeze(-1)
        candidates = x.masked_fill(keep.logical_not(), float('-inf'))
        return candidates.max(dim=1).values
|
RP3Net/model/model.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import re
|
|
3
|
+
import os
|
|
4
|
+
import enum
|
|
5
|
+
import typing
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import ml_collections as mlc
|
|
9
|
+
import transformers as hub
|
|
10
|
+
import pathlib
|
|
11
|
+
import peft
|
|
12
|
+
|
|
13
|
+
from . import layers
|
|
14
|
+
from .. import util
|
|
15
|
+
|
|
16
|
+
log = util.get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
def resolve(filename) -> pathlib.Path:
    """Return ``filename`` as an absolute ``pathlib.Path``.

    Args:
        filename: a string, a ``pathlib.Path`` or any other ``os.PathLike``.

    Returns:
        The path with a leading ``~`` expanded to the user's home directory
        and then resolved (made absolute, symlinks and ``..`` collapsed).
    """
    # pathlib.Path() accepts str (including subclasses), Path and os.PathLike
    # objects, so no exact-type check is needed. expanduser() leaves paths
    # that do not start with "~" untouched.
    return pathlib.Path(filename).expanduser().resolve()
|
|
24
|
+
|
|
25
|
+
class Mode(enum.Flag):
    """Operating modes of the model, one bit per mode so they can be OR-ed into groups."""
    Inference = 1 << 0
    Training_A = 1 << 1
    Training_B = 1 << 2
    Training_C = 1 << 3
    Training_D = 1 << 4
|
|
31
|
+
# Groups of modes, used for membership tests such as `mode in Mode_FM`.
# Training modes that run the foundation model.
Mode_Training_FM = Mode.Training_C | Mode.Training_D
# Training modes that use the pooling (aggregation) layer.
Mode_Training_Aggregation = Mode.Training_B | Mode_Training_FM
# Every training mode.
Mode_Training = Mode.Training_A | Mode_Training_Aggregation
# Every mode that runs the foundation model, inference included.
Mode_FM = Mode.Inference | Mode_Training_FM
# Every mode that uses the pooling layer, inference included.
Mode_Aggregation = Mode.Inference | Mode_Training_Aggregation
|
|
36
|
+
|
|
37
|
+
class RP3Net(nn.Module):
    """Base model: foundation model (FM) -> pooling -> classification head.

    Which parts are built and used depends on ``cfg['mode']`` (a ``Mode`` member
    name, ``'Inference'`` by default):

    * modes in ``Mode_FM`` build the FM through ``_init_fm``;
    * modes in ``Mode_Aggregation`` build ``self.pooling`` from ``cfg.aggregation``;
    * the classification head ``self.cls_head`` is always built.

    Subclasses provide the FM-specific hooks ``_init_fm``, ``tokenize_sequences``,
    ``sequence_representation`` and ``attention_mask``.
    NOTE: the hooks are marked with ``abc.abstractmethod`` but the class does not
    use ``abc.ABCMeta``, so the markers are not enforced at instantiation time.
    """

    def __init__(self, cfg:mlc.ConfigDict) -> None:
        """Build the sub-modules required by the mode named in ``cfg``.

        Raises:
            KeyError: if ``cfg['mode']`` is not the name of a ``Mode`` member.
            ValueError: if ``cfg.aggregation`` is not 'mean', 'max' or 'stp'.
        """
        super(RP3Net, self).__init__()
        self.fm = None
        self.mode = Mode[cfg.get('mode', 'Inference')]
        if self.mode in Mode_FM:
            self._init_fm(cfg, self.mode)
            # _init_fm is expected to have assigned self.fm.
            assert self.fm is not None, "Model must be initialized"
        if self.mode in Mode_Aggregation:
            if cfg.aggregation == 'mean':
                self.pooling = layers.MeanPooling()
            elif cfg.aggregation == 'max':
                self.pooling = layers.MaxPooling()
            elif cfg.aggregation == 'stp':
                self.pooling = layers.SetTransformerPooling(**cfg.stp)
            else:
                raise ValueError(f"Aggregation type {cfg.aggregation} not supported")
        self.cls_head = layers.ClsHead(**cfg.classification_head)

    def forward(self, batch, return_repr=False):
        """Compute logits for ``batch``.

        The keys read from ``batch`` depend on the mode:

        * ``Training_A``: ``batch['embeddings']`` goes straight to the head;
        * ``Training_B``: ``batch['embeddings']`` is pooled with
          ``batch['attention_mask']`` and then classified;
        * any other mode: ``batch['seq']`` is run through the FM, pooled and
          classified.

        Returns:
            ``logits``, or ``(logits, global_repr)`` when ``return_repr`` is true,
            where ``global_repr`` is the tensor that was fed to the head.
        """
        if self.mode == Mode.Training_A:
            # There is no pooling step in this mode: the tensor handed to the
            # head is the representation. Binding it here keeps
            # `return_repr=True` from failing with UnboundLocalError.
            global_repr = batch['embeddings']
            logits = self.cls_head(global_repr)
        elif self.mode == Mode.Training_B:
            global_repr = self.pooling(batch['embeddings'], mask=batch['attention_mask'])
            logits = self.cls_head(global_repr)
        else: # Inference, Training_C, Training_D
            seq_repr = self.sequence_representation(batch['seq'])
            mask = self.attention_mask(batch['seq'])
            global_repr = self.pooling(seq_repr, mask=mask)
            logits = self.cls_head(global_repr)
        if return_repr:
            return logits, global_repr
        return logits

    @abc.abstractmethod
    def _init_fm(self, cfg:mlc.ConfigDict, mode:Mode):
        """Create the foundation model and store it in ``self.fm``."""
        pass

    @abc.abstractmethod
    def tokenize_sequences(self, sequences:typing.Sequence[str]):
        """Turn raw sequences into the batch object consumed by the FM."""
        pass

    @abc.abstractmethod
    def sequence_representation(self, batch):
        """Return the FM's per-position representation for a tokenized batch."""
        pass

    @abc.abstractmethod
    def attention_mask(self, batch):
        """Return the mask of valid positions for a tokenized batch."""
        pass

    @torch.no_grad()
    def predict(self, sequences:typing.Sequence[str]|typing.Mapping[str,str], device=None):
        """Score sequences without tracking gradients.

        Args:
            sequences: either a sequence of strings or a mapping id -> string.
            device: if truthy, the tokenized batch is moved there before the
                forward pass. The model itself is not moved.

        Returns:
            The softmax value of class index 1 for each input: a dict
            ``{id: float}`` when a mapping was given, otherwise a 1-D CPU tensor
            in input order.
        """
        is_mapping=False
        if isinstance(sequences, typing.Mapping):
            is_mapping = True
            keys = list(sequences.keys())
            sequences = [sequences[k] for k in keys]
        seq_batch = self.tokenize_sequences(sequences)
        if device:
            seq_batch = seq_batch.to(device)
        batch = {'seq': seq_batch}
        logits = self(batch)
        logits_norm = torch.softmax(logits, dim=-1)
        logits_norm = logits_norm[:,1].cpu()
        if is_mapping:
            return {k:logits_norm[i].item() for i,k in enumerate(keys)}
        else:
            return logits_norm
|
|
120
|
+
|
|
121
|
+
class RP3Esm2(RP3Net):
    """``RP3Net`` whose foundation model is a Hugging Face ESM model."""

    def _init_esm2(self, cfg:mlc.ConfigDict, mode:Mode, cfg_path:str|os.PathLike) -> None:
        """Create the ESM model, its tokenizer and (optionally) a LoRA wrapper.

        The architecture and tokenizer are read from ``cfg_path``. Weights come
        from the checkpoint named by ``cfg['fm.cp']`` when that key is set;
        otherwise the model keeps its random initialisation. When
        ``cfg['fm.lora']`` is set, its entries are passed to ``peft.LoraConfig``
        and the model is wrapped with ``peft.get_peft_model``.
        """
        esm_cfg = hub.EsmConfig.from_pretrained(cfg_path, local_files_only=True)
        checkpoint_file = cfg.get('fm.cp')
        if not checkpoint_file:
            log.info("Loading random model")
            self.fm = hub.EsmModel(esm_cfg, add_pooling_layer=False)
        else:
            log.info(f"Loading pre-trained FM from checkpoint {checkpoint_file}")
            fm_weights = torch.load(util.resolve(checkpoint_file), map_location='cpu', weights_only=True)
            self.fm = hub.EsmModel.from_pretrained(
                None,
                config=esm_cfg,
                state_dict=fm_weights,
                local_files_only=True,
                add_pooling_layer=False,
            )
        self.tokenizer = hub.EsmTokenizer.from_pretrained(cfg_path, do_lower_case=False)
        # Letters that `tokenize_sequences` rewrites to 'X'.
        self.re_aa_x = re.compile(r"[UZOB]")
        lora_settings = cfg.get('fm.lora')
        if lora_settings:
            peft_cfg = peft.LoraConfig(**lora_settings, inference_mode=(mode == Mode.Inference))
            self.fm = peft.get_peft_model(self.fm, peft_cfg)

    def tokenize_sequences(self, seqs:typing.Sequence[str]):
        """Replace U/Z/O/B with 'X' and tokenize into a padded PyTorch batch."""
        cleaned = [self.re_aa_x.sub('X', s) for s in seqs]
        return self.tokenizer(cleaned, padding=True, return_tensors='pt')

    def sequence_representation(self, batch):
        """Return the FM's ``last_hidden_state`` for a tokenized batch."""
        fm_output = self.fm(**batch)
        return fm_output.last_hidden_state

    def attention_mask(self, batch):
        """Return the ``attention_mask`` entry of a tokenized batch."""
        return batch['attention_mask']

    def train(self, train_mode:bool=True):
        """Switch train/eval mode; in FM-training modes also freeze two FM parts.

        When training is enabled and the mode is in ``Mode_Training_FM``,
        gradients are disabled for the FM's position embeddings and contact head.
        """
        super().train(train_mode)
        if not train_mode or self.mode not in Mode_Training_FM:
            return self
        self.fm.embeddings.position_embeddings.requires_grad_(False)
        self.fm.contact_head.requires_grad_(False)
        return self
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class RP3Esm2_650m(RP3Esm2):
    """``RP3Esm2`` configured from the ``fm_cfg/esm2_650m`` files shipped with the package."""

    def _init_fm(self, cfg:mlc.ConfigDict, mode:Mode):
        """Initialise the FM from the bundled ``esm2_650m`` configuration directory."""
        # This module lives one directory below the package root.
        package_root = pathlib.Path(__file__).resolve().parents[1]
        self._init_esm2(cfg, mode, package_root / 'fm_cfg' / 'esm2_650m')
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# Default model configuration (immutable).
RP3_DEFAULT_CONFIG = mlc.FrozenConfigDict({
    # Foundation model settings.
    'fm': {
        # Selects the model class in `load_model`.
        'type': 'esm2_650m',
        # Passed as keyword arguments to `peft.LoraConfig`.
        'lora': {
            'r': 8,
            'lora_alpha': 1.0,
            'target_modules': ['query', 'key', 'value'],
            'lora_dropout': 0.1,
            'bias': 'lora_only'
        },
    },
    # Pooling type: 'mean', 'max' or 'stp' (SetTransformerPooling).
    'aggregation': 'stp',
    # Keyword arguments of `layers.SetTransformerPooling`.
    'stp': {
        # Same value as `hidden_size` in the bundled esm2_650m config.json.
        'seq_dim': 1280,
        'd': 256,
        'num_heads': 8,
        'layer_norm': True,
    },
    # Keyword arguments of `layers.ClsHead`.
    'classification_head': {
        # Same value as the pooling output width `stp.d`.
        'embedding_dim': 256,
        'bias': False,
        'end_bias': True,
        'layer_norm': False,
        # Hidden layers: `n` layers of width `d`.
        'layers': {
            'd': 256,
            'n': 1
        },
        # Name of an activation class in `torch.nn`.
        'nonlinearity': 'SiLU',
    }
})
|
|
198
|
+
|
|
199
|
+
# Alternative configuration (immutable): no LoRA section, pooling width 128
# and a 512-wide hidden layer in the classification head.
RP3_CONFIG_B = mlc.FrozenConfigDict({
    # Foundation model settings; `type` selects the model class in `load_model`.
    'fm':{'type': 'esm2_650m'},
    # Pooling type: 'mean', 'max' or 'stp' (SetTransformerPooling).
    'aggregation': 'stp',
    # Keyword arguments of `layers.SetTransformerPooling`.
    'stp': {
        # Same value as `hidden_size` in the bundled esm2_650m config.json.
        'seq_dim': 1280,
        'd': 128,
        'num_heads': 8,
        'layer_norm': True,
    },
    # Keyword arguments of `layers.ClsHead`.
    'classification_head': {
        # Same value as the pooling output width `stp.d`.
        'embedding_dim': 128,
        'bias': False,
        'end_bias': True,
        'layer_norm': False,
        # Hidden layers: `n` layers of width `d`.
        'layers': {
            'd': 512,
            'n': 1
        },
        # Name of an activation class in `torch.nn`.
        'nonlinearity': 'SiLU',
    }
})
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def load_model(cfg:mlc.ConfigDict, cp_path:str|os.PathLike|None=None) -> RP3Net:
    """Build the model described by ``cfg`` and optionally load its weights.

    Args:
        cfg: model configuration; ``cfg.fm.type`` selects the model class.
        cp_path: optional path of a state-dict checkpoint. It is loaded on the
            CPU with ``weights_only=True`` and applied with ``strict=True``.

    Raises:
        KeyError: if ``cfg['mode']`` is not the name of a ``Mode`` member.
        ValueError: if ``cfg.fm.type`` is not a supported model type.
    """
    fm_type = cfg.fm.type
    # Looked up only for its side effect: an unknown mode name fails here.
    Mode[cfg.get('mode', 'Inference')]
    if fm_type != 'esm2_650m':
        raise ValueError(f"Model {fm_type} not supported")
    net = RP3Esm2_650m(cfg)
    if cp_path:
        weights = torch.load(cp_path, map_location='cpu', weights_only=True)
        net.load_state_dict(weights, strict=True)
    return net
|
|
233
|
+
|
RP3Net/rp3_main.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
|
|
2
|
+
from typing import Optional
|
|
3
|
+
import yaml
|
|
4
|
+
import argparse
|
|
5
|
+
import logging
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
import ml_collections as mlc
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
import RP3Net.util as util
|
|
12
|
+
import RP3Net.model as model
|
|
13
|
+
|
|
14
|
+
log = util.get_logger(__file__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def setup_args(_args: Optional[list] = None):
    """Define the command-line interface of the predictor and parse ``_args``.

    Args:
        _args: argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="""
    Predict protein expression in E.coli from sequences. Takes in a fasta file with protein sequences and ids.
    Outputs a .csv[.gz] file with "id,score" columns.
    The score is the predicted probability of expression.
    """)
    # (flags, keyword arguments) for every option, in the order shown by --help.
    option_specs = (
        (("--log_file",),
         dict(help="Log file. Log output to console if set to None.")),
        (("--log_level",),
         dict(default="info", help="Log level of root logger. Appender levels are appropriately hard coded.")),
        (("-c", "--config"),
         dict(help="Model configuration in yaml, or registry entry.", default="RP3_DEFAULT_CONFIG")),
        (("-p", "--checkpoint"),
         dict(required=True, help="Model checkpoint.")),
        (("-f", "--fasta"),
         dict(required=True, help="Fasta file, possibly gzipped.")),
        (("-d", "--device"),
         dict(default='cpu', help="Device to use: cpu, cuda[:n], mps, ...")),
        (("-b", "--batch_size"),
         dict(default=8, type=int, help="Batch size. Memory consumption depends on construct sequence length and batch size. If the model runs out of SLURM/GPU memory, try reducing this parameter to make it fit.")),
        (("-o", "--out_file"),
         dict(required=True, help="Output .csv[.gz] file.")),
        (("--progress",),
         dict(action=argparse.BooleanOptionalAction, default=True, help="Show progress bar.")),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args(_args)
|
|
33
|
+
|
|
34
|
+
# Built-in configurations that can be selected by name with -c/--config.
MODEL_REGISTRY = {
    entry: getattr(model, entry)
    for entry in ('RP3_DEFAULT_CONFIG', 'RP3_CONFIG_B')
}
|
|
38
|
+
|
|
39
|
+
def rp3_main():
    """Command-line entry point: score the sequences of a FASTA file.

    Steps: parse the arguments, configure logging, read the FASTA file, pick
    the configuration (a YAML file if the given path exists, otherwise a
    ``MODEL_REGISTRY`` entry), load the model from the checkpoint, predict in
    batches and write an ``id,score`` table to ``--out_file``.

    Raises:
        ValueError: if the FASTA file contains no sequences, or if the
            configuration is neither an existing file nor a registry entry.
    """
    args = setup_args()
    util.setup_logging(save_path=args.log_file, level=args.log_level, log_console=args.log_file is None)

    seq_map = util.read_fasta(args.fasta)
    if len(seq_map) == 0:
        # Fail early with a clear message; the length statistics below would
        # otherwise stop with NumPy's "zero-size array" error.
        raise ValueError(f"No sequences found in {args.fasta}.")
    seq_lens = np.array(list(map(len, seq_map.values())))
    log.info(f"Read {len(seq_map)} sequences from {args.fasta}. Sequence lengths (mean/std/median/min/max): "
             f"{seq_lens.mean():.2f}/{seq_lens.std():.2f}/{np.median(seq_lens):.0f}/{seq_lens.min()}/{seq_lens.max()}")

    # An existing file takes precedence over a registry entry of the same name.
    config_path = util.resolve(args.config)
    if config_path.exists():
        with open(config_path, "r") as f:
            # NOTE(security): yaml.FullLoader can build more than plain data
            # types. Only pass trusted configuration files, or consider
            # yaml.safe_load if the configs never need custom tags.
            config = mlc.FrozenConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    elif args.config in MODEL_REGISTRY:
        config = MODEL_REGISTRY[args.config]
    else:
        raise ValueError(f"Config file {args.config} not found.")
    log.info(f"Loading model {args.config} from checkpoint {args.checkpoint}")
    m = model.load_model(config, args.checkpoint)
    m = m.to(device=args.device)
    m = m.eval()
    log.info(f"Loaded model {m}")

    seq_keys = list(seq_map.keys())
    batch_starts = range(0, len(seq_keys), args.batch_size)
    # len() of the range is the exact number of batches, including when the
    # number of sequences is a multiple of the batch size.
    n_batches = len(batch_starts)

    def batches():
        """Yield ``{id: sequence}`` dicts of at most ``batch_size`` entries."""
        if args.progress:
            tqdm_desc = args.fasta.replace('.fasta', '').replace('.gz', '')[-20:]
            starts = tqdm(batch_starts, desc=tqdm_desc)
        else:
            starts = batch_starts
        for batch_no, i in enumerate(starts, start=1):
            if not args.progress:
                log.info(f"Processing batch {batch_no} of {n_batches}.")
            yield {k: seq_map[k] for k in seq_keys[i:i + args.batch_size]}

    ret = {}
    for b in batches():
        ret.update(m.predict(b, device=args.device))
    out_df = pd.DataFrame({'id': list(ret.keys()), 'score': list(ret.values())})
    log.info(f"Writing {out_df.shape[0]} rows to {args.out_file}")
    out_df.to_csv(args.out_file, index=False)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# Run the prediction command-line entry point when this module is executed directly.
if __name__ == "__main__":
    rp3_main()
|
|
85
|
+
|
RP3Net/rp3_train.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
from RP3Net.training.cli import RP3Cli
|
|
3
|
+
from RP3Net.util import util
|
|
4
|
+
|
|
5
|
+
log = util.get_logger(__name__)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def rp3_train():
    """Command-line entry point for training.

    Checks that the optional ``lightning`` dependency is available and then
    instantiates ``RP3Cli`` (presumably the constructor runs the command-line
    interface - confirm against ``RP3Net.training.cli``). Any exception is
    logged and re-raised.

    Raises:
        ImportError: if ``lightning`` cannot be found. This used to be an
            ``assert``, which is skipped under ``python -O``; the exception
            type now matches the stub in the package ``__init__``.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    try:
        if importlib.util.find_spec('lightning') is None:
            raise ImportError("Please install 'RP3Net[training]' to enable training")
        RP3Cli()
    except Exception as e:
        log.error("Top level catch", exc_info=e)
        # Bare raise keeps the original traceback intact.
        raise
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Run the training command-line entry point when this module is executed directly.
if __name__ == "__main__":
    rp3_train()
|