RP3Net 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
RP3Net/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .model import RP3Net, RP3_DEFAULT_CONFIG, RP3_CONFIG_B, load_model
2
+ from .rp3_main import rp3_main
3
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded; import it explicitly before calling find_spec.
import importlib.util

# Training needs the optional 'lightning' dependency (the 'RP3Net[training]' extra).
# When it is missing, expose a stub so that `RP3Net.rp3_train` always exists and
# fails with an actionable message instead of an AttributeError.
if importlib.util.find_spec('lightning') is not None:
    from .rp3_train import rp3_train
else:
    def rp3_train():
        """Fallback training entry point used when 'lightning' is not installed.

        Raises:
            ImportError: always, naming the extra that enables training.
        """
        raise ImportError("Please install 'RP3Net[training]' to enable training")
@@ -0,0 +1,29 @@
1
+ {
2
+ "architectures": [
3
+ "EsmForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.0,
6
+ "classifier_dropout": null,
7
+ "emb_layer_norm_before": false,
8
+ "esmfold_config": null,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.0,
11
+ "hidden_size": 1280,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "is_folding_model": false,
15
+ "layer_norm_eps": 1e-05,
16
+ "mask_token_id": 32,
17
+ "max_position_embeddings": 1026,
18
+ "model_type": "esm",
19
+ "num_attention_heads": 20,
20
+ "num_hidden_layers": 33,
21
+ "pad_token_id": 1,
22
+ "position_embedding_type": "rotary",
23
+ "token_dropout": true,
24
+ "torch_dtype": "float32",
25
+ "transformers_version": "4.25.0.dev0",
26
+ "use_cache": true,
27
+ "vocab_list": null,
28
+ "vocab_size": 33
29
+ }
@@ -0,0 +1,7 @@
1
+ {
2
+ "cls_token": "<cls>",
3
+ "eos_token": "<eos>",
4
+ "mask_token": "<mask>",
5
+ "pad_token": "<pad>",
6
+ "unk_token": "<unk>"
7
+ }
@@ -0,0 +1,4 @@
1
+ {
2
+ "model_max_length": 1000000000000000019884624838656,
3
+ "tokenizer_class": "EsmTokenizer"
4
+ }
@@ -0,0 +1,33 @@
1
+ <cls>
2
+ <pad>
3
+ <eos>
4
+ <unk>
5
+ L
6
+ A
7
+ G
8
+ V
9
+ S
10
+ E
11
+ R
12
+ T
13
+ I
14
+ D
15
+ P
16
+ K
17
+ Q
18
+ N
19
+ F
20
+ Y
21
+ M
22
+ H
23
+ W
24
+ C
25
+ X
26
+ B
27
+ U
28
+ Z
29
+ O
30
+ .
31
+ -
32
+ <null_1>
33
+ <mask>
@@ -0,0 +1 @@
1
+ from .model import RP3Net, load_model, Mode_Training, RP3_DEFAULT_CONFIG, RP3_CONFIG_B
RP3Net/model/layers.py ADDED
@@ -0,0 +1,171 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torch.nn.attention as attn
5
+ from torch.nn import init
6
+ import math
7
+ # from lightning.pytorch.cli import instantiate_class
8
+ from typing import Optional, Mapping, Sequence, Union
9
+ # from .util import get_rank, safe_gather
10
+
11
+ TNonLin = Optional[Union[str,Mapping]]
12
+
13
def apply_p_drop(m:nn.Module, p:float):
    """Set the dropout probability of every ``nn.Dropout`` found in ``m``.

    Args:
        m: root module; it and all of its descendants are inspected.
        p: probability written to the ``p`` attribute of each ``nn.Dropout``.
    """
    for module in m.modules():
        if isinstance(module, nn.Dropout):
            module.p = p
18
+
19
class StackedLinear(torch.nn.Module):
    """``n_stack`` independent linear layers evaluated with a single batched matmul.

    ``weight`` has shape ``(n_stack, d_in, d_out)`` and ``bias`` (if enabled)
    ``(n_stack, 1, d_out)``. ``forward`` computes ``x @ weight (+ bias)``, so a
    2-D input ``(batch, d_in)`` produces ``(n_stack, batch, d_out)``. With
    ``transpose=True`` the first two output dimensions are swapped, giving
    ``(batch, n_stack, d_out)`` for that input.
    """

    def __init__(self, d_in:int, d_out:int, n_stack:int, bias:bool=True,
                 device=None, dtype=None, transpose=False) -> None:
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.n_stack = n_stack
        self.transpose = transpose
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = nn.Parameter(torch.empty((n_stack, d_in, d_out), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(n_stack, 1, d_out, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise like ``torch.nn.Linear``: weight and bias ~ U(-1/sqrt(d_in), 1/sqrt(d_in))."""
        # nn.Linear obtains this bound from kaiming_uniform_(a=sqrt(5)) on a weight stored
        # as (out, in). Our slices are stored as (in, out), so torch's fan-in heuristic
        # would pick d_out. Compute the bound from the real fan-in (d_in) directly.
        # NB: this is important; a wrong scale can send the loss off the scale.
        bound = 1 / math.sqrt(self.d_in) if self.d_in > 0 else 0
        init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x @ self.weight
        if self.bias is not None:
            out += self.bias
        if self.transpose:
            out = out.transpose(0, 1)
        return out

    def extra_repr(self) -> str:
        return 'd_in={}, d_out={}, n_stack={}, bias={}'.format(
            self.d_in, self.d_out, self.n_stack, self.bias is not None
        )
58
+
59
+
60
class ClsHeadBlock(nn.Module):
    """One classification-head block: linear -> [LayerNorm] -> [Dropout] -> [activation].

    Args:
        in_dim: input feature size of the linear layer.
        out_dim: output feature size of the linear layer.
        n_stack: when not None, a ``StackedLinear`` with this many stacked layers is
            used instead of ``nn.Linear``.
        layer_norm: append ``nn.LayerNorm(out_dim)``.
        p_drop: dropout probability; no dropout layer is added when None or 0.
        nonlinearity: name of a ``torch.nn`` class (e.g. ``'SiLU'``) instantiated
            with no arguments; None means no activation.
        bias: whether the linear layer has a bias term.
    """

    def __init__(self, in_dim:int, out_dim:int, n_stack:Optional[int], layer_norm: bool, p_drop:Optional[float],
                 nonlinearity:TNonLin, bias:bool) -> None:
        super().__init__()
        if n_stack is None:
            linear = nn.Linear(in_dim, out_dim, bias=bias)
        else:
            linear = StackedLinear(in_dim, out_dim, n_stack, bias=bias)
        parts = [linear]
        if layer_norm:
            parts.append(nn.LayerNorm(out_dim))
        if p_drop:
            parts.append(nn.Dropout(p_drop))
        if nonlinearity is not None:
            activation_cls = getattr(nn, nonlinearity)
            parts.append(activation_cls())
        self.layers = nn.Sequential(*parts)

    def forward(self, x):
        out = self.layers(x)
        return out
76
+
77
+
78
class StackedClsHead(nn.Module):
    """Binary classification head made of ``n_stack`` parallel MLPs.

    Args:
        embedding_dim: input feature size.
        n_stack: number of stacked heads (see ``StackedLinear``).
        layers: optional ``{'d': width, 'n': count}``. Adds ``n`` hidden blocks of
            width ``d``, both rounded to int. None or empty means no hidden blocks.
        layer_norm, p_drop, nonlinearity, bias: forwarded to every hidden block.
        end_bias: bias of the final 2-way ``StackedLinear`` projection.

    The final projection uses ``transpose=True``, so for a 2-D input the output is
    ``(batch, n_stack, 2)``.
    """

    def __init__(self, embedding_dim:int, n_stack:int, layers:Optional[Mapping[str, int]]=None,
                 layer_norm:bool=False, p_drop:Optional[float]=None,
                 nonlinearity:TNonLin=None, bias:bool=False, end_bias:bool=False) -> None:
        super().__init__()
        dims = [embedding_dim]
        if layers:
            dims += [round(layers['d'])] * round(layers['n'])
        hidden_blocks = [
            ClsHeadBlock(d_in, d_out, n_stack, layer_norm, p_drop, nonlinearity, bias)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ]
        self.layers = nn.Sequential(*hidden_blocks)
        # todo p_drop: the output projection has no dropout yet.
        self.layers.append(StackedLinear(dims[-1], 2, n_stack, bias=end_bias, transpose=True))

    def forward(self, x):
        out = self.layers(x)
        return out
95
+
96
+
97
class ClsHead(nn.Module):
    """MLP classification head: hidden ``ClsHeadBlock``s followed by an output block.

    Args:
        embedding_dim: input feature size.
        layers: optional ``{'d': width, 'n': count}``. Adds ``n`` hidden blocks of
            width ``d``, both rounded to int. None or empty means no hidden blocks.
        layer_norm, p_drop, nonlinearity, bias: forwarded to every hidden block.
        end_bias: bias of the output block's linear layer.
        n_stack: forwarded to the hidden blocks only; the output block always uses
            a plain ``nn.Linear``.
        out_dim: number of outputs (default 2).

    NOTE: the output block also receives ``layer_norm`` and ``p_drop`` (but no
    activation). Enabling either therefore applies LayerNorm/Dropout to the
    logits themselves.
    """

    def __init__(self, embedding_dim:int, layers:Optional[Mapping[str, int]]=None,
                 layer_norm:bool=False, p_drop:Optional[float]=None,
                 nonlinearity:TNonLin=None, bias:bool=False, end_bias:bool=False,
                 n_stack:Optional[int]=None, out_dim:int=2) -> None:
        super().__init__()
        hidden = [round(layers['d'])] * round(layers['n']) if layers else []
        dims = [embedding_dim, *hidden]
        self.layers = nn.Sequential(*(
            ClsHeadBlock(dims[j], dims[j + 1], n_stack, layer_norm, p_drop, nonlinearity, bias)
            for j in range(len(dims) - 1)
        ))
        output_block = ClsHeadBlock(dims[-1], out_dim, n_stack=None, layer_norm=layer_norm,
                                    p_drop=p_drop, nonlinearity=None, bias=end_bias)
        self.layers.append(output_block)

    def forward(self, x):
        out = self.layers(x)
        return out
116
+
117
+ class SetTransformerPooling(nn.Module):
118
+ def __init__(self, d:int, num_heads:int, seq_dim:int|None=None, num_seeds:int=1, layer_norm:bool=False, p_drop:float=.0, keep_dim:bool=False,
119
+ can_use_efficient:bool=True) -> None:
120
+ """
121
+ [Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks. Lee et al, ICLM 2019](https://arxiv.org/pdf/1810.00825v3.pdf)
122
+ Section 3.2
123
+ """
124
+ super().__init__()
125
+ self.seed = nn.Parameter(torch.empty(1, num_seeds, d))
126
+ nn.init.xavier_uniform_(self.seed)
127
+ if seq_dim is None:
128
+ seq_dim = d
129
+ self.mha = nn.MultiheadAttention(d, num_heads, dropout=p_drop, batch_first=True, kdim=seq_dim, vdim=seq_dim)
130
+ self.layer_norm = nn.LayerNorm(d) if layer_norm else None
131
+ self.keep_dim = keep_dim
132
+ self.attn_flags = [attn.SDPBackend.FLASH_ATTENTION, attn.SDPBackend.MATH, attn.SDPBackend.CUDNN_ATTENTION]
133
+ if can_use_efficient:
134
+ self.attn_flags.append(attn.SDPBackend.EFFICIENT_ATTENTION)
135
+
136
+ def forward(self, x:torch.Tensor, mask:Optional[torch.Tensor]=None, **kwargs) -> torch.Tensor:
137
+ if mask is not None:
138
+ mask = ~(mask.to(dtype=bool))
139
+ seed = self.seed.expand(x.shape[0], -1, -1)
140
+ # Required for EMLC to work.
141
+ # Forward mode differentiation is not implemented for memory efficient attention,
142
+ # so need to disable optimization here. NB: this might change in future version of PyTorch
143
+ with attn.sdpa_kernel(self.attn_flags):
144
+ out, _ = self.mha(seed, x, x, key_padding_mask=mask, need_weights=False)
145
+ if self.layer_norm is not None:
146
+ out = self.layer_norm(out)
147
+ if not self.keep_dim:
148
+ out = out.flatten(start_dim=1)
149
+ return out
150
+
151
class MeanPooling(nn.Module):
    """Average token embeddings over the positions selected by ``mask``."""

    def __init__(self):
        super().__init__()

    def forward(self, x:torch.Tensor, *, mask:torch.Tensor, **kwargs) -> torch.Tensor:
        """Masked mean over dim 1.

        Args:
            x: token embeddings. Assumed ``(batch, seq_len, dim)``, since the mask is
                broadcast over the last dim (TODO confirm for other ranks).
            mask: ``(batch, seq_len)``; non-zero entries mark positions to average.

        Returns:
            ``(batch, dim)`` means. Rows with an all-zero mask yield zeros instead of
            the NaN produced by 0/0.
        """
        keep = mask.to(dtype=torch.bool).unsqueeze(-1)
        # clamp(min=1) only affects fully-masked rows, whose masked sum is already 0.
        counts = keep.sum(dim=1).clamp(min=1)
        summed = x.masked_fill(~keep, 0).sum(dim=1)
        return summed / counts
161
+
162
class MaxPooling(nn.Module):
    """Element-wise maximum of token embeddings over the positions selected by ``mask``.

    ``mask`` has non-zero entries at positions that may contribute. Excluded
    positions are filled with ``-inf`` before the max over dim 1, so a row whose
    mask is entirely zero evaluates to ``-inf``.
    """

    def forward(self, x:torch.Tensor, *, mask:torch.Tensor, **kwargs) -> torch.Tensor:
        excluded = ~mask.to(dtype=torch.bool)[..., None]
        return x.masked_fill(excluded, float('-inf')).max(dim=1).values
RP3Net/model/model.py ADDED
@@ -0,0 +1,233 @@
1
+ import abc
2
+ import re
3
+ import os
4
+ import enum
5
+ import typing
6
+ import torch
7
+ import torch.nn as nn
8
+ import ml_collections as mlc
9
+ import transformers as hub
10
+ import pathlib
11
+ import peft
12
+
13
+ from . import layers
14
+ from .. import util
15
+
16
# Module-level logger from the package helper `util.get_logger`, keyed by this module's __name__.
log = util.get_logger(__name__)
17
+
18
def resolve(filename) -> pathlib.Path:
    """Return ``filename`` as an absolute ``pathlib.Path``.

    Accepts a ``str`` (including subclasses) or any ``os.PathLike``. A leading
    ``~`` / ``~user`` is expanded, and the result is made absolute with
    ``Path.resolve()``, which also resolves symlinks.
    """
    # Path.expanduser only acts on a leading '~', so no explicit prefix check is needed.
    return pathlib.Path(filename).expanduser().resolve()
24
+
25
class Mode(enum.Flag):
    """Operating mode of an RP3Net model (single bit per mode, combinable with ``|``)."""
    Inference = 1
    Training_A = 2
    Training_B = 4
    Training_C = 8
    Training_D = 16


# Every training mode.
Mode_Training = Mode.Training_A | Mode.Training_B | Mode.Training_C | Mode.Training_D
# Training modes that build and run the foundation model (FM).
Mode_Training_FM = Mode.Training_C | Mode.Training_D
# All modes that build the FM.
Mode_FM = Mode.Inference | Mode_Training_FM
# Training modes that build a pooling (aggregation) layer.
Mode_Training_Aggregation = Mode.Training_B | Mode_Training_FM
# All modes that build a pooling (aggregation) layer.
Mode_Aggregation = Mode.Inference | Mode_Training_Aggregation
36
+
37
class RP3Net(nn.Module):
    """Protein expression classifier: foundation model (FM) -> pooling -> classification head.

    This base class wires the components together. Subclasses supply the FM by
    implementing ``_init_fm`` (which must set ``self.fm``), ``tokenize_sequences``,
    ``sequence_representation`` and ``attention_mask``.

    Which components are built depends on the mode (``cfg.mode``, a ``Mode`` member
    name, default ``'Inference'``):

    * ``Training_A``: classification head only, applied to ``batch['embeddings']``.
    * ``Training_B``: pooling + head over ``batch['embeddings']`` masked by
      ``batch['attention_mask']``.
    * ``Inference`` / ``Training_C`` / ``Training_D``: FM + pooling + head over
      ``batch['seq']``.
    """

    def __init__(self, cfg:mlc.ConfigDict) -> None:
        """
        Args:
            cfg: model configuration. Reads ``mode`` (optional), ``aggregation``
                (``'mean'``, ``'max'`` or ``'stp'``), ``stp`` (kwargs for
                ``SetTransformerPooling`` when aggregation is ``'stp'``) and
                ``classification_head`` (kwargs for ``ClsHead``). Subclasses may read
                more keys in ``_init_fm``.

        Raises:
            KeyError: if ``mode`` is not a ``Mode`` member name.
            ValueError: if ``aggregation`` is not supported.
        """
        super().__init__()
        self.fm = None
        self.mode = Mode[cfg.get('mode', 'Inference')]
        if self.mode in Mode_FM:
            self._init_fm(cfg, self.mode)
            # Internal invariant: every subclass's _init_fm must assign self.fm.
            assert self.fm is not None, "Model must be initialized"
        if self.mode in Mode_Aggregation:
            if cfg.aggregation == 'mean':
                self.pooling = layers.MeanPooling()
            elif cfg.aggregation == 'max':
                self.pooling = layers.MaxPooling()
            elif cfg.aggregation == 'stp':
                self.pooling = layers.SetTransformerPooling(**cfg.stp)
            else:
                raise ValueError(f"Aggregation type {cfg.aggregation} not supported")
        self.cls_head = layers.ClsHead(**cfg.classification_head)

    def forward(self, batch, return_repr=False):
        """Compute classification logits for ``batch``.

        Args:
            batch: dict whose keys depend on the mode (see the class docstring).
            return_repr: if True, also return the per-sequence representation that
                was fed to the classification head.

        Returns:
            ``logits``, or ``(logits, global_repr)`` when ``return_repr`` is True.
        """
        if self.mode == Mode.Training_A:
            # Training_A feeds the embeddings straight into the head, so they are the
            # representation to return (previously left unbound -> NameError).
            global_repr = batch['embeddings']
        elif self.mode == Mode.Training_B:
            global_repr = self.pooling(batch['embeddings'], mask=batch['attention_mask'])
        else:  # Inference, Training_C, Training_D
            seq_repr = self.sequence_representation(batch['seq'])
            mask = self.attention_mask(batch['seq'])
            global_repr = self.pooling(seq_repr, mask=mask)
        logits = self.cls_head(global_repr)
        if return_repr:
            return logits, global_repr
        return logits

    @abc.abstractmethod
    def _init_fm(self, cfg:mlc.ConfigDict, mode:Mode):
        """Create the foundation model and assign it to ``self.fm``."""

    @abc.abstractmethod
    def tokenize_sequences(self, sequences:typing.Sequence[str]):
        """Convert raw sequences into the batch object passed as ``batch['seq']``."""

    @abc.abstractmethod
    def sequence_representation(self, batch):
        """Return per-token FM representations for a tokenized batch."""

    @abc.abstractmethod
    def attention_mask(self, batch):
        """Return the mask (non-zero = real token) for a tokenized batch."""

    @torch.no_grad()
    def predict(self, sequences:typing.Sequence[str]|typing.Mapping[str,str], device=None):
        """Return the softmax probability of class 1 for each sequence.

        Runs under ``torch.no_grad()`` but does not switch the model to eval mode;
        call ``.eval()`` first for inference.

        Args:
            sequences: a sequence of amino-acid strings, or a mapping id -> sequence.
            device: if given, the tokenized batch is moved there. The model itself
                must already be on that device.

        Returns:
            For a mapping, a dict id -> float probability. Otherwise a 1-D CPU tensor
            with one probability per input sequence, in input order. Empty input
            returns ``{}`` or an empty tensor respectively.
        """
        is_mapping = isinstance(sequences, typing.Mapping)
        if is_mapping:
            keys = list(sequences.keys())
            sequences = [sequences[k] for k in keys]
        if len(sequences) == 0:
            # Do not call the tokenizer/model on an empty batch.
            return {} if is_mapping else torch.empty(0)
        seq_batch = self.tokenize_sequences(sequences)
        if device:
            seq_batch = seq_batch.to(device)
        logits = self({'seq': seq_batch})
        probs = torch.softmax(logits, dim=-1)[:, 1].cpu()
        if is_mapping:
            return {k: probs[i].item() for i, k in enumerate(keys)}
        return probs
120
+
121
class RP3Esm2(RP3Net):
    """RP3Net whose foundation model is a Hugging Face ``transformers`` ``EsmModel``."""

    def _init_esm2(self, cfg:mlc.ConfigDict, mode:Mode, cfg_path:str|os.PathLike) -> None:
        """Build the ESM-2 FM, its tokenizer and optional LoRA adapters.

        Args:
            cfg: model config. Optional ``fm.cp`` is a state-dict file for the FM
                (random init otherwise). Optional ``fm.lora`` holds keyword arguments
                for ``peft.LoraConfig``.
            mode: operating mode; LoRA is built with ``inference_mode`` set iff
                ``mode`` is ``Mode.Inference``.
            cfg_path: directory containing the HF model config and tokenizer files.
        """
        esm_cfg = hub.EsmConfig.from_pretrained(cfg_path, local_files_only=True)
        checkpoint_file = cfg.get('fm.cp')
        if checkpoint_file:
            log.info(f"Loading pre-trained FM from checkpoint {checkpoint_file}")
            weights = torch.load(util.resolve(checkpoint_file), map_location='cpu', weights_only=True)
            self.fm = hub.EsmModel.from_pretrained(
                None,
                config=esm_cfg,
                state_dict=weights,
                local_files_only=True,
                add_pooling_layer=False,
            )
        else:
            log.info("Loading random model")
            self.fm = hub.EsmModel(esm_cfg, add_pooling_layer=False)
        self.tokenizer = hub.EsmTokenizer.from_pretrained(cfg_path, do_lower_case=False)
        # Residue letters U, Z, O and B are rewritten to 'X' before tokenization.
        self.re_aa_x = re.compile(r"[UZOB]")
        lora_kwargs = cfg.get('fm.lora')
        if lora_kwargs:
            peft_cfg = peft.LoraConfig(**lora_kwargs, inference_mode=(mode == Mode.Inference))
            self.fm = peft.get_peft_model(self.fm, peft_cfg)

    def tokenize_sequences(self, seqs:typing.Sequence[str]):
        """Tokenize amino-acid strings into a padded batch of PyTorch tensors."""
        to_x = self.re_aa_x.sub
        cleaned = [to_x('X', seq) for seq in seqs]
        return self.tokenizer(cleaned, padding=True, return_tensors='pt')

    def sequence_representation(self, batch):
        """Per-token hidden states from the last FM layer."""
        outputs = self.fm(**batch)
        return outputs.last_hidden_state

    def attention_mask(self, batch):
        """The tokenizer's ``attention_mask`` for ``batch``."""
        return batch['attention_mask']

    def train(self, train_mode:bool=True):
        """Set train/eval mode. When training the FM, also freeze its position
        embeddings and contact head."""
        super().train(train_mode)
        if not train_mode or self.mode not in Mode_Training_FM:
            return self
        self.fm.embeddings.position_embeddings.requires_grad_(False)
        self.fm.contact_head.requires_grad_(False)
        return self
160
+
161
+
162
class RP3Esm2_650m(RP3Esm2):
    """RP3Esm2 built from the packaged ``fm_cfg/esm2_650m`` configuration directory."""

    def _init_fm(self, cfg:mlc.ConfigDict, mode:Mode):
        # model.py lives in <package>/model/, so the package root is two levels up.
        package_root = pathlib.Path(__file__).resolve().parents[1]
        self._init_esm2(cfg, mode, package_root.joinpath('fm_cfg', 'esm2_650m'))
166
+
167
+
168
# Built-in model configurations. Neither sets 'mode', so models built from them
# default to Mode.Inference.
# Invariants:
#   - stp.seq_dim must equal the FM hidden size (1280 for esm2_650m).
#   - classification_head.embedding_dim must equal stp.d * num_seeds, because
#     SetTransformerPooling flattens its output (num_seeds defaults to 1).

# Default: ESM-2 650M with LoRA adapters (peft) on query/key/value, 256-d Set
# Transformer pooling, and one 256-wide SiLU hidden layer in the head.
RP3_DEFAULT_CONFIG = mlc.FrozenConfigDict(dict(
    fm=dict(
        type='esm2_650m',
        lora=dict(
            r=8,
            lora_alpha=1.0,
            target_modules=['query', 'key', 'value'],
            lora_dropout=0.1,
            bias='lora_only',
        ),
    ),
    aggregation='stp',
    stp=dict(seq_dim=1280, d=256, num_heads=8, layer_norm=True),
    classification_head=dict(
        embedding_dim=256,
        bias=False,
        end_bias=True,
        layer_norm=False,
        layers=dict(d=256, n=1),
        nonlinearity='SiLU',
    ),
))

# Variant B: ESM-2 650M without LoRA, 128-d Set Transformer pooling, and one
# 512-wide SiLU hidden layer in the head.
RP3_CONFIG_B = mlc.FrozenConfigDict(dict(
    fm=dict(type='esm2_650m'),
    aggregation='stp',
    stp=dict(seq_dim=1280, d=128, num_heads=8, layer_norm=True),
    classification_head=dict(
        embedding_dim=128,
        bias=False,
        end_bias=True,
        layer_norm=False,
        layers=dict(d=512, n=1),
        nonlinearity='SiLU',
    ),
))
220
+
221
+
222
def load_model(cfg:mlc.ConfigDict, cp_path:str|os.PathLike|None=None) -> RP3Net:
    """Build an RP3Net from ``cfg`` and optionally load a full-model checkpoint.

    Args:
        cfg: model configuration. ``cfg.fm.type`` selects the implementation
            (currently only ``'esm2_650m'``); ``cfg.mode`` (optional) selects the Mode.
        cp_path: optional state-dict file. ``~`` is expanded, and it is loaded on CPU
            with ``strict=True``, so its keys must match the model exactly.

    Returns:
        The constructed model; call ``.eval()`` on it before inference.

    Raises:
        ValueError: if ``cfg.fm.type`` is not supported.
        KeyError: if ``cfg.mode`` is not a ``Mode`` member name.
    """
    model_classes = {'esm2_650m': RP3Esm2_650m}
    model_type = cfg.fm.type
    if model_type not in model_classes:
        raise ValueError(f"Model {model_type} not supported")
    model = model_classes[model_type](cfg)
    if cp_path:
        state_dict = torch.load(resolve(cp_path), map_location='cpu', weights_only=True)
        model.load_state_dict(state_dict, strict=True)
    return model
233
+
RP3Net/rp3_main.py ADDED
@@ -0,0 +1,85 @@
1
+
2
+ from typing import Optional
3
+ import yaml
4
+ import argparse
5
+ import logging
6
+ from tqdm import tqdm
7
+ import ml_collections as mlc
8
+ import pandas as pd
9
+ import numpy as np
10
+
11
+ import RP3Net.util as util
12
+ import RP3Net.model as model
13
+
14
# Module logger keyed by __name__ (not the file path), matching model.py and
# rp3_train.py so it sits in the package's logger hierarchy.
log = util.get_logger(__name__)
15
+
16
+
17
def setup_args(_args: Optional[list] = None):
    """Parse the command-line options of ``rp3_main``.

    Args:
        _args: argument list to parse; None means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the parsed options.
    """
    description = """
Predict protein expression in E.coli from sequences. Takes in a fasta file with protein sequences and ids.
Outputs a .csv[.gz] file with "id,score" columns.
The score is the predicted probability of expression.
"""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     description=description)
    add = parser.add_argument
    add("--log_file", help="Log file. Log output to console if set to None.")
    add("--log_level", default="info",
        help="Log level of root logger. Appender levels are appropriately hard coded.")
    add("-c", "--config", help="Model configuration in yaml, or registry entry.",
        default="RP3_DEFAULT_CONFIG")
    add("-p", "--checkpoint", required=True, help="Model checkpoint.")
    add("-f", "--fasta", required=True, help="Fasta file, possibly gzipped.")
    add("-d", "--device", default='cpu', help="Device to use: cpu, cuda[:n], mps, ...")
    add("-b", "--batch_size", default=8, type=int,
        help="Batch size. Memory consumption depends on construct sequence length and batch size. If the model runs out of SLURM/GPU memory, try reducing this parameter to make it fit.")
    add("-o", "--out_file", required=True, help="Output .csv[.gz] file.")
    add("--progress", action=argparse.BooleanOptionalAction, default=True, help="Show progress bar.")
    return parser.parse_args(_args)
33
+
34
# Built-in model configurations that can be selected by name with --config.
MODEL_REGISTRY = {
    name: getattr(model, name)
    for name in ('RP3_DEFAULT_CONFIG', 'RP3_CONFIG_B')
}
38
+
39
def _load_config(config_arg: str) -> mlc.FrozenConfigDict:
    """Load a model config from a YAML file, or look ``config_arg`` up in MODEL_REGISTRY.

    Raises:
        ValueError: if ``config_arg`` is neither an existing file nor a registry key.
    """
    config_path = util.resolve(config_arg)
    if config_path.exists():
        with open(config_path, "r") as f:
            # NOTE(review): FullLoader is less restrictive than yaml.safe_load; configs
            # only need plain YAML types, so safe_load should suffice for untrusted files.
            return mlc.FrozenConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    if config_arg in MODEL_REGISTRY:
        return MODEL_REGISTRY[config_arg]
    raise ValueError(f"Config file {config_arg} not found.")


def _log_sequence_stats(seq_map, fasta) -> None:
    """Log how many sequences were read from ``fasta`` and summary stats of their lengths."""
    if not seq_map:
        # min()/max() of an empty numpy array raise, so report an empty input separately.
        log.warning(f"Read 0 sequences from {fasta}.")
        return
    seq_lens = np.array([len(s) for s in seq_map.values()])
    log.info(f"Read {len(seq_map)} sequences from {fasta}. Sequence lengths (mean/std/median/min/max): "
             f"{seq_lens.mean():.2f}/{seq_lens.std():.2f}/{np.median(seq_lens):.0f}/{seq_lens.min()}/{seq_lens.max()}")


def rp3_main():
    """Command-line entry point: score every sequence of a FASTA file with an RP3Net model.

    Writes a CSV with columns ``id`` and ``score`` (the model's class-1 probability)
    to ``--out_file``. An empty FASTA produces a header-only CSV.

    Raises:
        ValueError: if ``--batch_size`` is not positive or the config cannot be found.
    """
    args = setup_args()
    util.setup_logging(save_path=args.log_file, level=args.log_level, log_console=args.log_file is None)
    if args.batch_size < 1:
        raise ValueError(f"--batch_size must be a positive integer, got {args.batch_size}")

    seq_map = util.read_fasta(args.fasta)
    _log_sequence_stats(seq_map, args.fasta)

    config = _load_config(args.config)
    log.info(f"Loading model {args.config} from checkpoint {args.checkpoint}")
    m = model.load_model(config, args.checkpoint)
    m = m.to(device=args.device)
    m = m.eval()
    log.info(f"Loaded model {m}")

    seq_keys = list(seq_map.keys())
    # Ceiling division: the last batch may be partial but must still be counted.
    n_batches = (len(seq_keys) + args.batch_size - 1) // args.batch_size
    starts = range(0, len(seq_keys), args.batch_size)
    if args.progress:
        tqdm_desc = args.fasta.replace('.fasta', '').replace('.gz', '')[-20:]
        starts = tqdm(starts, desc=tqdm_desc)

    scores = {}
    for batch_no, start in enumerate(starts, start=1):
        if not args.progress:
            log.info(f"Processing batch {batch_no} of {n_batches}.")
        batch = {k: seq_map[k] for k in seq_keys[start:start + args.batch_size]}
        scores.update(m.predict(batch, device=args.device))

    # Explicit columns (not zip(*items)) so that an empty result still has a header.
    out_df = pd.DataFrame({'id': list(scores.keys()), 'score': list(scores.values())})
    log.info(f"Writing {out_df.shape[0]} rows to {args.out_file}")
    out_df.to_csv(args.out_file, index=False)


if __name__ == "__main__":
    rp3_main()
85
+
RP3Net/rp3_train.py ADDED
@@ -0,0 +1,18 @@
1
+ import importlib
2
+ from RP3Net.training.cli import RP3Cli
3
+ from RP3Net.util import util
4
+
5
log = util.get_logger(__name__)


def rp3_train():
    """Command-line entry point for training via ``RP3Cli``.

    Any exception escaping ``RP3Cli`` is logged with its traceback and re-raised.

    Raises:
        ImportError: if the optional 'lightning' package is not installed.
    """
    # `import importlib` alone does not guarantee the `importlib.util` submodule is loaded.
    import importlib.util

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if importlib.util.find_spec('lightning') is None:
        raise ImportError("Please install 'RP3Net[training]' to enable training")
    try:
        # Only the construction side effect is needed; the instance is not used.
        RP3Cli()
    except Exception as e:
        log.error("Top level catch", exc_info=e)
        raise  # bare raise keeps the original traceback


if __name__ == "__main__":
    rp3_train()
@@ -0,0 +1,6 @@
1
+ from . import lm
2
+ from . import lm_emlc
3
+ from . import cli
4
+ from . import data
5
+ from . import data_emlc
6
+ from . import metrics