python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
|
|
2
|
+
collate samples fetched from dataset into Tensor(s).
|
|
3
|
+
|
|
4
|
+
These **need** to be in global scope since Py2 doesn't support serializing
|
|
5
|
+
static methods.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import re
|
|
10
|
+
# Python-3 equivalents of the old ``torch._six`` compatibility aliases.
#
# The previous code chose between ``torch._six`` and a local fallback with
# ``torch.__version__ < "1.9.0"``. That is a *string* comparison, so e.g.
# "1.10.0" < "1.9.0" is True and the removed ``torch._six`` names were imported.
# The fallback also used ``import collections as container_abcs``. The ABC
# aliases (``collections.Mapping``, ``collections.Sequence``, ...) no longer
# exist on Python 3.10+, so every ``container_abcs.Mapping`` lookup failed.
# On Python 3, ``torch._six`` simply re-exported the objects below, so they are
# used unconditionally.
import collections.abc as container_abcs

string_classes = (str, bytes)
int_classes = int

# Matches numpy dtype strings for byte strings (S/a), unicode (U) and object (O)
# arrays. default_convert passes such arrays through unchanged, and the collate
# functions reject them with TypeError.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def default_convert(data):
    r"""Recursively turn NumPy values inside ``data`` into tensors.

    Handling by type:

    * Tensors are returned unchanged.
    * NumPy arrays and scalars are wrapped with ``torch.as_tensor``. Two cases
      are exempt: NumPy ``str_``/``string_`` scalars, and ndarrays whose dtype
      is a string/object type; these are returned as-is.
    * Mappings become plain dicts.
    * Namedtuples keep their type.
    * Other non-string sequences become lists.
    * Anything else is returned untouched.

    Items inside containers are converted recursively.
    """
    kind = type(data)
    if isinstance(data, torch.Tensor):
        return data

    from_numpy = kind.__module__ == 'numpy' and kind.__name__ not in ('str_', 'string_')
    if from_numpy:
        # arrays of string classes and objects cannot become tensors
        keep_as_is = (kind.__name__ == 'ndarray'
                      and np_str_obj_array_pattern.search(data.dtype.str) is not None)
        return data if keep_as_is else torch.as_tensor(data)

    if isinstance(data, container_abcs.Mapping):
        return {key: default_convert(data[key]) for key in data}
    if isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return kind(*map(default_convert, data))
    if isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):
        return list(map(default_convert, data))
    return data
|
|
40
|
+
|
|
41
|
+
def null_convert(data):
    """Identity converter: hand ``data`` back untouched (no tensor conversion)."""
    unchanged = data
    return unchanged
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Error template used by the collate functions; ``{}`` receives the offending
# element type (or the numpy dtype of a string/object array).
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found {}"
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def default_collate(batch):
    r"""Merge a list of samples into a batch, stacking tensors on a new dim 0.

    The type of ``batch[0]`` selects the strategy for the whole batch:

    * ``torch.Tensor``: ``torch.stack`` along a new leading dimension.
    * numpy ``ndarray``/``memmap``: converted with ``torch.as_tensor`` and
      collated again. String/object dtypes raise ``TypeError``.
    * numpy scalars (``shape == ()``): a single ``torch.as_tensor(batch)``.
    * ``float``: a ``float64`` tensor. ``int``: ``torch.tensor(batch)``.
    * ``str``/``bytes``: the batch is returned unchanged.
    * mappings: a dict collated key by key, with keys taken from ``batch[0]``.
    * namedtuples: the same type, collated field by field.
    * other sequences: a list, collated position by position. All samples
      must have the same length, otherwise ``RuntimeError`` is raised.

    Any other type, including unsupported numpy types, raises ``TypeError``.
    """
    first = batch[0]
    first_type = type(first)

    if isinstance(first, torch.Tensor):
        shared_out = None
        if torch.utils.data.get_worker_info() is not None:
            # Inside a torch DataLoader worker: write the result straight into
            # a shared-memory buffer to avoid an extra copy.
            total_numel = sum(t.numel() for t in batch)
            shared_out = first.new(first.storage()._new_shared(total_numel))
        return torch.stack(batch, 0, out=shared_out)

    is_numpy = (first_type.__module__ == 'numpy'
                and first_type.__name__ not in ('str_', 'string_'))
    if is_numpy:
        if first_type.__name__ in ('ndarray', 'memmap'):
            # arrays of string classes and objects cannot become tensors
            if np_str_obj_array_pattern.search(first.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(first.dtype))
            return default_collate([torch.as_tensor(arr) for arr in batch])
        if first.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
        raise TypeError(default_collate_err_msg_format.format(first_type))

    if isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(first, int_classes):
        return torch.tensor(batch)
    if isinstance(first, string_classes):
        return batch
    if isinstance(first, container_abcs.Mapping):
        return {key: default_collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        return first_type(*(default_collate(field) for field in zip(*batch)))
    if isinstance(first, container_abcs.Sequence):
        # every sample must have the same length before transposing
        samples = iter(batch)
        expected_len = len(next(samples))
        if any(len(sample) != expected_len for sample in samples):
            raise RuntimeError('each element in list of batch should be of equal size')
        return [default_collate(field) for field in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(first_type))
|
|
94
|
+
|
|
95
|
+
def detection_default_collate_cat(batch):
    r"""Merge a list of samples into a batch, concatenating tensors along dim 0.

    This works like ``default_collate``, except tensors are joined with
    ``torch.cat`` along their existing first dimension instead of being stacked
    on a new one. It can therefore merge samples with a varying number of rows,
    e.g. per-image box lists.

    * ``torch.Tensor``: ``torch.cat(batch, dim=0)``.
    * numpy ``ndarray``/``memmap``: converted with ``torch.as_tensor`` and
      concatenated. String/object dtypes raise ``TypeError``.
    * numpy scalars: a single ``torch.as_tensor(batch)``.
    * ``float``: a ``float64`` tensor. ``int``: ``torch.tensor(batch)``.
    * ``str``/``bytes``: the batch is returned unchanged.
    * mappings, namedtuples and sequences are handled recursively, key-wise or
      position-wise. Sequences must all have the same length, otherwise
      ``RuntimeError`` is raised.

    Any other type raises ``TypeError``.
    """
    head = batch[0]
    head_type = type(head)

    if isinstance(head, torch.Tensor):
        shared_out = None
        if torch.utils.data.get_worker_info() is not None:
            # Inside a torch DataLoader worker: concatenate directly into a
            # shared-memory buffer to avoid an extra copy.
            total_numel = sum(t.numel() for t in batch)
            shared_out = head.new(head.storage()._new_shared(total_numel))
        return torch.cat(batch, dim=0, out=shared_out)

    is_numpy = (head_type.__module__ == 'numpy'
                and head_type.__name__ not in ('str_', 'string_'))
    if is_numpy:
        if head_type.__name__ in ('ndarray', 'memmap'):
            # arrays of string classes and objects cannot become tensors
            if np_str_obj_array_pattern.search(head.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(head.dtype))
            return detection_default_collate_cat([torch.as_tensor(arr) for arr in batch])
        if head.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
        raise TypeError(default_collate_err_msg_format.format(head_type))

    if isinstance(head, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(head, int_classes):
        return torch.tensor(batch)
    if isinstance(head, string_classes):
        return batch
    if isinstance(head, container_abcs.Mapping):
        return {key: detection_default_collate_cat([sample[key] for sample in batch])
                for key in head}
    if isinstance(head, tuple) and hasattr(head, '_fields'):  # namedtuple
        return head_type(*(detection_default_collate_cat(field) for field in zip(*batch)))
    if isinstance(head, container_abcs.Sequence):
        # every sample must have the same length before transposing
        samples = iter(batch)
        expected_len = len(next(samples))
        if any(len(sample) != expected_len for sample in samples):
            raise RuntimeError('each element in list of batch should be of equal size')
        return [detection_default_collate_cat(field) for field in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(head_type))
|
|
139
|
+
|
|
140
|
+
def detection_default_collate(batch):
    r"""Merge a list of samples into a batch, stacking tensors on a new dim 0.

    The logic is the same as ``default_collate``, but recursion goes through
    this function.

    * Tensors are stacked with ``torch.stack``.
    * numpy arrays/memmaps are converted and collated again, and string/object
      dtypes raise ``TypeError``. numpy scalars become one tensor.
    * ``float``/``int`` become tensors; ``float`` uses ``float64``.
    * ``str``/``bytes`` batches are returned unchanged.
    * Mappings, namedtuples and sequences are collated recursively. Sequences
      must all have the same length, otherwise ``RuntimeError`` is raised.

    Any other type raises ``TypeError``.
    """
    sample0 = batch[0]
    sample0_type = type(sample0)

    if isinstance(sample0, torch.Tensor):
        shared_out = None
        if torch.utils.data.get_worker_info() is not None:
            # Inside a torch DataLoader worker: stack straight into a
            # shared-memory buffer to avoid an extra copy.
            total_numel = sum(t.numel() for t in batch)
            shared_out = sample0.new(sample0.storage()._new_shared(total_numel))
        return torch.stack(batch, 0, out=shared_out)

    is_numpy = (sample0_type.__module__ == 'numpy'
                and sample0_type.__name__ not in ('str_', 'string_'))
    if is_numpy:
        if sample0_type.__name__ in ('ndarray', 'memmap'):
            # arrays of string classes and objects cannot become tensors
            if np_str_obj_array_pattern.search(sample0.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(sample0.dtype))
            return detection_default_collate([torch.as_tensor(arr) for arr in batch])
        if sample0.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
        raise TypeError(default_collate_err_msg_format.format(sample0_type))

    if isinstance(sample0, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(sample0, int_classes):
        return torch.tensor(batch)
    if isinstance(sample0, string_classes):
        return batch
    if isinstance(sample0, container_abcs.Mapping):
        return {key: detection_default_collate([sample[key] for sample in batch])
                for key in sample0}
    if isinstance(sample0, tuple) and hasattr(sample0, '_fields'):  # namedtuple
        return sample0_type(*(detection_default_collate(field) for field in zip(*batch)))
    if isinstance(sample0, container_abcs.Sequence):
        # every sample must have the same length before transposing
        samples = iter(batch)
        expected_len = len(next(samples))
        if any(len(sample) != expected_len for sample in samples):
            raise RuntimeError('each element in list of batch should be of equal size')
        return [detection_default_collate(field) for field in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(sample0_type))
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
|
|
2
|
+
data from an iterable-style or map-style dataset. This logic is shared in both
|
|
3
|
+
single- and multi-processing data loading.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class _BaseDatasetFetcher(object):
    """Shared state for dataset fetchers.

    Stores the following, for subclasses that implement :meth:`fetch`:

    * the dataset;
    * whether automatic batching (``auto_collation``) is active;
    * the collate function;
    * the ``drop_last`` flag.
    """

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        self.dataset, self.auto_collation = dataset, auto_collation
        self.collate_fn, self.drop_last = collate_fn, drop_last

    def fetch(self, possibly_batched_index):
        """Return collated data for ``possibly_batched_index`` (abstract)."""
        raise NotImplementedError()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher for iterable-style datasets.

    A single iterator over the dataset is created up front. Each :meth:`fetch`
    pulls the next sample(s) from it. With auto-collation, the index argument
    only decides how many samples make up one batch.
    """

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        """Return the next collated sample or batch.

        Raises ``StopIteration`` in two cases:

        * the iterator is already exhausted (empty batch);
        * ``drop_last`` is set and the final batch is shorter than requested.
        """
        if not self.auto_collation:
            return self.collate_fn(next(self.dataset_iter))

        exhausted = object()
        batch = []
        for _ in possibly_batched_index:
            sample = next(self.dataset_iter, exhausted)
            if sample is exhausted:
                break
            batch.append(sample)

        if not batch or (self.drop_last and len(batch) < len(possibly_batched_index)):
            raise StopIteration
        return self.collate_fn(batch)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher for map-style datasets, which are read with ``dataset[index]``."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        """Index the dataset and pass the result through ``collate_fn``.

        With auto-collation, ``possibly_batched_index`` is an iterable of
        indices and a list of samples is collated. Otherwise it is a single
        index.
        """
        if not self.auto_collation:
            return self.collate_fn(self.dataset[possibly_batched_index])
        samples = [self.dataset[index] for index in possibly_batched_index]
        return self.collate_fn(samples)
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to put
|
|
2
|
+
fetched tensors into pinned memory.
|
|
3
|
+
|
|
4
|
+
These **need** to be in global scope since Py2 doesn't support serializing
|
|
5
|
+
static methods.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import time
|
|
9
|
+
import torch
|
|
10
|
+
# Python-3 replacements for the old ``torch._six`` compatibility aliases.
#
# The previous code picked the import source with ``torch.__version__ < "1.9.0"``.
# That is a string comparison, which treats e.g. "1.10.0" as older than "1.9.0".
# Its fallback also did ``import collections as container_abcs``; on Python 3.10+
# the ABC aliases (``collections.Mapping``, ``collections.Sequence``) are gone,
# so pin_memory's container checks raised AttributeError. On Python 3,
# ``torch._six`` re-exported exactly these objects.
import queue
import collections.abc as container_abcs

string_classes = (str, bytes)
|
|
16
|
+
|
|
17
|
+
from . import MP_STATUS_CHECK_INTERVAL
|
|
18
|
+
import os
|
|
19
|
+
from torch._utils import ExceptionWrapper
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    """Forward ``(idx, data)`` items from ``in_queue`` to ``out_queue`` after
    passing ``data`` through :func:`pin_memory`.

    Behaviour:

    * Runs until ``done_event`` is set.
    * Items whose data is already an ``ExceptionWrapper`` are forwarded
      unchanged.
    * A failure while pinning is forwarded as an ``ExceptionWrapper``.

    See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the shutdown
    protocol.
    """
    # Thread-local setting: keeps the copy in pin_memory from consuming all CPU cores.
    torch.set_num_threads(1)
    torch.cuda.set_device(device_id)

    while not done_event.is_set():
        try:
            item = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue

        idx, payload = item
        if not (done_event.is_set() or isinstance(payload, ExceptionWrapper)):
            try:
                payload = pin_memory(payload)
            except Exception:
                payload = ExceptionWrapper(
                    where="in pin memory thread for device {}".format(device_id))
            item = (idx, payload)

        # Retry the put until it succeeds or shutdown is requested.
        while not done_event.is_set():
            try:
                out_queue.put(item, timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Full:
                time.sleep(1)
            else:
                time.sleep(1e-4)
                break
        del item  # save memory
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def pin_memory(data):
    """Recursively pin ``data`` and move its tensors to the current CUDA device.

    Despite the name, tensors are both page-locked and copied to the GPU
    (``.pin_memory().cuda()``). Other types are handled as follows:

    * strings are returned as-is;
    * mappings become dicts;
    * namedtuples keep their type;
    * other sequences become lists;
    * any other object with a ``pin_memory`` method is only pinned;
    * everything else is returned unchanged.

    NOTE(review): objects in the ``pin_memory``-method branch are not moved to
    CUDA, unlike tensors. Confirm this asymmetry is intended.
    """
    if isinstance(data, torch.Tensor):
        return data.pin_memory().cuda()
    if isinstance(data, string_classes):
        return data
    if isinstance(data, container_abcs.Mapping):
        return {key: pin_memory(value) for key, value in data.items()}
    if isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*map(pin_memory, data))
    if isinstance(data, container_abcs.Sequence):
        return list(map(pin_memory, data))
    return data.pin_memory() if hasattr(data, "pin_memory") else data
|
|
70
|
+
|
|
71
|
+
def _pin_memory_loop_stream(in_queue, out_queue, device_id, done_event, stream):
    """Variant of ``_pin_memory_loop`` that moves each batch to the GPU on ``stream``.

    Items are taken from ``in_queue``, their data is run through
    :func:`pin_memory_stream` inside ``torch.cuda.stream(stream)``, and the
    result is put on ``out_queue``. Errors are forwarded as
    ``ExceptionWrapper``. The loop stops once ``done_event`` is set.
    """
    # Thread-local setting: keeps the copy in pin_memory from consuming all CPU cores.
    torch.set_num_threads(1)
    torch.cuda.set_device(device_id)

    while not done_event.is_set():
        try:
            item = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue

        idx, payload = item
        if not (done_event.is_set() or isinstance(payload, ExceptionWrapper)):
            try:
                with torch.cuda.stream(stream):
                    payload = pin_memory_stream(payload)
                # NOTE(review): copies are issued with non_blocking=True on
                # ``stream``, and the batch is published without synchronizing
                # that stream. Confirm the consumer waits on ``stream`` (and
                # records the tensors on its own stream) before using them.
            except Exception:
                payload = ExceptionWrapper(
                    where="in pin memory thread for device {}".format(device_id))
            item = (idx, payload)

        # Retry the put until it succeeds or shutdown is requested.
        while not done_event.is_set():
            try:
                out_queue.put(item, timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Full:
                time.sleep(1)
            else:
                time.sleep(1e-4)
                break
        del item  # save memory
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def pin_memory_stream(data):
    """Recursively pin ``data`` and start asynchronous copies to the GPU.

    * Tensors become ``.pin_memory().cuda(non_blocking=True)``. The exception is
      ``int16`` tensors, which are returned as-is (neither pinned nor moved).
    * Strings are returned unchanged.
    * Mappings become dicts, namedtuples keep their type, and other sequences
      become lists.
    * Other objects with a ``pin_memory`` method are pinned and copied the same
      way; everything else is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        # HACK: int16 tensors are deliberately not processed.
        if data.dtype == torch.int16:
            return data
        return data.pin_memory().cuda(non_blocking=True)
    if isinstance(data, string_classes):
        return data
    if isinstance(data, container_abcs.Mapping):
        return {key: pin_memory_stream(value) for key, value in data.items()}
    if isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*map(pin_memory_stream, data))
    if isinstance(data, container_abcs.Sequence):
        return list(map(pin_memory_stream, data))
    if hasattr(data, "pin_memory"):
        return data.pin_memory().cuda(non_blocking=True)
    return data
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
r""""Signal handling for multiprocessing data loading.
|
|
2
|
+
|
|
3
|
+
NOTE [ Signal handling in multiprocessing data loading ]
|
|
4
|
+
|
|
5
|
+
In cases like DataLoader, if a worker process dies due to bus error/segfault
|
|
6
|
+
or just hangs, the main process will hang waiting for data. This is difficult
|
|
7
|
+
to avoid on PyTorch side as it can be caused by limited shm, or other
|
|
8
|
+
libraries users call in the workers. In this file and `DataLoader.cpp`, we make
|
|
9
|
+
our best effort to provide some error message to users when such unfortunate
|
|
10
|
+
events happen.
|
|
11
|
+
|
|
12
|
+
When a _BaseDataLoaderIter starts worker processes, their pids are registered in a map
|
|
13
|
+
defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]
|
|
14
|
+
via `_set_worker_pids`.
|
|
15
|
+
|
|
16
|
+
When an error happens in a worker process, the main process receives a SIGCHLD,
|
|
17
|
+
and Python will eventually call the handler registered below
|
|
18
|
+
(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`
|
|
19
|
+
call checks all registered worker pids and raises a proper error message to
|
|
20
|
+
prevent main process from hanging waiting for data from worker.
|
|
21
|
+
|
|
22
|
+
Additionally, at the beginning of each worker's `_utils.worker._worker_loop`,
|
|
23
|
+
`_set_worker_signal_handlers` is called to register critical signal handlers
|
|
24
|
+
(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error
|
|
25
|
+
message to stderr before triggering the default handler. So a message will also
|
|
26
|
+
be printed from the worker process when it is killed by such signals.
|
|
27
|
+
|
|
28
|
+
See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of
|
|
29
|
+
this signal handling design and other mechanisms we implement to make our
|
|
30
|
+
multiprocessing data loading robust to errors.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
import signal
|
|
34
|
+
import threading
|
|
35
|
+
from . import IS_WINDOWS
|
|
36
|
+
|
|
37
|
+
# Some of the following imported functions are not used in this file, but are to
|
|
38
|
+
# be used `_utils.signal_handling.XXXXX`.
|
|
39
|
+
from torch._C import _set_worker_pids, _remove_worker_pids # noqa: F401
|
|
40
|
+
from torch._C import _error_if_any_worker_fails, _set_worker_signal_handlers # noqa: F401
|
|
41
|
+
|
|
42
|
+
# Whether the SIGCHLD handler that reports DataLoader worker failures has been
# installed. A single handler serves every DataLoader in the process.
_SIGCHLD_handler_set = False
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _set_SIGCHLD_handler():
    """Install, once per process, a SIGCHLD handler that reports worker failures.

    Does nothing on Windows, when called from a non-main thread, or when the
    handler is already installed. A previously installed callable handler is
    still invoked after the worker check.
    """
    global _SIGCHLD_handler_set
    # Windows has no SIGCHLD, and signal handlers can't be set from child threads.
    if IS_WINDOWS or not isinstance(threading.current_thread(), threading._MainThread):  # type: ignore
        return
    if _SIGCHLD_handler_set:
        return

    chained = signal.getsignal(signal.SIGCHLD)
    if not callable(chained):
        # This doesn't catch the default handler, but the SIGCHLD default
        # handler is a no-op.
        chained = None

    def handler(signum, frame):
        # This call uses `waitid` with WNOHANG from C, so Python can still get
        # and update the process status successfully.
        _error_if_any_worker_fails()
        if chained is not None:
            assert callable(chained)
            chained(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
|
|
2
|
+
|
|
3
|
+
These **need** to be in global scope since Py2 doesn't support serializing
|
|
4
|
+
static methods.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import random
|
|
9
|
+
import os
|
|
10
|
+
import time
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
# On Python 3, ``torch._six.queue`` was just the stdlib ``queue`` module, so import
# it directly. The previous ``torch.__version__ < "1.9.0"`` check compared strings
# lexicographically and treated e.g. "1.10.0" as older than "1.9.0".
import queue
|
|
16
|
+
from torch._utils import ExceptionWrapper
|
|
17
|
+
from typing import Union
|
|
18
|
+
from . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS
|
|
19
|
+
|
|
20
|
+
if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import DWORD, BOOL, HANDLE

    # On Windows the parent ID of a worker does not change when the manager
    # process is gone, so the worker keeps a process handle to the manager and
    # asks whether its status has changed.
    class ManagerWatchdog(object):
        """Reports whether the process that spawned this worker is still alive (Windows)."""

        def __init__(self):
            self.manager_pid = os.getppid()

            # mypy cannot detect this code is windows only
            kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)  # type: ignore
            kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            kernel32.OpenProcess.restype = HANDLE
            kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            kernel32.WaitForSingleObject.restype = DWORD
            self.kernel32 = kernel32

            # SYNCHRONIZE access right, value from
            # https://msdn.microsoft.com/en-us/library/ms684880.aspx
            synchronize_access = 0x00100000
            self.manager_handle = kernel32.OpenProcess(synchronize_access, 0, self.manager_pid)
            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())  # type: ignore

            self.manager_dead = False

        def is_alive(self):
            """Return False once the manager has been observed dead; the state is sticky."""
            if self.manager_dead:
                return False
            # A zero-timeout wait returning 0 means the process handle is signaled; see
            # https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
            self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
            return not self.manager_dead
else:
    class ManagerWatchdog(object):  # type: ignore[no-redef]
        """Reports whether the process that spawned this worker is still alive (POSIX)."""

        def __init__(self):
            self.manager_pid = os.getppid()
            self.manager_dead = False

        def is_alive(self):
            """Return False once the parent PID has changed; the state is sticky."""
            if not self.manager_dead:
                # A different parent PID means this worker was re-parented.
                self.manager_dead = self.manager_pid != os.getppid()
            return not self.manager_dead
|
|
62
|
+
|
|
63
|
+
# Per-process WorkerInfo for the current DataLoader worker. It is assigned by
# _worker_loop and stays None in the main process; read it through get_worker_info().
_worker_info = None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class WorkerInfo(object):
    """Immutable record of keyword attributes describing a DataLoader worker.

    Attributes are set from the constructor's keyword arguments. Any later
    assignment raises ``RuntimeError``.
    """

    __initialized = False

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.__keys = tuple(kwargs)
        self.__initialized = True

    def __setattr__(self, key, val):
        if not self.__initialized:
            return super(WorkerInfo, self).__setattr__(key, val)
        raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__))

    def __repr__(self):
        fields = ', '.join('{}={}'.format(name, getattr(self, name)) for name in self.__keys)
        return '{}({})'.format(self.__class__.__name__, fields)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def get_worker_info():
    r"""Return the :class:`WorkerInfo` of the current DataLoader worker process.

    Inside a worker started by this package's ``_worker_loop``, the returned
    object carries these attributes:

    * :attr:`id`: the index of this worker.
    * :attr:`num_workers`: how many workers the loader runs.
    * :attr:`seed`: the seed this worker passed to ``random.seed`` and
      ``torch.manual_seed``.
    * :attr:`dataset`: this process's own copy of the dataset object. It is a
      different object from the one in the main process.

    Outside a worker (e.g. in the main process) the result is ``None``.

    .. note::
        A ``worker_init_fn`` can use this to configure each worker
        differently. Examples: read only the shard selected by ``id`` from a
        sharded dataset, or derive seeds for other libraries (e.g. NumPy)
        from ``seed``.
    """
    return _worker_info
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@dataclass(frozen=True)
class _IterableDatasetStopIteration:
    """Dummy class used to signal the end of an IterableDataset.

    ``_worker_loop`` puts an instance on the data queue the first time its
    fetcher raises ``StopIteration`` for an iterable dataset. This tells the
    main process which worker has run out of data.
    """

    # Id of the worker whose dataset iterator was exhausted.
    worker_id: int
|
|
120
|
+
|
|
121
|
+
@dataclass(frozen=True)
class _ResumeIteration:
    """Dummy class used to resume the fetching when worker reuse is enabled.

    When ``_worker_loop`` reads an instance from its index queue, it does
    three things: it echoes ``(instance, None)`` back on the data queue as an
    acknowledgement, clears its end-of-iteration flag, and rebuilds its
    fetcher.
    """
|
|
125
|
+
|
|
126
|
+
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
                 num_workers, persistent_workers):
    """Main loop of a single DataLoader worker process.

    The function first sets up the process:

    * installs signal handlers;
    * limits torch to one thread;
    * seeds ``random`` and ``torch`` with ``seed``;
    * publishes a ``WorkerInfo`` in the module-level ``_worker_info``
      (the value ``get_worker_info()`` returns);
    * runs ``init_fn`` and builds a fetcher.

    It then serves requests from ``index_queue``. It stops when it receives
    ``None`` or when ``ManagerWatchdog.is_alive()`` returns False (presumably
    when the parent process has gone away; not visible here).

    Args:
        dataset_kind: compared against ``_DatasetKind.Iterable`` and forwarded
            to ``_DatasetKind.create_fetcher``.
        dataset: the dataset object. It is forwarded to the fetcher and stored
            on ``WorkerInfo``.
        index_queue: request queue read with ``get(timeout=...)``. Each item
            is ``None`` (final signal), a ``_ResumeIteration`` marker, or an
            ``(idx, index)`` pair.
        data_queue: queue the worker ``put``s ``(idx, data)`` results and
            ``(_ResumeIteration, None)`` acknowledgements on.
        done_event: event polled via ``is_set()``. Once it is set, the worker
            skips all processing and only waits for ``None``.
        auto_collation, collate_fn, drop_last: forwarded unchanged to
            ``_DatasetKind.create_fetcher``.
        seed: seed for ``random`` and ``torch``; also exposed on ``WorkerInfo``.
        init_fn: optional callable, invoked as ``init_fn(worker_id)`` before
            the fetcher is built.
        worker_id: id of this worker. It is used in ``WorkerInfo``, in error
            messages and in ``_IterableDatasetStopIteration``.
        num_workers: total number of workers, exposed on ``WorkerInfo``.
        persistent_workers: accepted but not read anywhere in this body.

    Exceptions from ``init_fn`` or from fetcher construction are not raised
    here. They are wrapped in an ``ExceptionWrapper`` and sent as the reply to
    the first ``(idx, index)`` request that gets processed. Exceptions from
    ``fetcher.fetch`` are likewise wrapped and sent in place of the batch,
    except ``StopIteration`` from an iterable dataset, which produces a single
    ``_IterableDatasetStopIteration``. ``KeyboardInterrupt`` is swallowed.
    """
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        # NOTE(review): presumably one intra-op thread per worker to avoid CPU
        # oversubscription when many workers run in parallel; confirm.
        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        # Publish per-worker info so get_worker_info() returns it in this process.
        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        # Function-local import; NOTE(review): likely to avoid a circular import
        # with wml.wtorch.data; confirm.
        from wml.wtorch.data import _DatasetKind

        # Holds a deferred setup error; it is reported on the first request
        # instead of crashing the worker.
        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        except Exception:
            # NOTE(review): on this path `fetcher` stays unbound. After the
            # stored error has been sent once, later requests reach
            # `fetcher.fetch` and fail with UnboundLocalError (also wrapped),
            # unless a _ResumeIteration rebuilds the fetcher first.
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        # When using Iterable mode, some worker can exit earlier than others due
        # to the IterableDataset behaving differently for different workers.
        # When such things happen, an `_IterableDatasetStopIteration` object is
        # sent over to the main process with the ID of this worker, so that the
        # main process won't send more tasks to this worker, and will send
        # `None` to this worker to properly exit it.
        #
        # Note that we cannot set `done_event` from a worker as it is shared
        # among all processes. Instead, we set the `iteration_end` flag to
        # signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing step and just wait for
        # `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                # Poll with a timeout so the watchdog condition is re-checked
                # periodically even when no requests arrive.
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put((r, None))
                iteration_end = False
                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r
            # NOTE(review): on success `data` is whatever fetcher.fetch returns,
            # which this annotation does not list.
            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
            if init_exception is not None:
                # Report the deferred setup error exactly once.
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        # (1) to save future `next(...)` calls, and
                        # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where=f"in DataLoader worker process {worker_id}, msg: {e}")
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        # NOTE(review): assuming data_queue is a multiprocessing.Queue,
        # cancel_join_thread() lets this process exit without blocking until
        # buffered results are flushed to the pipe (they may be lost).
        data_queue.cancel_join_thread()
        data_queue.close()
|