python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164) hide show
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,183 @@
1
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).

These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import re

import torch

try:
    # ``torch._six`` only exists in old torch releases (and lost
    # ``container_abcs`` in 1.9).  EAFP replaces the original
    # ``torch.__version__ < "1.9.0"`` string comparison, which is
    # lexicographic and wrong ("1.10.0" < "1.9.0" is True), and the
    # fallback now uses ``collections.abc`` instead of ``collections``
    # (``collections.Mapping`` was removed in Python 3.10).
    from torch._six import container_abcs, string_classes, int_classes
except ImportError:
    import collections.abc as container_abcs
    string_classes = (str, bytes)
    int_classes = int

# numpy dtype kind characters that cannot be converted to tensors:
# bytes (S), legacy string (a), unicode (U) and object (O) arrays.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
19
+
20
def default_convert(data):
    r"""Recursively turn NumPy data inside ``data`` into torch tensors.

    Tensors pass through untouched; containers are rebuilt with their
    elements converted; anything else is returned as-is.
    """
    kind = type(data)
    if isinstance(data, torch.Tensor):
        return data
    if kind.__module__ == 'numpy' and kind.__name__ not in ('str_', 'string_'):
        # string/object arrays have no tensor equivalent, hand them back
        if kind.__name__ == 'ndarray' \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)
    if isinstance(data, container_abcs.Mapping):
        return {name: default_convert(data[name]) for name in data}
    if isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return kind(*(default_convert(field) for field in data))
    if isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):
        return [default_convert(item) for item in data]
    return data
40
+
41
def null_convert(data):
    """No-op collate: hand the fetched samples back untouched."""
    return data
43
+
44
+
45
# Error text shared by all collate implementations in this module; ``{}``
# receives the offending dtype or element type.
default_collate_err_msg_format = ("default_collate: batch must contain tensors, numpy arrays, numbers, "
                                  "dicts or lists; found {}")
48
+
49
+
50
def default_collate(batch):
    r"""Collate a list of samples into a batch with an outer batch dimension.

    Tensors are stacked, numpy arrays converted then stacked, scalars become
    tensors, strings pass through, and mappings / namedtuples / sequences are
    collated field-wise via recursion.
    """
    first = batch[0]
    first_type = type(first)
    if isinstance(first, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a worker process: stack straight into a shared-memory
            # tensor so the main process doesn't pay for an extra copy.
            total = sum(t.numel() for t in batch)
            shared_storage = first.storage()._new_shared(total)
            out = first.new(shared_storage)
        return torch.stack(batch, 0, out=out)
    elif first_type.__module__ == 'numpy' and first_type.__name__ not in ('str_', 'string_'):
        if first_type.__name__ in ('ndarray', 'memmap'):
            # string/object arrays cannot be turned into tensors
            if np_str_obj_array_pattern.search(first.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(first.dtype))
            return default_collate([torch.as_tensor(sample) for sample in batch])
        elif first.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
    elif isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(first, int_classes):
        return torch.tensor(batch)
    elif isinstance(first, string_classes):
        return batch
    elif isinstance(first, container_abcs.Mapping):
        return {key: default_collate([sample[key] for sample in batch]) for key in first}
    elif isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        return first_type(*(default_collate(group) for group in zip(*batch)))
    elif isinstance(first, container_abcs.Sequence):
        # every sequence in the batch must have the same length
        rest = iter(batch)
        expected = len(next(rest))
        if not all(len(sample) == expected for sample in rest):
            raise RuntimeError('each element in list of batch should be of equal size')
        return [default_collate(group) for group in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(first_type))
94
+
95
def detection_default_collate_cat(batch):
    r"""Collate like :func:`default_collate`, but concatenate tensors along
    dim 0 instead of stacking them (variable-length detection targets).
    """
    first = batch[0]
    first_type = type(first)
    if isinstance(first, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a worker process: concatenate straight into a shared-memory
            # tensor so the main process doesn't pay for an extra copy.
            total = sum(t.numel() for t in batch)
            shared_storage = first.storage()._new_shared(total)
            out = first.new(shared_storage)
        return torch.cat(batch, dim=0, out=out)
    elif first_type.__module__ == 'numpy' and first_type.__name__ not in ('str_', 'string_'):
        if first_type.__name__ in ('ndarray', 'memmap'):
            # string/object arrays cannot be turned into tensors
            if np_str_obj_array_pattern.search(first.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(first.dtype))
            return detection_default_collate_cat([torch.as_tensor(sample) for sample in batch])
        elif first.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
    elif isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(first, int_classes):
        return torch.tensor(batch)
    elif isinstance(first, string_classes):
        return batch
    elif isinstance(first, container_abcs.Mapping):
        return {key: detection_default_collate_cat([sample[key] for sample in batch]) for key in first}
    elif isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        return first_type(*(detection_default_collate_cat(group) for group in zip(*batch)))
    elif isinstance(first, container_abcs.Sequence):
        # every sequence in the batch must have the same length
        rest = iter(batch)
        expected = len(next(rest))
        if not all(len(sample) == expected for sample in rest):
            raise RuntimeError('each element in list of batch should be of equal size')
        return [detection_default_collate_cat(group) for group in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(first_type))
139
+
140
def detection_default_collate(batch):
    r"""Detection variant of :func:`default_collate`; tensors are stacked
    with an outer batch dimension, containers are collated recursively.
    """
    first = batch[0]
    first_type = type(first)
    if isinstance(first, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a worker process: stack straight into a shared-memory
            # tensor so the main process doesn't pay for an extra copy.
            total = sum(t.numel() for t in batch)
            shared_storage = first.storage()._new_shared(total)
            out = first.new(shared_storage)
        return torch.stack(batch, 0, out=out)
    elif first_type.__module__ == 'numpy' and first_type.__name__ not in ('str_', 'string_'):
        if first_type.__name__ in ('ndarray', 'memmap'):
            # string/object arrays cannot be turned into tensors
            if np_str_obj_array_pattern.search(first.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(first.dtype))
            return detection_default_collate([torch.as_tensor(sample) for sample in batch])
        elif first.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
    elif isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(first, int_classes):
        return torch.tensor(batch)
    elif isinstance(first, string_classes):
        return batch
    elif isinstance(first, container_abcs.Mapping):
        return {key: detection_default_collate([sample[key] for sample in batch]) for key in first}
    elif isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        return first_type(*(detection_default_collate(group) for group in zip(*batch)))
    elif isinstance(first, container_abcs.Sequence):
        # every sequence in the batch must have the same length
        rest = iter(batch)
        expected = len(next(rest))
        if not all(len(sample) == expected for sample in rest):
            raise RuntimeError('each element in list of batch should be of equal size')
        return [detection_default_collate(group) for group in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(first_type))
@@ -0,0 +1,47 @@
1
+ r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
2
+ data from an iterable-style or map-style dataset. This logic is shared in both
3
+ single- and multi-processing data loading.
4
+ """
5
+
6
+
7
+ class _BaseDatasetFetcher(object):
8
+ def __init__(self, dataset, auto_collation, collate_fn, drop_last):
9
+ self.dataset = dataset
10
+ self.auto_collation = auto_collation
11
+ self.collate_fn = collate_fn
12
+ self.drop_last = drop_last
13
+
14
+ def fetch(self, possibly_batched_index):
15
+ raise NotImplementedError()
16
+
17
+
18
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher for iterable-style datasets.

    The index argument only conveys how many samples to draw; the data comes
    from a single iterator created once per fetcher.
    """

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        if not self.auto_collation:
            # Single-sample mode: one draw per call.
            return self.collate_fn(next(self.dataset_iter))
        samples = []
        for _ in possibly_batched_index:
            try:
                samples.append(next(self.dataset_iter))
            except StopIteration:
                break
        # Stop when the iterator is exhausted, or when drop_last forbids a
        # short final batch.
        if not samples or (self.drop_last and len(samples) < len(possibly_batched_index)):
            raise StopIteration
        return self.collate_fn(samples)
36
+
37
+
38
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher for map-style datasets: indices go straight to ``__getitem__``."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # Batched mode: look each index up individually.
            fetched = [self.dataset[i] for i in possibly_batched_index]
        else:
            fetched = self.dataset[possibly_batched_index]
        return self.collate_fn(fetched)
@@ -0,0 +1,121 @@
1
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to put
fetched tensors into pinned memory.

These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import os
import time

import torch
from torch._utils import ExceptionWrapper

try:
    # ``torch._six`` only exists in old torch releases.  EAFP replaces the
    # original ``torch.__version__ < "1.9.0"`` string comparison, which is
    # lexicographic and wrong ("1.10.0" < "1.9.0" is True), and the fallback
    # now uses ``collections.abc`` instead of ``collections``
    # (``collections.Mapping`` was removed in Python 3.10).
    from torch._six import queue, container_abcs, string_classes
except ImportError:
    import queue
    import collections.abc as container_abcs
    string_classes = (str, bytes)

from . import MP_STATUS_CHECK_INTERVAL
20
+
21
+
22
def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    """Pin-memory thread body: drain ``in_queue`` and forward to ``out_queue``.

    Each non-error batch is run through :func:`pin_memory`, which in this
    fork also copies it onto CUDA device ``device_id`` (stock PyTorch only
    pins).  Failures are forwarded as ``ExceptionWrapper``s; the loop exits
    once ``done_event`` is set.
    """
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    torch.cuda.set_device(device_id)

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                data = pin_memory(data)
            except Exception:
                data = ExceptionWrapper(
                    where="in pin memory thread for device {}".format(device_id))
        r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                # NOTE(review): short sleep after a successful put — presumably a
                # deliberate throttle/GIL yield; confirm it is intentional.
                time.sleep(1e-4)
                break
            except queue.Full:
                # Consumer is behind; back off before retrying.
                time.sleep(1)
                continue
        del r  # save memory
53
+
54
+
55
def pin_memory(data):
    """Recursively pin ``data`` and move tensors to the current CUDA device.

    Tensors are pinned then copied with ``.cuda()`` (a deviation from stock
    PyTorch, which only pins); containers are rebuilt element-wise; objects
    exposing a ``pin_memory`` method are pinned in place; everything else is
    returned untouched.
    """
    if isinstance(data, torch.Tensor):
        return data.pin_memory().cuda()
    if isinstance(data, string_classes):
        return data
    if isinstance(data, container_abcs.Mapping):
        return {key: pin_memory(value) for key, value in data.items()}
    if isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*(pin_memory(field) for field in data))
    if isinstance(data, container_abcs.Sequence):
        return [pin_memory(item) for item in data]
    if hasattr(data, "pin_memory"):
        return data.pin_memory()
    return data
70
+
71
def _pin_memory_loop_stream(in_queue, out_queue, device_id, done_event,stream):
    """Variant of ``_pin_memory_loop`` that copies on a dedicated CUDA stream.

    Batches are pinned and moved with non-blocking copies issued on ``stream``
    via :func:`pin_memory_stream`; everything else matches ``_pin_memory_loop``.
    """
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    torch.cuda.set_device(device_id)

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                # Issue the H2D copies on the side stream so they overlap with
                # compute on the default stream.
                with torch.cuda.stream(stream):
                    data = pin_memory_stream(data)
            except Exception:
                data = ExceptionWrapper(
                    where="in pin memory thread for device {}".format(device_id))
        r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                # NOTE(review): short sleep after a successful put — presumably a
                # deliberate throttle/GIL yield; confirm it is intentional.
                time.sleep(1e-4)
                break
            except queue.Full:
                # Consumer is behind; back off before retrying.
                time.sleep(1)
                continue
        del r  # save memory
103
+
104
+
105
def pin_memory_stream(data):
    """Recursively pin ``data`` and copy tensors to CUDA without blocking.

    Like :func:`pin_memory`, but device copies use ``non_blocking=True`` so
    they can be queued on the caller's CUDA stream; ``int16`` tensors are
    deliberately left untouched.
    """
    if isinstance(data, torch.Tensor):
        if data.dtype == torch.int16:
            # hack: skip int16 tensors
            return data
        return data.pin_memory().cuda(non_blocking=True)
    if isinstance(data, string_classes):
        return data
    if isinstance(data, container_abcs.Mapping):
        return {key: pin_memory_stream(value) for key, value in data.items()}
    if isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*(pin_memory_stream(field) for field in data))
    if isinstance(data, container_abcs.Sequence):
        return [pin_memory_stream(item) for item in data]
    if hasattr(data, "pin_memory"):
        return data.pin_memory().cuda(non_blocking=True)
    return data
@@ -0,0 +1,72 @@
1
+ r""""Signal handling for multiprocessing data loading.
2
+
3
+ NOTE [ Signal handling in multiprocessing data loading ]
4
+
5
+ In cases like DataLoader, if a worker process dies due to bus error/segfault
6
+ or just hang, the main process will hang waiting for data. This is difficult
7
+ to avoid on PyTorch side as it can be caused by limited shm, or other
8
+ libraries users call in the workers. In this file and `DataLoader.cpp`, we make
9
+ our best effort to provide some error message to users when such unfortunate
10
+ events happen.
11
+
12
+ When a _BaseDataLoaderIter starts worker processes, their pids are registered in a
13
+ defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]
14
+ via `_set_worker_pids`.
15
+
16
+ When an error happens in a worker process, the main process received a SIGCHLD,
17
+ and Python will eventually call the handler registered below
18
+ (in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`
19
+ call checks all registered worker pids and raise proper error message to
20
+ prevent main process from hanging waiting for data from worker.
21
+
22
+ Additionally, at the beginning of each worker's `_utils.worker._worker_loop`,
23
+ `_set_worker_signal_handlers` is called to register critical signal handlers
24
+ (e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error
25
+ message to stderr before triggering the default handler. So a message will also
26
+ be printed from the worker process when it is killed by such signals.
27
+
28
+ See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of
29
+ this signal handling design and other mechanism we implement to make our
30
+ multiprocessing data loading robust to errors.
31
+ """
32
+
33
+ import signal
34
+ import threading
35
+ from . import IS_WINDOWS
36
+
37
+ # Some of the following imported functions are not used in this file, but are to
38
+ # be used `_utils.signal_handling.XXXXX`.
39
+ from torch._C import _set_worker_pids, _remove_worker_pids # noqa: F401
40
+ from torch._C import _error_if_any_worker_fails, _set_worker_signal_handlers # noqa: F401
41
+
42
# Module-level guard: SIGCHLD handler must be installed at most once per process.
_SIGCHLD_handler_set = False
r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
handler needs to be set for all DataLoaders in a process."""


def _set_SIGCHLD_handler():
    """Install a SIGCHLD handler that surfaces DataLoader worker failures.

    The handler calls ``_error_if_any_worker_fails`` (C side) and then chains
    to any previously registered callable handler.  No-op on Windows, when
    called from a non-main thread, or when already installed.
    """
    # Windows doesn't support SIGCHLD handler
    if IS_WINDOWS:
        return
    # can't set signal in child threads
    if not isinstance(threading.current_thread(), threading._MainThread):  # type: ignore
        return
    global _SIGCHLD_handler_set
    if _SIGCHLD_handler_set:
        return
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        # This doesn't catch default handler, but SIGCHLD default handler is a
        # no-op.
        previous_handler = None

    def handler(signum, frame):
        # This following call uses `waitid` with WNOHANG from C side. Therefore,
        # Python can still get and update the process status successfully.
        _error_if_any_worker_fails()
        if previous_handler is not None:
            assert callable(previous_handler)
            previous_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True
@@ -0,0 +1,227 @@
1
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import os
import random
import time
from dataclasses import dataclass
from typing import Union

import torch
from torch._utils import ExceptionWrapper

try:
    # ``torch._six`` only exists in old torch releases.  EAFP replaces the
    # original ``torch.__version__ < "1.9.0"`` string comparison, which is
    # lexicographic and wrong ("1.10.0" < "1.9.0" is True).
    from torch._six import queue
except ImportError:
    import queue

from . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS
19
+
20
if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import DWORD, BOOL, HANDLE

    # On Windows, the parent ID of the worker process remains unchanged when the manager process
    # is gone, and the only way to check it through OS is to let the worker have a process handle
    # of the manager and ask if the process status has changed.
    class ManagerWatchdog(object):
        """Detects death of the manager (parent) process from inside a worker."""

        def __init__(self):
            self.manager_pid = os.getppid()

            # mypy cannot detect this code is windows only
            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)  # type: ignore
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())  # type: ignore

            self.manager_dead = False

        def is_alive(self):
            # Once observed dead, stays dead: skip the syscall on later calls.
            if not self.manager_dead:
                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
                self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
            return not self.manager_dead
else:
    class ManagerWatchdog(object):  # type: ignore[no-redef]
        """POSIX variant: a changed parent pid means the manager died
        (orphaned children are re-parented)."""

        def __init__(self):
            self.manager_pid = os.getppid()
            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                self.manager_dead = os.getppid() != self.manager_pid
            return not self.manager_dead

# Per-process WorkerInfo, set by `_worker_loop`; ``None`` in the main process.
_worker_info = None
64
+
65
+
66
class WorkerInfo(object):
    """Immutable bag of worker metadata (id, num_workers, seed, dataset).

    Attributes are fixed at construction; any later assignment raises
    ``RuntimeError``.
    """

    __initialized = False

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.__keys = tuple(kwargs.keys())
        self.__initialized = True

    def __setattr__(self, key, val):
        # Freeze the instance once __init__ has finished.
        if self.__initialized:
            raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__))
        return super(WorkerInfo, self).__setattr__(key, val)

    def __repr__(self):
        body = ', '.join('{}={}'.format(name, getattr(self, name)) for name in self.__keys)
        return '{}({})'.format(self.__class__.__name__, body)
85
+
86
+
87
def get_worker_info():
    r"""Return info about the current DataLoader iterator worker process.

    Inside a worker this returns an object exposing:

    * :attr:`id` – the current worker id.
    * :attr:`num_workers` – total number of workers.
    * :attr:`seed` – the random seed set for this worker, derived from the
      main-process RNG and the worker id (see the
      :class:`~torch.utils.data.DataLoader` docs).
    * :attr:`dataset` – this process's copy of the dataset object (a
      different object from the one in the main process).

    In the main process this returns ``None``.

    .. note::
        Handy inside a :attr:`worker_init_fn` to configure each worker
        differently, e.g. using ``worker_id`` to restrict the ``dataset``
        to a shard, or ``seed`` to seed other libraries (e.g., NumPy).
    """
    return _worker_info
114
+
115
+
116
+ r"""Dummy class used to signal the end of an IterableDataset"""
117
+ @dataclass(frozen=True)
118
+ class _IterableDatasetStopIteration(object):
119
+ worker_id: int
120
+
121
+ r"""Dummy class used to resume the fetching when worker reuse is enabled"""
122
+ @dataclass(frozen=True)
123
+ class _ResumeIteration(object):
124
+ pass
125
+
126
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
                 num_workers, persistent_workers):
    """Main body of a DataLoader worker process.

    Pulls index batches from ``index_queue``, resolves them through a fetcher
    built by ``_DatasetKind.create_fetcher`` and pushes ``(idx, data)`` pairs
    onto ``data_queue``.  Errors (including ``init_fn``/fetcher construction
    failures) are forwarded as ``ExceptionWrapper``s; iterable-dataset
    exhaustion is forwarded as ``_IterableDatasetStopIteration``.

    NOTE(review): ``persistent_workers`` is accepted but never read in this
    loop (upstream PyTorch consults it on StopIteration) — confirm intentional.
    """
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        # Publish this worker's metadata so get_worker_info() works in here.
        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        from wml.wtorch.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        except Exception:
            # Defer reporting: the wrapper is sent in place of the first batch.
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        # When using Iterable mode, some worker can exit earlier than others due
        # to the IterableDataset behaving differently for different workers.
        # When such things happen, an `_IterableDatasetStopIteration` object is
        # sent over to the main process with the ID of this worker, so that the
        # main process won't send more tasks to this worker, and will send
        # `None` to this worker to properly exit it.
        #
        # Note that we cannot set `done_event` from a worker as it is shared
        # among all processes. Instead, we set the `iteration_end` flag to
        # signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing step and just wait for
        # `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put((r, None))
                iteration_end = False
                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r
            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        # (1) to save future `next(...)` calls, and
                        # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where=f"in DataLoader worker process {worker_id}, msg: {e}")
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()