python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
wml/wtorch/dist.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
1
|
+
import functools
import os
import pickle
import subprocess
from collections import OrderedDict
from collections.abc import Iterable, Mapping
from typing import Callable, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup

import wml.wml_utils as wmlu
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Normalization layers whose running statistics are reduced lazily across
# ranks (see get_async_norm_states / all_reduce_norm) rather than every step.
ASYNC_NORM = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
)
|
|
27
|
+
|
|
28
|
+
def get_world_size(group: Optional[ProcessGroup] = None) -> int:
    """Return the number of the given process group.

    Note:
        Calling ``get_world_size`` in non-distributed environment will return
        1.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the number of processes of the given process group if in
        distributed environment, otherwise 1.
    """
    if not is_distributed():
        return 1
    # Older torch (e.g. 1.5.0) rejects group=None, so resolve it explicitly.
    if group is None:
        group = get_default_group()
    return torch_dist.get_world_size(group)
|
|
51
|
+
|
|
52
|
+
def get_rank(group: Optional[ProcessGroup] = None) -> int:
    """Return the rank of the current process (0 when not distributed).

    BUGFIX/generalization: ``broadcast`` and ``get_dist_info`` in this module
    call ``get_rank(group)``, but the original signature accepted no
    arguments, raising ``TypeError`` in distributed runs.  The new ``group``
    parameter defaults to None, so all existing zero-argument callers are
    unaffected.

    Args:
        group (ProcessGroup, optional): The process group to query. If None,
            the default process group is used. Defaults to None.

    Returns:
        int: rank within ``group`` (global rank when ``group`` is None), or 0
        when torch.distributed is unavailable or uninitialized.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    if group is None:
        return dist.get_rank()
    return dist.get_rank(group)
|
|
58
|
+
|
|
59
|
+
def is_main_process() -> bool:
    """Return True iff this process has rank 0 (always True when not distributed)."""
    return get_rank() == 0
|
|
61
|
+
|
|
62
|
+
def setup_dist_group(rank, world_size, port="12355", host="localhost", backend='nccl'):
    """Initialize the default torch.distributed process group.

    Args:
        rank (int): rank of the current process.
        world_size (int): total number of processes.
        port (str or int): master port. Defaults to "12355".
        host (str): master address. Defaults to "localhost".
        backend (str): communication backend, e.g. "gloo" or "nccl".
            Defaults to 'nccl'.
    """
    os.environ['MASTER_ADDR'] = host
    # BUGFIX: os.environ values must be strings; coerce so callers may pass
    # an int port without raising TypeError.
    os.environ['MASTER_PORT'] = str(port)
    #backend: gloo, nccl
    dist.init_process_group(backend, rank=rank, world_size=world_size)
|
|
67
|
+
|
|
68
|
+
def cleanup_dist_train():
    """Tear down the default process group created by ``setup_dist_group``."""
    dist.destroy_process_group()
|
|
70
|
+
|
|
71
|
+
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.

    Under nccl a dedicated gloo group is created once; for CPU-capable
    backends (gloo/mpi) the existing WORLD group is reused.
    """
    backend = dist.get_backend()
    if backend == "nccl":
        return dist.new_group(backend="gloo")
    return dist.group.WORLD
|
|
81
|
+
|
|
82
|
+
def pyobj2tensor(pyobj, device="cuda"):
    """serialize picklable python object to tensor"""
    payload = pickle.dumps(pyobj)
    byte_storage = torch.ByteStorage.from_buffer(payload)
    return torch.ByteTensor(byte_storage).to(device=device)
|
|
86
|
+
|
|
87
|
+
def tensor2pyobj(tensor):
    """deserialize tensor to picklable python object"""
    raw_bytes = tensor.cpu().numpy().tobytes()
    return pickle.loads(raw_bytes)
|
|
90
|
+
|
|
91
|
+
def _get_reduce_op(op_name):
|
|
92
|
+
return {
|
|
93
|
+
"sum": dist.ReduceOp.SUM,
|
|
94
|
+
"mean": dist.ReduceOp.SUM,
|
|
95
|
+
}[op_name.lower()]
|
|
96
|
+
|
|
97
|
+
def all_reduce(py_dict, op="sum", group=None):
    """
    Apply all reduce function for python dict object.

    NOTE: make sure that every py_dict has the same keys and values are in the same shape.

    Args:
        py_dict (dict): dict to apply all reduce op.
        op (str): operator, could be "sum" or "mean".
        group (ProcessGroup, optional): group to reduce over; a gloo group
            covering all ranks is used when None.

    Returns:
        An OrderedDict of reduced tensors, or ``py_dict`` unchanged when
        running single-process.
    """
    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return py_dict

    # all reduce logic across different devices.
    # Broadcast rank 0's key order so every rank flattens in the same order.
    py_key = list(py_dict.keys())
    py_key_tensor = pyobj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0)
    py_key = tensor2pyobj(py_key_tensor)

    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    # Concatenate all values into one flat tensor so a single collective
    # call reduces everything at once.
    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
    if op == "mean":
        # _get_reduce_op maps "mean" to SUM; divide to get the average.
        flatten_tensor /= world_size

    # Split the reduced flat tensor back into the original per-key shapes.
    split_tensors = [
        x.reshape(shape)
        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})
|
|
133
|
+
|
|
134
|
+
def get_async_norm_states(module):
    """Collect the state_dict entries of every ASYNC_NORM layer in *module*.

    Keys are prefixed with the submodule's qualified name, so the result can
    be loaded back via ``module.load_state_dict(..., strict=False)``.
    """
    collected = OrderedDict()
    for mod_name, sub in module.named_modules():
        if not isinstance(sub, ASYNC_NORM):
            continue
        for key, value in sub.state_dict().items():
            collected[f"{mod_name}.{key}"] = value
    return collected
|
|
141
|
+
|
|
142
|
+
def all_reduce_norm(module):
    """
    All reduce norm statistics in different devices.

    Averages the state of every ASYNC_NORM layer in *module* across all
    ranks and loads the averaged values back into the module.
    """
    states = get_async_norm_states(module)
    print("Reduce keys:")
    wmlu.show_list(list(states.keys()))
    # "mean": each statistic is summed across ranks then divided by world size.
    states = all_reduce(states, op="mean")
    # strict=False: `states` only covers the norm layers, not the full model.
    module.load_state_dict(states, strict=False)
|
|
151
|
+
|
|
152
|
+
def _find_free_port():
|
|
153
|
+
"""
|
|
154
|
+
Find an available port of current machine / node.
|
|
155
|
+
"""
|
|
156
|
+
import socket
|
|
157
|
+
|
|
158
|
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
159
|
+
# Binding to port 0 will cause the OS to find an available port for us
|
|
160
|
+
sock.bind(("", 0))
|
|
161
|
+
port = sock.getsockname()[1]
|
|
162
|
+
sock.close()
|
|
163
|
+
# NOTE: there is still a chance the port could be taken by other processes.
|
|
164
|
+
return port
|
|
165
|
+
|
|
166
|
+
def convert_sync_batchnorm(module, process_group=None):
    r"""Recursively replace every :attr:`BatchNorm*D` layer in *module* with
    :class:`torch.nn.SyncBatchNorm`.

    Only layers currently in training mode are converted; parameters, running
    statistics and ``qconfig`` are carried over.  If *module* itself is a
    :attr:`BatchNorm*D` layer, a new :class:`torch.nn.SyncBatchNorm` object
    is returned instead.

    Args:
        module (nn.Module): module containing one or more attr:`BatchNorm*D` layers
        process_group (optional): process group to scope synchronization,
            default is the whole world

    Returns:
        The original :attr:`module` with converted layers (or a converted
        replacement when *module* was itself a batch-norm layer).
    """
    converted = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.training:
        converted = torch.nn.SyncBatchNorm(module.num_features,
                                           module.eps, module.momentum,
                                           module.affine,
                                           module.track_running_stats,
                                           process_group)
        if module.affine:
            # Share (not copy) the learnable parameters with the new layer.
            with torch.no_grad():
                converted.weight = module.weight
                converted.bias = module.bias
        converted.running_mean = module.running_mean
        converted.running_var = module.running_var
        converted.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            converted.qconfig = module.qconfig
    for child_name, child in module.named_children():
        converted.add_module(child_name,
                             convert_sync_batchnorm(child, process_group))
    del module
    return converted
|
|
220
|
+
|
|
221
|
+
def configure_nccl():
    """Configure multi-machine environment variables of NCCL."""
    # Let multiple NCCL collectives launch concurrently instead of serially.
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    # Auto-detect Mellanox (mlx5_*) InfiniBand HCAs by scanning
    # /sys/class/infiniband; keeps devices whose port GID type matches "v".
    os.environ["NCCL_IB_HCA"] = subprocess.getoutput(
        "pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; "
        "do cat $i/ports/1/gid_attrs/types/* 2>/dev/null "
        "| grep v >/dev/null && echo $i ; done; popd > /dev/null"
    )
    # NOTE(review): these values look tuned for a RoCE v2 fabric (GID index 3,
    # traffic class 106) — confirm they match the target cluster.
    os.environ["NCCL_IB_GID_INDEX"] = "3"
    os.environ["NCCL_IB_TC"] = "106"
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def configure_omp(num_threads=1):
    """
    If OMP_NUM_THREADS is not configured and world_size is greater than 1,
    Configure OMP_NUM_THREADS environment variables of NCCL to `num_thread`.

    Args:
        num_threads (int): value of `OMP_NUM_THREADS` to set.
    """
    # We set OMP_NUM_THREADS=1 by default, which achieves the best speed on our machines
    # feel free to change it for better performance.
    if "OMP_NUM_THREADS" in os.environ or get_world_size() <= 1:
        return
    os.environ["OMP_NUM_THREADS"] = str(num_threads)
    if is_main_process():
        print(
            "\n***************************************************************\n"
            "We set `OMP_NUM_THREADS` for each process to {} to speed up.\n"
            "please further tune the variable for optimal performance.\n"
            "***************************************************************".format(
                os.environ["OMP_NUM_THREADS"]
            )
        )
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def configure_module(ulimit_value=8192):
    """
    Configure pytorch module environment. setting of ulimit and cv2 will be set.

    Args:
        ulimit_value(int): default open file number on linux. Default value: 8192.
    """
    # system setting: raise the soft open-file limit; best effort only
    # (may fail on Windows or when the hard limit is already lower).
    try:
        import resource

        _soft, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
        resource.setrlimit(resource.RLIMIT_NOFILE, (ulimit_value, hard_limit))
    except Exception:
        # Exception might be raised in Windows OS or rlimit reaches max limit number.
        # However, set rlimit value might not be necessary.
        pass

    # cv2
    # multiprocess might be harmful on performance of torch dataloader
    os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
    try:
        cv2.setNumThreads(0)
        cv2.ocl.setUseOpenCL(False)
    except Exception:
        # cv2 version mismatch might raise exceptions.
        pass
|
|
283
|
+
|
|
284
|
+
def reduce_mean(tensor):
    """Obtain the mean of tensor on different GPUs.

    Returns *tensor* untouched when torch.distributed is not initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    averaged = tensor.clone()
    averaged.div_(dist.get_world_size())
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged
|
|
291
|
+
|
|
292
|
+
def is_distributed() -> bool:
    """Return True if distributed environment has been initialized."""
    if not torch_dist.is_available():
        return False
    return torch_dist.is_initialized()
|
|
295
|
+
|
|
296
|
+
def get_default_group() -> Optional[ProcessGroup]:
    """Return default process group.

    NOTE: relies on the private ``torch.distributed.distributed_c10d``
    API; raises RuntimeError when the default group is not initialized.
    """

    return torch_dist.distributed_c10d._get_default_group()
|
|
300
|
+
|
|
301
|
+
def barrier(group: Optional[ProcessGroup] = None) -> None:
    """Synchronize all processes from the given process group.

    This collective blocks processes until the whole group enters this
    function.

    Note:
        Calling ``barrier`` in non-distributed environment will do nothing.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.
    """
    if not is_distributed():
        return
    # Older torch (e.g. 1.5.0) rejects group=None, so resolve it explicitly.
    if group is None:
        group = get_default_group()
    torch_dist.barrier(group)
|
|
320
|
+
|
|
321
|
+
def broadcast(data: Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None) -> None:
    """Broadcast the data from ``src`` process to the whole group.

    ``data`` is updated in place on non-source ranks; nothing is returned.
    ``data`` must have the same number of elements in all processes
    participating in the collective.

    Note:
        Calling ``broadcast`` in non-distributed environment does nothing.

    Args:
        data (Tensor): Data to be sent if ``src`` is the rank of current
            process, and data to be used to save received data otherwise.
        src (int): Source rank. Defaults to 0.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Examples:
        >>> import torch
        >>> import mmengine.dist as dist

        >>> # non-distributed environment
        >>> data = torch.arange(2, dtype=torch.int64)
        >>> data
        tensor([0, 1])
        >>> dist.broadcast(data)
        >>> data
        tensor([0, 1])

        >>> # distributed environment
        >>> # We have 2 process groups, 2 ranks.
        >>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> data
        tensor([1, 2]) # Rank 0
        tensor([3, 4]) # Rank 1
        >>> dist.broadcast(data)
        >>> data
        tensor([1, 2]) # Rank 0
        tensor([1, 2]) # Rank 1
    """
    if get_world_size(group) > 1:
        if group is None:
            group = get_default_group()

        # The collective must run on the backend's device (e.g. cuda for
        # nccl), which may differ from where the caller's tensor lives.
        input_device = get_data_device(data)
        backend_device = get_comm_device(group)
        data_on_device = cast_data_device(data, backend_device)
        # broadcast requires tensor is contiguous
        data_on_device = data_on_device.contiguous()  # type: ignore
        torch_dist.broadcast(data_on_device, src, group)

        # Copy the received payload back into the caller's tensor on every
        # rank except the source (which already holds the data).
        # NOTE(review): this passes ``group`` to get_rank — verify get_rank
        # accepts a group argument, otherwise this raises TypeError.
        if get_rank(group) != src:
            cast_data_device(data_on_device, input_device, data)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def get_data_device(data: Union[Tensor, Mapping, Iterable]) -> torch.device:
    """Return the device of ``data``.

    If ``data`` is a sequence of Tensor, all items in ``data`` should have a
    same device type.

    If ``data`` is a dict whose values are Tensor, all values should have a
    same device type.

    Args:
        data (Tensor or Sequence or dict): Inputs to be inferred the device.

    Returns:
        torch.device: The device of ``data``.

    Examples:
        >>> import torch
        >>> from mmengine.dist import cast_data_device
        >>> # data is a Tensor
        >>> data = torch.tensor([0, 1])
        >>> get_data_device(data)
        device(type='cpu')
        >>> # data is a list of Tensor
        >>> data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        >>> get_data_device(data)
        device(type='cpu')
        >>> # data is a dict
        >>> data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
        >>> get_data_device(data)
        device(type='cpu')
    """
    if isinstance(data, Tensor):
        return data.device

    # Mappings and non-string iterables are scanned the same way: every
    # nested leaf must agree on a single device.
    if isinstance(data, Mapping):
        leaves = data.values()
    elif isinstance(data, Iterable) and not isinstance(data, str):
        leaves = data
    else:
        raise TypeError('data should be a Tensor, sequence of tensor or dict, '
                        f'but got {data}')

    found = None
    for leaf in leaves:
        device = get_data_device(leaf)
        if found is None:
            found = device
        elif device != found:
            raise ValueError(
                'device type in data should be consistent, but got '
                f'{device} and {found}')
    if found is None:
        raise ValueError('data should not be empty.')
    return found
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
    """Return the device for communication among groups.

    Args:
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        torch.device: The device of backend.
    """
    backend = get_backend(group)
    if backend == 'hccl':
        import torch_npu  # noqa: F401
        return torch.device('npu', torch.npu.current_device())
    if backend == 'cncl':
        import torch_mlu  # noqa: F401
        return torch.device('mlu', torch.mlu.current_device())
    if backend == torch_dist.Backend.NCCL or backend == 'smddp':
        return torch.device('cuda', torch.cuda.current_device())
    # GLOO and MPI backends use cpu device by default
    return torch.device('cpu')
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def get_backend(group: Optional[ProcessGroup] = None) -> Optional[str]:
    """Return the backend of the given process group.

    Note:
        Calling ``get_backend`` in non-distributed environment will return
        None.

    Args:
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific
            group is specified, the calling process must be part of
            :attr:`group`. Defaults to None.

    Returns:
        str or None: Return the backend of the given process group as a lower
        case string if in distributed environment, otherwise None.
    """
    if not is_distributed():
        return None
    # Older torch (e.g. 1.5.0) rejects group=None, so resolve it explicitly.
    if group is None:
        group = get_default_group()
    return torch_dist.get_backend(group)
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def cast_data_device(
    data: Union[Tensor, Mapping, Iterable],
    device: torch.device,
    out: Optional[Union[Tensor, Mapping, Iterable]] = None
) -> Union[Tensor, Mapping, Iterable]:
    """Recursively convert Tensor in ``data`` to ``device``.

    If ``data`` has already on the ``device``, it will not be casted again.

    Args:
        data (Tensor or list or dict): Inputs to be casted.
        device (torch.device): Destination device type.
        out (Tensor or list or dict, optional): If ``out`` is specified, its
            value will be equal to ``data``. Defaults to None.

    Returns:
        Tensor or list or dict: ``data`` was casted to ``device``.

    Raises:
        TypeError: if ``out`` has a different type than ``data``, is a set,
            or ``data`` is not a Tensor / Mapping / Iterable.
        ValueError: if ``data`` (or a nested container) is empty, or when the
            lengths of ``data`` and ``out`` differ.
    """
    if out is not None:
        if type(data) != type(out):
            # BUGFIX: the message previously printed type(data) twice,
            # hiding the actual type of ``out``.
            raise TypeError(
                'out should be the same type with data, but got data is '
                f'{type(data)} and out is {type(out)}')

        if isinstance(out, set):
            raise TypeError('out should not be a set')

    if isinstance(data, Tensor):
        if get_data_device(data) == device:
            data_on_device = data
        else:
            data_on_device = data.to(device)

        if out is not None:
            # modify the value of out inplace
            out.copy_(data_on_device)  # type: ignore

        return data_on_device
    elif isinstance(data, Mapping):
        data_on_device = {}
        if out is not None:
            data_len = len(data)
            out_len = len(out)  # type: ignore
            if data_len != out_len:
                raise ValueError('length of data and out should be same, '
                                 f'but got {data_len} and {out_len}')

            for k, v in data.items():
                data_on_device[k] = cast_data_device(v, device,
                                                     out[k])  # type: ignore
        else:
            for k, v in data.items():
                data_on_device[k] = cast_data_device(v, device)

        if len(data_on_device) == 0:
            raise ValueError('data should not be empty')

        # To ensure the type of output as same as input, we use `type(data)`
        # to wrap the output
        return type(data)(data_on_device)  # type: ignore
    elif isinstance(data, Iterable) and not isinstance(
            data, str) and not isinstance(data, np.ndarray):
        # NOTE: numpy arrays are excluded — they are iterable but have no
        # device; requires `numpy` to be imported at module level.
        data_on_device = []
        if out is not None:
            for v1, v2 in zip(data, out):
                data_on_device.append(cast_data_device(v1, device, v2))
        else:
            for v in data:
                data_on_device.append(cast_data_device(v, device))

        if len(data_on_device) == 0:
            raise ValueError('data should not be empty')

        return type(data)(data_on_device)  # type: ignore
    else:
        raise TypeError('data should be a Tensor, list of tensor or dict, '
                        f'but got {data}')
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def get_dist_info(group: Optional[ProcessGroup] = None) -> Tuple[int, int]:
    """Get distributed information of the given process group.

    Note:
        Calling ``get_dist_info`` in non-distributed environment will return
        (0, 1).

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        tuple[int, int]: Return a tuple containing the ``rank`` and
        ``world_size``.
    """
    world_size = get_world_size(group)
    # NOTE(review): this passes ``group`` to get_rank — verify get_rank
    # accepts a group argument, otherwise this raises TypeError.
    rank = get_rank(group)
    return rank, world_size
|