genhpf-1.0.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/scripts/test.py
ADDED
@@ -0,0 +1,264 @@
import logging
import os
import pprint
import sys
from itertools import chain

import polars as pl
import torch.distributed
import torch.utils.data

logging.basicConfig(
    format="%(asctime)s | %(levelname)s %(name)s %(message)s)))",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("genhpf.test")

import hydra
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf, open_dict

from genhpf import criterions, models
from genhpf.configs import Config
from genhpf.configs.initialize import add_defaults, hydra_init
from genhpf.datasets import load_dataset
from genhpf.loggings import metrics, progress_bar
from genhpf.utils import distributed_utils, utils


def main(cfg: Config) -> None:
    assert (
        cfg.checkpoint.load_checkpoint is not None
    ), "Please specify the checkpoint to load with `checkpoint.load_checkpoint`"

    assert cfg.dataset.batch_size is not None, "batch_size must be specified"
    metrics.reset()

    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if cfg.distributed_training.distributed_world_size > 1:
        data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
    else:
        data_parallel_world_size = 1

    # print args
    logger.info(pprint.pformat(dict(cfg)))

    # load model
    model = models.build_model(cfg.model)
    logger.info(f"loading model from {cfg.checkpoint.load_checkpoint}")
    model_state_dict = torch.load(cfg.checkpoint.load_checkpoint, map_location="cpu")["model"]
    model.load_state_dict(model_state_dict, strict=True)
    logger.info(f"loaded model from {cfg.checkpoint.load_checkpoint}")

    logger.info(model)
    logger.info(f"model: {model.__class__.__name__}")
    logger.info(
        "num. shared model params: {:,} (num. trained: {:,})".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
        )
    )

    # Move model to GPU
    model.eval()
    if use_cuda:
        model.cuda()

    # build criterion
    criterion = criterions.build_criterion(cfg.criterion)
    criterion.eval()

    def _fp_convert_sample(sample):
        def apply_float(t):
            if t.dtype in [torch.float64, torch.float32, torch.int16]:
                return t.to(dtype=torch.float)
            return t

        sample = utils.apply_to_sample(apply_float, sample)

        return sample

    assert cfg.dataset.test_subset is not None, "Please specify the test subset with `dataset.test_subset`"
    test_subsets = cfg.dataset.test_subset.split(",")

    if cfg.dataset.combine_test_subsets:
        datasets = [("combined-test", load_dataset(cfg.dataset.data, test_subsets, cfg))]
    else:
        datasets = [
            (subset.strip(), load_dataset(cfg.dataset.data, [subset], cfg)) for subset in test_subsets
        ]

    output_meds_predictions = False
    if cfg.dataset.data_format == "meds" and cfg.meds.output_predictions:
        if len(test_subsets) > 1:
            raise NotImplementedError(
                "MEDS dataset does not currently support multiple test subsets when `output_predictions` "
                "is enabled. Please specify only one test subset."
            )
        if cfg.dataset.combine_test_subsets:
            raise NotImplementedError(
                "MEDS dataset does not currently support `combine_test_subsets` when `output_predictions` "
                "is enabled. Please set `dataset.combine_test_subsets` to False."
            )
        if len(cfg.criterion.task_names) > 1:
            raise NotImplementedError(
                "MEDS dataset does not currently support multiple tasks when `output_predictions` "
                "is enabled. Please specify only one task."
            )
        if cfg.criterion.num_labels[0] > 1:
            raise NotImplementedError(
                "MEDS dataset currently only supports binary classification when `output_predictions` "
                "is enabled. Please specify only one label by setting `criterion.num_labels` to 1."
            )

        assert (
            cfg.meds.labels_dir is not None and cfg.meds.output_dir is not None
        ), "Please specify labels_dir and output_dir in the MEDS config to output predictions."
        assert data_parallel_world_size == 1, (
            "MEDS dataset does not currently support distributed testing when `output_predictions` "
            "is enabled. Please set `distributed_training.distributed_world_size` to 1."
        )

        output_meds_predictions = True

        labels = pl.read_parquet(os.path.join(cfg.meds.labels_dir, f"{test_subsets[0]}/*.parquet"))
        labels = labels.sort(by=["subject_id", "prediction_time"])
        labels = labels.with_columns(pl.col("subject_id").cum_count().over("subject_id").alias("suffix"))
        labels = labels.with_columns(
            pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str).alias("subject_id")
        )
        labels = labels.drop("suffix")
        labels = labels.select(["subject_id", "prediction_time", "boolean_value"])

        meds_pred_output = {
            "subject_id": [],
            "predicted_boolean_value": [],
            "predicted_boolean_probability": [],
        }

    for subset, dataset in datasets:
        logger.info(f"begin validation on '{subset}' subset")

        # initialize data iterator
        batch_sampler = (
            torch.utils.data.DistributedSampler(dataset, shuffle=False)
            if torch.distributed.is_initialized()
            else None
        )
        batch_iterator = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.dataset.batch_size,
            shuffle=False,
            num_workers=cfg.dataset.num_workers,
            collate_fn=dataset.collator,
            sampler=batch_sampler,
        )

        progress = progress_bar.progress_bar(
            batch_iterator,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            log_file=cfg.common.log_file,
            epoch=0,
            default_log_format=("tqdm" if cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
            ),
            wandb_entity=(
                cfg.common.wandb_entity if distributed_utils.is_master(cfg.distributed_training) else None
            ),
            wandb_run_name=os.environ.get("WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            with torch.no_grad():
                sample = utils.prepare_sample(sample)
                sample = _fp_convert_sample(sample)
                _loss, _sample_size, log_output, net_output = criterion(model, sample, return_net_output=True)
                log_outputs.append(log_output)
                if output_meds_predictions:
                    meds_pred_output["subject_id"].extend(sample["id"])
                    logits = model.get_logits(sample, net_output)
                    probs = torch.sigmoid(logits).view(-1).cpu()
                    meds_pred_output["predicted_boolean_probability"].extend(probs.tolist())
                    meds_pred_output["predicted_boolean_value"].extend(
                        (probs >= cfg.criterion.threshold).int().tolist()
                    )

        if output_meds_predictions:
            meds_pred_output = pl.DataFrame(meds_pred_output)
            meds_pred_output = meds_pred_output.join(labels, on="subject_id", how="left")
            meds_pred_output = meds_pred_output.select(
                [
                    pl.col("subject_id"),
                    pl.col("prediction_time"),
                    pl.col("boolean_value"),
                    pl.col("predicted_boolean_value"),
                    pl.col("predicted_boolean_probability"),
                ]
            )
            meds_pred_output = meds_pred_output.with_columns(
                pl.col("subject_id").map_elements(lambda x: x.split("_")[0], return_dtype=pl.String).cast(int)
            )
            meds_pred_output = (
                meds_pred_output.with_columns(pl.col("predicted_boolean_value").cast(bool))
            )
            if not os.path.exists(cfg.meds.output_dir):
                os.makedirs(cfg.meds.output_dir)
            meds_pred_output.write_parquet(os.path.join(cfg.meds.output_dir, f"{subset}.parquet"))

        if data_parallel_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
                group=distributed_utils.get_data_parallel_group(),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate(new_root=True) as agg:
            criterion.__class__.reduce_metrics(log_outputs)
            del log_outputs
            log_outputs = agg.get_smoothed_values()

            if hasattr(criterion, "post_validate"):
                stats = criterion.post_validate(stats=log_outputs, agg=agg)

        progress.print(stats, tag=subset, step=None)


@hydra.main(config_path=os.path.join("..", "configs"), config_name="config")
def hydra_main(cfg: Config) -> None:
    add_defaults(cfg)

    with open_dict(cfg):
        # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
        cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True)

    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
    OmegaConf.set_struct(cfg, True)

    distributed_utils.call_main(cfg, main)


def cli_main():
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"
    hydra_init(cfg_name)
    hydra_main()


if __name__ == "__main__":
    cli_main()
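The script above is driven entirely by Hydra overrides whose keys appear in the code (`checkpoint.load_checkpoint`, `dataset.data`, `dataset.batch_size`, `dataset.test_subset`, and so on). As a minimal sketch of how it might be invoked, assuming the package is installed and using placeholder paths and a placeholder subset name, one can put the overrides on `sys.argv` and call `cli_main()`; the same overrides can equally be passed on the command line to `python -m genhpf.scripts.test` or to whichever console script `entry_points.txt` registers.

# Minimal sketch: invoke the evaluation entry point programmatically.
# The paths, subset name, and batch size below are illustrative placeholders,
# not values shipped with the package.
import sys

from genhpf.scripts.test import cli_main

sys.argv = [
    "genhpf-test",  # argv[0] is only a display name here, not a real console script
    "checkpoint.load_checkpoint=/path/to/checkpoint.pt",
    "dataset.data=/path/to/preprocessed/data",
    "dataset.batch_size=64",
    "dataset.test_subset=test",
]

if __name__ == "__main__":
    cli_main()  # hydra parses the overrides from sys.argv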
genhpf/scripts/train.py
ADDED
@@ -0,0 +1,365 @@
import argparse
import logging
import os
import pprint
import random
import sys
from typing import Any, Dict, List, Optional, Tuple

import torch.distributed

logging.basicConfig(
    format="%(asctime)s | %(levelname)s %(name)s %(message)s)))",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("genhpf.train")

import hydra
import numpy as np
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf, open_dict

from genhpf import criterions, models
from genhpf.configs import Config
from genhpf.configs.initialize import add_defaults, hydra_init
from genhpf.datasets import load_dataset
from genhpf.loggings import meters, metrics, progress_bar
from genhpf.trainer import Trainer
from genhpf.utils import checkpoint_utils, distributed_utils, utils


def main(cfg: Config) -> None:
    if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    if cfg.common.debug:
        os.environ["OMP_NUM_THREADS"] = "2"
        os.environ["MKL_NUM_THREADS"] = "2"
        torch.set_num_threads(2)
        torch.set_num_interop_threads(2)
        cfg.optimization.max_epoch = 1

    assert cfg.dataset.batch_size is not None, "batch_size must be specified"
    metrics.reset()

    np.random.seed(cfg.common.seed)
    random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # print args
    logger.info(pprint.pformat(dict(cfg)))

    model = models.build_model(cfg.model)
    if cfg.checkpoint.load_checkpoint is not None:
        state_dict = torch.load(cfg.checkpoint.load_checkpoint, map_location="cpu")["model"]
        model.load_state_dict(state_dict, strict=True)
        logger.info(f"loaded model from {cfg.checkpoint.load_checkpoint}")
    criterion = criterions.build_criterion(cfg.criterion)

    logger.info(model)
    logger.info(f"model: {model.__class__.__name__}")
    logger.info(f"criterion: {criterion.__class__.__name__}")
    logger.info(
        "num. shared model params: {:,} (num. trained: {:,})".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
        )
    )

    datasets = {}
    train_subsets = cfg.dataset.train_subset.split(",")
    if len(train_subsets) > 1:
        assert (
            cfg.dataset.combine_train_subsets
        ), "train_subset contains multiple datasets, but combine_train_subsets is not set"
        datasets["train"] = [("combined-train", load_dataset(cfg.dataset.data, train_subsets, cfg))]
    else:
        datasets["train"] = [(train_subsets[0].strip(), load_dataset(cfg.dataset.data, train_subsets, cfg))]

    if not cfg.dataset.disable_validation and cfg.dataset.valid_subset is not None:
        valid_subsets = cfg.dataset.valid_subset.split(",")
        if cfg.dataset.combine_valid_subsets:
            datasets["valid"] = [("combined-valid", load_dataset(cfg.dataset.data, valid_subsets, cfg))]
        else:
            datasets["valid"] = [
                (subset.strip(), load_dataset(cfg.dataset.data, [subset], cfg)) for subset in valid_subsets
            ]
    if cfg.dataset.test_subset is not None:
        test_subsets = cfg.dataset.test_subset.split(",")
        if cfg.dataset.combine_test_subsets:
            datasets["test"] = [("combined-test", load_dataset(cfg.dataset.data, test_subsets, cfg))]
        else:
            datasets["test"] = [
                (subset.strip(), load_dataset(cfg.dataset.data, [subset], cfg)) for subset in test_subsets
            ]

    trainer = Trainer(cfg, model, criterion)

    logger.info(f"training on {cfg.distributed_training.distributed_world_size} devices (GPUs)")
    logger.info(f"batch size per device = {cfg.dataset.batch_size}")

    max_epoch = cfg.optimization.max_epoch

    train_meter = meters.StopwatchMeter()
    train_meter.start()
    for i in range(1, max_epoch + 1):
        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, datasets, i)
        if should_stop:
            break
    train_meter.stop()
    logger.info(f"done training in {train_meter.sum:.1f} seconds")


def should_stop_early(cfg: Config, valid_loss: float) -> bool:
    # skip check if no validation was done in the current epoch
    if valid_loss is None:
        return False
    if cfg.checkpoint.patience <= 0:
        return False

    def is_better(a, b):
        return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b

    prev_best = getattr(should_stop_early, "best", None)
    if prev_best is None or is_better(valid_loss, prev_best):
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    else:
        should_stop_early.num_runs += 1
        if should_stop_early.num_runs >= cfg.checkpoint.patience:
            logger.info(
                f"early stop since valid performance hasn't improved for " f"{cfg.checkpoint.patience} runs"
            )
            return True
        else:
            return False


@metrics.aggregate("train")
def train(
    cfg: Config,
    trainer: Trainer,
    datasets,
    epoch: int,
) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # initialize data iterator
    data_loader, batch_sampler = trainer.get_train_iterator(datasets["train"][0][1])
    if batch_sampler is not None:
        batch_sampler.set_epoch(epoch)

    itr = iter(data_loader)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_file=cfg.common.log_file,
        log_interval=cfg.common.log_interval,
        epoch=epoch,
        default_log_format=("tqdm" if cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
        ),
        wandb_entity=(
            cfg.common.wandb_entity if distributed_utils.is_master(cfg.distributed_training) else None
        ),
        wandb_run_name=os.environ.get("WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
    )
    progress.update_config(_flatten_config(cfg))

    logger.info(f"begin training epoch {epoch}")

    should_stop = False
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, sample in enumerate(progress):
        with metrics.aggregate("train_inner"):
            log_output = trainer.train_step(sample)

        if log_output is not None:
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

    valid_losses, should_stop = validate_and_save(cfg, trainer, datasets, epoch)

    # log end-of-epoch stats
    logger.info(f"end of epoch {epoch} (average epoch stats below)")
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop


def validate_and_save(
    cfg: Config,
    trainer: Trainer,
    datasets,
    epoch: int,
) -> Tuple[List[Optional[float]], bool]:
    should_stop = False
    if epoch >= cfg.optimization.max_epoch:
        should_stop = True
        logger.info(
            "Stopping training due to " f"num_epochs: {epoch} >= max_epochs: {cfg.optimization.max_epoch}"
        )

    do_validate = "valid" in datasets or "test" in datasets

    # validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(cfg, trainer, datasets, epoch)

    should_stop |= should_stop_early(cfg, valid_losses[0])

    checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch, valid_losses[0])
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    return valid_losses, should_stop


def validate(
    cfg: Config,
    trainer: Trainer,
    datasets,
    epoch: int,
):
    """Evaluate the model on the validation set(s) and return the losses."""

    valid_subsets = datasets.get("valid", [])
    test_subsets = datasets.get("test", [])

    valid_losses = []
    for subset, dataset in valid_subsets + test_subsets:
        logger.info(f"begin validation on '{subset}' subset")

        # initialize data iterator
        data_loader, _ = trainer.get_valid_iterator(dataset)
        progress = progress_bar.progress_bar(
            data_loader,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            log_file=cfg.common.log_file,
            epoch=epoch,
            default_log_format=("tqdm" if cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
            ),
            wandb_entity=(
                cfg.common.wandb_entity if distributed_utils.is_master(cfg.distributed_training) else None
            ),
            wandb_run_name=os.environ.get("WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                trainer.valid_step(sample, subset=subset)

        stats = agg.get_smoothed_values()

        if hasattr(trainer.criterion, "post_validate"):
            stats = trainer.criterion.post_validate(
                stats=stats,
                agg=agg,
            )

        # log validation stats
        stats = get_valid_stats(cfg, trainer, subset, stats)

        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        if np.isnan(stats[cfg.checkpoint.best_checkpoint_metric]):
            logger.info(
                f"validation value for {cfg.checkpoint.best_checkpoint_metric} is NaN. "
                "Changed the best checkpoint metric to loss."
            )
            cfg.checkpoint.best_checkpoint_metric = "loss"
            cfg.checkpoint.maximize_best_checkpoint_metric = False

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])

    return valid_losses


def get_training_stats(stats):
    stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
    return stats


def get_valid_stats(cfg: Config, trainer: Trainer, subset: str, stats: Dict[str, Any]) -> Dict[str, Any]:
    stats["num_updates"] = trainer.get_num_updates()

    if not hasattr(get_valid_stats, "best"):
        get_valid_stats.best = dict()

    prev_best = getattr(get_valid_stats, "best").get(subset, stats[cfg.checkpoint.best_checkpoint_metric])
    best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
    get_valid_stats.best[subset] = best_function(stats[cfg.checkpoint.best_checkpoint_metric], prev_best)

    key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
    stats[key] = get_valid_stats.best[subset]

    return stats


def _flatten_config(cfg: Config):
    config = OmegaConf.to_container(cfg)
    # remove any legacy Namespaces and replace with a single "args"
    namespace = None
    for k, v in list(config.items()):
        if isinstance(v, argparse.Namespace):
            namespace = v
            del config[k]
    if namespace is not None:
        config["args"] = vars(namespace)
    return config


@hydra.main(config_path=os.path.join("..", "configs"), config_name="config")
def hydra_main(cfg: Config) -> None:
    add_defaults(cfg)

    with open_dict(cfg):
        # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
        cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True)

    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
    OmegaConf.set_struct(cfg, True)

    distributed_utils.call_main(cfg, main)


def cli_main():
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"
    hydra_init(cfg_name)
    hydra_main()


if __name__ == "__main__":
    cli_main()
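As with the evaluation script, train.py is configured purely through Hydra overrides whose keys are visible in the code (`dataset.train_subset`, `dataset.valid_subset`, `dataset.batch_size`, `optimization.max_epoch`, `checkpoint.save_dir`, ...). A minimal sketch of a single-process run, with placeholder paths and subset names that must be adapted to an actual preprocessed dataset, might look like the following; the same overrides can also be passed directly to `python -m genhpf.scripts.train`.

# Minimal sketch: launch training programmatically with hydra overrides.
# Paths, subset names, and hyperparameter values are illustrative placeholders.
import sys

from genhpf.scripts.train import cli_main

sys.argv = [
    "genhpf-train",  # argv[0] is only a display name here, not a real console script
    "dataset.data=/path/to/preprocessed/data",
    "dataset.train_subset=train",
    "dataset.valid_subset=valid",
    "dataset.batch_size=32",
    "optimization.max_epoch=10",
    "checkpoint.save_dir=/path/to/checkpoints",
]

if __name__ == "__main__":
    cli_main()  # hydra parses the overrides from sys.argv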