genhpf-1.0.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0

genhpf/utils/distributed_utils.py
ADDED
@@ -0,0 +1,497 @@
import io
from typing import Optional, List, Any
import logging
import os
import random
import socket
import struct
import pickle
import warnings
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.distributed as dist

from torch.nn.parallel import DistributedDataParallel

from genhpf.configs import Config, DistributedTrainingConfig

logger = logging.getLogger(__name__)

class ModuleProxyWrapper(nn.Module):
    """
    Wrap a DistributedDataParallel module and forward requests for missing
    attributes to the module wrapped by DDP (the twice-wrapped module).
    Also forward calls to :func:`state_dict` and :func:`load_state_dict`.

    Usage::

        module.xyz = "hello world"
        wrapped_module = DistributedDataParallel(module, **ddp_args)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        assert wrapped_module.xyz == "hello world"
        assert wrapped_module.state_dict().keys() == module.state_dict().keys()

    Args:
        module (nn.Module): module to wrap
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        assert hasattr(module, "module"), \
            "ModuleProxyWrapper expects input to wrap another module"
        self.module = module

    def __getattr__(self, name):
        """Forward missing attributes to twice-wrapped module."""
        try:
            # defer to nn.Module's logic
            return super().__getattr__(name)
        except AttributeError:
            try:
                # forward to the once-wrapped module
                return getattr(self.module, name)
            except AttributeError:
                # forward to the twice-wrapped module
                return getattr(self.module.module, name)

    def state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module."""
        return self.module.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module."""
        return self.module.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

def DistributedModel(
    args,
    model,
    process_group,
    device
):
    """
    Wrap a *model* to support distributed data parallel training.

    This is similar to the built-in DistributedDataParallel, but allows
    additional configuration of the DistributedDataParallel class to
    use, and also provides easier access to the wrapped model by
    forwarding requests for missing attributes to the wrapped model.

    Args:
        args (argparse.Namespace): genhpf model args
        model (BaseModel): model to wrap
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction
        device: device to move model to
    """
    assert isinstance(model, nn.Module)

    wrapped_model = DistributedDataParallel(
        module=model.to(device),
        device_ids=[args.device_id],
        output_device=args.device_id,
        broadcast_buffers=args.broadcast_buffers,
        bucket_cap_mb=args.bucket_cap_mb,
        process_group=process_group,
        find_unused_parameters=args.find_unused_parameters
    )

    if args.ddp_comm_hook == "fp16":
        logger.info("enable fp16 communication hook in DDP")
        try:
            from torch.distributed.algorithms.ddp_comm_hooks import (
                register_ddp_comm_hook,
                DDPCommHookType,
            )
        except:
            logger.error(
                "Could not import from torch.distributed.algorithms.ddp_comm_hooks; you may need to update your pytorch version"
            )
            raise

        register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, wrapped_model)

    # forward missing getattr and state_dict/load_state_dict to orig model
    wrapped_model = ModuleProxyWrapper(wrapped_model)

    return wrapped_model
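
As a usage sketch only (not part of the package source): DistributedModel is typically handed the distributed-training namespace whose fields it reads above (device_id, broadcast_buffers, bucket_cap_mb, find_unused_parameters, ddp_comm_hook); the cfg and model below are assumed to be built by the caller's training script.

# usage sketch, not from the package source; `cfg` and `model` come from the caller
device = torch.device("cuda", cfg.distributed_training.device_id)
model = DistributedModel(
    cfg.distributed_training,      # carries device_id, broadcast_buffers, ...
    model,
    process_group=get_data_parallel_group(),
    device=device,
)
# ModuleProxyWrapper keeps the original model's attributes and state_dict()
# reachable on the wrapped object
checkpoint_state = model.state_dict()
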
def is_master(cfg: DistributedTrainingConfig):
    return cfg.distributed_rank == 0

def infer_init_method(cfg: DistributedTrainingConfig):
    assert (
        cfg.distributed_world_size <= torch.cuda.device_count()
    ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices"
    port = random.randint(10000, 20000)
    cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port)

def distributed_init(cfg: Config):
    if dist.is_available() and dist.is_initialized():
        warnings.warn(
            "Distributed is already initialized, cannot initialize twice!"
        )
    else:
        logger.info(
            "distributed init (rank {}): {}".format(
                cfg.distributed_training.distributed_rank,
                cfg.distributed_training.distributed_init_method
            )
        )
        dist.init_process_group(
            backend="nccl",
            init_method=cfg.distributed_training.distributed_init_method,
            world_size=cfg.distributed_training.distributed_world_size,
            rank=cfg.distributed_training.distributed_rank
        )
        logger.info(
            "initialized host {} as rank {}".format(
                socket.gethostname(),
                cfg.distributed_training.distributed_rank
            )
        )

        # perform a dummy all-reduce to initialize the NCCL communicator
        if torch.cuda.is_available():
            dist.all_reduce(torch.zeros(1).cuda())

    cfg.distributed_training.distributed_rank = dist.get_rank()

    if is_master(cfg.distributed_training):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    return cfg.distributed_training.distributed_rank

def distributed_main(i, main, cfg: Config, kwargs):
    cfg.distributed_training.device_id = i
    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.distributed_training.device_id)
    if cfg.distributed_training.distributed_rank is None:
        cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i

    cfg.distributed_training.distributed_rank = distributed_init(cfg)

    main(cfg, **kwargs)

    if dist.is_initialized():
        dist.barrier(get_global_group())

def call_main(cfg: Config, main, **kwargs):
    if (
        cfg.distributed_training.distributed_world_size > 1
        and cfg.distributed_training.distributed_init_method is None
    ):
        infer_init_method(cfg.distributed_training)

    if cfg.distributed_training.distributed_init_method is not None:
        start_rank = cfg.distributed_training.distributed_rank
        cfg.distributed_training.distributed_rank = None  # assign automatically
        kwargs["start_rank"] = start_rank
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(main, cfg, kwargs),
            nprocs=min(torch.cuda.device_count(), cfg.distributed_training.distributed_world_size),
            join=True
        )
    else:
        main(cfg, **kwargs)
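
A hedged sketch of the launch path (not from the package source): call_main infers a local TCP init method when none is set, spawns one process per GPU via torch.multiprocessing.spawn, and each spawned process runs distributed_init before invoking the supplied main. The train_main function below is a hypothetical entry point.

# hypothetical entry point passed to call_main; not part of the package source
def train_main(cfg):
    # by this point distributed_init() has set the rank and CUDA device
    if is_master(cfg.distributed_training):
        logger.info("training on {} rank(s)".format(get_global_world_size()))
    # ... build datasets, model, and trainer here ...

call_main(cfg, train_main)  # cfg is assumed to be a genhpf Config with a distributed_training section
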
def get_rank(group):
    return dist.get_rank(group=group)

def get_world_size(group):
    if dist.is_initialized():
        return dist.get_world_size(group=group)
    else:
        return 1

def get_global_group():
    if dist.is_initialized():
        if not hasattr(get_global_group, "_global_group"):
            # ideally we could use torch.distributed.group.WORLD, but it seems
            # to cause random NCCL hangs in some cases
            get_global_group._global_group = dist.new_group()
        return get_global_group._global_group
    else:
        return None

def get_global_rank():
    if dist.is_initialized():
        return dist.get_rank()
    else:
        return 0

def get_global_world_size():
    if dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1

def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    return get_global_group()

def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return get_rank(get_data_parallel_group())

def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return get_world_size(get_data_parallel_group())

def all_reduce(tensor, group, op="sum"):
    if op == "sum":
        op = dist.ReduceOp.SUM
    elif op == "max":
        op = dist.ReduceOp.MAX
    else:
        raise NotImplementedError

    dist.all_reduce(tensor, op=op, group=group)

    return tensor

def all_gather_list(data, group=None, max_size=32768):
    """Gathers arbitrary data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable and any CUDA tensors will be moved
    to CPU and returned on CPU as well.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group: group of the collective
        max_size (int, optional): maximum size of the data to be gathered
            across workers
    """
    import utils.utils as utils

    if group is None:
        group = get_global_group()
    rank = get_rank(group=group)
    world_size = get_world_size(group=group)

    buffer_size = max_size * world_size
    if (
        not hasattr(all_gather_list, "_buffer")
        or all_gather_list._buffer.numel() < buffer_size
    ):
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
    buffer = all_gather_list._buffer
    buffer.zero_()
    cpu_buffer = all_gather_list._cpu_buffer

    data = utils.move_to_cpu(data)
    enc = pickle.dumps(data)
    enc_size = len(enc)
    header_size = 4  # size of header that contains the length of the encoded data
    size = header_size + enc_size
    if size > max_size:
        raise ValueError(
            "encoded data size ({}) exceeds max_size ({})".format(size, max_size)
        )

    header = struct.pack(">I", enc_size)
    cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
    start = rank * max_size
    buffer[start : start + size].copy_(cpu_buffer[:size])

    all_reduce(buffer, group=group)

    buffer = buffer.cpu()
    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size : (i + 1) * max_size]
            (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
            if enc_size > 0:
                result.append(
                    pickle.loads(
                        bytes(out_buffer[header_size : header_size + enc_size].tolist())
                    )
                )
        return result
    except pickle.UnpicklingError:
        raise Exception(
            "Unable to unpickle data from other workers. all_gather_list requires all "
            "workers to enter the function together, so this error usually indicates "
            "that the workers have fallen out of sync somehow. Workers can fall out of "
            "sync if one of them runs out of memory, or if there are other conditions "
            "in your training script that can cause one worker to finish an epoch "
            "while other workers are still iterating over their portions of the data. "
            # "Try rerunning with --ddp-backend=legacy_ddp and see if that helps."
        )
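
A small sketch of all_gather_list in use (not from the package source): each rank contributes a picklable object and receives the list of every rank's object, as long as the pickled payload plus the 4-byte header stays under max_size bytes.

# usage sketch, not from the package source; assumes CUDA and an initialized
# process group, since the gather buffer lives on the GPU
local_stats = {"rank": get_data_parallel_rank(), "n_samples": 128, "loss": 0.42}
all_stats = all_gather_list(local_stats, group=get_data_parallel_group())
total_samples = sum(s["n_samples"] for s in all_stats)  # same value on every rank
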
def batch_all_gather(tensor, group, return_tensor=False):
    """Perform an all-gather operation considering tensors with different batch size"""
    world_size = get_world_size(group=group)
    rank = get_rank(group=group)

    size_list = [
        tensor.new_zeros(tensor.dim(), dtype=torch.int64) for _ in range(world_size)
    ]
    local_size = tensor.new_tensor(tensor.shape, dtype=torch.int64)
    dist.all_gather(size_list, local_size, group=group)

    max_size = torch.stack(size_list).max(dim=0)[0][0]
    size_offsets = [max_size - size[0] for size in size_list]

    if local_size[0] != max_size:
        offset = torch.cat(
            (
                tensor.new_tensor([max_size - local_size[0]]),
                local_size[1:]
            )
        )
        padding = tensor.new_zeros(tuple(int(dim) for dim in offset), dtype=torch.uint8)
        tensor = torch.cat((tensor, padding), dim=0)

    tensor_list = [
        tensor if i == rank else torch.empty_like(tensor) for i in range(world_size)
    ]
    dist.all_gather(tensor_list, tensor, group=group)
    tensor_list = [
        tensor[:max_size - size_offsets[i]] for i, tensor in enumerate(tensor_list)
    ]
    if return_tensor:
        return torch.stack(tensor_list, dim=0)
    else:
        return tensor_list

def all_gather(tensor, group, return_tensor=False):
    """Perform an all-gather operation."""
    world_size = get_world_size(group=group)
    rank = get_rank(group=group)
    tensor_list = [
        tensor if i == rank else torch.empty_like(tensor) for i in range(world_size)
    ]
    dist.all_gather(tensor_list, tensor, group=group)
    if return_tensor:
        return torch.stack(tensor_list, dim=0)
    else:
        return tensor_list
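
To contrast the two gathers above, a sketch under assumed shapes (not from the package source): all_gather expects identical shapes on every rank, while batch_all_gather first exchanges shapes, pads the first (batch) dimension to the per-group maximum, and trims the padding after the collective, so ragged batch sizes are fine as long as the trailing dimensions match.

# usage sketch, not from the package source; `local_batch` is a hypothetical
# per-rank tensor whose first dimension may differ across ranks
group = get_data_parallel_group()
local_batch = torch.randn(3 if get_data_parallel_rank() == 0 else 5, 16, device="cuda")
pieces = batch_all_gather(local_batch, group=group)  # list of per-rank tensors
merged = torch.cat(pieces, dim=0)                    # shape: (sum of batch sizes, 16)
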
def broadcast(tensor, src, group):
    dist.broadcast(tensor, src=src, group=group)

def broadcast_tensors(
    tensors: Optional[List[torch.Tensor]],
    src_rank: int,
    group: object,
    dist_device: Optional[torch.device] = None
) -> List[torch.Tensor]:
    """
    Broadcast a list of tensors without other (non-src) ranks needing to know
    the dtypes/shapes of the tensors.
    """
    if dist_device is None:
        dist_device = torch.device("cuda")

    # share metadata first to simplify transfer
    is_src_rank = (get_rank(group) == src_rank)
    if is_src_rank:
        metadata = [
            {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
        ]
        metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device)
    else:
        metadata = _broadcast_object_slow(None, src_rank, group, dist_device)

    out_tensors = []
    for i, meta in enumerate(metadata):
        if is_src_rank:
            tensor = tensors[i]
            broadcast(tensors[i].to(dist_device), src=src_rank, group=group)
        else:
            tensor = torch.zeros(
                [meta["size"].numel()], dtype=meta["dtype"], device=dist_device
            )
            broadcast(tensor, src=src_rank, group=group)
        tensor = tensor.view(meta["size"]).to(meta["device"])
        out_tensors.append(tensor)
    return out_tensors

def broadcast_object(
    obj: Any,
    src_rank: int,
    group: object,
    dist_device: Optional[torch.device] = None
) -> Any:
    """Broadcast an arbitrary Python object to other workers."""
    if dist_device is None:
        dist_device = torch.device("cuda")

    if get_rank(group) == src_rank:
        # split the tensors from the non-tensors so we can broadcast them
        # directly, avoiding unnecessary serialization/deserialization
        tensors = []
        obj = _split_tensors_from_obj(obj, tensors)
        obj = _broadcast_object_slow(obj, src_rank, group, dist_device)
        tensors = broadcast_tensors(tensors, src_rank, group, dist_device)
    else:
        obj = _broadcast_object_slow(None, src_rank, group, dist_device)
        tensors = broadcast_tensors(None, src_rank, group, dist_device)
    return _put_tensors_in_obj(obj, tensors)
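
A sketch of broadcast_object (not from the package source): tensors inside the object are pulled out and broadcast directly, while the remaining Python structure goes through the slower serialized path, so a mixed dict like the one below arrives intact on every rank.

# usage sketch, not from the package source; non-source ranks pass any
# placeholder (here None) and receive the source rank's object
group = get_data_parallel_group()
if get_data_parallel_rank() == 0:
    payload = {"step": 1000, "weights": torch.randn(4, 4, device="cuda")}
else:
    payload = None
payload = broadcast_object(payload, src_rank=0, group=group)
assert payload["step"] == 1000  # holds on every rank after the broadcast
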
def _broadcast_object_slow(
    obj: Any, src_rank: int, group: object, dist_device: torch.device,
) -> Any:
    rank = get_rank(group)
    if get_rank(group) == src_rank:
        # Emit data
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device)
        length = torch.LongTensor([len(buffer)]).to(dist_device)
        broadcast(length, src=src_rank, group=group)
        broadcast(buffer, src=src_rank, group=group)
    else:
        # Fetch from the source
        length = torch.LongTensor([0]).to(dist_device)
        broadcast(length, src=src_rank, group=group)
        buffer = torch.ByteTensor(int(length.item())).to(dist_device)
        broadcast(buffer, src=src_rank, group=group)
        buffer = io.BytesIO(buffer.cpu().numpy())
        obj = torch.load(buffer, map_location="cpu")
    return obj

@dataclass(frozen=True)
class _TensorPlaceholder:
    index: int

def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
    if torch.is_tensor(obj):
        placeholder = _TensorPlaceholder(index=len(tensors))
        tensors.append(obj)
        return placeholder
    elif isinstance(obj, dict):
        return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_split_tensors_from_obj(v, tensors) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(_split_tensors_from_obj(v, tensors) for v in obj)
    elif isinstance(obj, set):
        return {_split_tensors_from_obj(v, tensors) for v in obj}
    else:
        return obj

def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
    if isinstance(obj, _TensorPlaceholder):
        return tensors[obj.index]
    elif isinstance(obj, dict):
        return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_put_tensors_in_obj(v, tensors) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(_put_tensors_in_obj(v, tensors) for v in obj)
    elif isinstance(obj, set):
        return {_put_tensors_in_obj(v, tensors) for v in obj}
    else:
        return obj

genhpf/utils/file_io.py
ADDED
@@ -0,0 +1,170 @@
#!/usr/bin/env python3

import logging
import os
import shutil
from typing import List, Optional

logger = logging.getLogger(__file__)

IOPathManager = None

class PathManager:
    """
    Wrapper for insulating OSS I/O (using Python builtin operations) from
    iopath's PathManager abstraction (for transparently handling various
    internal backends).
    """

    @staticmethod
    def open(
        path: str,
        mode: str = "r",
        buffering: int = -1,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
    ):
        if IOPathManager:
            return IOPathManager.open(
                path=path,
                mode=mode,
                buffering=buffering,
                encoding=encoding,
                errors=errors,
                newline=newline,
            )
        return open(
            path,
            mode=mode,
            buffering=buffering,
            encoding=encoding,
            errors=errors,
            newline=newline,
        )

    @staticmethod
    def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool:
        if IOPathManager:
            return IOPathManager.copy(
                src_path=src_path, dst_path=dst_path, overwrite=overwrite
            )
        return shutil.copyfile(src_path, dst_path)

    @staticmethod
    def get_local_path(path: str, **kwargs) -> str:
        if IOPathManager:
            return IOPathManager.get_local_path(path, **kwargs)
        return path

    @staticmethod
    def exists(path: str) -> bool:
        if IOPathManager:
            return IOPathManager.exists(path)
        return os.path.exists(path)

    @staticmethod
    def isfile(path: str) -> bool:
        if IOPathManager:
            return IOPathManager.isfile(path)
        return os.path.isfile(path)

    @staticmethod
    def ls(path: str) -> List[str]:
        if IOPathManager:
            return IOPathManager.ls(path)
        return os.listdir(path)

    @staticmethod
    def mkdirs(path: str) -> None:
        if IOPathManager:
            return IOPathManager.mkdirs(path)
        os.makedirs(path, exist_ok=True)

    @staticmethod
    def rm(path: str) -> None:
        if IOPathManager:
            return IOPathManager.rm(path)
        os.remove(path)

    @staticmethod
    def chmod(path: str, mode: int) -> None:
        if not PathManager.path_requires_pathmanager(path):
            os.chmod(path, mode)

    @staticmethod
    def register_handler(handler) -> None:
        if IOPathManager:
            return IOPathManager.register_handler(handler=handler)

    @staticmethod
    def copy_from_local(
        local_path: str, dst_path: str, overwrite: bool = False, **kwargs
    ) -> None:
        if IOPathManager:
            return IOPathManager.copy_from_local(
                local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs
            )
        return shutil.copyfile(local_path, dst_path)

    @staticmethod
    def path_requires_pathmanager(path: str) -> bool:
        """Do we require PathManager to access given path?"""
        if IOPathManager:
            for p in IOPathManager._path_handlers.keys():
                if path.startswith(p):
                    return True
        return False

    @staticmethod
    def supports_rename(path: str) -> bool:
        # PathManager doesn't yet support renames
        return not PathManager.path_requires_pathmanager(path)

    @staticmethod
    def rename(src: str, dst: str):
        os.rename(src, dst)

    """
    ioPath async PathManager methods:
    """
    @staticmethod
    def opena(
        path: str,
        mode: str = "r",
        buffering: int = -1,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
    ):
        """
        Return file descriptor with asynchronous write operations.
        """
        global IOPathManager
        if not IOPathManager:
            logging.info("ioPath is initializing PathManager.")
            try:
                from iopath.common.file_io import PathManager
                IOPathManager = PathManager()
            except Exception:
                logging.exception("Failed to initialize ioPath PathManager object.")
        return IOPathManager.opena(
            path=path,
            mode=mode,
            buffering=buffering,
            encoding=encoding,
            errors=errors,
            newline=newline,
        )

    @staticmethod
    def async_close() -> bool:
        """
        Wait for files to be written and clean up asynchronous PathManager.
        NOTE: `PathManager.async_close()` must be called at the end of any
        script that uses `PathManager.opena(...)`.
        """
        global IOPathManager
        if IOPathManager:
            return IOPathManager.async_close()
        return False
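
A short sketch of the fallback behaviour (not from the package source): when iopath is not installed, IOPathManager stays None and every call below resolves to the plain Python/os implementation, so local paths work without any extra dependency.

# usage sketch, not from the package source
from genhpf.utils.file_io import PathManager

PathManager.mkdirs("outputs/checkpoints")
with PathManager.open("outputs/checkpoints/train.log", "w") as f:
    f.write("hello\n")
assert PathManager.exists("outputs/checkpoints/train.log")
PathManager.copy("outputs/checkpoints/train.log", "outputs/checkpoints/train.bak")
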
genhpf/utils/pdb.py
ADDED
@@ -0,0 +1,38 @@
import multiprocessing
import os
import pdb
import sys

__all__ = ["set_trace"]

_stdin = [None]
_stdin_lock = multiprocessing.Lock()
try:
    _stdin_fd = sys.stdin.fileno()
except Exception:
    _stdin_fd = None

class MultiprocessingPdb(pdb.Pdb):
    """A pdb wrapper that works in a multiprocessing environment.

    Usage: `from fairseq_signals.utils import pdb; pdb.set_trace()`
    """

    def __init__(self):
        pdb.Pdb.__init__(self, nosigint=True)

    def _cmdloop(self):
        stdin_bak = sys.stdin
        with _stdin_lock:
            try:
                if _stdin_fd is not None:
                    if not _stdin[0]:
                        _stdin[0] = os.fdopen(_stdin_fd)
                    sys.stdin = _stdin[0]
                self.cmdloop()
            finally:
                sys.stdin = stdin_bak

def set_trace():
    pdb = MultiprocessingPdb()
    pdb.set_trace(sys._getframe().f_back)
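
A sketch of the intended use (not from the package source): inside a spawned worker the stock pdb fights over stdin, so this wrapper reopens the parent's stdin file descriptor under a lock, and breaking on a single rank keeps the prompt usable. Note that the docstring above still says `from fairseq_signals.utils import pdb`; within this package the equivalent import is `from genhpf.utils import pdb`, as used below.

# usage sketch, not from the package source
from genhpf.utils import pdb

def worker(rank):
    if rank == 0:          # break on a single rank so stdin is not contended
        pdb.set_trace()    # drops into MultiprocessingPdb at the caller's frame
    # ... rest of the worker ...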