genhpf-1.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/utils/distributed_utils.py ADDED
@@ -0,0 +1,497 @@
+ import io
+ from typing import Optional, List, Any
+ import logging
+ import os
+ import random
+ import socket
+ import struct
+ import pickle
+ import warnings
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torch.distributed as dist
+
+ from torch.nn.parallel import DistributedDataParallel
+
+ from genhpf.configs import Config, DistributedTrainingConfig
+
+ logger = logging.getLogger(__name__)
+
+ class ModuleProxyWrapper(nn.Module):
+     """
+     Wrap a DistributedDataParallel module and forward requests for missing
+     attributes to the module wrapped by DDP (the twice-wrapped module).
+     Also forward calls to :func:`state_dict` and :func:`load_state_dict`.
+
+     Usage::
+
+         module.xyz = "hello world"
+         wrapped_module = DistributedDataParallel(module, **ddp_args)
+         wrapped_module = ModuleProxyWrapper(wrapped_module)
+         assert wrapped_module.xyz == "hello world"
+         assert wrapped_module.state_dict().keys() == module.state_dict().keys()
+
+     Args:
+         module (nn.Module): module to wrap
+     """
+
+     def __init__(self, module: nn.Module):
+         super().__init__()
+         assert hasattr(module, "module"), \
+             "ModuleProxyWrapper expects input to wrap another module"
+         self.module = module
+
+     def __getattr__(self, name):
+         """Forward missing attributes to twice-wrapped module."""
+         try:
+             # defer to nn.Module's logic
+             return super().__getattr__(name)
+         except AttributeError:
+             try:
+                 # forward to the once-wrapped module
+                 return getattr(self.module, name)
+             except AttributeError:
+                 # forward to the twice-wrapped module
+                 return getattr(self.module.module, name)
+
+     def state_dict(self, *args, **kwargs):
+         """Forward to the twice-wrapped module."""
+         return self.module.module.state_dict(*args, **kwargs)
+
+     def load_state_dict(self, *args, **kwargs):
+         """Forward to the twice-wrapped module."""
+         return self.module.module.load_state_dict(*args, **kwargs)
+
+     def forward(self, *args, **kwargs):
+         return self.module(*args, **kwargs)
+
+ def DistributedModel(
+     args,
+     model,
+     process_group,
+     device
+ ):
+     """
+     Wrap a *model* to support distributed data parallel training.
+
+     This is similar to the built-in DistributedDataParallel, but allows
+     additional configuration of the DistributedDataParallel class to
+     use, and also provides easier access to the wrapped model by
+     forwarding requests for missing attributes to the wrapped model.
+
+     Args:
+         args (argparse.Namespace): genhpf model args
+         model (BaseModel): model to wrap
+         process_group: the c10d process group to be used for distributed data
+             parallel all-reduction
+         device: device to move model to
+     """
+     assert isinstance(model, nn.Module)
+
+     wrapped_model = DistributedDataParallel(
+         module=model.to(device),
+         device_ids=[args.device_id],
+         output_device=args.device_id,
+         broadcast_buffers=args.broadcast_buffers,
+         bucket_cap_mb=args.bucket_cap_mb,
+         process_group=process_group,
+         find_unused_parameters=args.find_unused_parameters
+     )
+
+     if args.ddp_comm_hook == "fp16":
+         logger.info("enable fp16 communication hook in DDP")
+         try:
+             from torch.distributed.algorithms.ddp_comm_hooks import (
+                 register_ddp_comm_hook,
+                 DDPCommHookType,
+             )
+         except ImportError:
+             logger.error(
+                 "Could not import from torch.distributed.algorithms.ddp_comm_hooks; you may need to update your pytorch version"
+             )
+             raise
+
+         register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, wrapped_model)
+
+     # forward missing getattr and state_dict/load_state_dict to orig model
+     wrapped_model = ModuleProxyWrapper(wrapped_model)
+
+     return wrapped_model
+
+ def is_master(cfg: DistributedTrainingConfig):
+     return cfg.distributed_rank == 0
+
+ def infer_init_method(cfg: DistributedTrainingConfig):
+     assert (
+         cfg.distributed_world_size <= torch.cuda.device_count()
+     ), f"world size is {cfg.distributed_world_size} but only {torch.cuda.device_count()} devices are available"
+     port = random.randint(10000, 20000)
+     cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port)
+
+ def distributed_init(cfg: Config):
+     if dist.is_available() and dist.is_initialized():
+         warnings.warn(
+             "Distributed is already initialized, cannot initialize twice!"
+         )
+     else:
+         logger.info(
+             "distributed init (rank {}): {}".format(
+                 cfg.distributed_training.distributed_rank,
+                 cfg.distributed_training.distributed_init_method
+             )
+         )
+         dist.init_process_group(
+             backend="nccl",
+             init_method=cfg.distributed_training.distributed_init_method,
+             world_size=cfg.distributed_training.distributed_world_size,
+             rank=cfg.distributed_training.distributed_rank
+         )
+         logger.info(
+             "initialized host {} as rank {}".format(
+                 socket.gethostname(),
+                 cfg.distributed_training.distributed_rank
+             )
+         )
+
+         # perform a dummy all-reduce to initialize the NCCL communicator
+         if torch.cuda.is_available():
+             dist.all_reduce(torch.zeros(1).cuda())
+
+     cfg.distributed_training.distributed_rank = dist.get_rank()
+
+     if is_master(cfg.distributed_training):
+         logging.getLogger().setLevel(logging.INFO)
+     else:
+         logging.getLogger().setLevel(logging.WARNING)
+
+     return cfg.distributed_training.distributed_rank
+
+ def distributed_main(i, main, cfg: Config, kwargs):
+     cfg.distributed_training.device_id = i
+     if torch.cuda.is_available():
+         torch.cuda.set_device(cfg.distributed_training.device_id)
+     if cfg.distributed_training.distributed_rank is None:
+         cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i
+
+     cfg.distributed_training.distributed_rank = distributed_init(cfg)
+
+     main(cfg, **kwargs)
+
+     if dist.is_initialized():
+         dist.barrier(get_global_group())
+
+ def call_main(cfg: Config, main, **kwargs):
+     if (
+         cfg.distributed_training.distributed_world_size > 1
+         and cfg.distributed_training.distributed_init_method is None
+     ):
+         infer_init_method(cfg.distributed_training)
+
+     if cfg.distributed_training.distributed_init_method is not None:
+         start_rank = cfg.distributed_training.distributed_rank
+         cfg.distributed_training.distributed_rank = None  # assign automatically
+         kwargs["start_rank"] = start_rank
+         torch.multiprocessing.spawn(
+             fn=distributed_main,
+             args=(main, cfg, kwargs),
+             nprocs=min(torch.cuda.device_count(), cfg.distributed_training.distributed_world_size),
+             join=True
+         )
+     else:
+         main(cfg, **kwargs)
+
+ def get_rank(group):
+     return dist.get_rank(group=group)
+
+ def get_world_size(group):
+     if dist.is_initialized():
+         return dist.get_world_size(group=group)
+     else:
+         return 1
+
+ def get_global_group():
+     if dist.is_initialized():
+         if not hasattr(get_global_group, "_global_group"):
+             # ideally we could use torch.distributed.group.WORLD, but it seems
+             # to cause random NCCL hangs in some cases
+             get_global_group._global_group = dist.new_group()
+         return get_global_group._global_group
+     else:
+         return None
+
+ def get_global_rank():
+     if dist.is_initialized():
+         return dist.get_rank()
+     else:
+         return 0
+
+ def get_global_world_size():
+     if dist.is_initialized():
+         return dist.get_world_size()
+     else:
+         return 1
+
+ def get_data_parallel_group():
+     """Get the data parallel group the caller rank belongs to."""
+     return get_global_group()
+
+ def get_data_parallel_rank():
+     """Return my rank for the data parallel group."""
+     return get_rank(get_data_parallel_group())
+
+ def get_data_parallel_world_size():
+     """Return world size for the data parallel group."""
+     return get_world_size(get_data_parallel_group())
+
+ def all_reduce(tensor, group, op="sum"):
+     if op == "sum":
+         op = dist.ReduceOp.SUM
+     elif op == "max":
+         op = dist.ReduceOp.MAX
+     else:
+         raise NotImplementedError
+
+     dist.all_reduce(tensor, op=op, group=group)
+
+     return tensor
+
+ def all_gather_list(data, group=None, max_size=32768):
+     """Gathers arbitrary data from all nodes into a list.
+
+     Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
+     data. Note that *data* must be picklable and any CUDA tensors will be moved
+     to CPU and returned on CPU as well.
+
+     Args:
+         data (Any): data from the local worker to be gathered on other workers
+         group: group of the collective
+         max_size (int, optional): maximum size of the data to be gathered
+             across workers
+     """
+     from genhpf.utils import utils
+
+     if group is None:
+         group = get_global_group()
+     rank = get_rank(group=group)
+     world_size = get_world_size(group=group)
+
+     buffer_size = max_size * world_size
+     if (
+         not hasattr(all_gather_list, "_buffer")
+         or all_gather_list._buffer.numel() < buffer_size
+     ):
+         all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
+         all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
+     buffer = all_gather_list._buffer
+     buffer.zero_()
+     cpu_buffer = all_gather_list._cpu_buffer
+
+     data = utils.move_to_cpu(data)
+     enc = pickle.dumps(data)
+     enc_size = len(enc)
+     header_size = 4  # size of header that contains the length of the encoded data
+     size = header_size + enc_size
+     if size > max_size:
+         raise ValueError(
+             "encoded data size ({}) exceeds max_size ({})".format(size, max_size)
+         )
+
+     header = struct.pack(">I", enc_size)
+     cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
+     start = rank * max_size
+     buffer[start : start + size].copy_(cpu_buffer[:size])
+
+     all_reduce(buffer, group=group)
+
+     buffer = buffer.cpu()
+     try:
+         result = []
+         for i in range(world_size):
+             out_buffer = buffer[i * max_size : (i + 1) * max_size]
+             (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
+             if enc_size > 0:
+                 result.append(
+                     pickle.loads(
+                         bytes(out_buffer[header_size : header_size + enc_size].tolist())
+                     )
+                 )
+         return result
+     except pickle.UnpicklingError:
+         raise Exception(
+             "Unable to unpickle data from other workers. all_gather_list requires all "
+             "workers to enter the function together, so this error usually indicates "
+             "that the workers have fallen out of sync somehow. Workers can fall out of "
+             "sync if one of them runs out of memory, or if there are other conditions "
+             "in your training script that can cause one worker to finish an epoch "
+             "while other workers are still iterating over their portions of the data. "
+             # "Try rerunning with --ddp-backend=legacy_ddp and see if that helps."
+         )
+
+ def batch_all_gather(tensor, group, return_tensor=False):
+     """Perform an all-gather operation considering tensors with different batch size"""
+     world_size = get_world_size(group=group)
+     rank = get_rank(group=group)
+
+     size_list = [
+         tensor.new_zeros(tensor.dim(), dtype=torch.int64) for _ in range(world_size)
+     ]
+     local_size = tensor.new_tensor(tensor.shape, dtype=torch.int64)
+     dist.all_gather(size_list, local_size, group=group)
+
+     max_size = torch.stack(size_list).max(dim=0)[0][0]
+     size_offsets = [max_size - size[0] for size in size_list]
+
+     if local_size[0] != max_size:
+         offset = torch.cat(
+             (
+                 tensor.new_tensor([max_size - local_size[0]]),
+                 local_size[1:]
+             )
+         )
+         padding = tensor.new_zeros(tuple(int(dim) for dim in offset), dtype=torch.uint8)
+         tensor = torch.cat((tensor, padding), dim=0)
+
+     tensor_list = [
+         tensor if i == rank else torch.empty_like(tensor) for i in range(world_size)
+     ]
+     dist.all_gather(tensor_list, tensor, group=group)
+     tensor_list = [
+         tensor[:max_size - size_offsets[i]] for i, tensor in enumerate(tensor_list)
+     ]
+     if return_tensor:
+         return torch.stack(tensor_list, dim=0)
+     else:
+         return tensor_list
+
+ def all_gather(tensor, group, return_tensor=False):
+     """Perform an all-gather operation."""
+     world_size = get_world_size(group=group)
+     rank = get_rank(group=group)
+     tensor_list = [
+         tensor if i == rank else torch.empty_like(tensor) for i in range(world_size)
+     ]
+     dist.all_gather(tensor_list, tensor, group=group)
+     if return_tensor:
+         return torch.stack(tensor_list, dim=0)
+     else:
+         return tensor_list
+
+ def broadcast(tensor, src, group):
+     dist.broadcast(tensor, src=src, group=group)
+
+ def broadcast_tensors(
+     tensors: Optional[List[torch.Tensor]],
+     src_rank: int,
+     group: object,
+     dist_device: Optional[torch.device] = None
+ ) -> List[torch.Tensor]:
+     """
+     Broadcast a list of tensors without other (non-src) ranks needing to know
+     the dtypes/shapes of the tensors.
+     """
+     if dist_device is None:
+         dist_device = torch.device("cuda")
+
+     # share metadata first to simplify transfer
+     is_src_rank = (get_rank(group) == src_rank)
+     if is_src_rank:
+         metadata = [
+             {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
+         ]
+         metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device)
+     else:
+         metadata = _broadcast_object_slow(None, src_rank, group, dist_device)
+
+     out_tensors = []
+     for i, meta in enumerate(metadata):
+         if is_src_rank:
+             tensor = tensors[i]
+             broadcast(tensors[i].to(dist_device), src=src_rank, group=group)
+         else:
+             tensor = torch.zeros(
+                 [meta["size"].numel()], dtype=meta["dtype"], device=dist_device
+             )
+             broadcast(tensor, src=src_rank, group=group)
+             tensor = tensor.view(meta["size"]).to(meta["device"])
+         out_tensors.append(tensor)
+     return out_tensors
+
+ def broadcast_object(
+     obj: Any,
+     src_rank: int,
+     group: object,
+     dist_device: Optional[torch.device] = None
+ ) -> Any:
+     """Broadcast an arbitrary Python object to other workers."""
+     if dist_device is None:
+         dist_device = torch.device("cuda")
+
+     if get_rank(group) == src_rank:
+         # split the tensors from the non-tensors so we can broadcast them
+         # directly, avoiding unnecessary serialization/deserialization
+         tensors = []
+         obj = _split_tensors_from_obj(obj, tensors)
+         obj = _broadcast_object_slow(obj, src_rank, group, dist_device)
+         tensors = broadcast_tensors(tensors, src_rank, group, dist_device)
+     else:
+         obj = _broadcast_object_slow(None, src_rank, group, dist_device)
+         tensors = broadcast_tensors(None, src_rank, group, dist_device)
+     return _put_tensors_in_obj(obj, tensors)
+
+ def _broadcast_object_slow(
+     obj: Any, src_rank: int, group: object, dist_device: torch.device,
+ ) -> Any:
+     rank = get_rank(group)
+     if get_rank(group) == src_rank:
+         # Emit data
+         buffer = io.BytesIO()
+         torch.save(obj, buffer)
+         buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device)
+         length = torch.LongTensor([len(buffer)]).to(dist_device)
+         broadcast(length, src=src_rank, group=group)
+         broadcast(buffer, src=src_rank, group=group)
+     else:
+         # Fetch from the source
+         length = torch.LongTensor([0]).to(dist_device)
+         broadcast(length, src=src_rank, group=group)
+         buffer = torch.ByteTensor(int(length.item())).to(dist_device)
+         broadcast(buffer, src=src_rank, group=group)
+         buffer = io.BytesIO(buffer.cpu().numpy())
+         obj = torch.load(buffer, map_location="cpu")
+     return obj
+
+ @dataclass(frozen=True)
+ class _TensorPlaceholder:
+     index: int
+
+ def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
+     if torch.is_tensor(obj):
+         placeholder = _TensorPlaceholder(index=len(tensors))
+         tensors.append(obj)
+         return placeholder
+     elif isinstance(obj, dict):
+         return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()}
+     elif isinstance(obj, list):
+         return [_split_tensors_from_obj(v, tensors) for v in obj]
+     elif isinstance(obj, tuple):
+         return tuple(_split_tensors_from_obj(v, tensors) for v in obj)
+     elif isinstance(obj, set):
+         return {_split_tensors_from_obj(v, tensors) for v in obj}
+     else:
+         return obj
+
+ def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
+     if isinstance(obj, _TensorPlaceholder):
+         return tensors[obj.index]
+     elif isinstance(obj, dict):
+         return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()}
+     elif isinstance(obj, list):
+         return [_put_tensors_in_obj(v, tensors) for v in obj]
+     elif isinstance(obj, tuple):
+         return tuple(_put_tensors_in_obj(v, tensors) for v in obj)
+     elif isinstance(obj, set):
+         return {_put_tensors_in_obj(v, tensors) for v in obj}
+     else:
+         return obj
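
A minimal usage sketch of the collectives defined in this module. The helper names (sync_epoch_stats, share_vocab) and their payloads are illustrative, not part of the package; it assumes a CUDA build of PyTorch (all_gather_list stages data in a torch.cuda.ByteTensor) and that the process group has already been initialized, for example through call_main.

import torch

from genhpf.utils import distributed_utils as dist_utils


def sync_epoch_stats(local_stats: dict) -> list:
    """Gather a small, picklable stats dict from every data-parallel worker."""
    if dist_utils.get_data_parallel_world_size() == 1:
        return [local_stats]  # single-process run: nothing to gather
    group = dist_utils.get_data_parallel_group()
    return dist_utils.all_gather_list(local_stats, group=group, max_size=32768)


def share_vocab(vocab: dict) -> dict:
    """Broadcast an arbitrary Python object (here, a vocab dict) from rank 0."""
    return dist_utils.broadcast_object(
        vocab,
        src_rank=0,
        group=dist_utils.get_global_group(),
        dist_device=torch.device("cuda"),
    )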
genhpf/utils/file_io.py ADDED
@@ -0,0 +1,170 @@
+ #!/usr/bin/env python3
+
+ import logging
+ import os
+ import shutil
+ from typing import List, Optional
+
+ logger = logging.getLogger(__file__)
+
+ IOPathManager = None
+
+ class PathManager:
+     """
+     Wrapper for insulating OSS I/O (using Python builtin operations) from
+     iopath's PathManager abstraction (for transparently handling various
+     internal backends).
+     """
+
+     @staticmethod
+     def open(
+         path: str,
+         mode: str = "r",
+         buffering: int = -1,
+         encoding: Optional[str] = None,
+         errors: Optional[str] = None,
+         newline: Optional[str] = None,
+     ):
+         if IOPathManager:
+             return IOPathManager.open(
+                 path=path,
+                 mode=mode,
+                 buffering=buffering,
+                 encoding=encoding,
+                 errors=errors,
+                 newline=newline,
+             )
+         return open(
+             path,
+             mode=mode,
+             buffering=buffering,
+             encoding=encoding,
+             errors=errors,
+             newline=newline,
+         )
+
+     @staticmethod
+     def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool:
+         if IOPathManager:
+             return IOPathManager.copy(
+                 src_path=src_path, dst_path=dst_path, overwrite=overwrite
+             )
+         return shutil.copyfile(src_path, dst_path)
+
+     @staticmethod
+     def get_local_path(path: str, **kwargs) -> str:
+         if IOPathManager:
+             return IOPathManager.get_local_path(path, **kwargs)
+         return path
+
+     @staticmethod
+     def exists(path: str) -> bool:
+         if IOPathManager:
+             return IOPathManager.exists(path)
+         return os.path.exists(path)
+
+     @staticmethod
+     def isfile(path: str) -> bool:
+         if IOPathManager:
+             return IOPathManager.isfile(path)
+         return os.path.isfile(path)
+
+     @staticmethod
+     def ls(path: str) -> List[str]:
+         if IOPathManager:
+             return IOPathManager.ls(path)
+         return os.listdir(path)
+
+     @staticmethod
+     def mkdirs(path: str) -> None:
+         if IOPathManager:
+             return IOPathManager.mkdirs(path)
+         os.makedirs(path, exist_ok=True)
+
+     @staticmethod
+     def rm(path: str) -> None:
+         if IOPathManager:
+             return IOPathManager.rm(path)
+         os.remove(path)
+
+     @staticmethod
+     def chmod(path: str, mode: int) -> None:
+         if not PathManager.path_requires_pathmanager(path):
+             os.chmod(path, mode)
+
+     @staticmethod
+     def register_handler(handler) -> None:
+         if IOPathManager:
+             return IOPathManager.register_handler(handler=handler)
+
+     @staticmethod
+     def copy_from_local(
+         local_path: str, dst_path: str, overwrite: bool = False, **kwargs
+     ) -> None:
+         if IOPathManager:
+             return IOPathManager.copy_from_local(
+                 local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs
+             )
+         return shutil.copyfile(local_path, dst_path)
+
+     @staticmethod
+     def path_requires_pathmanager(path: str) -> bool:
+         """Do we require PathManager to access given path?"""
+         if IOPathManager:
+             for p in IOPathManager._path_handlers.keys():
+                 if path.startswith(p):
+                     return True
+         return False
+
+     @staticmethod
+     def supports_rename(path: str) -> bool:
+         # PathManager doesn't yet support renames
+         return not PathManager.path_requires_pathmanager(path)
+
+     @staticmethod
+     def rename(src: str, dst: str):
+         os.rename(src, dst)
+
+     """
+     ioPath async PathManager methods:
+     """
+     @staticmethod
+     def opena(
+         path: str,
+         mode: str = "r",
+         buffering: int = -1,
+         encoding: Optional[str] = None,
+         errors: Optional[str] = None,
+         newline: Optional[str] = None,
+     ):
+         """
+         Return file descriptor with asynchronous write operations.
+         """
+         global IOPathManager
+         if not IOPathManager:
+             logging.info("ioPath is initializing PathManager.")
+             try:
+                 from iopath.common.file_io import PathManager
+                 IOPathManager = PathManager()
+             except Exception:
+                 logging.exception("Failed to initialize ioPath PathManager object.")
+         return IOPathManager.opena(
+             path=path,
+             mode=mode,
+             buffering=buffering,
+             encoding=encoding,
+             errors=errors,
+             newline=newline,
+         )
+
+     @staticmethod
+     def async_close() -> bool:
+         """
+         Wait for files to be written and clean up asynchronous PathManager.
+         NOTE: `PathManager.async_close()` must be called at the end of any
+         script that uses `PathManager.opena(...)`.
+         """
+         global IOPathManager
+         if IOPathManager:
+             return IOPathManager.async_close()
+         return False
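
A short, hypothetical usage sketch of the wrapper above (the output path is illustrative). With no iopath handlers registered, every call falls through to the Python/os builtins shown in the module.

import os

from genhpf.utils.file_io import PathManager

out_path = "outputs/example/notes.txt"  # illustrative path
PathManager.mkdirs(os.path.dirname(out_path))
with PathManager.open(out_path, mode="w", encoding="utf-8") as f:
    f.write("hello genhpf\n")
assert PathManager.exists(out_path) and PathManager.isfile(out_path)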
genhpf/utils/pdb.py ADDED
@@ -0,0 +1,38 @@
+ import multiprocessing
+ import os
+ import pdb
+ import sys
+
+ __all__ = ["set_trace"]
+
+ _stdin = [None]
+ _stdin_lock = multiprocessing.Lock()
+ try:
+     _stdin_fd = sys.stdin.fileno()
+ except Exception:
+     _stdin_fd = None
+
+ class MultiprocessingPdb(pdb.Pdb):
+     """A pdb wrapper that works in a multiprocessing environment.
+
+     Usage: `from genhpf.utils import pdb; pdb.set_trace()`
+     """
+
+     def __init__(self):
+         pdb.Pdb.__init__(self, nosigint=True)
+
+     def _cmdloop(self):
+         stdin_bak = sys.stdin
+         with _stdin_lock:
+             try:
+                 if _stdin_fd is not None:
+                     if not _stdin[0]:
+                         _stdin[0] = os.fdopen(_stdin_fd)
+                     sys.stdin = _stdin[0]
+                 self.cmdloop()
+             finally:
+                 sys.stdin = stdin_bak
+
+ def set_trace():
+     pdb = MultiprocessingPdb()
+     pdb.set_trace(sys._getframe().f_back)
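
A hypothetical debugging snippet for the wrapper above (the training_step function and NaN-loss trigger are illustrative); the locked stdin handoff is what lets the interactive prompt attach inside a worker spawned by torch.multiprocessing.

from genhpf.utils import pdb


def training_step(step: int, loss: float) -> None:
    if loss != loss:  # loss is NaN; illustrative trigger
        pdb.set_trace()  # interactive prompt, even inside a spawned worker process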