genhpf-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of genhpf might be problematic.

Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +233 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +174 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +584 -0
  53. genhpf/scripts/test.py +261 -0
  54. genhpf/scripts/train.py +350 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.0.dist-info/LICENSE +21 -0
  63. genhpf-1.0.0.dist-info/METADATA +197 -0
  64. genhpf-1.0.0.dist-info/RECORD +67 -0
  65. genhpf-1.0.0.dist-info/WHEEL +5 -0
  66. genhpf-1.0.0.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.0.dist-info/top_level.txt +1 -0
genhpf/scripts/test.py ADDED
@@ -0,0 +1,261 @@
+ import logging
+ import os
+ import pprint
+ import sys
+ from itertools import chain
+
+ import polars as pl
+ import torch.distributed
+ import torch.utils.data
+
+ logging.basicConfig(
+     format="%(asctime)s | %(levelname)s %(name)s %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+     level=os.environ.get("LOGLEVEL", "INFO").upper(),
+     stream=sys.stdout,
+ )
+ logger = logging.getLogger("genhpf.test")
+
+ import hydra
+ import torch
+ from hydra.core.hydra_config import HydraConfig
+ from omegaconf import OmegaConf, open_dict
+
+ from genhpf import criterions, models
+ from genhpf.configs import Config
+ from genhpf.configs.initialize import add_defaults, hydra_init
+ from genhpf.datasets import load_dataset
+ from genhpf.loggings import metrics, progress_bar
+ from genhpf.utils import distributed_utils, utils
+
+
+ def main(cfg: Config) -> None:
+     assert (
+         cfg.checkpoint.load_checkpoint is not None
+     ), "Please specify the checkpoint to load with `checkpoint.load_checkpoint`"
+
+     assert cfg.dataset.batch_size is not None, "batch_size must be specified"
+     metrics.reset()
+
+     use_cuda = torch.cuda.is_available()
+
+     if use_cuda:
+         torch.cuda.set_device(cfg.distributed_training.device_id)
+
+     if cfg.distributed_training.distributed_world_size > 1:
+         data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
+     else:
+         data_parallel_world_size = 1
+
+     # print args
+     logger.info(pprint.pformat(dict(cfg)))
+
+     # load model
+     model = models.build_model(cfg.model)
+     logger.info(f"loading model from {cfg.checkpoint.load_checkpoint}")
+     model_state_dict = torch.load(cfg.checkpoint.load_checkpoint, map_location="cpu")["model"]
+     model.load_state_dict(model_state_dict, strict=True)
+     logger.info(f"loaded model from {cfg.checkpoint.load_checkpoint}")
+
+     logger.info(model)
+     logger.info(f"model: {model.__class__.__name__}")
+     logger.info(
+         "num. shared model params: {:,} (num. trained: {:,})".format(
+             sum(p.numel() for p in model.parameters()),
+             sum(p.numel() for p in model.parameters() if p.requires_grad),
+         )
+     )
+
+     # move model to GPU and switch to eval mode
+     model.eval()
+     if use_cuda:
+         model.cuda()
+
+     # build criterion
+     criterion = criterions.build_criterion(cfg.criterion)
+     criterion.eval()
+
+     def _fp_convert_sample(sample):
+         def apply_float(t):
+             if t.dtype in [torch.float64, torch.float32, torch.int16]:
+                 return t.to(dtype=torch.float)
+             return t
+
+         sample = utils.apply_to_sample(apply_float, sample)
+
+         return sample
+
+     assert cfg.dataset.test_subset is not None, "Please specify the test subset with `dataset.test_subset`"
+     test_subsets = cfg.dataset.test_subset.split(",")
+
+     if cfg.dataset.combine_test_subsets:
+         datasets = [("combined-test", load_dataset(cfg.dataset.data, test_subsets, cfg))]
+     else:
+         datasets = [
+             (subset.strip(), load_dataset(cfg.dataset.data, [subset], cfg)) for subset in test_subsets
+         ]
+
+     output_meds_predictions = False
+     if cfg.dataset.data_format == "meds" and cfg.meds.output_predictions:
+         if len(test_subsets) > 1:
+             raise NotImplementedError(
+                 "MEDS dataset does not currently support multiple test subsets when `output_predictions` "
+                 "is enabled. Please specify only one test subset."
+             )
+         if cfg.dataset.combine_test_subsets:
+             raise NotImplementedError(
+                 "MEDS dataset does not currently support `combine_test_subsets` when `output_predictions` "
+                 "is enabled. Please set `dataset.combine_test_subsets` to False."
+             )
+         if len(cfg.criterion.task_names) > 1:
+             raise NotImplementedError(
+                 "MEDS dataset does not currently support multiple tasks when `output_predictions` "
+                 "is enabled. Please specify only one task."
+             )
+         if cfg.criterion.num_labels[0] > 1:
+             raise NotImplementedError(
+                 "MEDS dataset currently only supports binary classification when `output_predictions` "
+                 "is enabled. Please specify only one label by setting `criterion.num_labels` to 1."
+             )
+
+         assert (
+             cfg.meds.labels_dir is not None and cfg.meds.output_dir is not None
+         ), "Please specify labels_dir and output_dir in the MEDS config to output predictions."
+         assert data_parallel_world_size == 1, (
+             "MEDS dataset does not currently support distributed testing when `output_predictions` "
+             "is enabled. Please set `distributed_training.distributed_world_size` to 1."
+         )
+
+         output_meds_predictions = True
+
+         labels = pl.read_parquet(os.path.join(cfg.meds.labels_dir, f"{test_subsets[0]}/*.parquet"))
+         labels = labels.sort(by=["subject_id", "prediction_time"])
+         labels = labels.with_columns(pl.col("subject_id").cum_count().over("subject_id").alias("suffix"))
+         labels = labels.with_columns(
+             pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str).alias("subject_id")
+         )
+         labels = labels.drop("suffix")
+         labels = labels.select(["subject_id", "prediction_time", "boolean_value"])
+
+         meds_pred_output = {
+             "subject_id": [],
+             "predicted_boolean_value": [],
+             "predicted_boolean_probability": [],
+         }
+
+     for subset, dataset in datasets:
+         logger.info(f"begin validation on '{subset}' subset")
+
+         # initialize data iterator
+         batch_sampler = (
+             torch.utils.data.DistributedSampler(dataset, shuffle=False)
+             if torch.distributed.is_initialized()
+             else None
+         )
+         batch_iterator = torch.utils.data.DataLoader(
+             dataset,
+             batch_size=cfg.dataset.batch_size,
+             shuffle=False,
+             num_workers=cfg.dataset.num_workers,
+             collate_fn=dataset.collator,
+             sampler=batch_sampler,
+         )
+
+         progress = progress_bar.progress_bar(
+             batch_iterator,
+             log_format=cfg.common.log_format,
+             log_interval=cfg.common.log_interval,
+             log_file=cfg.common.log_file,
+             epoch=0,
+             default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
+             wandb_project=(
+                 cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
+             ),
+             wandb_entity=(
+                 cfg.common.wandb_entity if distributed_utils.is_master(cfg.distributed_training) else None
+             ),
+             wandb_run_name=os.environ.get("WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
+         )
+
+         log_outputs = []
+         for i, sample in enumerate(progress):
+             with torch.no_grad():
+                 sample = utils.prepare_sample(sample)
+                 sample = _fp_convert_sample(sample)
+                 _loss, _sample_size, log_output, net_output = criterion(model, sample, return_net_output=True)
+                 log_outputs.append(log_output)
+                 if output_meds_predictions:
+                     meds_pred_output["subject_id"].extend(sample["id"])
+                     logits = model.get_logits(sample, net_output)
+                     probs = torch.sigmoid(logits).view(-1).cpu()
+                     meds_pred_output["predicted_boolean_probability"].extend(probs.tolist())
+                     meds_pred_output["predicted_boolean_value"].extend(
+                         (probs >= cfg.criterion.threshold).int().tolist()
+                     )
+
+         if output_meds_predictions:
+             meds_pred_output = pl.DataFrame(meds_pred_output)
+             meds_pred_output = meds_pred_output.join(labels, on="subject_id", how="left")
+             meds_pred_output = meds_pred_output.select(
+                 [
+                     pl.col("subject_id"),
+                     pl.col("prediction_time"),
+                     pl.col("boolean_value"),
+                     pl.col("predicted_boolean_value"),
+                     pl.col("predicted_boolean_probability"),
+                 ]
+             )
+             meds_pred_output = meds_pred_output.with_columns(
+                 pl.col("subject_id").map_elements(lambda x: x.split("_")[0], return_dtype=pl.String).cast(int)
+             )
+             if not os.path.exists(cfg.meds.output_dir):
+                 os.makedirs(cfg.meds.output_dir)
+             meds_pred_output.write_parquet(os.path.join(cfg.meds.output_dir, f"{subset}.parquet"))
+
+         if data_parallel_world_size > 1:
+             log_outputs = distributed_utils.all_gather_list(
+                 log_outputs,
+                 max_size=cfg.common.all_gather_list_size,
+                 group=distributed_utils.get_data_parallel_group(),
+             )
+             log_outputs = list(chain.from_iterable(log_outputs))
+
+         with metrics.aggregate(new_root=True) as agg:
+             criterion.__class__.reduce_metrics(log_outputs)
+             del log_outputs
+             log_outputs = agg.get_smoothed_values()
+
+         stats = log_outputs
+         if hasattr(criterion, "post_validate"):
+             stats = criterion.post_validate(stats=log_outputs, agg=agg)
+
+         progress.print(stats, tag=subset, step=None)
+
+
+ @hydra.main(config_path=os.path.join("..", "configs"), config_name="config")
+ def hydra_main(cfg: Config) -> None:
+     add_defaults(cfg)
+
+     with open_dict(cfg):
+         # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
+         cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True)
+
+     cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
+     OmegaConf.set_struct(cfg, True)
+
+     distributed_utils.call_main(cfg, main)
+
+
+ def cli_main():
+     try:
+         from hydra._internal.utils import get_args
+
+         cfg_name = get_args().config_name or "config"
+     except Exception:
+         logger.warning("Failed to get config name from hydra args")
+         cfg_name = "config"
+     hydra_init(cfg_name)
+     hydra_main()
+
+
+ if __name__ == "__main__":
+     cli_main()
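
Note on the MEDS prediction-output path above: the script numbers repeated subject_ids (1 becomes 1_1, 1_2, ...) so that per-window predictions can be joined back onto the label frame one-to-one, then strips the suffix before writing parquet. A minimal, self-contained sketch of that polars pattern with made-up data (column names follow the script; the frames and values here are purely illustrative, and the concatenation is parenthesized before the alias for clarity):

    import polars as pl

    # Toy labels frame: two prediction windows for subject 1, one for subject 2 (made-up data).
    labels = pl.DataFrame(
        {
            "subject_id": [1, 1, 2],
            "prediction_time": ["2020-01-01", "2020-01-05", "2020-02-01"],
            "boolean_value": [False, True, True],
        }
    )

    # Same trick as test.py: number each subject's rows and build "<subject_id>_<n>" keys
    # so repeated subjects join one-to-one with per-window predictions.
    labels = labels.sort(by=["subject_id", "prediction_time"])
    labels = labels.with_columns(pl.col("subject_id").cum_count().over("subject_id").alias("suffix"))
    labels = labels.with_columns(
        (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias("subject_id")
    ).drop("suffix")

    # Predictions keyed by the same suffixed ids (values are made up).
    preds = pl.DataFrame(
        {
            "subject_id": ["1_1", "1_2", "2_1"],
            "predicted_boolean_probability": [0.2, 0.7, 0.9],
            "predicted_boolean_value": [0, 1, 1],
        }
    )

    out = preds.join(labels, on="subject_id", how="left")
    # Recover the original integer subject_id, as the script does before writing parquet.
    out = out.with_columns(
        pl.col("subject_id").map_elements(lambda x: x.split("_")[0], return_dtype=pl.String).cast(int)
    )
    print(out)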
genhpf/scripts/train.py ADDED
@@ -0,0 +1,350 @@
+ import argparse
+ import logging
+ import logging.config
+ import os
+ import pprint
+ import random
+ import sys
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch.distributed
+
+ logging.basicConfig(
+     format="%(asctime)s | %(levelname)s %(name)s %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+     level=os.environ.get("LOGLEVEL", "INFO").upper(),
+     stream=sys.stdout,
+ )
+ logger = logging.getLogger("genhpf.train")
+
+ import hydra
+ import numpy as np
+ import torch
+ from hydra.core.hydra_config import HydraConfig
+ from omegaconf import OmegaConf, open_dict
+
+ from genhpf import criterions, models
+ from genhpf.configs import Config
+ from genhpf.configs.initialize import add_defaults, hydra_init
+ from genhpf.datasets import load_dataset
+ from genhpf.loggings import meters, metrics, progress_bar
+ from genhpf.trainer import Trainer
+ from genhpf.utils import checkpoint_utils, distributed_utils, utils
+
+
+ def main(cfg: Config) -> None:
+     if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
+         # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
+         logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
+
+     assert cfg.dataset.batch_size is not None, "batch_size must be specified"
+     metrics.reset()
+
+     np.random.seed(cfg.common.seed)
+     random.seed(cfg.common.seed)
+     utils.set_torch_seed(cfg.common.seed)
+
+     if distributed_utils.is_master(cfg.distributed_training):
+         checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
+
+     # print args
+     logger.info(pprint.pformat(dict(cfg)))
+
+     model = models.build_model(cfg.model)
+     if cfg.checkpoint.load_checkpoint is not None:
+         state_dict = torch.load(cfg.checkpoint.load_checkpoint, map_location="cpu")["model"]
+         model.load_state_dict(state_dict, strict=True)
+         logger.info(f"loaded model from {cfg.checkpoint.load_checkpoint}")
+     criterion = criterions.build_criterion(cfg.criterion)
+
+     logger.info(model)
+     logger.info(f"model: {model.__class__.__name__}")
+     logger.info(f"criterion: {criterion.__class__.__name__}")
+     logger.info(
+         "num. shared model params: {:,} (num. trained: {:,})".format(
+             sum(p.numel() for p in model.parameters()),
+             sum(p.numel() for p in model.parameters() if p.requires_grad),
+         )
+     )
+
+     datasets = {}
+     train_subsets = cfg.dataset.train_subset.split(",")
+     if len(train_subsets) > 1:
+         assert (
+             cfg.dataset.combine_train_subsets
+         ), "train_subset contains multiple datasets, but combine_train_subsets is not set"
+         datasets["train"] = [("combined-train", load_dataset(cfg.dataset.data, train_subsets, cfg))]
+     else:
+         datasets["train"] = [(train_subsets[0].strip(), load_dataset(cfg.dataset.data, train_subsets, cfg))]
+
+     if not cfg.dataset.disable_validation and cfg.dataset.valid_subset is not None:
+         valid_subsets = cfg.dataset.valid_subset.split(",")
+         if cfg.dataset.combine_valid_subsets:
+             datasets["valid"] = [("combined-valid", load_dataset(cfg.dataset.data, valid_subsets, cfg))]
+         else:
+             datasets["valid"] = [
+                 (subset.strip(), load_dataset(cfg.dataset.data, [subset], cfg)) for subset in valid_subsets
+             ]
+     if cfg.dataset.test_subset is not None:
+         test_subsets = cfg.dataset.test_subset.split(",")
+         if cfg.dataset.combine_test_subsets:
+             datasets["test"] = [("combined-test", load_dataset(cfg.dataset.data, test_subsets, cfg))]
+         else:
+             datasets["test"] = [
+                 (subset.strip(), load_dataset(cfg.dataset.data, [subset], cfg)) for subset in test_subsets
+             ]
+
+     trainer = Trainer(cfg, model, criterion)
+
+     logger.info(f"training on {cfg.distributed_training.distributed_world_size} devices (GPUs)")
+     logger.info(f"batch size per device = {cfg.dataset.batch_size}")
+
+     max_epoch = cfg.optimization.max_epoch
+
+     train_meter = meters.StopwatchMeter()
+     train_meter.start()
+     for i in range(1, max_epoch + 1):
+         # train for one epoch
+         valid_losses, should_stop = train(cfg, trainer, datasets, i)
+         if should_stop:
+             break
+     train_meter.stop()
+     logger.info(f"done training in {train_meter.sum:.1f} seconds")
+
+
+ def should_stop_early(cfg: Config, valid_loss: float) -> bool:
+     # skip check if no validation was done in the current epoch
+     if valid_loss is None:
+         return False
+     if cfg.checkpoint.patience <= 0:
+         return False
+
+     def is_better(a, b):
+         return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
+
+     prev_best = getattr(should_stop_early, "best", None)
+     if prev_best is None or is_better(valid_loss, prev_best):
+         should_stop_early.best = valid_loss
+         should_stop_early.num_runs = 0
+         return False
+     else:
+         should_stop_early.num_runs += 1
+         if should_stop_early.num_runs >= cfg.checkpoint.patience:
+             logger.info(
+                 f"early stop since valid performance hasn't improved for {cfg.checkpoint.patience} runs"
+             )
+             return True
+         else:
+             return False
+
+
+ @metrics.aggregate("train")
+ def train(
+     cfg: Config,
+     trainer: Trainer,
+     datasets,
+     epoch: int,
+ ) -> Tuple[List[Optional[float]], bool]:
+     """Train the model for one epoch and return validation losses."""
+     # initialize data iterator
+     data_loader, batch_sampler = trainer.get_train_iterator(datasets["train"][0][1])
+     if batch_sampler is not None:
+         batch_sampler.set_epoch(epoch)
+
+     itr = iter(data_loader)
+     progress = progress_bar.progress_bar(
+         itr,
+         log_format=cfg.common.log_format,
+         log_file=cfg.common.log_file,
+         log_interval=cfg.common.log_interval,
+         epoch=epoch,
+         default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
+         wandb_project=(
+             cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
+         ),
+         wandb_entity=(
+             cfg.common.wandb_entity if distributed_utils.is_master(cfg.distributed_training) else None
+         ),
+         wandb_run_name=os.environ.get("WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
+     )
+     progress.update_config(_flatten_config(cfg))
+
+     logger.info(f"begin training epoch {epoch}")
+
+     should_stop = False
+     num_updates = trainer.get_num_updates()
+     logger.info("Start iterating over samples")
+     for i, sample in enumerate(progress):
+         with metrics.aggregate("train_inner"):
+             log_output = trainer.train_step(sample)
+
+         if log_output is not None:
+             # log mid-epoch stats
+             num_updates = trainer.get_num_updates()
+             if num_updates % cfg.common.log_interval == 0:
+                 stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
+                 progress.log(stats, tag="train_inner", step=num_updates)
+
+                 # reset mid-epoch stats after each log interval
+                 # the end-of-epoch stats will still be preserved
+                 metrics.reset_meters("train_inner")
+
+     valid_losses, should_stop = validate_and_save(cfg, trainer, datasets, epoch)
+
+     # log end-of-epoch stats
+     logger.info(f"end of epoch {epoch} (average epoch stats below)")
+     stats = get_training_stats(metrics.get_smoothed_values("train"))
+     progress.print(stats, tag="train", step=num_updates)
+
+     # reset epoch-level meters
+     metrics.reset_meters("train")
+     return valid_losses, should_stop
+
+
+ def validate_and_save(
+     cfg: Config,
+     trainer: Trainer,
+     datasets,
+     epoch: int,
+ ) -> Tuple[List[Optional[float]], bool]:
+     should_stop = False
+     if epoch >= cfg.optimization.max_epoch:
+         should_stop = True
+         logger.info(
+             f"Stopping training due to num_epochs: {epoch} >= max_epochs: {cfg.optimization.max_epoch}"
+         )
+
+     do_validate = "valid" in datasets or "test" in datasets
+
+     # validate
+     valid_losses = [None]
+     if do_validate:
+         valid_losses = validate(cfg, trainer, datasets, epoch)
+
+     should_stop |= should_stop_early(cfg, valid_losses[0])
+
+     checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch, valid_losses[0])
+     if torch.distributed.is_initialized():
+         torch.distributed.barrier()
+
+     return valid_losses, should_stop
+
+
+ def validate(
+     cfg: Config,
+     trainer: Trainer,
+     datasets,
+     epoch: int,
+ ):
+     """Evaluate the model on the validation set(s) and return the losses."""
+
+     valid_subsets = datasets.get("valid", [])
+     test_subsets = datasets.get("test", [])
+
+     valid_losses = []
+     for subset, dataset in valid_subsets + test_subsets:
+         logger.info(f"begin validation on '{subset}' subset")
+
+         # initialize data iterator
+         data_loader, _ = trainer.get_valid_iterator(dataset)
+         progress = progress_bar.progress_bar(
+             data_loader,
+             log_format=cfg.common.log_format,
+             log_interval=cfg.common.log_interval,
+             log_file=cfg.common.log_file,
+             epoch=epoch,
+             default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
+             wandb_project=(
+                 cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
+             ),
+             wandb_entity=(
+                 cfg.common.wandb_entity if distributed_utils.is_master(cfg.distributed_training) else None
+             ),
+             wandb_run_name=os.environ.get("WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
+         )
+
+         # create a new root metrics aggregator so validation metrics
+         # don't pollute other aggregators (e.g., train meters)
+         with metrics.aggregate(new_root=True) as agg:
+             for i, sample in enumerate(progress):
+                 trainer.valid_step(sample, subset=subset)
+
+         stats = agg.get_smoothed_values()
+
+         if hasattr(trainer.criterion, "post_validate"):
+             stats = trainer.criterion.post_validate(
+                 stats=stats,
+                 agg=agg,
+             )
+
+         # log validation stats
+         stats = get_valid_stats(cfg, trainer, subset, stats)
+
+         progress.print(stats, tag=subset, step=trainer.get_num_updates())
+
+         valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
+
+     return valid_losses
+
+
+ def get_training_stats(stats):
+     stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
+     return stats
+
+
+ def get_valid_stats(cfg: Config, trainer: Trainer, subset: str, stats: Dict[str, Any]) -> Dict[str, Any]:
+     stats["num_updates"] = trainer.get_num_updates()
+
+     if not hasattr(get_valid_stats, "best"):
+         get_valid_stats.best = dict()
+
+     prev_best = getattr(get_valid_stats, "best").get(subset, stats[cfg.checkpoint.best_checkpoint_metric])
+     best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
+     get_valid_stats.best[subset] = best_function(stats[cfg.checkpoint.best_checkpoint_metric], prev_best)
+
+     key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
+     stats[key] = get_valid_stats.best[subset]
+
+     return stats
+
+
+ def _flatten_config(cfg: Config):
+     config = OmegaConf.to_container(cfg)
+     # remove any legacy Namespaces and replace with a single "args"
+     namespace = None
+     for k, v in list(config.items()):
+         if isinstance(v, argparse.Namespace):
+             namespace = v
+             del config[k]
+     if namespace is not None:
+         config["args"] = vars(namespace)
+     return config
+
+
+ @hydra.main(config_path=os.path.join("..", "configs"), config_name="config")
+ def hydra_main(cfg: Config) -> None:
+     add_defaults(cfg)
+
+     with open_dict(cfg):
+         # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
+         cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True)
+
+     cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
+     OmegaConf.set_struct(cfg, True)
+
+     distributed_utils.call_main(cfg, main)
+
+
+ def cli_main():
+     try:
+         from hydra._internal.utils import get_args
+
+         cfg_name = get_args().config_name or "config"
+     except Exception:
+         logger.warning("Failed to get config name from hydra args")
+         cfg_name = "config"
+     hydra_init(cfg_name)
+     hydra_main()
+
+
+ if __name__ == "__main__":
+     cli_main()
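
Side note on train.py's early stopping: should_stop_early keeps its best metric and patience counter as function attributes, so the state persists across epochs without a dedicated class. A stripped-down, standalone sketch of that pattern (hard-coded patience and a simulated loss sequence stand in for cfg.checkpoint.patience and real validation results):

    def should_stop_early(valid_metric: float, patience: int = 3, maximize: bool = False) -> bool:
        """Standalone version of the patience check used in train.py."""
        if valid_metric is None or patience <= 0:
            return False

        def is_better(a, b):
            return a > b if maximize else a < b

        prev_best = getattr(should_stop_early, "best", None)
        if prev_best is None or is_better(valid_metric, prev_best):
            # new best: remember it and reset the counter
            should_stop_early.best = valid_metric
            should_stop_early.num_runs = 0
            return False
        should_stop_early.num_runs += 1
        return should_stop_early.num_runs >= patience

    # simulated per-epoch validation losses (lower is better here)
    for epoch, loss in enumerate([0.9, 0.8, 0.85, 0.83, 0.84], start=1):
        if should_stop_early(loss, patience=3):
            print(f"stop after epoch {epoch}")  # triggers at epoch 5: no improvement for 3 epochs
            break

With this simulated sequence the counter reaches the patience of 3 at the fifth epoch, which mirrors how validate_and_save combines the patience check with the max_epoch condition to decide whether training should stop.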