genhpf-1.0.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/trainer.py
ADDED
@@ -0,0 +1,370 @@

import logging
import time
from itertools import chain
from typing import Any, Dict, List, Tuple

import torch
import torch.distributed as dist
import torch.utils
import torch.utils.data
from omegaconf import OmegaConf
from torch.optim import Adam

from genhpf.configs import Config
from genhpf.datasets import BaseDataset
from genhpf.loggings import metrics
from genhpf.utils import checkpoint_utils, distributed_utils, utils
from genhpf.utils.file_io import PathManager

logger = logging.getLogger(__name__)


class Trainer(object):
    def __init__(self, cfg: Config, model, criterion):
        self.cfg = cfg
        self.cuda = torch.cuda.is_available()
        if self.cuda:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self._criterion = criterion
        self._model = model

        self._criterion = self._criterion.to(device=self.device)
        self._model = self._model.to(device=self.device)

        self._num_updates = 0

        self._optimizer = None
        self._wrapped_criterion = None
        self._wrapped_model = None

        if self.cuda:
            self.cuda_env = utils.CudaEnvironment()
            if self.data_parallel_world_size > 1:
                self.cuda_env_arr = distributed_utils.all_gather_list(
                    self.cuda_env, group=distributed_utils.get_global_group()
                )
            else:
                self.cuda_env_arr = [self.cuda_env]
            if self.data_parallel_rank == 0:
                utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
        else:
            self.cuda_env = None
            self.cuda_env_arr = None

        metrics.log_start_time("wall", priority=790, round=0)

        self._start_time = time.time()
        self._previous_training_time = 0
        self._cumulative_training_time = None

    @property
    def data_parallel_world_size(self):
        if self.cfg.distributed_training.distributed_world_size == 1:
            return 1
        return distributed_utils.get_data_parallel_world_size()

    @property
    def data_parallel_process_group(self):
        return distributed_utils.get_data_parallel_group()

    @property
    def data_parallel_rank(self):
        if self.cfg.distributed_training.distributed_world_size == 1:
            return 0
        return distributed_utils.get_data_parallel_rank()

    @property
    def is_data_parallel_master(self):
        return self.data_parallel_rank == 0

    @property
    def use_distributed_wrapper(self) -> bool:
        return self.data_parallel_world_size > 1

    @property
    def criterion(self):
        if self._wrapped_criterion is None:
            if utils.has_parameters(self._criterion) and self.use_distributed_wrapper:
                self._wrapped_criterion = distributed_utils.DistributedModel(
                    self.cfg.distributed_training,
                    self._criterion,
                    process_group=self.data_parallel_process_group,
                    device=self.device,
                )
            else:
                self._wrapped_criterion = self._criterion
        return self._wrapped_criterion

    @property
    def model(self):
        if self._wrapped_model is None:
            if self.use_distributed_wrapper:
                self._wrapped_model = distributed_utils.DistributedModel(
                    self.cfg.distributed_training,
                    self._model,
                    process_group=self.data_parallel_process_group,
                    device=self.device,
                )
            else:
                self._wrapped_model = self._model
        return self._wrapped_model

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    def _build_optimizer(self):
        params = list(
            filter(lambda p: p.requires_grad, chain(self.model.parameters(), self.criterion.parameters()))
        )
        self._optimizer = Adam(
            params,
            lr=self.cfg.optimization.lr,
            betas=self.cfg.optimization.adam_betas,
            eps=self.cfg.optimization.adam_eps,
            weight_decay=self.cfg.optimization.weight_decay,
        )

    def state_dict(self):
        pretraining_parameter_names = []
        if hasattr(self.model, "get_pretraining_parameter_names"):
            pretraining_parameter_names = self.model.get_pretraining_parameter_names()
        model_state_dict = {
            k: v for k, v in self.model.state_dict().items() if k not in pretraining_parameter_names
        }

        state_dict = {
            "cfg": (
                OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True)
                if OmegaConf.is_config(self.cfg)
                else self.cfg
            ),
            "model": model_state_dict,
        }

        return state_dict

    def save_checkpoint(self, filename):
        """Save all training state in a checkpoint file."""
        logger.info(f"Saving checkpoint to {filename}")

        state_dict = utils.move_to_cpu(self.state_dict())
        if self.is_data_parallel_master:
            checkpoint_utils.torch_persistent_save(state_dict, filename, async_write=False)
        logger.info(f"Finished saving checkpoint to {filename}")

    def load_checkpoint(self, filename) -> None:
        """
        Load all training state from a checkpoint file.
        rank = 0 will load the checkpoint and broadcast it to other ranks.
        """
        logger.info(f"Loading checkpoint from {filename}")
        is_distributed = self.data_parallel_world_size > 1
        bexists = PathManager.isfile(filename)
        if bexists:
            if self.data_parallel_rank == 0:
                state = checkpoint_utils.load_checkpoint_to_cpu(filename)
            else:
                state = None

            if is_distributed:
                state = distributed_utils.broadcast_object(
                    state, src_rank=0, group=self.data_parallel_process_group, dist_device=self.device
                )

            # load model parameters
            try:
                self.model.load_state_dict(state["model"], strict=True)
                # save memory for later steps
                del state["model"]
            except Exception:
                raise Exception(
                    f"Cannot load model parameters from checkpoint {filename}; "
                    "please ensure that the architectures match."
                )
        else:
            logger.info(f"No existing checkpoint found {filename}")

    def get_train_iterator(
        self,
        dataset: BaseDataset,
    ) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.Sampler]:
        """Return a DataLoader instance for training."""
        batch_sampler = (
            torch.utils.data.DistributedSampler(dataset, shuffle=True) if dist.is_initialized() else None
        )
        batch_iterator = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.cfg.dataset.batch_size,
            shuffle=True if not dist.is_initialized() else False,
            num_workers=self.cfg.dataset.num_workers if not self.cfg.common.debug else 0,
            collate_fn=dataset.collator,
            sampler=batch_sampler,
        )
        return batch_iterator, batch_sampler

    def get_valid_iterator(
        self,
        dataset: BaseDataset,
    ) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.Sampler]:
        """Return a DataLoader instance for validation."""
        batch_sampler = (
            torch.utils.data.DistributedSampler(dataset, shuffle=False) if dist.is_initialized() else None
        )
        batch_iterator = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.cfg.dataset.batch_size,
            shuffle=False,
            num_workers=self.cfg.dataset.num_workers if not self.cfg.common.debug else 0,
            collate_fn=dataset.collator,
            sampler=batch_sampler,
        )
        return batch_iterator, batch_sampler

    @metrics.aggregate("train")
    def train_step(self, sample):
        """Do forward, backward and optimization steps."""

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        logging_outputs = []
        sample = utils.prepare_sample(sample)

        loss, sample_size, logging_output = self.criterion(self.model, sample)
        if loss.item() > 0:
            loss.backward()  # backward
            self.optimizer.step()

        logging_outputs.append(logging_output)

        # emptying the CUDA cache after the first step can
        # reduce the chance of OOM
        if self.cuda and self.get_num_updates() == 0:
            torch.cuda.empty_cache()

        sample_size = float(sample_size)

        if self._sync_stats():
            train_time = self._local_cumulative_training_time()

            logging_outputs, (sample_size, total_train_time) = self._aggregate_logging_outputs(
                logging_outputs, sample_size, train_time
            )

            self._cumulative_training_time = total_train_time / self.data_parallel_world_size

        logging_output = None
        self.set_num_updates(self.get_num_updates() + 1)

        if self.cuda and self.cuda_env is not None:
            gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
            torch.cuda.reset_peak_memory_stats()
            gb_free = self.cuda_env.total_memory_in_GB - gb_used
            metrics.log_scalar("gb_free", gb_free, priority=1500, round=1, weight=0)

        # extract private logs (usually only for valid steps) before logging
        logging_outputs = list(
            map(lambda x: {key: x[key] for key in x if not key.startswith("_")}, logging_outputs)
        )
        # log stats
        logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)

        metrics.log_stop_time("train_wall")
        return logging_output

    @metrics.aggregate("valid")
    def valid_step(self, sample, subset=None):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = utils.prepare_sample(sample)
            _loss, sample_size, logging_output = self.criterion(self.model, sample)
            logging_outputs = [logging_output]

            # gather logging outputs from all replicas
            if self.data_parallel_world_size > 1:
                logging_outputs, (sample_size,) = self._aggregate_logging_outputs(
                    logging_outputs,
                    sample_size,
                    ignore=False,
                )

            # log validation stats
            logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)

        return _loss, sample_size, logging_outputs

    def zero_grad(self):
        self.optimizer.zero_grad()

    def get_num_updates(self):
        return self._num_updates

    def set_num_updates(self, num_updates):
        self._num_updates = num_updates
        metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)

    def cumulative_training_time(self):
        if self._cumulative_training_time is None:
            return self._local_cumulative_training_time()
        else:
            return self._cumulative_training_time

    def _local_cumulative_training_time(self):
        return time.time() - self._start_time + self._previous_training_time

    def _set_seed(self):
        seed = self.cfg.common.seed + self.get_num_updates()
        utils.set_torch_seed(seed)

    def _sync_stats(self):
        if self.data_parallel_world_size == 1:
            return False
        else:
            return True

    def _aggregate_logging_outputs(
        self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False
    ):
        return self._all_gather_list_sync(logging_outputs, *extra_stats_to_sum, ignore=ignore)

    def _all_gather_list_sync(self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False):
        if ignore:
            logging_outputs = []
        results = list(
            zip(
                *distributed_utils.all_gather_list(
                    [logging_outputs] + list(extra_stats_to_sum),
                    max_size=getattr(self.cfg.common, "all_gather_list_size", 32768),
                    group=self.data_parallel_process_group,
                )
            )
        )
        logging_outputs, extra_stats_to_sum = results[0], results[1:]
        logging_outputs = list(chain.from_iterable(logging_outputs))
        extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
        return logging_outputs, extra_stats_to_sum

    def _reduce_and_log_stats(self, logging_outputs, sample_size):
        metrics.log_speed("ups", 1.0, priority=100, round=2)

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.criterion.__class__.reduce_metrics(logging_outputs)
                del logging_outputs

            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            return logging_output
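
As a rough illustration only (not part of the package), the sketch below shows how the Trainer methods seen above (get_train_iterator, train_step, save_checkpoint) would typically be driven. The cfg, model, criterion, and dataset objects are assumed to come from the package's configs, models, criterions, and datasets modules; the checkpoint filename is hypothetical.

```python
# Minimal sketch, assuming a populated genhpf Config plus model/criterion/dataset
# objects built elsewhere in the package; only Trainer methods shown above are used.
from genhpf.trainer import Trainer


def run_training(cfg, model, criterion, train_dataset, num_epochs: int = 1):
    trainer = Trainer(cfg, model, criterion)
    data_loader, sampler = trainer.get_train_iterator(train_dataset)
    for epoch in range(num_epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)  # reshuffle per epoch under DistributedSampler
        for sample in data_loader:
            trainer.train_step(sample)  # forward, backward, optimizer step
        trainer.save_checkpoint(f"checkpoint_{epoch}.pt")  # illustrative filename
```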

genhpf/utils/checkpoint_utils.py
ADDED

@@ -0,0 +1,171 @@

import collections
import logging
import os
import re
import traceback

import torch

from genhpf.configs import CheckpointConfig
from genhpf.utils.file_io import PathManager

logger = logging.getLogger(__name__)


def save_checkpoint(cfg: CheckpointConfig, trainer, epoch: int, val_loss: float):
    from genhpf.loggings import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    if not trainer.is_data_parallel_master:
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    updates = trainer.get_num_updates()

    logger.info(f"preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint_{}.pt".format(epoch)] = epoch % cfg.save_interval == 0
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best") or is_better(val_loss, save_checkpoint.best)
    )

    checkpoint_conds["checkpoint_last.pt"] = not cfg.no_last_checkpoints

    checkpoints = [os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0])
        for cp in checkpoints[1:]:
            assert PathManager.copy(
                checkpoints[0], cp, overwrite=True
            ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Save checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint_(\d+)\.pt")
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)


def torch_persistent_save(obj, filename, async_write: bool = False):
    if async_write:
        with PathManager.opena(filename, "wb") as f:
            _torch_persistent_save(obj, f)
    else:
        if PathManager.supports_rename(filename):
            # do atomic save
            with PathManager.open(filename + ".tmp", "wb") as f:
                _torch_persistent_save(obj, f)
            PathManager.rename(filename + ".tmp", filename)


def _torch_persistent_save(obj, f):
    if isinstance(f, str):
        with PathManager.open(f, "wb") as h:
            _torch_persistent_save(obj, h)
        return
    for i in range(3):
        try:
            return torch.save(obj, f)
        except Exception:
            if i == 2:
                logger.error(traceback.format_exc())


def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = PathManager.ls(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = float(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    if keep_match:
        return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
    else:
        return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]


def load_checkpoint_to_cpu(path, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    return state


def verify_checkpoint_directory(save_dir: str) -> None:
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    temp_file_path = os.path.join(save_dir, "dummy")
    try:
        with open(temp_file_path, "w"):
            pass
    except OSError as e:
        logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
        raise e
    else:
        os.remove(temp_file_path)
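
As an aside (not part of the package), the sketch below illustrates the ordering contract of checkpoint_paths() above, which save_checkpoint() relies on when pruning old checkpoints with cfg.keep_last_epochs. It assumes PathManager.ls simply lists a local directory; the filenames are illustrative.

```python
# Minimal sketch: numeric, descending ordering of epoch checkpoints.
import os
import tempfile

from genhpf.utils.checkpoint_utils import checkpoint_paths

with tempfile.TemporaryDirectory() as save_dir:
    for epoch in (1, 2, 10):
        open(os.path.join(save_dir, f"checkpoint_{epoch}.pt"), "w").close()
    open(os.path.join(save_dir, "checkpoint_best.pt"), "w").close()  # not matched by the pattern

    ordered = checkpoint_paths(save_dir, pattern=r"checkpoint_(\d+)\.pt")
    # Sorted numerically (not lexicographically), newest epoch first:
    # ['checkpoint_10.pt', 'checkpoint_2.pt', 'checkpoint_1.pt'];
    # with keep_last_epochs=2, save_checkpoint() would delete ordered[2:].
    print([os.path.basename(p) for p in ordered])
```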

genhpf/utils/data_utils.py
ADDED

@@ -0,0 +1,130 @@

from typing import Optional, Tuple

import numpy as np
import torch


def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape.

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from Poisson distribution with lambda = mask_length
        min_masks: minimum number of masked spans
        no_overlap: if True, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - keep_length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        mask[i, mask_idc] = True

    return mask
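
As an aside (not part of the package), the sketch below exercises compute_mask_indices() above with the "static" span type; the shapes, mask_prob, and mask_length values are illustrative only.

```python
# Minimal sketch: span masking over a padded batch.
import torch

from genhpf.utils.data_utils import compute_mask_indices

batch_size, timesteps = 4, 100
padding_mask = torch.zeros(batch_size, timesteps, dtype=torch.bool)
padding_mask[0, 80:] = True  # pretend the first sequence has 20 padded steps

mask = compute_mask_indices(
    shape=(batch_size, timesteps),
    padding_mask=padding_mask,
    mask_prob=0.65,
    mask_length=10,
    mask_type="static",
    min_masks=2,
)
# Boolean numpy array of shape (batch_size, timesteps); every row ends up with the
# same number of masked positions because rows are truncated to the minimum count.
print(mask.shape, mask.sum(axis=1))
```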