RP3Net 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
RP3Net/training/cli.py ADDED
@@ -0,0 +1,166 @@
1
+ import re
2
+ import os
3
+ import socket
4
+ import sys
5
+ import logging
6
+ import lightning.pytorch as L
7
+ import lightning.pytorch.utilities as L_util
8
+ import lightning.pytorch.cli as L_cli
9
+ import lightning.pytorch.callbacks as L_cb
10
+ import lightning.pytorch.loggers as L_log
11
+ import wandb
12
+
13
+ from . import lm
14
+ from .. import util
15
+
16
log = util.get_logger(__name__)


def setup_logging_torch(args):
    """Configure process-wide logging from the parsed CLI configuration.

    Exports ``PP_LOG_LEVEL`` (and, when a log file is given, ``PP_LOGFILE_BASE``:
    the resolved log file path without its extension) to the environment, calls
    ``util.setup_logging`` and makes Lightning's loggers propagate to the root
    logger instead of using their own handlers.

    Args:
        args: Parsed configuration exposing ``log_level`` and ``logfile``;
            ``logfile=None`` means log to the console.
    """
    log_level = args.log_level
    os.environ["PP_LOG_LEVEL"] = log_level
    logfile = args.logfile
    if logfile is not None:
        logfile_base = util.resolve(logfile)
        os.environ["PP_LOGFILE_BASE"] = str(logfile_base.parent / logfile_base.stem)
        # With several SLURM tasks each task gets its own file, otherwise all ranks
        # would write to the same one. SLURM_NTASKS / SLURM_PROCID are the variables
        # Lightning's SLURMEnvironment uses for world size / global rank.
        # The rank is inserted before the extension; unlike str.replace(".log", ...)
        # this also works for names without ".log" and leaves directory names alone.
        if int(os.environ.get("SLURM_NTASKS", "0")) > 1:
            root, ext = os.path.splitext(logfile)
            logfile = f"{root}_{os.environ.get('SLURM_PROCID', '0')}{ext}"
    util.setup_logging(logfile, log_level, log_console=logfile is None)
    # Drop Lightning's own handlers and route its records through the root logger.
    for logger_name in ("lightning", "lightning.pytorch"):
        lightning_logger = logging.getLogger(logger_name)
        lightning_logger.handlers.clear()
        lightning_logger.propagate = True
    log.info(f"Host: {socket.gethostname()}; PID: {os.getpid()}; Command line: {' '.join(sys.argv)}")
38
+
39
class RP3Cli(L_cli.LightningCLI):
    """LightningCLI with project logging, optional wandb, metric checkpoints and test-after-fit.

    Extra options are registered in ``add_arguments_to_parser``. Logging and the
    optional wandb logger are configured in ``before_instantiate_classes``, and
    checkpoint callbacks / default loggers are injected into the trainer config
    in ``instantiate_trainer`` before the trainer is built.
    """

    def __init__(self, *args, **kwargs):
        # Must exist before super().__init__(), which (in LightningCLI) parses the
        # arguments and calls the hooks below.
        self.wandb_logger = None
        # Always allow overwriting a previously saved config, but keep any other
        # save_config_kwargs the caller passed (they used to be dropped silently).
        save_config_kwargs = {**(kwargs.get('save_config_kwargs') or {}), "overwrite": True}
        super().__init__(*args, **{**kwargs, 'save_config_kwargs': save_config_kwargs})

    @staticmethod
    @L_util.rank_zero_only
    def wandb_init(wandb_project, wandb_run_name, wandb_run_id):
        """Start (or resume) a wandb run on global rank 0; other ranks get ``None``."""
        return wandb.init(id=wandb_run_id, project=wandb_project, name=wandb_run_name, resume='allow')

    def wandb_logger_init(self, config):
        """Create a ``WandbLogger`` from the ``wandb.*`` options, or return ``None``.

        Returns ``None`` when no project is set, when neither a run name nor a run id
        is set, or when ``wandb.disable`` is true. The run id defaults to the run name.
        """
        if 'wandb' not in config or config.wandb is None or \
                'project' not in config.wandb or ('run' not in config.wandb and 'run_id' not in config.wandb) or \
                config.wandb.project is None or (config.wandb.run is None and config.wandb.run_id is None) or \
                ('disable' in config.wandb and config.wandb.disable):
            log.info("No wandb logging")
            return None
        log.info("Configure wandb logging")
        run_id = config.wandb.run_id if 'run_id' in config.wandb and config.wandb.run_id is not None \
            else config.wandb.run
        run = self.wandb_init(wandb_project=config.wandb.project, wandb_run_name=config.wandb.run, wandb_run_id=run_id)
        if run is not None:
            log.info(f"Wandb run: {run.name}({run.id})")
        return L_log.WandbLogger(project=config.wandb.project, name=config.wandb.run, id=run_id)

    def add_arguments_to_parser(self, parser: L_cli.LightningArgumentParser) -> None:
        """Register the project-specific command-line options."""
        parser.add_argument("--logfile", help="Log file. Log output to console if set to None.")
        parser.add_argument("--log_level", default="info",
                            help="Log level of root logger. Appender levels are appropriately hard coded.")
        parser.add_argument("--track_metric_checkpoints", choices=["last", "all", "best"],
                            help="""
Track checkpoints for training and validation metrics from the module.
If not provided, no checkpoints will be recorded at all.
If any value is provided, only the best checkpoint will be recorded for all the metrics specified by `model.metrics_for_checkpointing()`.
The value of this argument affects how checkpoints will be saved for 'train_loss'.
""")
        parser.add_argument("--wandb.project", help="Wandb project name", default=None)
        parser.add_argument("--wandb.run", help="Wandb run name", default=None)
        parser.add_argument("--wandb.run_id", help="Wandb run id, same as name by default", default=None)
        parser.add_argument("--wandb.disable", help="Set to true to turn off wandb logging, without removing the rest of wandb settings", action='store_true')
        parser.add_argument("--test_after_fit_metric", help="Metric to use for test_after_fit. If not set, then do not run test_after_fit", default=None)
        parser.add_argument("--emlc_k", help="Number of student gradinent steps to perform per teacher step for EMLC", default=1, type=int)

    def before_instantiate_classes(self) -> None:
        """Set up logging and the optional wandb logger from the parsed config."""
        # With a subcommand (fit/test/...) its options live under that key.
        config = self.config.get(str(self.subcommand), self.config)
        setup_logging_torch(config)
        self.wandb_logger = self.wandb_logger_init(config)

    def instantiate_trainer(self, **kwargs) -> L.Trainer:
        """Inject checkpoint callbacks and loggers into the trainer config, then build the trainer."""
        log.info("Instantiating trainer")
        config = self.config.get(str(self.subcommand), self.config)
        metric_checkpoints = self._get(self.config_init, 'track_metric_checkpoints')
        if metric_checkpoints is not None:
            self.init_metric_checkpoints(metric_checkpoints)
        self.add_loggers(config)
        return super().instantiate_trainer(**kwargs)

    def add_loggers(self, config):
        """Make sure the trainer gets a ``CSVLogger`` plus the wandb logger, if any.

        ``trainer.logger: false`` disables logging and is left untouched. ``true`` or
        unset is replaced by the loggers added here. A single logger or a list of
        loggers is extended.
        """
        configured_loggers = self._get(self.config_init, 'trainer.logger', default=[])
        if configured_loggers is False:
            return
        if configured_loggers is True or configured_loggers is None:
            configured_loggers = []
        elif isinstance(configured_loggers, L_log.Logger):
            configured_loggers = [configured_loggers]
        else:
            configured_loggers = list(configured_loggers)
        if not any(isinstance(logger, L_log.CSVLogger) for logger in configured_loggers):
            # Lightning's Trainer falls back to the working directory when
            # default_root_dir is unset; CSVLogger needs an explicit directory.
            save_dir = config['trainer']['default_root_dir'] or os.getcwd()
            configured_loggers.append(L_log.CSVLogger(save_dir))
        if self.wandb_logger is not None:
            configured_loggers.append(self.wandb_logger)
        config_init = self.config_init.get(str(self.subcommand), self.config_init)
        config_init['trainer']['logger'] = configured_loggers

    def init_metric_checkpoints(self, checkpoint_save_flag):
        """Add ``ModelCheckpoint`` callbacks to the trainer config.

        For each metric from ``model.metrics.metrics_for_checkpointing()``, one callback
        keeps the best checkpoint (mode from ``higher_is_better``). One more callback
        on ``train_loss`` keeps the best one (``'best'``), every one (``'all'``) or none
        (``'last'``), and always writes ``last.ckpt``. Nothing is added when the model
        reports no checkpointing metrics.
        """
        model: lm.RP3LM = self.model
        metrics = model.metrics.metrics_for_checkpointing()
        default_root_dir = util.resolve(self._get(self.config, 'trainer.default_root_dir'))
        if metrics is None:
            return
        trainer_config = self._get(self.config_init, 'trainer', default={})
        if 'callbacks' not in trainer_config or trainer_config['callbacks'] is None:
            callbacks = []
            trainer_config['callbacks'] = callbacks
        else:
            callbacks = trainer_config['callbacks']
            if isinstance(callbacks, L.Callback):
                callbacks = [callbacks]
                trainer_config['callbacks'] = callbacks
        for key, metric in metrics.items():
            mode = 'max' if metric.higher_is_better else 'min'
            callbacks.append(L_cb.ModelCheckpoint(
                dirpath=default_root_dir, monitor=key, mode=mode,
                filename='{epoch}_{' + key + ':.2f}',
                save_on_train_epoch_end=False, save_weights_only=False
            ))
        assert checkpoint_save_flag in ['last', 'all', 'best']
        save_top_k = {'best': 1, 'all': -1, 'last': 0}[checkpoint_save_flag]
        callbacks.append(L_cb.ModelCheckpoint(dirpath=default_root_dir, monitor='train_loss', filename='{epoch}_{train_loss:.2f}', mode='min',
                                              save_top_k=save_top_k, save_last=True))

    def test_after_fit(self, metric):
        """Run ``trainer.test`` from the checkpoint saved for ``metric``.

        Looks in ``trainer.default_root_dir`` for a file named
        ``epoch=<n>_<metric>=<value>.ckpt``, via ``util.find_checkpoint_file``. This is
        presumably the name written by the per-metric callbacks of
        ``init_metric_checkpoints`` with Lightning's default ``auto_insert_metric_name``.
        """
        model: lm.RP3LM = self.model
        dm = self.datamodule
        ckpt_dir = util.resolve(self.trainer.default_root_dir)
        # Escape the metric name, which may contain regex metacharacters such as '.',
        # and accept negative metric values ('-0.12').
        filename_pattern = re.compile(r'^epoch=\d+_' + re.escape(metric) + r'=-?\d+(\.\d+)?\.ckpt$')
        cp_file = util.find_checkpoint_file(ckpt_dir, filename_pattern)
        log.info(f"Loading checkpoint {cp_file}")
        self.trainer.test(model, dm, ckpt_path=str(cp_file))

    def after_fit(self):
        """Run ``test_after_fit`` when ``--test_after_fit_metric`` was given."""
        metric = self._get(self.config, 'test_after_fit_metric')
        if metric is not None:
            log.info(f"Running test_after_fit on {metric}")
            self.test_after_fit(metric)
@@ -0,0 +1,300 @@
1
+ import os
2
+ import typing
3
+ import functools
4
+ import zipfile
5
+
6
+ import torch
7
+ import torch.utils.data as torch_data
8
+ import polars as pl
9
+ import numpy as np
10
+ import lightning as L
11
+ import ml_collections as mlc
12
+
13
+ from .lm import RP3LM
14
+ from .. import util
15
+ from .. import model
16
+
17
log = util.get_logger(__name__)

# Polars dtypes applied when reading the full dataset CSV (default schema_overrides
# of read_full_df_pl). Each key appears exactly once; 'fasta_id_no_tags' used to be
# listed twice (same dtype), which linters flag as a duplicate dict key.
FULL_DF_DTYPE_PL = {
    'created_at': pl.Datetime(), 'source': pl.Categorical(), 'sub_source': pl.Categorical(),
    'no_tags_cluster_40_id': pl.String(), 'with_tags_cluster_90_id': pl.String(),
    'has_dna': pl.Boolean(),
    'experiment_id': pl.String(),
    'yield_binary': pl.Boolean(), 'yield_cat': pl.Int64(),
    'host': pl.Categorical(), 'exp_outcome': pl.Categorical(),
    'id': pl.String(), 'fasta_id': pl.String(), 'dna_fasta_id': pl.String(), 'fasta_id_no_tags': pl.String(),
    'ds_type': pl.Categorical(),
    'n_tags_end': pl.Int64(), 'c_tags_start': pl.Int64(), 'n_fragments': pl.Int64(), 'unique_target_count': pl.Int64(),
    'uniprot_id': pl.String(), 'gene_id': pl.String(), 'taxon_id': pl.Int64(),
}
31
+
32
def read_full_df_pl(path: str | os.PathLike, **kwargs) -> pl.DataFrame:
    """Read the full dataset CSV into a polars DataFrame.

    Unless the caller supplies ``schema_overrides``, the dtypes from
    ``FULL_DF_DTYPE_PL`` are applied. Every other keyword argument is forwarded
    unchanged to ``pl.read_csv``.
    """
    schema = kwargs.pop('schema_overrides', FULL_DF_DTYPE_PL)
    return pl.read_csv(path, schema_overrides=schema, **kwargs)
37
+
38
def load_global_embeddings_file(embeddings_file: os.PathLike) -> typing.Mapping:
    """Load a ``torch.save``-d embeddings file into an ``id -> embedding`` dict.

    The file must contain a mapping with an ``'ids'`` sequence and an indexable
    ``'embeddings'`` whose entry at position ``i`` belongs to ``ids[i]``.
    """
    log.info(f"Loading global embeddings from {embeddings_file}")
    # NOTE(review): torch.load unpickles the file; unless weights_only=True is in
    # effect this can execute arbitrary code, so only load trusted files.
    payload = torch.load(embeddings_file)
    seq_ids = payload['ids']
    vectors = payload['embeddings']
    by_id = {}
    for position, seq_id in enumerate(seq_ids):
        by_id[seq_id] = vectors[position]
    return by_id
44
+
45
class RP3GlobalEmbeddingsDataSet(torch_data.Dataset):
    """Map-style dataset pairing each dataframe row with a precomputed embedding.

    Args:
        df: Rows with at least ``src_idx``, ``yield_binary``, ``source`` and ``id``.
        prefix: Label used in log messages (for example ``'train'``).
        embeddings: Mapping from a row's ``id`` to its embedding tensor.

    Items are dicts with keys ``idx``, ``yield_binary`` (as int), ``source`` and
    ``embeddings``.
    """

    def __init__(self, df: pl.DataFrame, prefix: str, embeddings: typing.Mapping[str, torch.Tensor]) -> None:
        super().__init__()
        self.df = df
        self.prefix = prefix
        self.embeddings = embeddings

    def __len__(self):
        return self.df.height

    def __getitem__(self, idx):
        try:
            record = self.df.row(idx, named=True)
            item = dict(
                idx=record['src_idx'],
                yield_binary=int(record['yield_binary']),
                source=record['source'],
                embeddings=self.embeddings[record['id']],
            )
            log.debug(f"{self.prefix}: (torch={idx}, csv={item['idx']}){record['id']}: {item['yield_binary']}")
            return item
        except Exception as err:
            # Log with context before re-raising; errors inside DataLoader
            # iteration are otherwise hard to attribute to a dataset split.
            log.error(f"Top level catch in {self.prefix} __getitem__", exc_info=err)
            raise err
69
+
70
+
71
class RP3SequenceEmbeddingsDataSet(torch_data.Dataset):
    """Map-style dataset that loads per-row embedding tensors lazily from a zip archive.

    Each row's tensor is stored in the archive under the row's ``id`` and is read
    with ``torch.load(..., weights_only=True)``. Items are dicts with keys ``idx``,
    ``yield_binary`` (as int), ``source`` and ``embeddings``.

    NOTE(review): one ``ZipFile`` handle is shared by every item read; this
    relies on single-process loading (num_workers=0).
    """

    def __init__(self, df: pl.DataFrame, prefix: str, embeddings: zipfile.ZipFile) -> None:
        super().__init__()
        self.df = df
        self.prefix = prefix
        self.embeddings: zipfile.ZipFile = embeddings

    def __len__(self):
        return self.df.height

    def __getitem__(self, idx):
        try:
            record = self.df.row(idx, named=True)
            with self.embeddings.open(record['id'], 'r') as member:
                tensor = torch.load(member, weights_only=True)
            item = dict(
                idx=record['src_idx'],
                yield_binary=int(record['yield_binary']),
                source=record['source'],
                embeddings=tensor,
            )
            log.debug(f"{self.prefix}: (torch={idx}, csv={item['idx']}){record['id']}: {item['yield_binary']}")
            return item
        except Exception as err:
            log.error(f"Top level catch in {self.prefix} __getitem__", exc_info=err)
            raise err
97
+
98
class RP3SequenceDataSet(torch_data.Dataset):
    """Map-style dataset of raw sequences, optionally cropped to ``max_seq_len``.

    Args:
        df: Rows with at least ``src_idx``, ``source``, ``seq``, ``yield_binary`` and ``id``.
        prefix: Label used in log messages.
        rng: Generator used to choose crop windows.
        max_seq_len: Maximum sequence length; 0 disables cropping.
    """

    def __init__(self, df: pl.DataFrame, prefix, rng: np.random.Generator, max_seq_len: int = 0):
        super().__init__()
        self.rng = rng
        self.df = df
        self.prefix = prefix
        self.max_seq_len = max_seq_len

    def seq_chunk(self, seq: str):
        """Return ``seq`` unchanged, or a random window of ``max_seq_len`` characters if longer.

        The window start is drawn uniformly over all valid offsets from ``self.rng``.
        """
        limit = self.max_seq_len
        length = len(seq)
        if limit != 0 and length > limit:
            begin = self.rng.integers(length - limit + 1)
            finish = begin + limit
            log.debug(f"Returning the {begin}:{finish} chunk from sequence of length {length}; max_seq_len={limit}")
            return seq[begin:finish]
        log.debug(f"Not changing sequence of length {length}; max_seq_len={limit}")
        return seq

    def __len__(self):
        return self.df.height

    def __getitem__(self, idx):
        try:
            record = self.df.row(idx, named=True)
            item = dict(
                idx=record['src_idx'],
                source=record['source'],
                seq=self.seq_chunk(record['seq']),
                yield_binary=int(record['yield_binary']),
            )
            log.debug(f"{self.prefix}: (torch={idx}, csv={item['idx']}){record['id']}: {item['yield_binary']}")
            return item
        except Exception as err:
            log.error(f"Top level catch in {self.prefix} __getitem__", exc_info=err)
            raise err
132
+
133
+
134
class RP3LDM(L.LightningDataModule):
    """Base data module: loads the dataset CSV, splits it and builds dataloaders.

    Rows are split on ``ds_type``. ``'TEST'`` rows form the test set. Rows in the
    configured ``validation_slice`` (default ``'VALIDATION'``) form the validation
    set. Every other non-null value is training data. Only rows whose ``source`` is
    in ``hypers.sources`` are kept, and ``source`` is re-encoded as its index in
    that list. Subclasses implement ``torch_dataset`` and may override
    ``get_collate_fn``.
    """

    def __init__(self, hypers) -> None:
        super().__init__()
        log.debug("DataModule init")
        self.save_hyperparameters({'data': hypers})
        self.hypers = mlc.ConfigDict(self.hparams.data)
        self.sources_map = {s: i for i, s in enumerate(self.hypers.sources)}
        self.rng = np.random.default_rng(self.hypers.get('seed', None))
        self.validation_slice = self.hypers.get('validation_slice', 'VALIDATION')

    def torch_dataset(self, df: pl.DataFrame, prefix: str) -> torch_data.Dataset:
        """Wrap ``df`` in a torch Dataset labelled ``prefix``; subclasses must override."""
        # `NotImplemented` is a sentinel value, not an exception class: calling it
        # raised a confusing TypeError instead of signalling a missing override.
        raise NotImplementedError(f"{type(self).__name__} must implement torch_dataset()")

    def load_df(self) -> pl.DataFrame:
        """Read ``hypers.ds_path``, add ``src_idx`` row numbers and keep the configured sources."""
        data_path = self.hypers.ds_path
        log.info(f"Loading data from {data_path}; validation slice: {self.validation_slice}")
        df = read_full_df_pl(data_path).with_row_index('src_idx')
        df_sources = set(df.get_column('source').cast(pl.String).unique().to_list())
        for s in self.sources_map:
            assert s in df_sources, f"Source {s!r} not found in {data_path}"
        df = (df
              .filter(pl.col('source').is_in(self.hypers.sources))
              .with_columns(pl.col('source').cast(pl.String).replace_strict(self.sources_map))
              )
        return df

    def setup(self, stage: str) -> None:
        """Load the data, split it into train/val/test and build the torch datasets."""
        log.debug("RP3LDM setup")
        if self.trainer is not None:
            # lightning_module is always the unwrapped LightningModule
            # (get_collate_fn in subclasses uses the same accessor).
            assert self.trainer.lightning_module.sources == self.hypers.sources
        df = self.load_df()
        self.df_train = df.filter(
            pl.col('ds_type').is_not_null() &
            pl.col('ds_type').is_in(['TEST', self.validation_slice]).not_()
        )
        assert self.df_train.shape[0] > 0, f"No training data for slice {self.validation_slice}"
        self.df_val = df.filter(ds_type=self.validation_slice)
        assert self.df_val.shape[0] > 0, f"No validation data for slice {self.validation_slice}"
        self.df_test = df.filter(ds_type='TEST')
        assert self.df_test.shape[0] > 0, "No test data (ds_type == 'TEST')"
        self.create_torch_datasets()

    def create_torch_datasets(self):
        """Build train/val/test datasets plus a training subset used to monitor training metrics."""
        self.train_ds = self.torch_dataset(self.df_train, 'train')
        self.val_ds = self.torch_dataset(self.df_val, "val")
        # Same size as the validation set, capped at the training set size:
        # sampling without replacement fails when more rows are requested than exist.
        n_val_train = min(len(self.val_ds), self.df_train.height)
        df_val_train = self.df_train.sample(n_val_train, with_replacement=False, seed=self.hypers.test_val_seed)
        self.val_train_ds = self.torch_dataset(df_val_train, "val-training")
        self.test_ds = self.torch_dataset(self.df_test, "test")

    def get_collate_fn(self):
        """Collate function for all dataloaders; ``None`` selects torch's default."""
        return None

    def get_batch_size(self, key: str) -> int:
        """Return ``hypers.<key>_batch_size``, falling back to ``hypers.batch_size``, else -1."""
        return int(self.hypers.get(f'{key}_batch_size', self.hypers.get('batch_size', -1)))

    def train_dataloader(self):
        batch_size = self.get_batch_size('training')
        return torch_data.DataLoader(self.train_ds, batch_size=batch_size, collate_fn=self.get_collate_fn(), shuffle=True,
                                     num_workers=0, pin_memory=True)

    def _build_val_test_loader(self, ds):
        """Sequential loader for evaluation.

        When ``hypers.use_distributed_sampler`` is set and ``util.is_distr_env()``
        reports a distributed run, a ``DistributedSampler`` is used. Both the sampler
        and the loader then drop incomplete tail batches.
        """
        batch_size = self.get_batch_size('val_test')
        sampler = None
        drop_last = False
        if self.hypers.get('use_distributed_sampler', False) and util.is_distr_env():
            sampler = torch_data.DistributedSampler(ds, drop_last=True, shuffle=False)
            drop_last = True
        return torch_data.DataLoader(ds, batch_size=batch_size, collate_fn=self.get_collate_fn(), sampler=sampler,
                                     num_workers=0, pin_memory=True, drop_last=drop_last)

    def val_dataloader(self):
        """Two loaders: index 0 is the training-subset monitor, index 1 the validation set."""
        train_dl = self._build_val_test_loader(self.val_train_ds)
        val_dl = self._build_val_test_loader(self.val_ds)
        return [train_dl, val_dl]

    def test_dataloader(self):
        return self._build_val_test_loader(self.test_ds)
212
+
213
class RP3GlobalEmbeddingsLDM(RP3LDM):
    """Data module whose samples carry one precomputed embedding vector per row.

    ``hypers.embeddings_file`` selects the embeddings:

    * ``'onehot'``: the mean one-hot encoding of the sequence over the 20
      standard amino acids;
    * ``'random_<dim>'``: the mean of a randomly initialised ``<dim>``-wide
      ``nn.Embedding`` over the sequence;
    * anything else: a file path loaded with ``load_global_embeddings_file``.

    The first two read sequences from ``hypers.fasta_path``, looked up by the
    ``hypers.fasta_key`` column.
    """

    def __init__(self, hypers) -> None:
        super().__init__(hypers)
        self.embeddings = None

    def load_df(self):
        df = super().load_df()
        spec = self.hypers.embeddings_file
        use_onehot = spec == 'onehot'
        use_random = spec.startswith('random_')
        if not (use_onehot or use_random):
            self.embeddings = load_global_embeddings_file(util.resolve(self.hypers.embeddings_file))
            return df

        seqs = util.read_fasta(self.hypers.fasta_path)
        aa_to_int = {aa: i for i, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
        if use_random:
            # NOTE(review): the table is initialised from torch's global RNG, so it is
            # reproducible (and identical across processes) only if that RNG is seeded.
            emb = torch.nn.Embedding(20, int(spec[len('random_'):])).to('cpu').requires_grad_(False)
        self.embeddings = {}
        for seq_id, fasta_key in df.select('id', self.hypers.fasta_key).iter_rows():
            # Raises KeyError for residues outside the 20-letter alphabet.
            residues = torch.tensor([aa_to_int[aa] for aa in seqs[fasta_key]], dtype=torch.int64)
            if use_onehot:
                encoded = torch.nn.functional.one_hot(residues, num_classes=len(aa_to_int)).to(dtype=torch.float32)
            else:
                encoded = emb(residues)
            # Average over residues to get one fixed-size vector per row.
            self.embeddings[seq_id] = encoded.mean(0)
        return df

    def torch_dataset(self, df: pl.DataFrame, prefix: str) -> torch_data.Dataset:
        return RP3GlobalEmbeddingsDataSet(df, prefix, self.embeddings)
242
+
243
class RP3SequenceEmbeddingsLDM(RP3LDM):
    """Data module whose samples carry variable-length embeddings read from a zip archive.

    ``hypers.embeddings_file`` is the archive; the member for a row is named after
    its ``id``. Batches are padded by ``collate``.
    """

    def __init__(self, hypers) -> None:
        super().__init__(hypers)
        self.embeddings_file = None

    def load_df(self):
        df = super().load_df()
        # setup() (and so load_df) runs once per stage, e.g. fit and then test.
        # Reuse the archive opened earlier instead of leaking a new handle each time.
        if self.embeddings_file is None:
            self.embeddings_file = zipfile.ZipFile(util.resolve(self.hypers.embeddings_file), 'r')
        return df

    @staticmethod
    def collate(batch):
        """Collate samples whose ``'embeddings'`` differ in their first dimension.

        The ``'embeddings'`` entry is removed from every sample and the remaining
        fields go through ``default_collate``. The embeddings are zero-padded at the
        end of dim 0 to the longest one and stacked. ``'attention_mask'`` is an int32
        ``(batch, max_len)`` tensor with 1 for real positions and 0 for padding.
        Assumes 2-D embeddings (length, dim), since padding is applied to dim 0 of
        the last two dims.
        """
        embeddings = [b.pop('embeddings') for b in batch]
        ret = torch_data.default_collate(batch)
        emb_len = torch.tensor([e.shape[0] for e in embeddings])
        max_len = int(emb_len.max())
        emb_padded = torch.stack([torch.nn.functional.pad(e, (0, 0, 0, max_len - e.shape[0]), value=0)
                                  for e in embeddings])
        attn_mask = (torch.arange(max_len).unsqueeze(0) < emb_len.unsqueeze(1)).to(torch.int32)
        ret['embeddings'] = emb_padded
        ret['attention_mask'] = attn_mask
        return ret

    def get_collate_fn(self):
        return RP3SequenceEmbeddingsLDM.collate

    def torch_dataset(self, df: pl.DataFrame, prefix: str) -> torch_data.Dataset:
        return RP3SequenceEmbeddingsDataSet(df, prefix, self.embeddings_file)
272
+
273
+
274
class RP3SequenceLDM(RP3LDM):
    """Data module yielding raw sequences that are tokenized by the model at collate time.

    Sequences come from ``hypers.fasta_path`` and are matched to rows through the
    ``hypers.fasta_id_col`` column. They are cropped to ``hypers.max_seq_len``
    (0 or unset disables cropping).
    """

    def __init__(self, hypers) -> None:
        super().__init__(hypers)

    def load_df(self):
        frame = super().load_df()
        fasta_path = self.hypers.fasta_path
        log.info(f"Reading sequences from {fasta_path}")
        sequences_by_id = util.read_fasta(fasta_path)
        # replace_strict fails if a row's FASTA id is missing from the file.
        return frame.with_columns(seq=pl.col(self.hypers.fasta_id_col).replace_strict(sequences_by_id))

    def torch_dataset(self, df: pl.DataFrame, prefix: str) -> torch_data.Dataset:
        return RP3SequenceDataSet(df, prefix, self.rng, self.hypers.get('max_seq_len', 0))

    @staticmethod
    def collate(tokenizer: model.RP3Net, batch):
        """Default-collate every field except ``'seq'``, which ``tokenizer.tokenize_sequences`` handles."""
        raw_sequences = [sample.pop('seq') for sample in batch]
        collated = torch_data.default_collate(batch)
        collated['seq'] = tokenizer.tokenize_sequences(raw_sequences)
        return collated

    def get_collate_fn(self):
        # The attached LightningModule's model acts as the tokenizer.
        module: RP3LM = self.trainer.lightning_module
        return functools.partial(RP3SequenceLDM.collate, module.model)
300
+
@@ -0,0 +1,94 @@
1
+ import logging
2
+ import numpy as np
3
+ import polars as pl
4
+ import torch.utils.data as torch_data
5
+
6
+ from . import data
7
+ from .. import util
8
+
9
log = util.get_logger(__name__)


class EmlcBatchSampler(torch_data.Sampler):
    """Yield index batches mixing clean and noisy rows for EMLC training.

    Every batch holds ``batch_size_clean`` clean rows followed by
    ``batch_size_clean * emlc_k`` noisy rows (positions in ``df``). Each epoch the
    clean and noisy rows are shuffled with ``rng`` and then split across ranks with
    ``[global_rank::world_size]``. ``batch_count`` is the number of full batches
    every rank can serve from both pools.

    NOTE(review): the strided split only partitions the data if every rank draws
    the same permutation, i.e. ``rng`` must be in the same state on all ranks.
    """

    def __init__(self, *, df: pl.DataFrame, rng: np.random.Generator, clean_sources: list[int], noisy_sources: list[int],
                 batch_size_clean: int, emlc_k: int = 1, world_size: int = 1, global_rank: int = 0):
        self.df = df.with_row_index('_row_idx')
        self.rng = rng
        self.emlc_k = emlc_k
        self.clean_sources = clean_sources
        self.noisy_sources = noisy_sources
        self.batch_size_clean = batch_size_clean
        self.batch_size_noisy = batch_size_clean * emlc_k
        n_clean = df.select(pl.col('source').is_in(clean_sources)).sum()[0, 0]
        n_noisy = df.select(pl.col.source.is_in(noisy_sources)).sum()[0, 0]
        self.batch_count = min(n_clean // (self.batch_size_clean * world_size),
                               n_noisy // (self.batch_size_noisy * world_size))
        self.global_rank = global_rank
        self.world_size = world_size

    def _shuffled_rows(self, sources):
        """Row positions whose ``source`` is in ``sources``, shuffled in place with ``self.rng``."""
        rows = self.df.filter(pl.col.source.is_in(sources)).select('_row_idx').to_numpy().flatten()
        self.rng.shuffle(rows)
        return rows

    def __iter__(self):
        clean_rows = self._shuffled_rows(self.clean_sources)
        noisy_rows = self._shuffled_rows(self.noisy_sources)
        if log.isEnabledFor(logging.DEBUG):
            log.debug(f"Clean index:\n{clean_rows[:100]}")
            log.debug(f"Noisy index:\n{noisy_rows[:100]}")
        rank, world = self.global_rank, self.world_size
        clean_rows = clean_rows[rank::world]
        noisy_rows = noisy_rows[rank::world]
        log.debug(f"{self.batch_count} batches per worker {self.global_rank}/{self.world_size}")
        n_c, n_n = self.batch_size_clean, self.batch_size_noisy
        for batch_no in range(self.batch_count):
            out = np.concatenate([clean_rows[batch_no * n_c:(batch_no + 1) * n_c],
                                  noisy_rows[batch_no * n_n:(batch_no + 1) * n_n]])
            log.debug(f"Batch {batch_no}/{self.batch_count}: {out}")
            yield out

    def __len__(self):
        return self.batch_count
45
+
46
+
47
class EmlcLDM(data.RP3SequenceLDM):
    """Sequence data module for EMLC training with clean and noisy sources.

    ``hypers.clean_sources`` names the clean sources; every other configured source
    is treated as noisy. Training batches come from ``EmlcBatchSampler``, with
    ``batch_size`` clean and ``batch_size * emlc_k`` noisy rows. Validation, the
    validation-training monitor set and the test set use clean sources only.
    """

    def __init__(self, hypers) -> None:
        super().__init__(hypers)
        log.info("EmlcLDM init")
        clean_sources = self.hypers.clean_sources
        self.emlc_k = int(self.hypers.emlc_k)
        self.clean_sources = [self.sources_map[s] for s in clean_sources]
        # Dedicated generator for the training batch sampler. The sampler splits one
        # permutation across ranks with [rank::world_size], which only partitions the
        # data if every rank draws the same permutation. The shared self.rng is also
        # consumed by per-sample sequence cropping (seq_chunk), whose draws differ
        # between ranks, so sharing it let the permutations drift apart. The child
        # SeedSequence gives a stream distinct from self.rng's, identical on all ranks
        # when a seed is configured.
        self.sampler_rng = np.random.default_rng(np.random.SeedSequence(self.hypers.get('seed', None)).spawn(1)[0])

    def create_torch_datasets(self):
        self.train_ds = self.torch_dataset(self.df_train, 'train')
        self.val_ds = self.torch_dataset(self.df_val, "val")
        df_clean_train = self.df_train.filter(pl.col.source.is_in(self.clean_sources))
        # Cap the sample size: sampling without replacement fails when there are
        # fewer clean training rows than validation rows.
        df_val_train = df_clean_train.sample(min(len(self.val_ds), df_clean_train.height),
                                             with_replacement=False, seed=self.hypers.test_val_seed)
        self.val_train_ds = self.torch_dataset(df_val_train, "val-training")
        self.test_ds = self.torch_dataset(self.df_test, "test")

    def setup(self, stage: str) -> None:
        log.debug("EmlcLDM setup")
        if self.trainer is not None:
            # lightning_module is always the unwrapped LightningModule.
            module = self.trainer.lightning_module
            assert module.sources == self.hypers.sources
            assert module.emlc_k == self.emlc_k
        df = self.load_df()
        self.df_train = df.filter(
            pl.col('ds_type').is_not_null() &
            pl.col('ds_type').is_in(['TEST', self.validation_slice]).not_()
        )
        assert self.df_train.shape[0] > 0, f"No training data for slice {self.validation_slice}"
        self.df_val = df.filter(pl.col('ds_type') == self.validation_slice, pl.col.source.is_in(self.clean_sources))
        assert self.df_val.shape[0] > 0, f"No validation data for slice {self.validation_slice}"
        self.df_test = df.filter(pl.col('ds_type') == 'TEST', pl.col.source.is_in(self.clean_sources))
        assert self.df_test.shape[0] > 0, "No test data (ds_type == 'TEST') from clean sources"
        self.create_torch_datasets()

    def train_dataloader(self):
        """Loader driven by ``EmlcBatchSampler`` over the full training set (clean + noisy)."""
        log.debug("EmlcLDM train_dataloader")
        batch_size = self.get_batch_size('training')
        noisy_sources = [i for i in self.sources_map.values() if i not in self.clean_sources]
        world_size = self.trainer.world_size
        global_rank = self.trainer.global_rank
        if world_size > 1 and self.hypers.get('seed', None) is None:
            log.warning("No data seed configured: EMLC batch shuffles will differ between ranks, "
                        "so ranks may see overlapping training samples")
        batch_sampler = EmlcBatchSampler(df=self.df_train, rng=self.sampler_rng,
                                         clean_sources=self.clean_sources, noisy_sources=noisy_sources,
                                         batch_size_clean=batch_size, emlc_k=self.emlc_k,
                                         world_size=world_size, global_rank=global_rank)
        return torch_data.DataLoader(self.train_ds, batch_sampler=batch_sampler, num_workers=0, pin_memory=True,
                                     collate_fn=self.get_collate_fn())
94
+
RP3Net/training/lm.py ADDED
@@ -0,0 +1,123 @@
1
+ import lightning as L
2
+ import lightning.pytorch.utilities as L_util
3
+
4
+ import torch.nn as nn
5
+ import torch
6
+ import ml_collections as mlc
7
+ import pandas as pd
8
+ import os
9
+
10
+ from . import metrics
11
+ from .. import util
12
+ from .. import model
13
+
14
+
15
log = util.get_logger(__name__)


class RP3LM(L.LightningModule):
    """LightningModule training an RP3Net model as a two-class classifier on ``yield_binary``.

    Hyper-parameters are saved under the ``'model'`` key. ``hypers.sources`` lists
    the data sources the metrics are created for, and ``hypers.model`` is passed to
    ``model.load_model``.
    """

    def __init__(self, hypers) -> None:
        super().__init__()
        log.debug("Lightning module init")
        self._hypers_prefix = 'model'
        self.save_hyperparameters({'model': hypers})
        self.hypers = mlc.ConfigDict(self.hparams.model)
        self.sources = self.hypers.sources
        self.sources_map = {s: i for i, s in enumerate(self.sources)}
        log.info(f"Sources: {self.sources}")
        self.metrics = metrics.ClassificationMetricContainer.create_classification_metrics(self.sources, 2)
        self.loss = nn.CrossEntropyLoss()
        log.info(f"Loss: {self.loss}")
        self.model: model.RP3Net = model.load_model(self.hypers.model)
        log.info(f"Model: {self.model}")

    def setup(self, stage):
        if stage == 'fit':
            assert self.model.mode in model.Mode_Training, "Model must be in training mode"

    def force_train_on_fit_start(self):
        """
        Need this, because loading a pre-trained HF model calls .eval() under the hood,
        and PL preserves the state of training flags on modules when switching back from eval to train.
        """
        self.model.train()

    def on_fit_start(self) -> None:
        self.force_train_on_fit_start()

    def forward(self, batch):
        return self.model(batch)

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        """Return the predicted class index for each sample."""
        logits = self.model(batch)
        return torch.argmax(logits, dim=1)

    def training_step(self, batch, batch_idx):
        log.debug(f"Training batch ids: {batch['idx']}")
        logits = self(batch)
        loss = self.loss(logits, batch['yield_binary'])
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch['yield_binary'].shape[0])
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx):
        """Accumulate metrics: dataloader 0 feeds the train metrics, dataloader 1 the validation metrics."""
        ids = batch['idx']
        log.debug(f"Validation batch ids for dataloader index {dataloader_idx}: {ids}")
        logits = self(batch)
        if dataloader_idx == 0:
            self.metrics.update_train(logits, batch)
        elif dataloader_idx == 1:
            self.metrics.update_val(logits, batch)
        else:
            raise RuntimeError(f"Unknown dataloader index: {dataloader_idx}")

    def test_step(self, batch, batch_idx):
        log.debug(f"Test batch index: {batch['idx']}")
        logits = self(batch)
        self.metrics.update_test(logits, batch)

    @L_util.rank_zero_only
    def write_results_df(self, filename: os.PathLike, ids: torch.Tensor, logits: torch.Tensor):
        """Write per-sample predictions to ``filename`` as CSV (on global rank 0 only).

        Columns: ``id`` (from ``ids``), ``y_hat`` (argmax class), ``logit_<i>`` and
        ``prob_<i>`` (softmax over dim 1). pandas infers compression such as gzip
        from the file suffix.
        """
        # .numpy() rejects tensors that require grad and dtypes NumPy lacks
        # (e.g. bfloat16 logits under bf16 mixed precision), so detach and upcast.
        logits = logits.detach().cpu()
        if logits.dtype not in (torch.float32, torch.float64):
            logits = logits.float()
        proba = torch.softmax(logits, dim=1).numpy()
        y_hat = proba.argmax(axis=1)
        df = pd.DataFrame({'id': ids.to(dtype=torch.int32, device='cpu').numpy(), 'y_hat': y_hat})
        df_logits = pd.DataFrame(logits.numpy(), columns=[f'logit_{i}' for i in range(logits.shape[1])])
        df_proba = pd.DataFrame(proba, columns=[f'prob_{i}' for i in range(proba.shape[1])])
        df = pd.concat([df, df_logits, df_proba], axis=1)
        df.to_csv(filename, index=False)

    def on_validation_epoch_end(self) -> None:
        """Log train/val metrics, dump per-sample results and reset the metric state.

        Nothing is logged or written during Lightning's sanity check, but the
        metrics are still reset.
        """
        log.info(f"Validation epoch {self.current_epoch} end.")
        train_log_dict = self.metrics.compute_train_dict()
        if not self.trainer.sanity_checking:
            self.log_dict(train_log_dict, on_epoch=True, add_dataloader_idx=False, sync_dist=True)

        val_log_dict = self.metrics.compute_val_dict()
        if not self.trainer.sanity_checking:
            self.log_dict(val_log_dict, on_epoch=True, add_dataloader_idx=False, sync_dist=True)

        train_df_file = util.resolve(self.trainer.default_root_dir) / f"train_df_{self.current_epoch}.csv.gz"
        train_ids, train_logits = self.metrics.train_curve()
        if isinstance(train_logits, torch.Tensor) and train_logits.shape[0] > 0 and not self.trainer.sanity_checking:
            log.info(f"Writing training results for epoch {self.current_epoch} to {train_df_file}")
            self.write_results_df(train_df_file, train_ids, train_logits)

        val_df_file = util.resolve(self.trainer.default_root_dir) / f"val_df_{self.current_epoch}.csv.gz"
        val_ids, val_logits = self.metrics.val_curve()
        if isinstance(val_logits, torch.Tensor) and val_logits.shape[0] > 0 and not self.trainer.sanity_checking:
            log.info(f"Writing validation results for epoch {self.current_epoch} to {val_df_file}")
            self.write_results_df(val_df_file, val_ids, val_logits)
        self.metrics.reset()

    def on_test_epoch_end(self) -> None:
        """Log test metrics, write ``test_df.csv.gz`` and reset the metric state."""
        test_log_dict = self.metrics.compute_test_dict()
        # sync_dist matches the validation logging above.
        self.log_dict(test_log_dict, on_epoch=True, add_dataloader_idx=False, sync_dist=True)
        test_df_file = util.resolve(self.trainer.default_root_dir) / "test_df.csv.gz"
        test_ids, test_logits = self.metrics.test_curve()
        if isinstance(test_logits, torch.Tensor) and test_logits.shape[0] > 0:
            log.info(f"Writing test results to {test_df_file}")
            self.write_results_df(test_df_file, test_ids, test_logits)
        self.metrics.reset()
120
+
121
+
122
+
123
+