python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164) hide show
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
wml/wtorch/dist.py ADDED
@@ -0,0 +1,591 @@
# Standard library
import functools
import os
import pickle
import subprocess
from collections import OrderedDict
from collections.abc import Iterable, Mapping
from typing import Callable, Optional, Tuple, Union

# Third-party
import cv2
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup

# Local
import wml.wml_utils as wmlu
16
+
17
+
18
+
19
# Norm layers whose running statistics are reduced lazily across processes
# (see get_async_norm_states / all_reduce_norm) instead of on every forward.
ASYNC_NORM = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
)
27
+
28
def get_world_size(group: Optional[ProcessGroup] = None) -> int:
    """Return the number of processes in the given process group.

    In a non-distributed environment this always returns 1.

    Args:
        group (ProcessGroup, optional): Process group to query. When None,
            the default process group is used. Defaults to None.

    Returns:
        int: Number of processes in the group, or 1 if distributed mode
        has not been initialized.
    """
    # Guard clause: outside distributed mode there is exactly one process.
    if not (torch_dist.is_available() and torch_dist.is_initialized()):
        return 1
    # Old torch versions (e.g. 1.5.0) reject group=None, so resolve the
    # default group explicitly.
    target_group = group if group is not None else get_default_group()
    return torch_dist.get_world_size(target_group)
51
+
52
def get_rank(group: Optional[ProcessGroup] = None) -> int:
    """Return the rank of the current process.

    Bug fix / generalization: ``broadcast`` and ``get_dist_info`` in this
    module call ``get_rank(group)``, but the previous signature accepted no
    arguments, so those calls raised ``TypeError``. An optional ``group``
    parameter (default None) keeps all existing no-argument callers working.

    Args:
        group (ProcessGroup, optional): The process group to work on. If
            None, the default process group is used. Defaults to None.

    Returns:
        int: Rank within the group, or 0 in a non-distributed environment.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    # Pass the group through only when given, to stay compatible with old
    # torch versions that dislike group=None.
    return dist.get_rank(group) if group is not None else dist.get_rank()
58
+
59
def is_main_process() -> bool:
    """Return True when this process is rank 0 (or distributed is off)."""
    # Non-distributed runs have a single process, which is the main one.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0
61
+
62
def setup_dist_group(rank,world_size,port="12355",host="localhost",backend='nccl'):
    """Initialize the default process group for one worker process.

    Sets the rendezvous address via environment variables, then joins the
    default process group.

    Args:
        rank (int): rank of the current process.
        world_size (int): total number of processes.
        port (str): TCP port of the master process. Defaults to "12355".
        host (str): hostname/IP of the master process. Defaults to "localhost".
        backend (str): communication backend ('nccl' or 'gloo').
    """
    os.environ['MASTER_ADDR'] = host
    os.environ['MASTER_PORT'] = port
    #backend: gloo, nccl
    dist.init_process_group(backend,rank=rank,world_size=world_size)
67
+
68
def cleanup_dist_train():
    """Tear down the default process group created by ``setup_dist_group``."""
    dist.destroy_process_group()
70
+
71
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.

    Note:
        Must be called after the default process group is initialized;
        ``dist.get_backend()`` raises otherwise. A gloo group is created
        when the default backend is nccl so CPU tensors can be exchanged.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD
81
+
82
def pyobj2tensor(pyobj, device="cuda"):
    """Serialize a picklable python object into a uint8 tensor on *device*."""
    payload = pickle.dumps(pyobj)
    byte_storage = torch.ByteStorage.from_buffer(payload)
    return torch.ByteTensor(byte_storage).to(device=device)
86
+
87
def tensor2pyobj(tensor):
    """Deserialize a uint8 tensor (from ``pyobj2tensor``) back to a python object."""
    raw = tensor.cpu().numpy().tobytes()
    return pickle.loads(raw)
90
+
91
+ def _get_reduce_op(op_name):
92
+ return {
93
+ "sum": dist.ReduceOp.SUM,
94
+ "mean": dist.ReduceOp.SUM,
95
+ }[op_name.lower()]
96
+
97
def all_reduce(py_dict, op="sum", group=None):
    """
    Apply all reduce function for python dict object.
    NOTE: make sure that every py_dict has the same keys and values are in the same shape.

    Args:
        py_dict (dict): dict to apply all reduce op.
        op (str): operator, could be "sum" or "mean".
        group (ProcessGroup, optional): group used for the collectives;
            defaults to the cached global gloo group.

    Returns:
        The input dict unchanged when world size is 1, otherwise an
        OrderedDict of reduced tensors keyed in rank 0's key order.
    """
    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return py_dict

    # all reduce logic across different devices.
    # Broadcast rank 0's key list first so every rank concatenates the
    # tensors in the same order.
    py_key = list(py_dict.keys())
    py_key_tensor = pyobj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0)
    py_key = tensor2pyobj(py_key_tensor)

    # Remember per-key shapes/sizes so the flat result can be split back.
    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    # One flattened all_reduce instead of one collective per key.
    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
    if op == "mean":
        # "mean" is implemented as SUM followed by dividing by world size.
        flatten_tensor /= world_size

    split_tensors = [
        x.reshape(shape)
        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})
133
+
134
def get_async_norm_states(module):
    """Collect state_dict entries of every ASYNC_NORM layer inside *module*.

    Keys are prefixed with the submodule's qualified name so the result can
    be fed straight back to ``module.load_state_dict``.
    """
    states = OrderedDict()
    for prefix, sub in module.named_modules():
        if not isinstance(sub, ASYNC_NORM):
            continue
        for key, value in sub.state_dict().items():
            states[f"{prefix}.{key}"] = value
    return states
141
+
142
def all_reduce_norm(module):
    """
    All reduce norm statistics in different devices.

    Averages the state (running stats and affine params) of every
    ASYNC_NORM layer across processes, then loads the averaged values
    back into *module* in place.
    """
    states = get_async_norm_states(module)
    print("Reduce keys:")
    wmlu.show_list(list(states.keys()))
    states = all_reduce(states, op="mean")
    # strict=False: only the norm-layer keys are present in `states`.
    module.load_state_dict(states, strict=False)
151
+
152
+ def _find_free_port():
153
+ """
154
+ Find an available port of current machine / node.
155
+ """
156
+ import socket
157
+
158
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
159
+ # Binding to port 0 will cause the OS to find an available port for us
160
+ sock.bind(("", 0))
161
+ port = sock.getsockname()[1]
162
+ sock.close()
163
+ # NOTE: there is still a chance the port could be taken by other processes.
164
+ return port
165
+
166
def convert_sync_batchnorm(module, process_group=None):
    r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
    :class:`torch.nn.SyncBatchNorm` layers.

    Args:
        module (nn.Module): module containing one or more attr:`BatchNorm*D` layers
        process_group (optional): process group to scope synchronization,
            default is the whole world

    Returns:
        The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
        layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
        a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
        instead.

    Example::

        >>> # Network with nn.BatchNorm layer
        >>> module = torch.nn.Sequential(
        >>>            torch.nn.Linear(20, 100),
        >>>            torch.nn.BatchNorm1d(100),
        >>>          ).cuda()
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

    """
    module_output = module
    # NOTE(review): unlike torch's own convert_sync_batchnorm, this skips
    # layers that are in eval mode (`module.training` check) — presumably
    # intentional, confirm before relying on it for frozen backbones.
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.training:
        module_output = torch.nn.SyncBatchNorm(module.num_features,
                                               module.eps, module.momentum,
                                               module.affine,
                                               module.track_running_stats,
                                               process_group)
        if module.affine:
            # Share the original parameters/buffers instead of copying, so
            # optimizer references remain valid.
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    # Recurse into children; converted children replace the originals.
    for name, child in module.named_children():
        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
    del module
    return module_output
220
+
221
def configure_nccl():
    """Configure multi-machine environment variables of NCCL."""
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    # Auto-detect InfiniBand HCAs (mlx5_*) whose port 1 exposes RoCE GID types.
    os.environ["NCCL_IB_HCA"] = subprocess.getoutput(
        "pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; "
        "do cat $i/ports/1/gid_attrs/types/* 2>/dev/null "
        "| grep v >/dev/null && echo $i ; done; popd > /dev/null"
    )
    # NOTE(review): GID index 3 and traffic class 106 are cluster-specific
    # tuning values — confirm they match the target fabric.
    os.environ["NCCL_IB_GID_INDEX"] = "3"
    os.environ["NCCL_IB_TC"] = "106"
231
+
232
+
233
def configure_omp(num_threads=1):
    """
    If OMP_NUM_THREADS is not configured and world_size is greater than 1,
    Configure OMP_NUM_THREADS environment variables of NCCL to `num_thread`.

    Args:
        num_threads (int): value of `OMP_NUM_THREADS` to set.
    """
    # We set OMP_NUM_THREADS=1 by default, which achieves the best speed on our machines
    # feel free to change it for better performance.
    # Respect an explicit user setting; only touch the env when unset and
    # when running with more than one process.
    if "OMP_NUM_THREADS" not in os.environ and get_world_size() > 1:
        os.environ["OMP_NUM_THREADS"] = str(num_threads)
        # Only rank 0 prints, to avoid duplicated log lines.
        if is_main_process():
            print(
                "\n***************************************************************\n"
                "We set `OMP_NUM_THREADS` for each process to {} to speed up.\n"
                "please further tune the variable for optimal performance.\n"
                "***************************************************************".format(
                    os.environ["OMP_NUM_THREADS"]
                )
            )
254
+
255
+
256
def configure_module(ulimit_value=8192):
    """
    Configure pytorch module environment. setting of ulimit and cv2 will be set.

    Args:
        ulimit_value(int): default open file number on linux. Default value: 8192.
    """
    # system setting
    try:
        import resource

        # Raise the soft open-file limit (dataloader workers use many fds);
        # keep the hard limit unchanged.
        rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
        resource.setrlimit(resource.RLIMIT_NOFILE, (ulimit_value, rlimit[1]))
    except Exception:
        # Exception might be raised in Windows OS or rlimit reaches max limit number.
        # However, set rlimit value might not be necessary.
        pass

    # cv2
    # multiprocess might be harmful on performance of torch dataloader
    os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
    try:
        cv2.setNumThreads(0)
        cv2.ocl.setUseOpenCL(False)
    except Exception:
        # cv2 version mismatch might rasie exceptions.
        pass
283
+
284
def reduce_mean(tensor):
    """Average *tensor* over all processes.

    Returns the input tensor itself (same object) when distributed mode is
    not initialized; otherwise a new tensor holding the global mean.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    # Divide locally first, then SUM across ranks == global mean.
    averaged = tensor.clone().div_(dist.get_world_size())
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged
291
+
292
def is_distributed() -> bool:
    """Return True if distributed environment has been initialized."""
    if not torch_dist.is_available():
        return False
    return torch_dist.is_initialized()
295
+
296
def get_default_group() -> Optional[ProcessGroup]:
    """Return default process group.

    Note:
        Uses a private torch API (``distributed_c10d._get_default_group``);
        raises if the default process group has not been initialized.
    """

    return torch_dist.distributed_c10d._get_default_group()
300
+
301
def barrier(group: Optional[ProcessGroup] = None) -> None:
    """Synchronize all processes from the given process group.

    This collective blocks processes until the whole group enters this
    function.

    Note:
        Calling ``barrier`` in non-distributed environment will do nothing.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.
    """
    if is_distributed():
        # handle low versions of torch like 1.5.0 which does not support
        # passing in None for group argument
        if group is None:
            group = get_default_group()
        torch_dist.barrier(group)
320
+
321
def broadcast(data: Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None) -> None:
    """Broadcast the data from ``src`` process to the whole group.

    ``data`` must have the same number of elements in all processes
    participating in the collective.

    Note:
        Calling ``broadcast`` in non-distributed environment does nothing.

    Args:
        data (Tensor): Data to be sent if ``src`` is the rank of current
            process, and data to be used to save received data otherwise.
        src (int): Source rank. Defaults to 0.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Examples:
        >>> import torch
        >>> import mmengine.dist as dist

        >>> # non-distributed environment
        >>> data = torch.arange(2, dtype=torch.int64)
        >>> data
        tensor([0, 1])
        >>> dist.broadcast(data)
        >>> data
        tensor([0, 1])

        >>> # distributed environment
        >>> # We have 2 process groups, 2 ranks.
        >>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> data
        tensor([1, 2])  # Rank 0
        tensor([3, 4])  # Rank 1
        >>> dist.broadcast(data)
        >>> data
        tensor([1, 2])  # Rank 0
        tensor([1, 2])  # Rank 1
    """
    if get_world_size(group) > 1:
        if group is None:
            group = get_default_group()

        # Move the tensor to the device the backend communicates on
        # (e.g. CUDA for nccl, CPU for gloo) before the collective.
        input_device = get_data_device(data)
        backend_device = get_comm_device(group)
        data_on_device = cast_data_device(data, backend_device)
        # broadcast requires tensor is contiguous
        data_on_device = data_on_device.contiguous()  # type: ignore

        torch_dist.broadcast(data_on_device, src, group)

        # Receivers copy the result back into the caller's tensor in place.
        # NOTE(review): this calls get_rank(group); the module's get_rank must
        # accept an optional group argument for this branch to work.
        if get_rank(group) != src:
            cast_data_device(data_on_device, input_device, data)
375
+
376
+
377
def get_data_device(data: Union[Tensor, Mapping, Iterable]) -> torch.device:
    """Return the device of ``data``.

    For a Tensor this is its device; for a mapping or a (non-string)
    iterable, all contained tensors must share one device, which is
    returned.

    Args:
        data (Tensor or Sequence or dict): Inputs to be inferred the device.

    Returns:
        torch.device: The device of ``data``.

    Raises:
        ValueError: If ``data`` is empty or contains tensors on mixed
            devices.
        TypeError: If ``data`` is neither a Tensor, a mapping nor an
            iterable.
    """
    if isinstance(data, Tensor):
        return data.device

    # Mappings are inspected by value; any other non-string iterable by item.
    if isinstance(data, Mapping):
        items = data.values()
    elif isinstance(data, Iterable) and not isinstance(data, str):
        items = data
    else:
        raise TypeError('data should be a Tensor, sequence of tensor or dict, '
                        f'but got {data}')

    found = None
    for item in items:
        current = get_data_device(item)
        if found is None:
            found = current
        elif current != found:
            raise ValueError(
                'device type in data should be consistent, but got '
                f'{current} and {found}')
    if found is None:
        raise ValueError('data should not be empty.')
    return found
441
+
442
+
443
def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
    """Return the device for communication among groups.

    Args:
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        torch.device: The device of backend.
    """
    backend = get_backend(group)
    if backend == 'hccl':
        # Ascend NPU backend; importing torch_npu registers the device type.
        import torch_npu  # noqa: F401
        return torch.device('npu', torch.npu.current_device())
    elif backend == torch_dist.Backend.NCCL:
        return torch.device('cuda', torch.cuda.current_device())
    elif backend == 'cncl':
        # Cambricon MLU backend.
        import torch_mlu  # noqa: F401
        return torch.device('mlu', torch.mlu.current_device())
    elif backend == 'smddp':
        # AWS SageMaker distributed data parallel communicates over CUDA.
        return torch.device('cuda', torch.cuda.current_device())
    else:
        # GLOO and MPI backends use cpu device by default
        return torch.device('cpu')
466
+
467
+
468
def get_backend(group: Optional[ProcessGroup] = None) -> Optional[str]:
    """Return the backend name of the given process group.

    Args:
        group (ProcessGroup, optional): The process group to work on. The
            calling process must be part of it. When None, the default
            process group is used. Defaults to None.

    Returns:
        str or None: Lower-case backend name in a distributed environment,
        otherwise None.
    """
    # Outside distributed mode there is no backend to report.
    if not (torch_dist.is_available() and torch_dist.is_initialized()):
        return None
    # Old torch versions (e.g. 1.5.0) reject group=None, so resolve the
    # default group explicitly.
    target_group = group if group is not None else get_default_group()
    return torch_dist.get_backend(target_group)
493
+
494
+
495
def cast_data_device(
        data: Union[Tensor, Mapping, Iterable],
        device: torch.device,
        out: Optional[Union[Tensor, Mapping, Iterable]] = None
) -> Union[Tensor, Mapping, Iterable]:
    """Recursively convert Tensor in ``data`` to ``device``.

    If ``data`` has already on the ``device``, it will not be casted again.

    Args:
        data (Tensor or list or dict): Inputs to be casted.
        device (torch.device): Destination device type.
        out (Tensor or list or dict, optional): If ``out`` is specified, its
            value will be equal to ``data``. Defaults to None.

    Returns:
        Tensor or list or dict: ``data`` was casted to ``device``.

    Raises:
        TypeError: If ``out`` and ``data`` have different types, ``out`` is
            a set, or ``data`` is not a Tensor/mapping/iterable.
        ValueError: If ``data`` is empty, or ``data`` and ``out`` have
            different lengths.
    """
    if out is not None:
        if type(data) != type(out):
            # BUG FIX: the message previously printed type(data) twice,
            # hiding the actual type of `out`.
            raise TypeError(
                'out should be the same type with data, but got data is '
                f'{type(data)} and out is {type(out)}')

        if isinstance(out, set):
            raise TypeError('out should not be a set')

    if isinstance(data, Tensor):
        if get_data_device(data) == device:
            data_on_device = data
        else:
            data_on_device = data.to(device)

        if out is not None:
            # modify the value of out inplace
            out.copy_(data_on_device)  # type: ignore

        return data_on_device
    elif isinstance(data, Mapping):
        data_on_device = {}
        if out is not None:
            data_len = len(data)
            out_len = len(out)  # type: ignore
            if data_len != out_len:
                raise ValueError('length of data and out should be same, '
                                 f'but got {data_len} and {out_len}')

            for k, v in data.items():
                data_on_device[k] = cast_data_device(v, device,
                                                     out[k])  # type: ignore
        else:
            for k, v in data.items():
                data_on_device[k] = cast_data_device(v, device)

        if len(data_on_device) == 0:
            raise ValueError('data should not be empty')

        # To ensure the type of output as same as input, we use `type(data)`
        # to wrap the output
        return type(data)(data_on_device)  # type: ignore
    elif isinstance(data, Iterable) and not isinstance(
            data, str) and not isinstance(data, np.ndarray):
        # NOTE: `np` requires a module-level `import numpy as np`, which was
        # missing in the original file (NameError on this branch).
        data_on_device = []
        if out is not None:
            for v1, v2 in zip(data, out):
                data_on_device.append(cast_data_device(v1, device, v2))
        else:
            for v in data:
                data_on_device.append(cast_data_device(v, device))

        if len(data_on_device) == 0:
            raise ValueError('data should not be empty')

        return type(data)(data_on_device)  # type: ignore
    else:
        raise TypeError('data should be a Tensor, list of tensor or dict, '
                        f'but got {data}')
572
+
573
+
574
def get_dist_info(group: Optional[ProcessGroup] = None) -> Tuple[int, int]:
    """Get distributed information of the given process group.

    Note:
        Calling ``get_dist_info`` in non-distributed environment will return
        (0, 1).

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        tuple[int, int]: ``(rank, world_size)`` of the given group.
    """
    return get_rank(group), get_world_size(group)
@@ -0,0 +1,6 @@
1
+ from .dropblock import DropBlock2D, WDropBlock2D, DropBlock3D
2
+ from .scheduler import LinearScheduler
3
+ from .dropout import WDropout
4
+
5
+
6
+ __all__ = ['DropBlock2D', 'DropBlock3D', 'LinearScheduler','WDropout', 'WDropBlock2D']