genhpf 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/scripts/test.py ADDED
@@ -0,0 +1,264 @@
1
+ import logging
2
+ import os
3
+ import pprint
4
+ import sys
5
+ from itertools import chain
6
+
7
+ import polars as pl
8
+ import torch.distributed
9
+ import torch.utils.data
10
+
11
+ logging.basicConfig(
12
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
13
+ datefmt="%Y-%m-%d %H:%M:%S",
14
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
15
+ stream=sys.stdout,
16
+ )
17
+ logger = logging.getLogger("genhpf.test")
18
+
19
+ import hydra
20
+ import torch
21
+ from hydra.core.hydra_config import HydraConfig
22
+ from omegaconf import OmegaConf, open_dict
23
+
24
+ from genhpf import criterions, models
25
+ from genhpf.configs import Config
26
+ from genhpf.configs.initialize import add_defaults, hydra_init
27
+ from genhpf.datasets import load_dataset
28
+ from genhpf.loggings import metrics, progress_bar
29
+ from genhpf.utils import distributed_utils, utils
30
+
31
+
32
+ def main(cfg: Config) -> None:
33
+ assert (
34
+ cfg.checkpoint.load_checkpoint is not None
35
+ ), "Please specify the checkpoint to load with `checkpoint.load_checkpoint`"
36
+
37
+ assert cfg.dataset.batch_size is not None, "batch_size must be specified"
38
+ metrics.reset()
39
+
40
+ use_cuda = torch.cuda.is_available()
41
+
42
+ if use_cuda:
43
+ torch.cuda.set_device(cfg.distributed_training.device_id)
44
+
45
+ if cfg.distributed_training.distributed_world_size > 1:
46
+ data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
47
+ else:
48
+ data_parallel_world_size = 1
49
+
50
+ # print args
51
+ logger.info(pprint.pformat(dict(cfg)))
52
+
53
+ # load model
54
+ model = models.build_model(cfg.model)
55
+ logger.info(f"loading model from {cfg.checkpoint.load_checkpoint}")
56
+ model_state_dict = torch.load(cfg.checkpoint.load_checkpoint, map_location="cpu")["model"]
57
+ model.load_state_dict(model_state_dict, strict=True)
58
+ logger.info(f"loaded model from {cfg.checkpoint.load_checkpoint}")
59
+
60
+ logger.info(model)
61
+ logger.info(f"model: {model.__class__.__name__}")
62
+ logger.info(
63
+ "num. shared model params: {:,} (num. trained: {:,})".format(
64
+ sum(p.numel() for p in model.parameters()),
65
+ sum(p.numel() for p in model.parameters() if p.requires_grad),
66
+ )
67
+ )
68
+
69
+ # Move model to GPU
70
+ model.eval()
71
+ if use_cuda:
72
+ model.cuda()
73
+
74
+ # build criterion
75
+ criterion = criterions.build_criterion(cfg.criterion)
76
+ criterion.eval()
77
+
78
+ def _fp_convert_sample(sample):
79
+ def apply_float(t):
80
+ if t.dtype in [torch.float64, torch.float32, torch.int16]:
81
+ return t.to(dtype=torch.float)
82
+ return t
83
+
84
+ sample = utils.apply_to_sample(apply_float, sample)
85
+
86
+ return sample
87
+
88
+ assert cfg.dataset.test_subset is not None, "Please specify the test subset with `dataset.test_subset`"
89
+ test_subsets = cfg.dataset.test_subset.split(",")
90
+
91
+ if cfg.dataset.combine_test_subsets:
92
+ datasets = [("combined-test", load_dataset(cfg.dataset.data, test_subsets, cfg))]
93
+ else:
94
+ datasets = [
95
+ (subset.strip(), load_dataset(cfg.dataset.data, [subset], cfg)) for subset in test_subsets
96
+ ]
97
+
98
+ output_meds_predictions = False
99
+ if cfg.dataset.data_format == "meds" and cfg.meds.output_predictions:
100
+ if len(test_subsets) > 1:
101
+ raise NotImplementedError(
102
+ "MEDS dataset does not currently support multiple test subsets when `output_predictions` "
103
+ "is enabled. Please specify only one test subset."
104
+ )
105
+ if cfg.dataset.combine_test_subsets:
106
+ raise NotImplementedError(
107
+ "MEDS dataset does not currently support `combine_test_subsets` when `output_predictions` "
108
+ "is enabled. Please set `dataset.combine_test_subsets` to False."
109
+ )
110
+ if len(cfg.criterion.task_names) > 1:
111
+ raise NotImplementedError(
112
+ "MEDS dataset does not currently support multiple tasks when `output_predictions` "
113
+ "is enabled. Please specify only one task."
114
+ )
115
+ if cfg.criterion.num_labels[0] > 1:
116
+ raise NotImplementedError(
117
+ "MEDS dataset currently only supports binary classification when `output_predictions` "
118
+ "is enabled. Please specify only one label by setting `criterion.num_labels` to 1."
119
+ )
120
+
121
+ assert (
122
+ cfg.meds.labels_dir is not None and cfg.meds.output_dir is not None
123
+ ), "Please specify labels_dir and output_dir in the MEDS config to output predictions."
124
+ assert data_parallel_world_size == 1, (
125
+ "MEDS dataset does not currently support distributed testing when `output_predictions` "
126
+ "is enabled. Please set `distributed_training.distributed_world_size` to 1."
127
+ )
128
+
129
+ output_meds_predictions = True
130
+
131
+ labels = pl.read_parquet(os.path.join(cfg.meds.labels_dir, f"{test_subsets[0]}/*.parquet"))
132
+ labels = labels.sort(by=["subject_id", "prediction_time"])
133
+ labels = labels.with_columns(pl.col("subject_id").cum_count().over("subject_id").alias("suffix"))
134
+ labels = labels.with_columns(
135
+ pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str).alias("subject_id")
136
+ )
137
+ labels = labels.drop("suffix")
138
+ labels = labels.select(["subject_id", "prediction_time", "boolean_value"])
139
+
140
+ meds_pred_output = {
141
+ "subject_id": [],
142
+ "predicted_boolean_value": [],
143
+ "predicted_boolean_probability": [],
144
+ }
145
+
146
+ for subset, dataset in datasets:
147
+ logger.info(f"begin validation on '{subset}' subset")
148
+
149
+ # initialize data iterator
150
+ batch_sampler = (
151
+ torch.utils.data.DistributedSampler(dataset, shuffle=False)
152
+ if torch.distributed.is_initialized()
153
+ else None
154
+ )
155
+ batch_iterator = torch.utils.data.DataLoader(
156
+ dataset,
157
+ batch_size=cfg.dataset.batch_size,
158
+ shuffle=False,
159
+ num_workers=cfg.dataset.num_workers,
160
+ collate_fn=dataset.collator,
161
+ sampler=batch_sampler,
162
+ )
163
+
164
+ progress = progress_bar.progress_bar(
165
+ batch_iterator,
166
+ log_format=cfg.common.log_format,
167
+ log_interval=cfg.common.log_interval,
168
+ log_file=cfg.common.log_file,
169
+ epoch=0,
170
+ default_log_format=("tqdm" if cfg.common.no_progress_bar else "simple"),
171
+ wandb_project=(
172
+ cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
173
+ ),
174
+ wandb_entity=(
175
+ cfg.common.wandb_entity if distributed_utils.is_master(cfg.distributed_training) else None
176
+ ),
177
+ wandb_run_name=os.environ.get("WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
178
+ )
179
+
180
+ log_outputs = []
181
+ for i, sample in enumerate(progress):
182
+ with torch.no_grad():
183
+ sample = utils.prepare_sample(sample)
184
+ sample = _fp_convert_sample(sample)
185
+ _loss, _sample_size, log_output, net_output = criterion(model, sample, return_net_output=True)
186
+ log_outputs.append(log_output)
187
+ if output_meds_predictions:
188
+ meds_pred_output["subject_id"].extend(sample["id"])
189
+ logits = model.get_logits(sample, net_output)
190
+ probs = torch.sigmoid(logits).view(-1).cpu()
191
+ meds_pred_output["predicted_boolean_probability"].extend(probs.tolist())
192
+ meds_pred_output["predicted_boolean_value"].extend(
193
+ (probs >= cfg.criterion.threshold).int().tolist()
194
+ )
195
+
196
+ if output_meds_predictions:
197
+ meds_pred_output = pl.DataFrame(meds_pred_output)
198
+ meds_pred_output = meds_pred_output.join(labels, on="subject_id", how="left")
199
+ meds_pred_output = meds_pred_output.select(
200
+ [
201
+ pl.col("subject_id"),
202
+ pl.col("prediction_time"),
203
+ pl.col("boolean_value"),
204
+ pl.col("predicted_boolean_value"),
205
+ pl.col("predicted_boolean_probability"),
206
+ ]
207
+ )
208
+ meds_pred_output = meds_pred_output.with_columns(
209
+ pl.col("subject_id").map_elements(lambda x: x.split("_")[0], return_dtype=pl.String).cast(int)
210
+ )
211
+ meds_pred_output = (
212
+ meds_pred_output.with_columns(pl.col("predicted_boolean_value").cast(bool))
213
+ )
214
+ if not os.path.exists(cfg.meds.output_dir):
215
+ os.makedirs(cfg.meds.output_dir)
216
+ meds_pred_output.write_parquet(os.path.join(cfg.meds.output_dir, f"{subset}.parquet"))
217
+
218
+ if data_parallel_world_size > 1:
219
+ log_outputs = distributed_utils.all_gather_list(
220
+ log_outputs,
221
+ max_size=cfg.common.all_gather_list_size,
222
+ group=distributed_utils.get_data_parallel_group(),
223
+ )
224
+ log_outputs = list(chain.from_iterable(log_outputs))
225
+
226
+ with metrics.aggregate(new_root=True) as agg:
227
+ criterion.__class__.reduce_metrics(log_outputs)
228
+ del log_outputs
229
+ stats = agg.get_smoothed_values()
230
+
231
+ if hasattr(criterion, "post_validate"):
232
+ stats = criterion.post_validate(stats=stats, agg=agg)
233
+
234
+ progress.print(stats, tag=subset, step=None)
235
+
236
+
237
+ @hydra.main(config_path=os.path.join("..", "configs"), config_name="config")
238
+ def hydra_main(cfg: Config) -> None:
239
+ add_defaults(cfg)
240
+
241
+ with open_dict(cfg):
242
+ # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
243
+ cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True)
244
+
245
+ cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
246
+ OmegaConf.set_struct(cfg, True)
247
+
248
+ distributed_utils.call_main(cfg, main)
249
+
250
+
251
+ def cli_main():
252
+ try:
253
+ from hydra._internal.utils import get_args
254
+
255
+ cfg_name = get_args().config_name or "config"
256
+ except Exception:
257
+ logger.warning("Failed to get config name from hydra args")
258
+ cfg_name = "config"
259
+ hydra_init(cfg_name)
260
+ hydra_main()
261
+
262
+
263
+ if __name__ == "__main__":
264
+ cli_main()
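For reference, the trickiest part of test.py is the MEDS label keying. Below is a minimal standalone sketch of that logic (not part of the package: the toy data and values are invented, and it mirrors the script's polars calls, so the exact suffix values depend on the polars version's cum_count semantics). Repeated subject_ids are suffixed with a per-subject counter so predictions can be joined back onto the sorted label table, and the suffix is stripped again before the parquet output is written.

import polars as pl

# Toy label table standing in for the parquet files under cfg.meds.labels_dir.
labels = pl.DataFrame(
    {
        "subject_id": [1, 1, 2],
        "prediction_time": ["2020-01-01", "2020-01-02", "2020-01-01"],
        "boolean_value": [True, False, True],
    }
).sort(by=["subject_id", "prediction_time"])

# Number repeated subject_ids within each subject so every label row gets a
# unique string key (e.g. "1_1", "1_2", "2_1" with recent polars, where
# cum_count starts at 1).
labels = labels.with_columns(
    pl.col("subject_id").cum_count().over("subject_id").alias("suffix")
)
labels = labels.with_columns(
    (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias("subject_id")
).drop("suffix")

# Predictions coming out of the model carry the same suffixed ids, so a left
# join recovers prediction_time and the ground-truth boolean_value.
preds = pl.DataFrame(
    {
        "subject_id": ["1_1", "1_2", "2_1"],
        "predicted_boolean_probability": [0.9, 0.2, 0.7],
    }
)
merged = preds.join(labels, on="subject_id", how="left")

# Strip the suffix again to restore integer subject_ids, as the script does
# before writing the per-subset parquet file.
merged = merged.with_columns(
    pl.col("subject_id").map_elements(lambda x: x.split("_")[0], return_dtype=pl.String).cast(int)
)
print(merged)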
genhpf/scripts/train.py ADDED
@@ -0,0 +1,365 @@
1
+ import argparse
2
+ import logging.config
3
+ import os
4
+ import pprint
5
+ import random
6
+ import sys
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import torch.distributed
10
+
11
+ logging.basicConfig(
12
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
13
+ datefmt="%Y-%m-%d %H:%M:%S",
14
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
15
+ stream=sys.stdout,
16
+ )
17
+ logger = logging.getLogger("genhpf.train")
18
+
19
+ import hydra
20
+ import numpy as np
21
+ import torch
22
+ from hydra.core.hydra_config import HydraConfig
23
+ from omegaconf import OmegaConf, open_dict
24
+
25
+ from genhpf import criterions, models
26
+ from genhpf.configs import Config
27
+ from genhpf.configs.initialize import add_defaults, hydra_init
28
+ from genhpf.datasets import load_dataset
29
+ from genhpf.loggings import meters, metrics, progress_bar
30
+ from genhpf.trainer import Trainer
31
+ from genhpf.utils import checkpoint_utils, distributed_utils, utils
32
+
33
+
34
+ def main(cfg: Config) -> None:
35
+ if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
36
+ # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
37
+ logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
38
+
39
+ if cfg.common.debug:
40
+ os.environ["OMP_NUM_THREADS"] = "2"
41
+ os.environ["MKL_NUM_THREADS"] = "2"
42
+ torch.set_num_threads(2)
43
+ torch.set_num_interop_threads(2)
44
+ cfg.optimization.max_epoch = 1
45
+
46
+ assert cfg.dataset.batch_size is not None, "batch_size must be specified"
47
+ metrics.reset()
48
+
49
+ np.random.seed(cfg.common.seed)
50
+ random.seed(cfg.common.seed)
51
+ utils.set_torch_seed(cfg.common.seed)
52
+
53
+ if distributed_utils.is_master(cfg.distributed_training):
54
+ checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
55
+
56
+ # print args
57
+ logger.info(pprint.pformat(dict(cfg)))
58
+
59
+ model = models.build_model(cfg.model)
60
+ if cfg.checkpoint.load_checkpoint is not None:
61
+ state_dict = torch.load(cfg.checkpoint.load_checkpoint, map_location="cpu")["model"]
62
+ model.load_state_dict(state_dict, strict=True)
63
+ logger.info(f"loaded model from {cfg.checkpoint.load_checkpoint}")
64
+ criterion = criterions.build_criterion(cfg.criterion)
65
+
66
+ logger.info(model)
67
+ logger.info(f"model: {model.__class__.__name__}")
68
+ logger.info(f"criterion: {criterion.__class__.__name__}")
69
+ logger.info(
70
+ "num. shared model params: {:,} (num. trained: {:,})".format(
71
+ sum(p.numel() for p in model.parameters()),
72
+ sum(p.numel() for p in model.parameters() if p.requires_grad),
73
+ )
74
+ )
75
+
76
+ datasets = {}
77
+ train_subsets = cfg.dataset.train_subset.split(",")
78
+ if len(train_subsets) > 1:
79
+ assert (
80
+ cfg.dataset.combine_train_subsets
81
+ ), "train_subset contains multiple datasets, but combine_train_subsets is not set"
82
+ datasets["train"] = [("combined-train", load_dataset(cfg.dataset.data, train_subsets, cfg))]
83
+ else:
84
+ datasets["train"] = [(train_subsets[0].strip(), load_dataset(cfg.dataset.data, train_subsets, cfg))]
85
+
86
+ if not cfg.dataset.disable_validation and cfg.dataset.valid_subset is not None:
87
+ valid_subsets = cfg.dataset.valid_subset.split(",")
88
+ if cfg.dataset.combine_valid_subsets:
89
+ datasets["valid"] = [("combined-valid", load_dataset(cfg.dataset.data, valid_subsets, cfg))]
90
+ else:
91
+ datasets["valid"] = [
92
+ (subset.strip(), load_dataset(cfg.dataset.data, [subset], cfg)) for subset in valid_subsets
93
+ ]
94
+ if cfg.dataset.test_subset is not None:
95
+ test_subsets = cfg.dataset.test_subset.split(",")
96
+ if cfg.dataset.combine_test_subsets:
97
+ datasets["test"] = [("combined-test", load_dataset(cfg.dataset.data, test_subsets, cfg))]
98
+ else:
99
+ datasets["test"] = [
100
+ (subset.strip(), load_dataset(cfg.dataset.data, [subset], cfg)) for subset in test_subsets
101
+ ]
102
+
103
+ trainer = Trainer(cfg, model, criterion)
104
+
105
+ logger.info(f"training on {cfg.distributed_training.distributed_world_size} devices (GPUs)")
106
+ logger.info(f"batch size per device = {cfg.dataset.batch_size}")
107
+
108
+ max_epoch = cfg.optimization.max_epoch
109
+
110
+ train_meter = meters.StopwatchMeter()
111
+ train_meter.start()
112
+ for i in range(1, max_epoch + 1):
113
+ # train for one epoch
114
+ valid_losses, should_stop = train(cfg, trainer, datasets, i)
115
+ if should_stop:
116
+ break
117
+ train_meter.stop()
118
+ logger.info(f"done training in {train_meter.sum:.1f} seconds")
119
+
120
+
121
+ def should_stop_early(cfg: Config, valid_loss: float) -> bool:
122
+ # skip check if no validation was done in the current epoch
123
+ if valid_loss is None:
124
+ return False
125
+ if cfg.checkpoint.patience <= 0:
126
+ return False
127
+
128
+ def is_better(a, b):
129
+ return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
130
+
131
+ prev_best = getattr(should_stop_early, "best", None)
132
+ if prev_best is None or is_better(valid_loss, prev_best):
133
+ should_stop_early.best = valid_loss
134
+ should_stop_early.num_runs = 0
135
+ return False
136
+ else:
137
+ should_stop_early.num_runs += 1
138
+ if should_stop_early.num_runs >= cfg.checkpoint.patience:
139
+ logger.info(
140
+ f"early stop since valid performance hasn't improved for " f"{cfg.checkpoint.patience} runs"
141
+ )
142
+ return True
143
+ else:
144
+ return False
145
+
146
+
147
+ @metrics.aggregate("train")
148
+ def train(
149
+ cfg: Config,
150
+ trainer: Trainer,
151
+ datasets,
152
+ epoch: int,
153
+ ) -> Tuple[List[Optional[float]], bool]:
154
+ """Train the model for one epoch and return validation losses."""
155
+ # initialize data iterator
156
+ data_loader, batch_sampler = trainer.get_train_iterator(datasets["train"][0][1])
157
+ if batch_sampler is not None:
158
+ batch_sampler.set_epoch(epoch)
159
+
160
+ itr = iter(data_loader)
161
+ progress = progress_bar.progress_bar(
162
+ itr,
163
+ log_format=cfg.common.log_format,
164
+ log_file=cfg.common.log_file,
165
+ log_interval=cfg.common.log_interval,
166
+ epoch=epoch,
167
+ default_log_format=("tqdm" if cfg.common.no_progress_bar else "simple"),
168
+ wandb_project=(
169
+ cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
170
+ ),
171
+ wandb_entity=(
172
+ cfg.common.wandb_entity if distributed_utils.is_master(cfg.distributed_training) else None
173
+ ),
174
+ wandb_run_name=os.environ.get("WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
175
+ )
176
+ progress.update_config(_flatten_config(cfg))
177
+
178
+ logger.info(f"begin training epoch {epoch}")
179
+
180
+ should_stop = False
181
+ num_updates = trainer.get_num_updates()
182
+ logger.info("Start iterating over samples")
183
+ for i, sample in enumerate(progress):
184
+ with metrics.aggregate("train_inner"):
185
+ log_output = trainer.train_step(sample)
186
+
187
+ if log_output is not None:
188
+ # log mid-epoch stats
189
+ num_updates = trainer.get_num_updates()
190
+ if num_updates % cfg.common.log_interval == 0:
191
+ stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
192
+ progress.log(stats, tag="train_inner", step=num_updates)
193
+
194
+ # reset mid-epoch stats after each log interval
195
+ # the end-of-epoch stats will still be preserved
196
+ metrics.reset_meters("train_inner")
197
+
198
+ valid_losses, should_stop = validate_and_save(cfg, trainer, datasets, epoch)
199
+
200
+ # log end-of-epoch stats
201
+ logger.info(f"end of epoch {epoch} (average epoch stats below)")
202
+ stats = get_training_stats(metrics.get_smoothed_values("train"))
203
+ progress.print(stats, tag="train", step=num_updates)
204
+
205
+ # reset epoch-level meters
206
+ metrics.reset_meters("train")
207
+ return valid_losses, should_stop
208
+
209
+
210
+ def validate_and_save(
211
+ cfg: Config,
212
+ trainer: Trainer,
213
+ datasets,
214
+ epoch: int,
215
+ ) -> Tuple[List[Optional[float]], bool]:
216
+ should_stop = False
217
+ if epoch >= cfg.optimization.max_epoch:
218
+ should_stop = True
219
+ logger.info(
220
+ "Stopping training due to " f"num_epochs: {epoch} >= max_epochs: {cfg.optimization.max_epoch}"
221
+ )
222
+
223
+ do_validate = "valid" in datasets or "test" in datasets
224
+
225
+ # validate
226
+ valid_losses = [None]
227
+ if do_validate:
228
+ valid_losses = validate(cfg, trainer, datasets, epoch)
229
+
230
+ should_stop |= should_stop_early(cfg, valid_losses[0])
231
+
232
+ checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch, valid_losses[0])
233
+ if torch.distributed.is_initialized():
234
+ torch.distributed.barrier()
235
+
236
+ return valid_losses, should_stop
237
+
238
+
239
+ def validate(
240
+ cfg: Config,
241
+ trainer: Trainer,
242
+ datasets,
243
+ epoch: int,
244
+ ):
245
+ """Evaluate the model on the validation set(s) and return the losses."""
246
+
247
+ valid_subsets = datasets.get("valid", [])
248
+ test_subsets = datasets.get("test", [])
249
+
250
+ valid_losses = []
251
+ for subset, dataset in valid_subsets + test_subsets:
252
+ logger.info(f"begin validation on '{subset}' subset")
253
+
254
+ # initialize data iterator
255
+ data_loader, _ = trainer.get_valid_iterator(dataset)
256
+ progress = progress_bar.progress_bar(
257
+ data_loader,
258
+ log_format=cfg.common.log_format,
259
+ log_interval=cfg.common.log_interval,
260
+ log_file=cfg.common.log_file,
261
+ epoch=epoch,
262
+ default_log_format=("tqdm" if cfg.common.no_progress_bar else "simple"),
263
+ wandb_project=(
264
+ cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
265
+ ),
266
+ wandb_entity=(
267
+ cfg.common.wandb_entity if distributed_utils.is_master(cfg.distributed_training) else None
268
+ ),
269
+ wandb_run_name=os.environ.get("WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
270
+ )
271
+
272
+ # create a new root metrics aggregator so validation metrics
273
+ # don't pollute other aggregators (e.g., train meters)
274
+ with metrics.aggregate(new_root=True) as agg:
275
+ for i, sample in enumerate(progress):
276
+ trainer.valid_step(sample, subset=subset)
277
+
278
+ stats = agg.get_smoothed_values()
279
+
280
+ if hasattr(trainer.criterion, "post_validate"):
281
+ stats = trainer.criterion.post_validate(
282
+ stats=stats,
283
+ agg=agg,
284
+ )
285
+
286
+ # log validation stats
287
+ stats = get_valid_stats(cfg, trainer, subset, stats)
288
+
289
+ progress.print(stats, tag=subset, step=trainer.get_num_updates())
290
+
291
+ if np.isnan(stats[cfg.checkpoint.best_checkpoint_metric]):
292
+ logger.info(
293
+ f"validation value for {cfg.checkpoint.best_checkpoint_metric} is NaN. "
294
+ "Changed the best checkpoint metric to loss."
295
+ )
296
+ cfg.checkpoint.best_checkpoint_metric = "loss"
297
+ cfg.checkpoint.maximize_best_checkpoint_metric = False
298
+
299
+ valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
300
+
301
+ return valid_losses
302
+
303
+
304
+ def get_training_stats(stats):
305
+ stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
306
+ return stats
307
+
308
+
309
+ def get_valid_stats(cfg: Config, trainer: Trainer, subset: str, stats: Dict[str, Any]) -> Dict[str, Any]:
310
+ stats["num_updates"] = trainer.get_num_updates()
311
+
312
+ if not hasattr(get_valid_stats, "best"):
313
+ get_valid_stats.best = dict()
314
+
315
+ prev_best = getattr(get_valid_stats, "best").get(subset, stats[cfg.checkpoint.best_checkpoint_metric])
316
+ best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
317
+ get_valid_stats.best[subset] = best_function(stats[cfg.checkpoint.best_checkpoint_metric], prev_best)
318
+
319
+ key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
320
+ stats[key] = get_valid_stats.best[subset]
321
+
322
+ return stats
323
+
324
+
325
+ def _flatten_config(cfg: Config):
326
+ config = OmegaConf.to_container(cfg)
327
+ # remove any legacy Namespaces and replace with a single "args"
328
+ namespace = None
329
+ for k, v in list(config.items()):
330
+ if isinstance(v, argparse.Namespace):
331
+ namespace = v
332
+ del config[k]
333
+ if namespace is not None:
334
+ config["args"] = vars(namespace)
335
+ return config
336
+
337
+
338
+ @hydra.main(config_path=os.path.join("..", "configs"), config_name="config")
339
+ def hydra_main(cfg: Config) -> None:
340
+ add_defaults(cfg)
341
+
342
+ with open_dict(cfg):
343
+ # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
344
+ cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True)
345
+
346
+ cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
347
+ OmegaConf.set_struct(cfg, True)
348
+
349
+ distributed_utils.call_main(cfg, main)
350
+
351
+
352
+ def cli_main():
353
+ try:
354
+ from hydra._internal.utils import get_args
355
+
356
+ cfg_name = get_args().config_name or "config"
357
+ except Exception:
358
+ logger.warning("Failed to get config name from hydra args")
359
+ cfg_name = "config"
360
+ hydra_init(cfg_name)
361
+ hydra_main()
362
+
363
+
364
+ if __name__ == "__main__":
365
+ cli_main()
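To make the stopping rule in train.py concrete, here is a standalone sketch of the patience logic in should_stop_early (not package code; CheckpointCfg is a stand-in for the real cfg.checkpoint): the counter advances only on consecutive non-improving validations and resets whenever a new best value is seen.

from dataclasses import dataclass

@dataclass
class CheckpointCfg:
    patience: int = 3
    maximize_best_checkpoint_metric: bool = False  # False: lower metric is better

def should_stop_early(cfg: CheckpointCfg, valid_loss: float) -> bool:
    if valid_loss is None or cfg.patience <= 0:
        return False

    def is_better(a, b):
        return a > b if cfg.maximize_best_checkpoint_metric else a < b

    prev_best = getattr(should_stop_early, "best", None)
    if prev_best is None or is_better(valid_loss, prev_best):
        # new best value: remember it and reset the non-improvement counter
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    should_stop_early.num_runs += 1
    return should_stop_early.num_runs >= cfg.patience

cfg = CheckpointCfg(patience=2)
# losses improve twice, then stall; training stops after two consecutive
# validations without improvement (epoch 4 here)
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.75, 0.72], start=1):
    if should_stop_early(cfg, loss):
        print(f"stop after epoch {epoch}")
        break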