RP3Net 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RP3Net/__init__.py +8 -0
- RP3Net/fm_cfg/esm2_650m/config.json +29 -0
- RP3Net/fm_cfg/esm2_650m/special_tokens_map.json +7 -0
- RP3Net/fm_cfg/esm2_650m/tokenizer_config.json +4 -0
- RP3Net/fm_cfg/esm2_650m/vocab.txt +33 -0
- RP3Net/model/__init__.py +1 -0
- RP3Net/model/layers.py +171 -0
- RP3Net/model/model.py +233 -0
- RP3Net/rp3_main.py +85 -0
- RP3Net/rp3_train.py +18 -0
- RP3Net/training/__init__.py +6 -0
- RP3Net/training/cli.py +166 -0
- RP3Net/training/data.py +300 -0
- RP3Net/training/data_emlc.py +94 -0
- RP3Net/training/lm.py +123 -0
- RP3Net/training/lm_emlc.py +400 -0
- RP3Net/training/metrics.py +357 -0
- RP3Net/util/__init__.py +3 -0
- RP3Net/util/fasta.py +26 -0
- RP3Net/util/torch.py +89 -0
- RP3Net/util/util.py +65 -0
- rp3net-0.0.1.dist-info/METADATA +77 -0
- rp3net-0.0.1.dist-info/RECORD +27 -0
- rp3net-0.0.1.dist-info/WHEEL +5 -0
- rp3net-0.0.1.dist-info/entry_points.txt +3 -0
- rp3net-0.0.1.dist-info/licenses/LICENSE +21 -0
- rp3net-0.0.1.dist-info/top_level.txt +1 -0
RP3Net/training/cli.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import os
|
|
3
|
+
import socket
|
|
4
|
+
import sys
|
|
5
|
+
import logging
|
|
6
|
+
import lightning.pytorch as L
|
|
7
|
+
import lightning.pytorch.utilities as L_util
|
|
8
|
+
import lightning.pytorch.cli as L_cli
|
|
9
|
+
import lightning.pytorch.callbacks as L_cb
|
|
10
|
+
import lightning.pytorch.loggers as L_log
|
|
11
|
+
import wandb
|
|
12
|
+
|
|
13
|
+
from . import lm
|
|
14
|
+
from .. import util
|
|
15
|
+
|
|
16
|
+
# Module-level logger obtained from the project's util.get_logger helper.
log = util.get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
def setup_logging_torch(args):
    """Configure process-wide logging for a training run.

    Exports ``PP_LOG_LEVEL`` (and ``PP_LOGFILE_BASE`` when a log file is given) to the
    environment and makes the log file name unique per SLURM task when more than one
    task runs. It then initialises logging via ``util.setup_logging`` and routes the
    ``lightning`` / ``lightning.pytorch`` loggers through the root handlers.

    Args:
        args: parsed CLI config exposing ``log_level`` and ``logfile``
            (``logfile=None`` logs to the console).
    """
    log_level = args.log_level
    os.environ["PP_LOG_LEVEL"] = log_level
    logfile = args.logfile
    if logfile is not None:
        logfile_base = util.resolve(logfile)
        os.environ["PP_LOGFILE_BASE"] = str(logfile_base.parent / logfile_base.stem)
        # Mirrors lightning/fabric/plugins/environments/slurm.py:
        # SLURMEnvironment.world_size() and SLURMEnvironment.global_rank()
        if int(os.environ.get("SLURM_NTASKS", "0")) > 1:
            # Give every SLURM task its own file so ranks do not interleave output.
            # Only the file name is rewritten (directories containing ".log" are left
            # alone), and names without ".log" still get a per-rank suffix.
            rank = os.environ.get('SLURM_PROCID', '0')
            log_dir, log_name = os.path.split(logfile)
            if ".log" in log_name:
                log_name = log_name.replace(".log", f"_{rank}.log")
            else:
                stem, ext = os.path.splitext(log_name)
                log_name = f"{stem}_{rank}{ext}"
            logfile = os.path.join(log_dir, log_name)
    util.setup_logging(logfile, log_level, log_console=logfile is None)
    # Drop Lightning's own handlers and let its records propagate to the root logger.
    for logger_name in ("lightning", "lightning.pytorch"):
        ll = logging.getLogger(logger_name)
        ll.handlers.clear()
        ll.propagate = True
    log.info(f"Host: {socket.gethostname()}; PID: {os.getpid()}; Command line: {' '.join(sys.argv)}")
|
|
38
|
+
|
|
39
|
+
class RP3Cli(L_cli.LightningCLI):
    """LightningCLI with RP3-specific logging, wandb, checkpointing and test-after-fit.

    On top of the stock LightningCLI this class:

    * registers extra arguments (see :meth:`add_arguments_to_parser`);
    * configures process logging and an optional wandb logger before the
      model/data/trainer classes are instantiated;
    * injects metric ``ModelCheckpoint`` callbacks and loggers (CSV, wandb) into the
      trainer config right before the trainer is built;
    * optionally runs ``trainer.test`` on a saved checkpoint after ``fit``.
    """

    def __init__(self, *args, **kwargs):
        self.wandb_logger = None
        # Force 'overwrite' so an existing saved config in the root dir does not abort
        # the run; keep any other save_config_kwargs supplied by the caller.
        save_config_kwargs = {**(kwargs.get('save_config_kwargs') or {}), "overwrite": True}
        super().__init__(*args, **{**kwargs, 'save_config_kwargs': save_config_kwargs})

    @staticmethod
    @L_util.rank_zero_only
    def wandb_init(wandb_project, wandb_run_name, wandb_run_id):
        """Start (or resume) a wandb run; executed on global rank zero only (``None`` elsewhere)."""
        return wandb.init(id=wandb_run_id, project=wandb_project, name=wandb_run_name, resume='allow')

    def wandb_logger_init(self, config):
        """Build a ``WandbLogger`` from ``config.wandb``, or return ``None`` if wandb is off.

        wandb is off when the section is missing, has no project, has neither ``run`` nor
        ``run_id``, or has ``disable`` set. The run id defaults to the run name.
        """
        if 'wandb' not in config or config.wandb is None or \
                'project' not in config.wandb or ('run' not in config.wandb and 'run_id' not in config.wandb) or \
                config.wandb.project is None or (config.wandb.run is None and config.wandb.run_id is None) or \
                ('disable' in config.wandb and config.wandb.disable):
            log.info("No wandb logging")
            return None
        log.info("Configure wandb logging")
        run_id = config.wandb.run_id if 'run_id' in config.wandb and config.wandb.run_id is not None \
            else config.wandb.run
        run = self.wandb_init(wandb_project=config.wandb.project, wandb_run_name=config.wandb.run, wandb_run_id=run_id)
        if run is not None:
            log.info(f"Wandb run: {run.name}({run.id})")
        return L_log.WandbLogger(project=config.wandb.project, name=config.wandb.run, id=run_id)

    def add_arguments_to_parser(self, parser: L_cli.LightningArgumentParser) -> None:
        """Register the RP3-specific command-line arguments."""
        parser.add_argument("--logfile", help="Log file. Log output to console if set to None.")
        parser.add_argument("--log_level", default="info",
                            help="Log level of root logger. Appender levels are appropriately hard coded.")
        parser.add_argument("--track_metric_checkpoints", choices=["last", "all", "best"],
                            help="""
                            Track checkpoints for training and validation metrics from the module.
                            If not provided, no checkpoints will be recorded at all.
                            If any value is provided, only the best checkpoint will be recorded for all the metrics specified by `model.metrics_for_checkpointing()`.
                            The value of this argument affects how checkpoints will be saved for 'train_loss'.
                            """)
        parser.add_argument("--wandb.project", help="Wandb project name", default=None)
        parser.add_argument("--wandb.run", help="Wandb run name", default=None)
        parser.add_argument("--wandb.run_id", help="Wandb run id, same as name by default", default=None)
        parser.add_argument("--wandb.disable", help="Set to true to turn off wandb logging, without removing the rest of wandb settings", action='store_true')
        parser.add_argument("--test_after_fit_metric", help="Metric to use for test_after_fit. If not set, then do not run test_after_fit", default=None)
        parser.add_argument("--emlc_k", help="Number of student gradient steps to perform per teacher step for EMLC", default=1, type=int)

    def before_instantiate_classes(self) -> None:
        """Set up logging and the optional wandb logger from the subcommand's config."""
        config = self.config.get(str(self.subcommand), self.config)
        setup_logging_torch(config)
        self.wandb_logger = self.wandb_logger_init(config)

    def instantiate_trainer(self, **kwargs) -> L.Trainer:
        """Inject checkpoint callbacks and loggers into the trainer config, then build the trainer."""
        log.info("Instantiating trainer")
        config = self.config.get(str(self.subcommand), self.config)
        metric_checkpoints = self._get(self.config_init, 'track_metric_checkpoints')
        if metric_checkpoints is not None:
            self.init_metric_checkpoints(metric_checkpoints)
        self.add_loggers(config)
        return super().instantiate_trainer(**kwargs)

    def add_loggers(self, config):
        """Turn ``trainer.logger`` into a list containing a CSV logger and, if configured, wandb.

        ``logger: false`` disables logging and is left untouched. ``true``/``None``
        (Lightning's default logger) are replaced by the explicit list.
        """
        configured_loggers = self._get(self.config_init, 'trainer.logger', default=[])
        if configured_loggers is False:
            return
        if configured_loggers is True or configured_loggers is None:
            configured_loggers = []
        elif isinstance(configured_loggers, L_log.Logger):
            configured_loggers = [configured_loggers]
        else:
            # Any iterable of loggers (e.g. a tuple) -> a list we can append to.
            configured_loggers = list(configured_loggers)
        if not any(isinstance(lg, L_log.CSVLogger) for lg in configured_loggers):
            # Trainer itself falls back to the CWD when default_root_dir is unset.
            save_dir = config['trainer']['default_root_dir'] or os.getcwd()
            configured_loggers.append(L_log.CSVLogger(save_dir))
        if self.wandb_logger is not None:
            configured_loggers.append(self.wandb_logger)
        config_init = self.config_init.get(str(self.subcommand), self.config_init)
        config_init['trainer']['logger'] = configured_loggers

    def init_metric_checkpoints(self, checkpoint_save_flag):
        """Append ``ModelCheckpoint`` callbacks to the trainer config.

        Adds one callback per metric from ``model.metrics.metrics_for_checkpointing()``,
        each keeping ModelCheckpoint's default ``save_top_k``. Adds one more for
        ``train_loss``, where ``checkpoint_save_flag`` selects ``save_top_k``:
        ``'best'`` -> 1, ``'all'`` -> -1, ``'last'`` -> 0. The ``train_loss`` callback
        always sets ``save_last=True``.
        """
        model: lm.RP3LM = self.model
        metrics = model.metrics.metrics_for_checkpointing()
        if metrics is None:
            return
        # Same fallback as Trainer when default_root_dir is not configured.
        default_root_dir = util.resolve(self._get(self.config, 'trainer.default_root_dir') or os.getcwd())
        trainer_config = self._get(self.config_init, 'trainer', default={})
        if 'callbacks' not in trainer_config or trainer_config['callbacks'] is None:
            callbacks = []
        else:
            callbacks = trainer_config['callbacks']
            if isinstance(callbacks, L.Callback):
                callbacks = [callbacks]
        trainer_config['callbacks'] = callbacks
        for key, metric in metrics.items():
            mode = 'max' if metric.higher_is_better else 'min'
            callbacks.append(L_cb.ModelCheckpoint(
                dirpath=default_root_dir, monitor=key, mode=mode,
                filename='{epoch}_{' + key + ':.2f}',
                save_on_train_epoch_end=False, save_weights_only=False
            ))
        assert checkpoint_save_flag in ['last', 'all', 'best']
        save_top_k = {'best': 1, 'all': -1, 'last': 0}[checkpoint_save_flag]
        callbacks.append(L_cb.ModelCheckpoint(dirpath=default_root_dir, monitor='train_loss',
                                              filename='{epoch}_{train_loss:.2f}', mode='min',
                                              save_top_k=save_top_k, save_last=True))

    def test_after_fit(self, metric):
        """Run ``trainer.test`` on the checkpoint saved for ``metric`` in the root dir."""
        model: lm.RP3LM = self.model
        dm = self.datamodule
        ckpt_dir = util.resolve(self.trainer.default_root_dir)
        # Matches names produced by the '{epoch}_{<metric>:.2f}' templates, e.g.
        # "epoch=3_val_auc=0.87.ckpt". The value may be negative, and Lightning appends
        # "-vN" when a file with the same name already exists.
        filename_pattern = re.compile(
            r'^epoch=\d+_' + re.escape(metric) + r'=-?\d+(\.\d+)?(-v\d+)?\.ckpt$')
        cp_file = util.find_checkpoint_file(ckpt_dir, filename_pattern)
        log.info(f"Loading checkpoint {cp_file}")
        self.trainer.test(model, dm, ckpt_path=str(cp_file))

    def after_fit(self):
        """Run :meth:`test_after_fit` when ``--test_after_fit_metric`` is set."""
        metric = self._get(self.config, 'test_after_fit_metric')
        if metric is not None:
            log.info(f"Running test_after_fit on {metric}")
            self.test_after_fit(metric)
|
RP3Net/training/data.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import typing
|
|
3
|
+
import functools
|
|
4
|
+
import zipfile
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.utils.data as torch_data
|
|
8
|
+
import polars as pl
|
|
9
|
+
import numpy as np
|
|
10
|
+
import lightning as L
|
|
11
|
+
import ml_collections as mlc
|
|
12
|
+
|
|
13
|
+
from .lm import RP3LM
|
|
14
|
+
from .. import util
|
|
15
|
+
from .. import model
|
|
16
|
+
|
|
17
|
+
# Module-level logger obtained from the project's util.get_logger helper.
log = util.get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
# Polars dtypes applied (as schema overrides) when reading the full dataset CSV
# in read_full_df_pl. Each key appears once; a duplicate 'fasta_id_no_tags' entry
# was removed (same dtype, so the resulting mapping is unchanged).
FULL_DF_DTYPE_PL = {
    'created_at': pl.Datetime(), 'source': pl.Categorical(), 'sub_source': pl.Categorical(),
    'no_tags_cluster_40_id': pl.String(), 'with_tags_cluster_90_id': pl.String(),
    'has_dna': pl.Boolean(),
    'experiment_id': pl.String(),
    'yield_binary': pl.Boolean(), 'yield_cat': pl.Int64(),
    'host': pl.Categorical(), 'exp_outcome': pl.Categorical(),
    'id': pl.String(), 'fasta_id': pl.String(), 'dna_fasta_id': pl.String(), 'fasta_id_no_tags': pl.String(),
    'ds_type': pl.Categorical(),
    'n_tags_end': pl.Int64(), 'c_tags_start': pl.Int64(), 'n_fragments': pl.Int64(), 'unique_target_count': pl.Int64(),
    'uniprot_id': pl.String(), 'gene_id': pl.String(), 'taxon_id': pl.Int64(),
}
|
|
31
|
+
|
|
32
|
+
def read_full_df_pl(path: str | os.PathLike, **kwargs) -> pl.DataFrame:
    """Read the full dataset CSV at ``path`` with polars.

    ``schema_overrides`` defaults to ``FULL_DF_DTYPE_PL`` when not given; every other
    keyword argument is forwarded unchanged to ``pl.read_csv``.
    """
    schema = kwargs.pop('schema_overrides', FULL_DF_DTYPE_PL)
    return pl.read_csv(path, schema_overrides=schema, **kwargs)
|
|
37
|
+
|
|
38
|
+
def load_global_embeddings_file(embeddings_file: os.PathLike) -> typing.Mapping:
    """Load per-sequence embeddings written with ``torch.save``.

    The file must contain a mapping with ``'ids'`` and ``'embeddings'``, where
    ``embeddings[i]`` is the embedding for ``ids[i]``.

    Returns:
        dict mapping each id to its embedding entry.

    Raises:
        ValueError: if the numbers of ids and embeddings differ.
    """
    log.info(f"Loading global embeddings from {embeddings_file}")
    # NOTE(security): depending on the torch version, torch.load may unpickle
    # arbitrary objects; only load embedding files from trusted sources.
    embeddings_data = torch.load(embeddings_file)
    ids = embeddings_data['ids']
    embeddings = embeddings_data['embeddings']
    if len(ids) != len(embeddings):
        raise ValueError(f"{embeddings_file}: {len(ids)} ids but {len(embeddings)} embeddings")
    return {seq_id: embeddings[i] for i, seq_id in enumerate(ids)}
|
|
44
|
+
|
|
45
|
+
class RP3GlobalEmbeddingsDataSet(torch_data.Dataset):
    """Map-style dataset pairing rows of ``df`` with precomputed per-sequence embeddings.

    Each item is a dict with the source-row index (``idx``), the binary label as an
    int, the row's ``source`` value and the embedding looked up by the row's ``id``.
    """

    def __init__(self, df: pl.DataFrame, prefix: str, embeddings: typing.Mapping[str, torch.Tensor]) -> None:
        super().__init__()
        self.df = df
        self.prefix = prefix  # label used in log messages
        self.embeddings = embeddings

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        try:
            record = self.df.row(idx, named=True)
            item = dict(
                idx=record['src_idx'],
                yield_binary=int(record['yield_binary']),
                source=record['source'],
                embeddings=self.embeddings[record['id']],
            )
            log.debug(f"{self.prefix}: (torch={idx}, csv={item['idx']}){record['id']}: {item['yield_binary']}")
            return item
        except Exception as e:
            # Log with context, then propagate the original exception.
            log.error(f"Top level catch in {self.prefix} __getitem__", exc_info=e)
            raise
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class RP3SequenceEmbeddingsDataSet(torch_data.Dataset):
    """Map-style dataset reading one embedding tensor per row from an open zip archive.

    The archive member named after the row's ``id`` is deserialised with
    ``torch.load(..., weights_only=True)``.
    """

    def __init__(self, df: pl.DataFrame, prefix: str, embeddings: zipfile.ZipFile) -> None:
        super().__init__()
        self.df = df
        self.prefix = prefix  # label used in log messages
        self.embeddings: zipfile.ZipFile = embeddings

    def __len__(self):
        return self.df.shape[0]

    def _load_embedding(self, member):
        """Deserialise the tensor stored under ``member`` in the archive."""
        with self.embeddings.open(member, 'r') as handle:
            return torch.load(handle, weights_only=True)

    def __getitem__(self, idx):
        try:
            record = self.df.row(idx, named=True)
            tensor = self._load_embedding(record['id'])
            item = dict(
                idx=record['src_idx'],
                yield_binary=int(record['yield_binary']),
                source=record['source'],
                embeddings=tensor,
            )
            log.debug(f"{self.prefix}: (torch={idx}, csv={item['idx']}){record['id']}: {item['yield_binary']}")
            return item
        except Exception as e:
            # Log with context, then propagate the original exception.
            log.error(f"Top level catch in {self.prefix} __getitem__", exc_info=e)
            raise
|
|
97
|
+
|
|
98
|
+
class RP3SequenceDataSet(torch_data.Dataset):
    """Map-style dataset yielding (optionally randomly cropped) amino-acid sequences.

    Each item is a dict with the source-row index (``idx``), ``source``, the sequence
    (``seq``) and the binary label as an int.
    """

    def __init__(self, df: pl.DataFrame, prefix, rng: np.random.Generator, max_seq_len: int = 0):
        super().__init__()
        self.rng = rng
        self.df = df
        self.prefix = prefix  # label used in log messages
        self.max_seq_len = max_seq_len

    def seq_chunk(self, seq: str):
        """Return ``seq``, or a random contiguous window of ``max_seq_len`` residues.

        Cropping is disabled when ``max_seq_len`` is 0, ``None`` or negative (a
        ``None`` from config previously raised TypeError; a negative value produced
        empty slices). The window start is drawn uniformly with ``self.rng``.
        """
        if not self.max_seq_len or self.max_seq_len < 0 or len(seq) <= self.max_seq_len:
            log.debug(f"Not changing sequence of length {len(seq)}; max_seq_len={self.max_seq_len}")
            return seq
        # integers(n) draws from [0, n), so every start keeps a full window in range.
        start_idx = self.rng.integers(len(seq) - self.max_seq_len + 1)
        end_idx = start_idx + self.max_seq_len
        log.debug(f"Returning the {start_idx}:{end_idx} chunk from sequence of length {len(seq)}; max_seq_len={self.max_seq_len}")
        return seq[start_idx:end_idx]

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        try:
            row = self.df.row(idx, named=True)
            ret = {
                'idx': row['src_idx'],
                'source': row['source'],
                'seq': self.seq_chunk(row['seq']),
                'yield_binary': int(row['yield_binary']),
            }
            log.debug(f"{self.prefix}: (torch={idx}, csv={ret['idx']}){row['id']}: {ret['yield_binary']}")
            return ret
        except Exception as e:
            # Log with context, then propagate the original exception.
            log.error(f"Top level catch in {self.prefix} __getitem__", exc_info=e)
            raise
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class RP3LDM(L.LightningDataModule):
    """Base data module: loads the dataset CSV and splits it into train/val/test.

    Subclasses implement :meth:`torch_dataset` and may override :meth:`get_collate_fn`
    or extend :meth:`load_df`. Hyper-parameters read from ``hypers``:

    * ``sources``, ``ds_path`` and ``test_val_seed``;
    * optionally ``seed`` and ``validation_slice`` (default ``'VALIDATION'``);
    * optionally ``batch_size`` / ``<key>_batch_size`` and ``use_distributed_sampler``.
    """

    def __init__(self, hypers) -> None:
        super().__init__()
        log.debug("DataModule init")
        self.save_hyperparameters({'data': hypers})
        self.hypers = mlc.ConfigDict(self.hparams.data)
        # Sources are encoded as their position in hypers.sources (see load_df).
        self.sources_map = {s: i for i, s in enumerate(self.hypers.sources)}
        self.rng = np.random.default_rng(self.hypers.get('seed', None))
        self.validation_slice = self.hypers.get('validation_slice', 'VALIDATION')

    def torch_dataset(self, df: pl.DataFrame, prefix: str) -> torch_data.Dataset:
        """Wrap ``df`` into a torch dataset labelled ``prefix``; subclasses must override."""
        # `raise NotImplemented()` would raise a TypeError ("not callable") instead.
        raise NotImplementedError(f"{type(self).__name__} must implement torch_dataset()")

    def load_df(self) -> pl.DataFrame:
        """Read ``hypers.ds_path``, keep the configured sources and encode ``source`` as int.

        A ``src_idx`` column with the original row position is added first.

        Raises:
            AssertionError: if a configured source does not occur in the file.
        """
        data_path = self.hypers.ds_path
        log.info(f"Loading data from {data_path}; validation slice: {self.validation_slice}")
        df = read_full_df_pl(data_path).with_row_index('src_idx')
        df_sources = set(df.get_column('source').cast(pl.String).unique().to_list())
        missing = [s for s in self.sources_map if s not in df_sources]
        assert not missing, f"Sources not found in {data_path}: {missing}"
        return (df
                .filter(pl.col('source').is_in(self.hypers.sources))
                .with_columns(pl.col('source').cast(pl.String).replace_strict(self.sources_map))
                )

    def setup(self, stage: str) -> None:
        """Load the data and split it by ``ds_type``.

        ``'TEST'`` rows go to test and the validation slice to val. Every other
        non-null ``ds_type`` goes to train.
        """
        log.debug("RP3LDM setup")
        if self.trainer is not None:
            # lightning_module is the unwrapped module, even if a strategy wraps the model.
            assert self.trainer.lightning_module.sources == self.hypers.sources
        df = self.load_df()
        self.df_train = df.filter(
            pl.col('ds_type').is_not_null() &
            pl.col('ds_type').is_in(['TEST', self.validation_slice]).not_()
        )
        assert self.df_train.shape[0] > 0, f"No training data for slice {self.validation_slice}"
        self.df_val = df.filter(ds_type=self.validation_slice)
        assert self.df_val.shape[0] > 0, f"No validation data for slice {self.validation_slice}"
        self.df_test = df.filter(ds_type='TEST')
        assert self.df_test.shape[0] > 0, f"No test data for slice {self.validation_slice}"
        self.create_torch_datasets()

    def create_torch_datasets(self):
        """Build train/val/test datasets plus a train sample ("val-training") the size of val."""
        self.train_ds = self.torch_dataset(self.df_train, 'train')
        self.val_ds = self.torch_dataset(self.df_val, "val")
        # Sampling without replacement cannot draw more rows than train has.
        n_val_train = min(len(self.val_ds), self.df_train.shape[0])
        df_val_train = self.df_train.sample(n_val_train, with_replacement=False, seed=self.hypers.test_val_seed)
        self.val_train_ds = self.torch_dataset(df_val_train, "val-training")
        self.test_ds = self.torch_dataset(self.df_test, "test")

    def get_collate_fn(self):
        """Collate function for the dataloaders; ``None`` selects torch's default."""
        return None

    def get_batch_size(self, key: str) -> int:
        """Batch size from ``<key>_batch_size``, else ``batch_size``, else -1."""
        return int(self.hypers.get(f'{key}_batch_size', self.hypers.get('batch_size', -1)))

    def train_dataloader(self):
        batch_size = self.get_batch_size('training')
        return torch_data.DataLoader(self.train_ds, batch_size=batch_size, collate_fn=self.get_collate_fn(), shuffle=True,
                                     num_workers=0, pin_memory=True)

    def _build_val_test_loader(self, ds):
        """Unshuffled loader for evaluation; optionally sharded with a DistributedSampler."""
        batch_size = self.get_batch_size('val_test')
        sampler = None
        drop_last = False
        if self.hypers.get('use_distributed_sampler', False) and util.is_distr_env():
            # drop_last=True keeps shards equal-sized, at the cost of dropping a remainder.
            sampler = torch_data.DistributedSampler(ds, drop_last=True, shuffle=False)
            drop_last = True
        return torch_data.DataLoader(ds, batch_size=batch_size, collate_fn=self.get_collate_fn(), sampler=sampler,
                                     num_workers=0, pin_memory=True, drop_last=drop_last)

    def val_dataloader(self):
        """Two loaders: index 0 is the train sample, index 1 the validation set."""
        train_dl = self._build_val_test_loader(self.val_train_ds)
        val_dl = self._build_val_test_loader(self.val_ds)
        return [train_dl, val_dl]

    def test_dataloader(self):
        return self._build_val_test_loader(self.test_ds)
|
|
212
|
+
|
|
213
|
+
class RP3GlobalEmbeddingsLDM(RP3LDM):
    """Data module serving one embedding vector per sequence.

    ``hypers.embeddings_file`` takes one of three forms:

    * a file readable by ``load_global_embeddings_file``;
    * ``'onehot'`` — the mean one-hot residue encoding;
    * ``'random_<dim>'`` — the mean of a randomly initialised ``<dim>``-d residue embedding.

    The last two read sequences from ``hypers.fasta_path``, looked up by the
    ``hypers.fasta_key`` column.
    """

    def __init__(self, hypers) -> None:
        super().__init__(hypers)
        self.embeddings = None

    def load_df(self):
        df = super().load_df()
        # setup() (and hence load_df) runs once per trainer stage. Build the embeddings
        # only once: re-drawing the 'random_<dim>' table each time would evaluate the
        # test stage on different features than the model was trained on.
        if self.embeddings is None:
            self.embeddings = self._build_embeddings(df)
        return df

    def _build_embeddings(self, df):
        """Compute or load the id -> embedding mapping for the rows of ``df``."""
        embeddings_file = self.hypers.embeddings_file
        if embeddings_file != 'onehot' and not embeddings_file.startswith('random_'):
            return load_global_embeddings_file(util.resolve(embeddings_file))
        seqs = util.read_fasta(self.hypers.fasta_path)
        # Canonical amino acids only; any other residue letter raises KeyError.
        aa_to_int = {aa: i for i, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
        emb = None
        if embeddings_file.startswith('random_'):
            dim = int(embeddings_file[len('random_'):])
            emb = torch.nn.Embedding(len(aa_to_int), dim).to('cpu').requires_grad_(False)
        embeddings = {}
        for seq_id, fasta_key in df.select('id', self.hypers.fasta_key).iter_rows():
            seq_tz = torch.tensor([aa_to_int[aa] for aa in seqs[fasta_key]], dtype=torch.int64)
            if emb is None:
                seq_enc = torch.nn.functional.one_hot(seq_tz, num_classes=len(aa_to_int)).to(dtype=torch.float32)
            else:
                seq_enc = emb(seq_tz)
            # Average over residues to get one vector per sequence.
            embeddings[seq_id] = seq_enc.mean(0)
        return embeddings

    def torch_dataset(self, df: pl.DataFrame, prefix: str) -> torch_data.Dataset:
        return RP3GlobalEmbeddingsDataSet(df, prefix, self.embeddings)
|
|
242
|
+
|
|
243
|
+
class RP3SequenceEmbeddingsLDM(RP3LDM):
    """Data module serving per-row embedding tensors stored in a zip archive.

    ``hypers.embeddings_file`` is a zip whose members are named after row ``id``s.
    Batches are zero-padded to the longest item and carry an ``attention_mask``.
    """

    def __init__(self, hypers) -> None:
        super().__init__(hypers)
        self.embeddings_file = None

    def load_df(self):
        df = super().load_df()
        # load_df runs on every setup() call (once per trainer stage); reuse the open
        # archive instead of leaking a new ZipFile handle each time.
        if self.embeddings_file is None:
            self.embeddings_file = zipfile.ZipFile(util.resolve(self.hypers.embeddings_file), 'r')
        return df

    @staticmethod
    def collate(batch):
        """Collate items whose ``'embeddings'`` tensors differ in length.

        ``'embeddings'`` is popped from each input dict; the remaining fields use
        ``default_collate``. Each embedding is zero-padded at the end of its
        second-to-last dimension up to the largest ``shape[0]`` in the batch; this
        matches the sequence axis for 2-D ``(length, dim)`` tensors (assumed — TODO
        confirm). ``'attention_mask'`` is int32 with 1 on real positions.
        """
        embeddings = [b.pop('embeddings') for b in batch]
        ret = torch_data.default_collate(batch)
        emb_len = torch.tensor([e.shape[0] for e in embeddings])
        max_len = int(emb_len.max())
        emb_padded = torch.stack([torch.nn.functional.pad(e, (0, 0, 0, max_len - e.shape[0]), value=0)
                                  for e in embeddings])
        attn_mask = (torch.arange(max_len)[None, :] < emb_len[:, None]).to(dtype=torch.int32)
        ret['embeddings'] = emb_padded
        ret['attention_mask'] = attn_mask
        return ret

    def get_collate_fn(self):
        return RP3SequenceEmbeddingsLDM.collate

    def torch_dataset(self, df: pl.DataFrame, prefix: str) -> torch_data.Dataset:
        return RP3SequenceEmbeddingsDataSet(df, prefix, self.embeddings_file)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class RP3SequenceLDM(RP3LDM):
    """Data module serving raw amino-acid sequences, tokenized when batches are collated.

    Sequences from ``hypers.fasta_path`` are attached as a ``seq`` column by mapping the
    ``hypers.fasta_id_col`` column. The optional ``hypers.max_seq_len`` is passed to
    ``RP3SequenceDataSet`` for random cropping.
    """

    def __init__(self, hypers) -> None:
        super().__init__(hypers)

    def load_df(self):
        frame = super().load_df()
        log.info(f"Reading sequences from {self.hypers.fasta_path}")
        id_to_seq = util.read_fasta(self.hypers.fasta_path)
        id_column = self.hypers.fasta_id_col
        return frame.with_columns(seq=pl.col(id_column).replace_strict(id_to_seq))

    def torch_dataset(self, df: pl.DataFrame, prefix: str) -> torch_data.Dataset:
        crop = self.hypers.get('max_seq_len', 0)
        return RP3SequenceDataSet(df, prefix, self.rng, crop)

    @staticmethod
    def collate(tokenizer: model.RP3Net, batch):
        """Default-collate all fields except ``seq``, which is tokenized by ``tokenizer``."""
        sequences = [item.pop('seq') for item in batch]
        collated = torch_data.default_collate(batch)
        collated['seq'] = tokenizer.tokenize_sequences(sequences)
        return collated

    def get_collate_fn(self):
        """Bind :meth:`collate` to the attached LightningModule's network as tokenizer."""
        module: RP3LM = self.trainer.lightning_module
        return functools.partial(RP3SequenceLDM.collate, module.model)
|
|
300
|
+
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import numpy as np
|
|
3
|
+
import polars as pl
|
|
4
|
+
import torch.utils.data as torch_data
|
|
5
|
+
|
|
6
|
+
from . import data
|
|
7
|
+
from .. import util
|
|
8
|
+
|
|
9
|
+
# Module-level logger obtained from the project's util.get_logger helper.
log = util.get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
class EmlcBatchSampler(torch_data.Sampler):
    """Batch sampler mixing clean and noisy training rows for EMLC.

    Every batch contains ``batch_size_clean`` clean-source row positions followed by
    ``batch_size_clean * emlc_k`` noisy-source row positions (positions within ``df``).
    On each ``__iter__`` both pools are reshuffled with ``rng`` and strided across
    ``world_size`` ranks. The number of batches is limited by whichever pool runs out
    first.

    NOTE(review): rank shards are disjoint only if every rank's ``rng`` produces the
    same permutation (i.e. is seeded identically) — confirm.
    """

    def __init__(self, *, df: pl.DataFrame, rng: np.random.Generator, clean_sources: list[int], noisy_sources: list[int],
                 batch_size_clean: int, emlc_k: int = 1, world_size: int = 1, global_rank: int = 0):
        self.df = df.with_row_index('_row_idx')
        self.rng = rng
        self.emlc_k = emlc_k
        self.clean_sources = clean_sources
        self.noisy_sources = noisy_sources
        self.batch_size_clean = batch_size_clean
        self.batch_size_noisy = batch_size_clean * emlc_k
        n_clean = df.select(pl.col('source').is_in(clean_sources)).sum()[0, 0]
        n_noisy = df.select(pl.col.source.is_in(noisy_sources)).sum()[0, 0]
        self.batch_count = min(n_clean // (self.batch_size_clean * world_size),
                               n_noisy // (self.batch_size_noisy * world_size))
        self.global_rank = global_rank
        self.world_size = world_size

    def _shuffled_positions(self, sources):
        """Row positions of rows whose ``source`` is in ``sources``, shuffled in place by ``rng``."""
        positions = self.df.filter(pl.col.source.is_in(sources)).select('_row_idx').to_numpy().flatten()
        self.rng.shuffle(positions)
        return positions

    def __iter__(self):
        # The clean pool is drawn before the noisy one; this fixes the rng call order.
        clean = self._shuffled_positions(self.clean_sources)
        noisy = self._shuffled_positions(self.noisy_sources)
        if log.isEnabledFor(logging.DEBUG):
            log.debug(f"Clean index:\n{clean[:100]}")
            log.debug(f"Noisy index:\n{noisy[:100]}")
        shard = slice(self.global_rank, None, self.world_size)
        clean, noisy = clean[shard], noisy[shard]
        log.debug(f"{self.batch_count} batches per worker {self.global_rank}/{self.world_size}")
        n_c, n_n = self.batch_size_clean, self.batch_size_noisy
        for i in range(self.batch_count):
            out = np.concatenate([clean[i * n_c:(i + 1) * n_c], noisy[i * n_n:(i + 1) * n_n]])
            log.debug(f"Batch {i}/{self.batch_count}: {out}")
            yield out

    def __len__(self):
        return self.batch_count
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class EmlcLDM(data.RP3SequenceLDM):
    """Sequence data module for EMLC training with clean and noisy sources.

    ``hypers.clean_sources`` names the clean sources; every other configured source is
    treated as noisy. Validation, test and the "val-training" sample use clean rows
    only. Training batches are produced by ``EmlcBatchSampler``.
    """

    def __init__(self, hypers) -> None:
        super().__init__(hypers)
        log.info("EmlcLDM init")
        clean_sources = self.hypers.clean_sources
        self.emlc_k = int(self.hypers.emlc_k)
        # Integer codes of the clean sources, matching the encoding applied in load_df.
        self.clean_sources = [self.sources_map[s] for s in clean_sources]

    def create_torch_datasets(self):
        """Build datasets; the "val-training" sample is drawn from clean training rows."""
        self.train_ds = self.torch_dataset(self.df_train, 'train')
        self.val_ds = self.torch_dataset(self.df_val, "val")
        df_clean_train = self.df_train.filter(pl.col.source.is_in(self.clean_sources))
        # Sampling without replacement cannot draw more rows than are available.
        n_val_train = min(len(self.val_ds), df_clean_train.shape[0])
        df_val_train = df_clean_train.sample(n_val_train, with_replacement=False, seed=self.hypers.test_val_seed)
        self.val_train_ds = self.torch_dataset(df_val_train, "val-training")
        self.test_ds = self.torch_dataset(self.df_test, "test")

    def setup(self, stage: str) -> None:
        """Split like the base class, but keep only clean sources in val and test."""
        log.debug("EmlcLDM setup")
        if self.trainer is not None:
            # lightning_module is the unwrapped module, even if a strategy wraps the model.
            lightning_module = self.trainer.lightning_module
            assert lightning_module.sources == self.hypers.sources
            assert lightning_module.emlc_k == self.emlc_k
        df = self.load_df()
        self.df_train = df.filter(
            pl.col('ds_type').is_not_null() &
            pl.col('ds_type').is_in(['TEST', self.validation_slice]).not_()
        )
        assert self.df_train.shape[0] > 0, f"No training data for slice {self.validation_slice}"
        self.df_val = df.filter(pl.col('ds_type') == self.validation_slice, pl.col.source.is_in(self.clean_sources))
        assert self.df_val.shape[0] > 0, f"No validation data for slice {self.validation_slice}"
        self.df_test = df.filter(pl.col('ds_type') == 'TEST', pl.col.source.is_in(self.clean_sources))
        assert self.df_test.shape[0] > 0, f"No test data for slice {self.validation_slice}"
        self.create_torch_datasets()

    def train_dataloader(self):
        """Training loader whose batches mix clean and noisy rows (see EmlcBatchSampler)."""
        log.debug("EmlcLDM train_dataloader")
        batch_size = self.get_batch_size('training')
        noisy_sources = [code for code in self.sources_map.values() if code not in self.clean_sources]
        batch_sampler = EmlcBatchSampler(df=self.df_train, rng=self.rng,
                                         clean_sources=self.clean_sources, noisy_sources=noisy_sources,
                                         batch_size_clean=batch_size, emlc_k=self.emlc_k,
                                         world_size=self.trainer.world_size, global_rank=self.trainer.global_rank)
        return torch_data.DataLoader(self.train_ds, batch_sampler=batch_sampler, num_workers=0, pin_memory=True,
                                     collate_fn=self.get_collate_fn())
|
|
94
|
+
|
RP3Net/training/lm.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import lightning as L
|
|
2
|
+
import lightning.pytorch.utilities as L_util
|
|
3
|
+
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch
|
|
6
|
+
import ml_collections as mlc
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
from . import metrics
|
|
11
|
+
from .. import util
|
|
12
|
+
from .. import model
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Module-level logger obtained from the project's util.get_logger helper.
log = util.get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
class RP3LM(L.LightningModule):
    """LightningModule training an RP3Net model as a 2-class classifier.

    The network is created by ``model.load_model(hypers.model)`` and trained with
    cross-entropy on ``batch['yield_binary']``. Validation dataloader 0 feeds the
    "train" metrics and dataloader 1 the "val" metrics. Per-epoch predictions are
    written as CSV files under ``trainer.default_root_dir``.
    """

    def __init__(self, hypers) -> None:
        super().__init__()
        log.debug("Lightning module init")
        self._hypers_prefix = 'model'
        self.save_hyperparameters({'model': hypers})
        self.hypers = mlc.ConfigDict(self.hparams.model)
        self.sources = self.hypers.sources
        self.sources_map = {s: i for i, s in enumerate(self.sources)}
        log.info(f"Sources: {self.sources}")
        self.metrics = metrics.ClassificationMetricContainer.create_classification_metrics(self.sources, 2)
        self.loss = nn.CrossEntropyLoss()
        log.info(f"Loss: {self.loss}")
        self.model: model.RP3Net = model.load_model(self.hypers.model)
        log.info(f"Model: {self.model}")

    def setup(self, stage):
        if stage == 'fit':
            assert self.model.mode in model.Mode_Training, "Model must be in training mode"

    def force_train_on_fit_start(self):
        """
        Need this, because loading a pre-trained HF model calls .eval() under the hood,
        and PL preserves the state of training flags on modules when switching back from eval to train.
        """
        self.model.train()

    def on_fit_start(self) -> None:
        self.force_train_on_fit_start()

    def forward(self, batch):
        return self.model(batch)

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        """Return the predicted class index (argmax over dim 1 of the logits)."""
        logits = self.model(batch)
        return torch.argmax(logits, dim=1)

    def training_step(self, batch, batch_idx):
        log.debug(f"Training batch ids: {batch['idx']}")
        logits = self(batch)
        loss = self.loss(logits, batch['yield_binary'])
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True,
                 batch_size=batch['yield_binary'].shape[0])
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx):
        ids = batch['idx']
        log.debug(f"Validation batch ids for dataloader index {dataloader_idx}: {ids}")
        logits = self(batch)
        if dataloader_idx == 0:
            self.metrics.update_train(logits, batch)
        elif dataloader_idx == 1:
            self.metrics.update_val(logits, batch)
        else:
            raise RuntimeError(f"Unknown dataloader index: {dataloader_idx}")

    def test_step(self, batch, batch_idx):
        log.debug(f"Test batch index: {batch['idx']}")
        logits = self(batch)
        self.metrics.update_test(logits, batch)

    @L_util.rank_zero_only
    def write_results_df(self, filename: os.PathLike, ids: torch.Tensor, logits: torch.Tensor):
        """Write ids, predicted class, logits and softmax probabilities to a CSV file (rank zero only).

        Output columns: ``id``, ``y_hat``, ``logit_<k>`` and ``prob_<k>``, one ``k`` per
        column of ``logits``.
        """
        logits = logits.detach().cpu()
        # bfloat16 tensors cannot be converted to numpy (e.g. under bf16-mixed
        # precision), and softmax is more accurate in float32.
        if logits.dtype not in (torch.float32, torch.float64):
            logits = logits.float()
        proba = torch.softmax(logits, dim=1).numpy()
        y_hat = proba.argmax(axis=1)
        df = pd.DataFrame({'id': ids.to(dtype=torch.int32, device='cpu').numpy(), 'y_hat': y_hat})
        df_logits = pd.DataFrame(logits.numpy(), columns=[f'logit_{i}' for i in range(logits.shape[1])])
        df_proba = pd.DataFrame(proba, columns=[f'prob_{i}' for i in range(proba.shape[1])])
        pd.concat([df, df_logits, df_proba], axis=1).to_csv(filename, index=False)

    def on_validation_epoch_end(self) -> None:
        """Log train/val metrics, write per-epoch prediction CSVs and reset metric state.

        Logging and CSV writing are skipped during the sanity-check run.
        """
        log.info(f"Validation epoch {self.current_epoch} end.")
        train_log_dict = self.metrics.compute_train_dict()
        if not self.trainer.sanity_checking:
            self.log_dict(train_log_dict, on_epoch=True, add_dataloader_idx=False, sync_dist=True)

        val_log_dict = self.metrics.compute_val_dict()
        if not self.trainer.sanity_checking:
            self.log_dict(val_log_dict, on_epoch=True, add_dataloader_idx=False, sync_dist=True)

        train_df_file = util.resolve(self.trainer.default_root_dir) / f"train_df_{self.current_epoch}.csv.gz"
        train_ids, train_logits = self.metrics.train_curve()
        if isinstance(train_logits, torch.Tensor) and train_logits.shape[0] > 0 and not self.trainer.sanity_checking:
            log.info(f"Writing training results for epoch {self.current_epoch} to {train_df_file}")
            self.write_results_df(train_df_file, train_ids, train_logits)

        val_df_file = util.resolve(self.trainer.default_root_dir) / f"val_df_{self.current_epoch}.csv.gz"
        val_ids, val_logits = self.metrics.val_curve()
        if isinstance(val_logits, torch.Tensor) and val_logits.shape[0] > 0 and not self.trainer.sanity_checking:
            log.info(f"Writing validation results for epoch {self.current_epoch} to {val_df_file}")
            self.write_results_df(val_df_file, val_ids, val_logits)
        self.metrics.reset()

    def on_test_epoch_end(self) -> None:
        """Log test metrics, write the test prediction CSV and reset metric state."""
        test_log_dict = self.metrics.compute_test_dict()
        self.log_dict(test_log_dict, on_epoch=True, add_dataloader_idx=False)
        test_df_file = util.resolve(self.trainer.default_root_dir) / "test_df.csv.gz"
        test_ids, test_logits = self.metrics.test_curve()
        if isinstance(test_logits, torch.Tensor) and test_logits.shape[0] > 0:
            log.info(f"Writing test results to {test_df_file}")
            self.write_results_df(test_df_file, test_ids, test_logits)
        self.metrics.reset()
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
|