genhpf 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/trainer.py ADDED
@@ -0,0 +1,370 @@
+ import logging
+ import time
+ from itertools import chain
+ from typing import Any, Dict, List, Tuple
+
+ import torch
+ import torch.distributed as dist
+ import torch.utils
+ import torch.utils.data
+ from omegaconf import OmegaConf
+ from torch.optim import Adam
+
+ from genhpf.configs import Config
+ from genhpf.datasets import BaseDataset
+ from genhpf.loggings import metrics
+ from genhpf.utils import checkpoint_utils, distributed_utils, utils
+ from genhpf.utils.file_io import PathManager
+
+ logger = logging.getLogger(__name__)
+
+
+ class Trainer(object):
+     def __init__(self, cfg: Config, model, criterion):
+         self.cfg = cfg
+         self.cuda = torch.cuda.is_available()
+         if self.cuda:
+             self.device = torch.device("cuda")
+         else:
+             self.device = torch.device("cpu")
+
+         self._criterion = criterion
+         self._model = model
+
+         self._criterion = self._criterion.to(device=self.device)
+         self._model = self._model.to(device=self.device)
+
+         self._num_updates = 0
+
+         self._optimizer = None
+         self._wrapped_criterion = None
+         self._wrapped_model = None
+
+         if self.cuda:
+             self.cuda_env = utils.CudaEnvironment()
+             if self.data_parallel_world_size > 1:
+                 self.cuda_env_arr = distributed_utils.all_gather_list(
+                     self.cuda_env, group=distributed_utils.get_global_group()
+                 )
+             else:
+                 self.cuda_env_arr = [self.cuda_env]
+             if self.data_parallel_rank == 0:
+                 utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
+         else:
+             self.cuda_env = None
+             self.cuda_env_arr = None
+
+         metrics.log_start_time("wall", priority=790, round=0)
+
+         self._start_time = time.time()
+         self._previous_training_time = 0
+         self._cumulative_training_time = None
+
+     @property
+     def data_parallel_world_size(self):
+         if self.cfg.distributed_training.distributed_world_size == 1:
+             return 1
+         return distributed_utils.get_data_parallel_world_size()
+
+     @property
+     def data_parallel_process_group(self):
+         return distributed_utils.get_data_parallel_group()
+
+     @property
+     def data_parallel_rank(self):
+         if self.cfg.distributed_training.distributed_world_size == 1:
+             return 0
+         return distributed_utils.get_data_parallel_rank()
+
+     @property
+     def is_data_parallel_master(self):
+         return self.data_parallel_rank == 0
+
+     @property
+     def use_distributed_wrapper(self) -> bool:
+         return self.data_parallel_world_size > 1
+
+     @property
+     def criterion(self):
+         if self._wrapped_criterion is None:
+             if utils.has_parameters(self._criterion) and self.use_distributed_wrapper:
+                 self._wrapped_criterion = distributed_utils.DistributedModel(
+                     self.cfg.distributed_training,
+                     self._criterion,
+                     process_group=self.data_parallel_process_group,
+                     device=self.device,
+                 )
+             else:
+                 self._wrapped_criterion = self._criterion
+         return self._wrapped_criterion
+
+     @property
+     def model(self):
+         if self._wrapped_model is None:
+             if self.use_distributed_wrapper:
+                 self._wrapped_model = distributed_utils.DistributedModel(
+                     self.cfg.distributed_training,
+                     self._model,
+                     process_group=self.data_parallel_process_group,
+                     device=self.device,
+                 )
+             else:
+                 self._wrapped_model = self._model
+         return self._wrapped_model
+
+     @property
+     def optimizer(self):
+         if self._optimizer is None:
+             self._build_optimizer()
+         return self._optimizer
+
+     def _build_optimizer(self):
+         params = list(
+             filter(lambda p: p.requires_grad, chain(self.model.parameters(), self.criterion.parameters()))
+         )
+         self._optimizer = Adam(
+             params,
+             lr=self.cfg.optimization.lr,
+             betas=self.cfg.optimization.adam_betas,
+             eps=self.cfg.optimization.adam_eps,
+             weight_decay=self.cfg.optimization.weight_decay,
+         )
+
+     def state_dict(self):
+         pretraining_parameter_names = []
+         if hasattr(self.model, "get_pretraining_parameter_names"):
+             pretraining_parameter_names = self.model.get_pretraining_parameter_names()
+         model_state_dict = {
+             k: v for k, v in self.model.state_dict().items() if k not in pretraining_parameter_names
+         }
+
+         state_dict = {
+             "cfg": (
+                 OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True)
+                 if OmegaConf.is_config(self.cfg)
+                 else self.cfg
+             ),
+             "model": model_state_dict,
+         }
+
+         return state_dict
+
+     def save_checkpoint(self, filename):
+         """Save all training state in a checkpoint file."""
+         logger.info(f"Saving checkpoint to {filename}")
+
+         state_dict = utils.move_to_cpu(self.state_dict())
+         if self.is_data_parallel_master:
+             checkpoint_utils.torch_persistent_save(state_dict, filename, async_write=False)
+             logger.info(f"Finished saving checkpoint to {filename}")
+
+     def load_checkpoint(self, filename) -> None:
+         """
+         Load all training state from a checkpoint file.
+         Rank 0 loads the checkpoint and broadcasts it to the other ranks.
+         """
+         logger.info(f"Loading checkpoint from {filename}")
+         is_distributed = self.data_parallel_world_size > 1
+         bexists = PathManager.isfile(filename)
+         if bexists:
+             if self.data_parallel_rank == 0:
+                 state = checkpoint_utils.load_checkpoint_to_cpu(filename)
+             else:
+                 state = None
+
+             if is_distributed:
+                 state = distributed_utils.broadcast_object(
+                     state, src_rank=0, group=self.data_parallel_process_group, dist_device=self.device
+                 )
+
+             # load model parameters
+             try:
+                 self.model.load_state_dict(state["model"], strict=True)
+                 # save memory for later steps
+                 del state["model"]
+             except Exception:
+                 raise Exception(
+                     f"Cannot load model parameters from checkpoint {filename}; "
+                     "please ensure that the architectures match."
+                 )
+         else:
+             logger.info(f"No existing checkpoint found at {filename}")
+
+     def get_train_iterator(
+         self,
+         dataset: BaseDataset,
+     ) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.Sampler]:
+         """Return a DataLoader instance for training."""
+         batch_sampler = (
+             torch.utils.data.DistributedSampler(dataset, shuffle=True) if dist.is_initialized() else None
+         )
+         batch_iterator = torch.utils.data.DataLoader(
+             dataset,
+             batch_size=self.cfg.dataset.batch_size,
+             shuffle=True if not dist.is_initialized() else False,
+             num_workers=self.cfg.dataset.num_workers if not self.cfg.common.debug else 0,
+             collate_fn=dataset.collator,
+             sampler=batch_sampler,
+         )
+         return batch_iterator, batch_sampler
+
+     def get_valid_iterator(
+         self,
+         dataset: BaseDataset,
+     ) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.Sampler]:
+         """Return a DataLoader instance for validation."""
+         batch_sampler = (
+             torch.utils.data.DistributedSampler(dataset, shuffle=False) if dist.is_initialized() else None
+         )
+         batch_iterator = torch.utils.data.DataLoader(
+             dataset,
+             batch_size=self.cfg.dataset.batch_size,
+             shuffle=False,
+             num_workers=self.cfg.dataset.num_workers if not self.cfg.common.debug else 0,
+             collate_fn=dataset.collator,
+             sampler=batch_sampler,
+         )
+         return batch_iterator, batch_sampler
+
+     @metrics.aggregate("train")
+     def train_step(self, sample):
+         """Do forward, backward and optimization steps."""
+
+         self._set_seed()
+         self.model.train()
+         self.criterion.train()
+         self.zero_grad()
+
+         metrics.log_start_time("train_wall", priority=800, round=0)
+
+         logging_outputs = []
+         sample = utils.prepare_sample(sample)
+
+         loss, sample_size, logging_output = self.criterion(self.model, sample)
+         if loss.item() > 0:
+             loss.backward()
+             self.optimizer.step()
+
+         logging_outputs.append(logging_output)
+
+         # emptying the CUDA cache after the first step can
+         # reduce the chance of OOM
+         if self.cuda and self.get_num_updates() == 0:
+             torch.cuda.empty_cache()
+
+         sample_size = float(sample_size)
+
+         if self._sync_stats():
+             train_time = self._local_cumulative_training_time()
+
+             logging_outputs, (sample_size, total_train_time) = self._aggregate_logging_outputs(
+                 logging_outputs, sample_size, train_time
+             )
+
+             self._cumulative_training_time = total_train_time / self.data_parallel_world_size
+
+         logging_output = None
+         self.set_num_updates(self.get_num_updates() + 1)
+
+         if self.cuda and self.cuda_env is not None:
+             gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
+             torch.cuda.reset_peak_memory_stats()
+             gb_free = self.cuda_env.total_memory_in_GB - gb_used
+             metrics.log_scalar("gb_free", gb_free, priority=1500, round=1, weight=0)
+
+         # strip private logs (keys starting with "_", usually only used in valid steps) before logging
+         logging_outputs = list(
+             map(lambda x: {key: x[key] for key in x if not key.startswith("_")}, logging_outputs)
+         )
+         # log stats
+         logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
+
+         metrics.log_stop_time("train_wall")
+         return logging_output
+
+     @metrics.aggregate("valid")
+     def valid_step(self, sample, subset=None):
+         """Do forward pass in evaluation mode."""
+         with torch.no_grad():
+             self.model.eval()
+             self.criterion.eval()
+
+             sample = utils.prepare_sample(sample)
+             _loss, sample_size, logging_output = self.criterion(self.model, sample)
+             logging_outputs = [logging_output]
+
+         # gather logging outputs from all replicas
+         if self.data_parallel_world_size > 1:
+             logging_outputs, (sample_size,) = self._aggregate_logging_outputs(
+                 logging_outputs,
+                 sample_size,
+                 ignore=False,
+             )
+
+         # log validation stats
+         logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
+
+         return _loss, sample_size, logging_outputs
+
+     def zero_grad(self):
+         self.optimizer.zero_grad()
+
+     def get_num_updates(self):
+         return self._num_updates
+
+     def set_num_updates(self, num_updates):
+         self._num_updates = num_updates
+         metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
+
+     def cumulative_training_time(self):
+         if self._cumulative_training_time is None:
+             return self._local_cumulative_training_time()
+         else:
+             return self._cumulative_training_time
+
+     def _local_cumulative_training_time(self):
+         return time.time() - self._start_time + self._previous_training_time
+
+     def _set_seed(self):
+         seed = self.cfg.common.seed + self.get_num_updates()
+         utils.set_torch_seed(seed)
+
+     def _sync_stats(self):
+         if self.data_parallel_world_size == 1:
+             return False
+         else:
+             return True
+
+     def _aggregate_logging_outputs(
+         self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False
+     ):
+         return self._all_gather_list_sync(logging_outputs, *extra_stats_to_sum, ignore=ignore)
+
+     def _all_gather_list_sync(self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False):
+         if ignore:
+             logging_outputs = []
+         results = list(
+             zip(
+                 *distributed_utils.all_gather_list(
+                     [logging_outputs] + list(extra_stats_to_sum),
+                     max_size=getattr(self.cfg.common, "all_gather_list_size", 32768),
+                     group=self.data_parallel_process_group,
+                 )
+             )
+         )
+         logging_outputs, extra_stats_to_sum = results[0], results[1:]
+         logging_outputs = list(chain.from_iterable(logging_outputs))
+         extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
+         return logging_outputs, extra_stats_to_sum
+
+     def _reduce_and_log_stats(self, logging_outputs, sample_size):
+         metrics.log_speed("ups", 1.0, priority=100, round=2)
+
+         with metrics.aggregate() as agg:
+             if logging_outputs is not None:
+                 self.criterion.__class__.reduce_metrics(logging_outputs)
+                 del logging_outputs
+
+             logging_output = agg.get_smoothed_values()
+             logging_output["sample_size"] = sample_size
+             return logging_output
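For orientation, here is a minimal sketch of how a Trainer instance might be driven; the names cfg, model, criterion, train_dataset, and max_epochs are assumptions for illustration and are not taken from genhpf/scripts/train.py.

from genhpf.trainer import Trainer

trainer = Trainer(cfg, model, criterion)                   # cfg: a composed genhpf Config (assumed available)
itr, sampler = trainer.get_train_iterator(train_dataset)   # train_dataset: a BaseDataset instance (assumed)
for epoch in range(1, max_epochs + 1):                     # max_epochs: assumed loop bound
    if sampler is not None:
        sampler.set_epoch(epoch)                           # reshuffle per epoch when a DistributedSampler is used
    for sample in itr:
        log = trainer.train_step(sample)                   # forward, backward, and one Adam update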
genhpf/utils/checkpoint_utils.py ADDED
@@ -0,0 +1,171 @@
+ import collections
+ import logging
+ import os
+ import re
+ import traceback
+
+ import torch
+
+ from genhpf.configs import CheckpointConfig
+ from genhpf.utils.file_io import PathManager
+
+ logger = logging.getLogger(__name__)
+
+
+ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch: int, val_loss: float):
+     from genhpf.loggings import meters
+
+     # only one worker should attempt to create the required dir
+     if trainer.data_parallel_rank == 0:
+         os.makedirs(cfg.save_dir, exist_ok=True)
+
+     prev_best = getattr(save_checkpoint, "best", val_loss)
+     if val_loss is not None:
+         best_function = max if cfg.maximize_best_checkpoint_metric else min
+         save_checkpoint.best = best_function(val_loss, prev_best)
+
+     if cfg.no_save:
+         return
+
+     if not trainer.is_data_parallel_master:
+         return
+
+     write_timer = meters.StopwatchMeter()
+     write_timer.start()
+
+     updates = trainer.get_num_updates()
+
+     logger.info(f"preparing to save checkpoint for epoch {epoch} @ {updates} updates")
+
+     def is_better(a, b):
+         return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
+
+     checkpoint_conds = collections.OrderedDict()
+     checkpoint_conds["checkpoint_{}.pt".format(epoch)] = epoch % cfg.save_interval == 0
+     checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
+         not hasattr(save_checkpoint, "best") or is_better(val_loss, save_checkpoint.best)
+     )
+
+     checkpoint_conds["checkpoint_last.pt"] = not cfg.no_last_checkpoints
+
+     checkpoints = [os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
+     if len(checkpoints) > 0:
+         trainer.save_checkpoint(checkpoints[0])
+         for cp in checkpoints[1:]:
+             assert PathManager.copy(
+                 checkpoints[0], cp, overwrite=True
+             ), f"Failed to copy {checkpoints[0]} to {cp}"
+
+         write_timer.stop()
+         logger.info(
+             "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
+                 checkpoints[0], epoch, updates, val_loss, write_timer.sum
+             )
+         )
+
+     if cfg.keep_last_epochs > 0:
+         # remove old epoch checkpoints; checkpoints are sorted in descending order
+         checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint_(\d+)\.pt")
+         for old_chk in checkpoints[cfg.keep_last_epochs :]:
+             if os.path.lexists(old_chk):
+                 os.remove(old_chk)
+
+
+ def torch_persistent_save(obj, filename, async_write: bool = False):
+     if async_write:
+         with PathManager.opena(filename, "wb") as f:
+             _torch_persistent_save(obj, f)
+     else:
+         if PathManager.supports_rename(filename):
+             # do atomic save
+             with PathManager.open(filename + ".tmp", "wb") as f:
+                 _torch_persistent_save(obj, f)
+             PathManager.rename(filename + ".tmp", filename)
+         else:
+             # rename not supported: fall back to a direct (non-atomic) save
+             with PathManager.open(filename, "wb") as f:
+                 _torch_persistent_save(obj, f)
+
+
+ def _torch_persistent_save(obj, f):
+     if isinstance(f, str):
+         with PathManager.open(f, "wb") as h:
+             _torch_persistent_save(obj, h)
+         return
+     for i in range(3):
+         try:
+             return torch.save(obj, f)
+         except Exception:
+             if i == 2:
+                 logger.error(traceback.format_exc())
+
+
+ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
+     """Retrieves all checkpoints found in the `path` directory.
+
+     Checkpoints are identified by matching filename to the specified pattern. If
+     the pattern contains groups, the result will be sorted by the first group in
+     descending order.
+     """
+     pt_regexp = re.compile(pattern)
+     files = PathManager.ls(path)
+
+     entries = []
+     for i, f in enumerate(files):
+         m = pt_regexp.fullmatch(f)
+         if m is not None:
+             idx = float(m.group(1)) if len(m.groups()) > 0 else i
+             entries.append((idx, m.group(0)))
+     if keep_match:
+         return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
+     else:
+         return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
+
+
+ def load_checkpoint_to_cpu(path, load_on_all_ranks=False):
+     """Loads a checkpoint to CPU (with upgrading for backward compatibility).
+
+     If doing single-GPU training or if the checkpoint is only being loaded by at
+     most one process on each node (current default behavior is for only rank 0
+     to read the checkpoint from disk), load_on_all_ranks should be False to
+     avoid errors from torch.distributed not having been initialized or
+     torch.distributed.barrier() hanging.
+
+     If all processes on each node may be loading the checkpoint
+     simultaneously, load_on_all_ranks should be set to True to avoid I/O
+     conflicts.
+
+     There's currently no support for > 1 but < all processes loading the
+     checkpoint on each node.
+     """
+     local_path = PathManager.get_local_path(path)
+     # The locally cached file returned by get_local_path() may be stale for
+     # remote files that are periodically updated/overwritten (ex:
+     # checkpoint_last.pt) - so we remove the local copy, sync across processes
+     # (if needed), and then download a fresh copy.
+     if local_path != path and PathManager.path_requires_pathmanager(path):
+         try:
+             os.remove(local_path)
+         except FileNotFoundError:
+             # With potentially multiple processes removing the same file, the
+             # file being missing is benign (missing_ok isn't available until
+             # Python 3.8).
+             pass
+         if load_on_all_ranks:
+             torch.distributed.barrier()
+         local_path = PathManager.get_local_path(path)
+
+     with open(local_path, "rb") as f:
+         state = torch.load(f, map_location=torch.device("cpu"))
+
+     return state
+
+
+ def verify_checkpoint_directory(save_dir: str) -> None:
+     if not os.path.exists(save_dir):
+         os.makedirs(save_dir, exist_ok=True)
+     temp_file_path = os.path.join(save_dir, "dummy")
+     try:
+         with open(temp_file_path, "w"):
+             pass
+     except OSError as e:
+         logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
+         raise e
+     else:
+         os.remove(temp_file_path)
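A hedged sketch of how these helpers typically fit into an epoch loop; cfg.checkpoint, trainer, model, max_epochs, and the validation loop that produces val_loss are assumptions for illustration rather than code shown above.

import os

verify_checkpoint_directory(cfg.checkpoint.save_dir)           # fail fast if the directory is not writable
for epoch in range(1, max_epochs + 1):
    val_loss = ...                                             # assumed to come from a validation loop
    save_checkpoint(cfg.checkpoint, trainer, epoch, val_loss)  # writes checkpoint_{epoch} / _best / _last as configured

# e.g. to evaluate the best model later:
state = load_checkpoint_to_cpu(os.path.join(cfg.checkpoint.save_dir, "checkpoint_best.pt"))
model.load_state_dict(state["model"])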
genhpf/utils/data_utils.py ADDED
@@ -0,0 +1,130 @@
+ from typing import Optional, Tuple
+
+ import numpy as np
+ import torch
+
+ def compute_mask_indices(
+     shape: Tuple[int, int],
+     padding_mask: Optional[torch.Tensor],
+     mask_prob: float,
+     mask_length: int,
+     mask_type: str = "static",
+     mask_other: float = 0.0,
+     min_masks: int = 0,
+     no_overlap: bool = False,
+     min_space: int = 0,
+ ) -> np.ndarray:
+     """
+     Computes random mask spans for a given shape.
+
+     Args:
+         shape: the shape for which to compute masks;
+             should be of size 2 where the first element is the batch size and the second is the number of timesteps
+         padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+         mask_prob: probability for each token to be chosen as the start of a span to be masked. this will be multiplied by
+             the number of timesteps divided by the mask span length to mask approximately this percentage of all elements.
+             however, due to overlaps, the actual number will be smaller (unless no_overlap is True)
+         mask_type: how to compute mask lengths
+             static = fixed size
+             uniform = sample from uniform distribution [mask_other, mask_length*2]
+             normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+             poisson = sample from Poisson distribution with lambda = mask_length
+         min_masks: minimum number of masked spans
+         no_overlap: if True, uses an alternative recursive algorithm that prevents spans from overlapping
+         min_space: only used if no_overlap is True; this is how many elements to keep unmasked between spans
+     """
+
+     bsz, all_sz = shape
+     mask = np.full((bsz, all_sz), False)
+
+     all_num_mask = int(
+         # add a random number for probabilistic rounding
+         mask_prob * all_sz / float(mask_length)
+         + np.random.rand()
+     )
+
+     all_num_mask = max(min_masks, all_num_mask)
+
+     mask_idcs = []
+     for i in range(bsz):
+         if padding_mask is not None:
+             sz = all_sz - padding_mask[i].long().sum().item()
+             num_mask = int(
+                 # add a random number for probabilistic rounding
+                 mask_prob * sz / float(mask_length)
+                 + np.random.rand()
+             )
+             num_mask = max(min_masks, num_mask)
+         else:
+             sz = all_sz
+             num_mask = all_num_mask
+
+         if mask_type == "static":
+             lengths = np.full(num_mask, mask_length)
+         elif mask_type == "uniform":
+             lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+         elif mask_type == "normal":
+             lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+             lengths = [max(1, int(round(x))) for x in lengths]
+         elif mask_type == "poisson":
+             lengths = np.random.poisson(mask_length, size=num_mask)
+             lengths = [int(round(x)) for x in lengths]
+         else:
+             raise Exception("unknown mask selection " + mask_type)
+
+         if sum(lengths) == 0:
+             lengths[0] = min(mask_length, sz - 1)
+
+         if no_overlap:
+             mask_idc = []
+
+             def arrange(s, e, length, keep_length):
+                 span_start = np.random.randint(s, e - length)
+                 mask_idc.extend(span_start + i for i in range(length))
+
+                 new_parts = []
+                 if span_start - s - min_space >= keep_length:
+                     new_parts.append((s, span_start - min_space + 1))
+                 if e - span_start - keep_length - min_space > keep_length:
+                     new_parts.append((span_start + length + min_space, e))
+                 return new_parts
+
+             parts = [(0, sz)]
+             min_length = min(lengths)
+             for length in sorted(lengths, reverse=True):
+                 lens = np.fromiter(
+                     (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                     int,  # use the builtin int; np.int was removed in NumPy >= 1.24
+                 )
+                 l_sum = np.sum(lens)
+                 if l_sum == 0:
+                     break
+                 probs = lens / np.sum(lens)
+                 c = np.random.choice(len(parts), p=probs)
+                 s, e = parts.pop(c)
+                 parts.extend(arrange(s, e, length, min_length))
+             mask_idc = np.asarray(mask_idc)
+         else:
+             min_len = min(lengths)
+             if sz - min_len <= num_mask:
+                 min_len = sz - num_mask - 1
+
+             mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+             mask_idc = np.asarray(
+                 [
+                     mask_idc[j] + offset
+                     for j in range(len(mask_idc))
+                     for offset in range(lengths[j])
+                 ]
+             )
+
+         mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+     min_len = min([len(m) for m in mask_idcs])
+     for i, mask_idc in enumerate(mask_idcs):
+         if len(mask_idc) > min_len:
+             mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+         mask[i, mask_idc] = True
+
+     return mask
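A small, self-contained usage example (an illustration of how wav2vec2-style masking typically consumes this helper, not code from the package): compute span masks for a padded batch and zero out the masked timesteps.

import torch

B, T, C = 2, 16, 8
features = torch.randn(B, T, C)
padding_mask = torch.zeros(B, T, dtype=torch.bool)
padding_mask[1, 12:] = True                       # last 4 timesteps of the second sample are padding
mask = compute_mask_indices((B, T), padding_mask, mask_prob=0.5, mask_length=3, min_masks=1)
mask = torch.from_numpy(mask)                     # (B, T) boolean mask; True marks masked positions
features[mask] = 0.0                              # in practice, often replaced by a learned mask embedding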