python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional
|
|
3
|
+
import torch
|
|
4
|
+
from . import _utils
|
|
5
|
+
|
|
6
|
+
class _DatasetKind(object):
    """Enumerates the two dataset flavors and builds the matching fetcher."""

    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        """Return the `_utils.fetch` fetcher appropriate for *kind*.

        Any kind other than ``Map`` is treated as ``Iterable``.
        """
        if kind == _DatasetKind.Map:
            fetcher_cls = _utils.fetch._MapDatasetFetcher
        else:
            fetcher_cls = _utils.fetch._IterableDatasetFetcher
        return fetcher_cls(dataset, auto_collation, collate_fn, drop_last)
|
|
16
|
+
|
|
17
|
+
class _BaseDataLoaderIter(object):
|
|
18
|
+
def __init__(self, loader) -> None:
|
|
19
|
+
self._dataset = loader.dataset
|
|
20
|
+
self._dataset_kind = loader._dataset_kind
|
|
21
|
+
self._IterableDataset_len_called = loader._IterableDataset_len_called
|
|
22
|
+
self._auto_collation = loader._auto_collation
|
|
23
|
+
self._drop_last = loader.drop_last
|
|
24
|
+
self._index_sampler = loader._index_sampler
|
|
25
|
+
self._num_workers = loader.num_workers
|
|
26
|
+
self._prefetch_factor = loader.prefetch_factor
|
|
27
|
+
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
|
|
28
|
+
self._timeout = loader.timeout
|
|
29
|
+
self._collate_fn = loader.collate_fn
|
|
30
|
+
self._sampler_iter = iter(self._index_sampler)
|
|
31
|
+
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
|
|
32
|
+
self._persistent_workers = loader.persistent_workers
|
|
33
|
+
self._num_yielded = 0
|
|
34
|
+
self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
|
|
35
|
+
|
|
36
|
+
def __iter__(self) -> '_BaseDataLoaderIter':
|
|
37
|
+
return self
|
|
38
|
+
|
|
39
|
+
def _reset(self, loader, first_iter=False):
|
|
40
|
+
self._sampler_iter = iter(self._index_sampler)
|
|
41
|
+
self._num_yielded = 0
|
|
42
|
+
self._IterableDataset_len_called = loader._IterableDataset_len_called
|
|
43
|
+
|
|
44
|
+
def _next_index(self):
|
|
45
|
+
res = next(self._sampler_iter) # may raise StopIteration
|
|
46
|
+
return _BaseDataLoaderIter.to_item(res)
|
|
47
|
+
|
|
48
|
+
@staticmethod
|
|
49
|
+
def to_item(data):
|
|
50
|
+
if torch.is_tensor(data):
|
|
51
|
+
return data.item()
|
|
52
|
+
elif isinstance(data,Iterable):
|
|
53
|
+
return [_BaseDataLoaderIter.to_item(x) for x in data]
|
|
54
|
+
else:
|
|
55
|
+
return data
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _next_data(self):
|
|
59
|
+
raise NotImplementedError
|
|
60
|
+
|
|
61
|
+
def __next__(self) -> Any:
|
|
62
|
+
with torch.autograd.profiler.record_function(self._profile_name):
|
|
63
|
+
if self._sampler_iter is None:
|
|
64
|
+
self._reset()
|
|
65
|
+
data = self._next_data()
|
|
66
|
+
self._num_yielded += 1
|
|
67
|
+
if self._dataset_kind == _DatasetKind.Iterable and \
|
|
68
|
+
self._IterableDataset_len_called is not None and \
|
|
69
|
+
self._num_yielded > self._IterableDataset_len_called:
|
|
70
|
+
warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
|
|
71
|
+
"samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
|
|
72
|
+
self._num_yielded)
|
|
73
|
+
if self._num_workers > 0:
|
|
74
|
+
warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
|
|
75
|
+
"IterableDataset replica at each worker. Please see "
|
|
76
|
+
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
|
|
77
|
+
warnings.warn(warn_msg)
|
|
78
|
+
return data
|
|
79
|
+
|
|
80
|
+
next = __next__ # Python 2 compatibility
|
|
81
|
+
|
|
82
|
+
def __len__(self) -> int:
|
|
83
|
+
return len(self._index_sampler)
|
|
84
|
+
|
|
85
|
+
def __getstate__(self):
|
|
86
|
+
# TODO: add limited pickling support for sharing an iterator
|
|
87
|
+
# across multiple threads for HOGWILD.
|
|
88
|
+
# Probably the best way to do this is by moving the sample pushing
|
|
89
|
+
# to a separate thread and then just sharing the data queue
|
|
90
|
+
# but signalling the end is tricky without a non-blocking API
|
|
91
|
+
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
|
|
92
|
+
|
|
93
|
+
|
|
@@ -0,0 +1,501 @@
|
|
|
1
|
+
r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
|
|
2
|
+
|
|
3
|
+
To support these two classes, in `./_utils` we define many utility methods and
|
|
4
|
+
functions to be run in multiprocessing. E.g., the data loading worker loop is
|
|
5
|
+
in `./_utils/worker.py`.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
'''
|
|
9
|
+
相对于torch.DataLoader的改进点
|
|
10
|
+
1, 原始dataloader的处理流程为主进程采样,工作进程处理,主进程取工作进程的结果数据,在取数据时
|
|
11
|
+
取的顺序与采样的顺序一致,为了保持这种一致会导致工作进程无法满负荷运行,数据处理效率底,同时由于
|
|
12
|
+
采样大多为随机采样,保持这种顺序并不是十分必要,因些这里删除了顺序一致性的要求;
|
|
13
|
+
2,原始的dataloader按顺序给工作进程分派任务,而不考虑工作进程的实际状态,这时修改为按工作进程的排队
|
|
14
|
+
任务状态分配任务;
|
|
15
|
+
3,原始dataloader生成数据的粒度为一个batch, 当工作进程数较多是会导致工作进程已经处理了大量的数据,
|
|
16
|
+
但主进程却无法取到任何一个完整的batch而引起训练阻塞,产生的效果就是工作进程数据多了之后,数据处理反
|
|
17
|
+
而变慢了,这里引入一个参数batch_split_nr,让工作进程以batch_size/batch_split_nr的粒度进行数据数据;
|
|
18
|
+
4, 当batch size较大时,数据从CPU复制到GPU时会消耗不少时间,这里在pin_memory时,直接将数据复制到GPU;
|
|
19
|
+
|
|
20
|
+
参数设置参考
|
|
21
|
+
1, num_workers设置为4~32
|
|
22
|
+
2, pin_memory设置True
|
|
23
|
+
3, batch_split_nr设置为4左右(需要能整除batch_size)
|
|
24
|
+
4, ulimit -n 65535 #设置可以打开的最大文件数
|
|
25
|
+
'''
|
|
26
|
+
|
|
27
|
+
import os
|
|
28
|
+
import threading
|
|
29
|
+
import itertools
|
|
30
|
+
import warnings
|
|
31
|
+
from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional
|
|
32
|
+
import multiprocessing as python_multiprocessing
|
|
33
|
+
import torch
|
|
34
|
+
import torch.multiprocessing as multiprocessing
|
|
35
|
+
from torch._utils import ExceptionWrapper
|
|
36
|
+
import wml.wtorch.utils as wtu
|
|
37
|
+
import time
|
|
38
|
+
from . import IterableDataset, Sampler,InfiniteSequentialSampler, SequentialSampler, RandomSampler, BatchSampler, Dataset
|
|
39
|
+
from . import _utils
|
|
40
|
+
import wml.wml_utils as wmlu
|
|
41
|
+
from . import _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter, _DatasetKind
|
|
42
|
+
|
|
43
|
+
# Text types treated as atomic; used by the multiprocessing_context setter's
# isinstance check (replacement for the removed torch._six.string_classes).
string_classes = (str, bytes)
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
# Signature of a worker_init_fn: receives the worker id as its only argument.
_worker_init_fn_t = Callable[[int], None]

# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]


# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate: _collate_fn_t = _utils.collate.default_collate

# Re-exported for parity with torch.utils.data.get_worker_info.
get_worker_info = _utils.worker.get_worker_info
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class _InfiniteConstantSampler(Sampler):
    r"""Sampler that endlessly yields ``None``.

    Analogous to ``itertools.repeat(None, None)``. Used as the sampler for
    :class:`~torch.utils.data.IterableDataset`, where indices are irrelevant.

    Args:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self):
        super(_InfiniteConstantSampler, self).__init__(None)

    def __iter__(self):
        # Unbounded stream of placeholder "indices".
        yield from itertools.repeat(None)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class DataLoader(Generic[T_co]):
|
|
83
|
+
r"""
|
|
84
|
+
Data loader. Combines a dataset and a sampler, and provides an iterable over
|
|
85
|
+
the given dataset.
|
|
86
|
+
|
|
87
|
+
The :class:`~torch.utils.data.DataLoader` supports both map-style and
|
|
88
|
+
iterable-style datasets with single- or multi-process loading, customizing
|
|
89
|
+
loading order and optional automatic batching (collation) and memory pinning.
|
|
90
|
+
|
|
91
|
+
See :py:mod:`torch.utils.data` documentation page for more details.
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
dataset (Dataset): dataset from which to load the data.
|
|
95
|
+
batch_size (int, optional): how many samples per batch to load
|
|
96
|
+
(default: ``1``).
|
|
97
|
+
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
|
98
|
+
at every epoch (default: ``False``).
|
|
99
|
+
sampler (Sampler or Iterable, optional): defines the strategy to draw
|
|
100
|
+
samples from the dataset. Can be any ``Iterable`` with ``__len__``
|
|
101
|
+
implemented. If specified, :attr:`shuffle` must not be specified.
|
|
102
|
+
batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
|
|
103
|
+
returns a batch of indices at a time. Mutually exclusive with
|
|
104
|
+
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
|
|
105
|
+
and :attr:`drop_last`.
|
|
106
|
+
num_workers (int, optional): how many subprocesses to use for data
|
|
107
|
+
loading. ``0`` means that the data will be loaded in the main process.
|
|
108
|
+
(default: ``0``)
|
|
109
|
+
collate_fn (callable, optional): merges a list of samples to form a
|
|
110
|
+
mini-batch of Tensor(s). Used when using batched loading from a
|
|
111
|
+
map-style dataset.
|
|
112
|
+
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
|
|
113
|
+
into CUDA pinned memory before returning them. If your data elements
|
|
114
|
+
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
|
|
115
|
+
see the example below.
|
|
116
|
+
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
|
117
|
+
if the dataset size is not divisible by the batch size. If ``False`` and
|
|
118
|
+
the size of dataset is not divisible by the batch size, then the last batch
|
|
119
|
+
will be smaller. (default: ``False``)
|
|
120
|
+
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
|
121
|
+
from workers. Should always be non-negative. (default: ``0``)
|
|
122
|
+
worker_init_fn (callable, optional): If not ``None``, this will be called on each
|
|
123
|
+
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
|
124
|
+
input, after seeding and before data loading. (default: ``None``)
|
|
125
|
+
prefetch_factor (int, optional, keyword-only arg): Number of samples loaded
|
|
126
|
+
in advance by each worker. ``2`` means there will be a total of
|
|
127
|
+
2 * num_workers samples prefetched across all workers. (default: ``2``)
|
|
128
|
+
persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
|
|
129
|
+
the worker processes after a dataset has been consumed once. This allows to
|
|
130
|
+
maintain the workers `Dataset` instances alive. (default: ``False``)
|
|
131
|
+
batch_split_nr: Each element in data_queue's sample number reduced from batch_size to
|
|
132
|
+
batch_size/batch_split_nr
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
|
|
136
|
+
cannot be an unpicklable object, e.g., a lambda function. See
|
|
137
|
+
:ref:`multiprocessing-best-practices` on more details related
|
|
138
|
+
to multiprocessing in PyTorch.
|
|
139
|
+
|
|
140
|
+
.. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
|
|
141
|
+
When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
|
|
142
|
+
it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
|
|
143
|
+
rounding depending on :attr:`drop_last`, regardless of multi-process loading
|
|
144
|
+
configurations. This represents the best guess PyTorch can make because PyTorch
|
|
145
|
+
trusts user :attr:`dataset` code in correctly handling multi-process
|
|
146
|
+
loading to avoid duplicate data.
|
|
147
|
+
|
|
148
|
+
However, if sharding results in multiple workers having incomplete last batches,
|
|
149
|
+
this estimate can still be inaccurate, because (1) an otherwise complete batch can
|
|
150
|
+
be broken into multiple ones and (2) more than one batch worth of samples can be
|
|
151
|
+
dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
|
|
152
|
+
cases in general.
|
|
153
|
+
|
|
154
|
+
See `Dataset Types`_ for more details on these two types of datasets and how
|
|
155
|
+
:class:`~torch.utils.data.IterableDataset` interacts with
|
|
156
|
+
`Multi-process data loading`_.
|
|
157
|
+
|
|
158
|
+
.. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
|
|
159
|
+
:ref:`data-loading-randomness` notes for random seed related questions.
|
|
160
|
+
"""
|
|
161
|
+
dataset: Dataset[T_co]
|
|
162
|
+
batch_size: Optional[int]
|
|
163
|
+
num_workers: int
|
|
164
|
+
pin_memory: bool
|
|
165
|
+
drop_last: bool
|
|
166
|
+
timeout: float
|
|
167
|
+
sampler: Sampler
|
|
168
|
+
prefetch_factor: int
|
|
169
|
+
_iterator : Optional['_BaseDataLoaderIter']
|
|
170
|
+
__initialized = False
|
|
171
|
+
|
|
172
|
+
    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = True, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
                 pin_memory: bool = True, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = True,
                 batch_split_nr:int =2):
        """Validate arguments and derive sampler / batch_sampler / collate_fn.

        NOTE(review): defaults deliberately differ from stock
        ``torch.utils.data.DataLoader``: ``shuffle``, ``pin_memory`` and
        ``persistent_workers`` default to ``True``, and the extra keyword-only
        ``batch_split_nr`` controls how finely each batch is split into worker
        result messages (see the module docstring for rationale).
        Also note: sampler/batch_sampler conflicts are reported via ``print``
        warnings here rather than raised as errors.
        """
        torch._C._log_api_usage_once("python.data_loader") # type: ignore

        # Clamp to at least 1: a value < 1 would mean "no splitting at all".
        self.batch_split_nr = max(1,batch_split_nr)

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        # prefetch_factor only makes sense with worker processes; reject any
        # non-default value in single-process mode.
        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        # Goes through the property setter below, which validates the value.
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and IterableDataset ]
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since indices are irrelevant; an infinite dummy
            # sampler (`_InfiniteConstantSampler`) is always used instead,
            # so any explicit shuffle/sampler/batch_sampler is rejected.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        # Unlike torch, a sampler+shuffle conflict only warns (shuffle defaults
        # to True here, so raising would break common call sites).
        if sampler is not None and shuffle:
            print('WARNING: sampler option is mutually exclusive with '
                  'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                print('WARNING: batch_sampler option is mutually exclusive '
                      'with batch_size, shuffle, sampler, and '
                      'drop_last')
            # A custom batch_sampler takes over batching entirely.
            batch_size = None
            drop_last = False
            print(f"Use batch sampler")
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')
            print(f"batch_size={batch_size}")

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                # NOTE(review): multi-worker loading uses RandomSampler even
                # without shuffle — consistent with this module dropping the
                # output-order guarantee (see module docstring); sequential
                # single-process loading gets an infinite sequential sampler.
                if shuffle or self.num_workers>0:
                    # Cannot statically verify that dataset is Sized
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore
                else:
                    sampler = InfiniteSequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        # Default collation: merge samples into batches when auto-batching,
        # otherwise just convert single samples.
        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        # From here on, __setattr__ blocks mutation of the core attributes.
        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None

        self.check_worker_number_rationality()
|
|
314
|
+
|
|
315
|
+
def _get_iterator(self) -> '_BaseDataLoaderIter':
|
|
316
|
+
if self.num_workers == 0:
|
|
317
|
+
return _SingleProcessDataLoaderIter(self)
|
|
318
|
+
else:
|
|
319
|
+
self.check_worker_number_rationality()
|
|
320
|
+
return _MultiProcessingDataLoaderIter(self,batch_split_nr=self.batch_split_nr)
|
|
321
|
+
|
|
322
|
+
    @property
    def multiprocessing_context(self):
        # Read-only view of the context configured via the setter below
        # (stored under the name-mangled attribute).
        return self.__multiprocessing_context
|
|
325
|
+
|
|
326
|
+
    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        # Accepts either a start-method name (e.g. 'fork', 'spawn') or an
        # existing multiprocessing context object; only meaningful when
        # worker processes are in use (num_workers > 0).
        if multiprocessing_context is not None:
            if self.num_workers > 0:
                if isinstance(multiprocessing_context, string_classes):
                    # Resolve a start-method name to a context object.
                    valid_start_methods = multiprocessing.get_all_start_methods()
                    if multiprocessing_context not in valid_start_methods:
                        raise ValueError(
                            ('multiprocessing_context option '
                             'should specify a valid start method in {!r}, but got '
                             'multiprocessing_context={!r}').format(valid_start_methods, multiprocessing_context))
                    # error: Argument 1 to "get_context" has incompatible type "Union[str, bytes]"; expected "str" [arg-type]
                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)  # type: ignore

                # By now we must hold a real context object, whether passed
                # in directly or resolved from a name above.
                if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                    raise TypeError(('multiprocessing_context option should be a valid context '
                                     'object or a string specifying the start method, but got '
                                     'multiprocessing_context={}').format(multiprocessing_context))
            else:
                raise ValueError(('multiprocessing_context can only be used with '
                                  'multi-process loading (num_workers > 0), but got '
                                  'num_workers={}').format(self.num_workers))

        # Name-mangled backing attribute read by the property getter.
        self.__multiprocessing_context = multiprocessing_context
|
|
350
|
+
|
|
351
|
+
def __setattr__(self, attr, val):
|
|
352
|
+
if self.__initialized and attr in (
|
|
353
|
+
'batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset', 'persistent_workers'):
|
|
354
|
+
raise ValueError('{} attribute should not be set after {} is '
|
|
355
|
+
'initialized'.format(attr, self.__class__.__name__))
|
|
356
|
+
|
|
357
|
+
super(DataLoader, self).__setattr__(attr, val)
|
|
358
|
+
|
|
359
|
+
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
|
|
360
|
+
# since '_BaseDataLoaderIter' references 'DataLoader'.
|
|
361
|
+
def __iter__(self) -> '_BaseDataLoaderIter':
|
|
362
|
+
# When using a single worker the returned iterator should be
|
|
363
|
+
# created everytime to avoid reseting its state
|
|
364
|
+
# However, in the case of a multiple workers iterator
|
|
365
|
+
# the iterator is only created once in the lifetime of the
|
|
366
|
+
# DataLoader object so that workers can be reused
|
|
367
|
+
if self.persistent_workers and self.num_workers > 0:
|
|
368
|
+
if self._iterator is None:
|
|
369
|
+
self._iterator = self._get_iterator()
|
|
370
|
+
else:
|
|
371
|
+
self._iterator._reset(self)
|
|
372
|
+
return self._iterator
|
|
373
|
+
else:
|
|
374
|
+
if self._iterator is not None:
|
|
375
|
+
del self._iterator
|
|
376
|
+
self._iterator = None
|
|
377
|
+
return self._get_iterator()
|
|
378
|
+
|
|
379
|
+
    @property
    def _auto_collation(self):
        # Auto-batching is in effect exactly when a batch_sampler exists —
        # either user-supplied or built from batch_size in __init__.
        return self.batch_sampler is not None
|
|
382
|
+
|
|
383
|
+
@property
|
|
384
|
+
def _index_sampler(self):
|
|
385
|
+
# The actual sampler used for generating indices for `_DatasetFetcher`
|
|
386
|
+
# (see _utils/fetch.py) to read data at each time. This would be
|
|
387
|
+
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
|
|
388
|
+
# We can't change `.sampler` and `.batch_sampler` attributes for BC
|
|
389
|
+
# reasons.
|
|
390
|
+
if self._auto_collation:
|
|
391
|
+
return self.batch_sampler
|
|
392
|
+
else:
|
|
393
|
+
return self.sampler
|
|
394
|
+
|
|
395
|
+
def __len__(self) -> int:
|
|
396
|
+
if self._dataset_kind == _DatasetKind.Iterable:
|
|
397
|
+
# NOTE [ IterableDataset and __len__ ]
|
|
398
|
+
#
|
|
399
|
+
# For `IterableDataset`, `__len__` could be inaccurate when one naively
|
|
400
|
+
# does multi-processing data loading, since the samples will be duplicated.
|
|
401
|
+
# However, no real use case should be actually using that behavior, so
|
|
402
|
+
# it should count as a user error. We should generally trust user
|
|
403
|
+
# code to do the proper thing (e.g., configure each replica differently
|
|
404
|
+
# in `__iter__`), and give us the correct `__len__` if they choose to
|
|
405
|
+
# implement it (this will still throw if the dataset does not implement
|
|
406
|
+
# a `__len__`).
|
|
407
|
+
#
|
|
408
|
+
# To provide a further warning, we track if `__len__` was called on the
|
|
409
|
+
# `DataLoader`, save the returned value in `self._len_called`, and warn
|
|
410
|
+
# if the iterator ends up yielding more than this number of samples.
|
|
411
|
+
|
|
412
|
+
# Cannot statically verify that dataset is Sized
|
|
413
|
+
length = self._IterableDataset_len_called = len(self.dataset) # type: ignore
|
|
414
|
+
if self.batch_size is not None: # IterableDataset doesn't allow custom sampler or batch_sampler
|
|
415
|
+
from math import ceil
|
|
416
|
+
if self.drop_last:
|
|
417
|
+
length = length // self.batch_size
|
|
418
|
+
else:
|
|
419
|
+
length = ceil(length / self.batch_size)
|
|
420
|
+
return length
|
|
421
|
+
else:
|
|
422
|
+
return len(self._index_sampler)
|
|
423
|
+
|
|
424
|
+
def check_worker_number_rationality(self):
|
|
425
|
+
# This function check whether the dataloader's worker number is rational based on
|
|
426
|
+
# current system's resource. Current rule is that if the number of workers this
|
|
427
|
+
# Dataloader will create is bigger than the number of logical cpus that is allowed to
|
|
428
|
+
# use, than we will pop up a warning to let user pay attention.
|
|
429
|
+
#
|
|
430
|
+
# eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
|
|
431
|
+
# threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
|
|
432
|
+
# DataLoader process can use half of them which is 32, then the rational max number of
|
|
433
|
+
# worker that initiated from this process is 32.
|
|
434
|
+
# Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.
|
|
435
|
+
# So the warning message is triggered to notify the user to lower the worker number if
|
|
436
|
+
# necessary.
|
|
437
|
+
#
|
|
438
|
+
#
|
|
439
|
+
# [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is
|
|
440
|
+
# available (available in most of Linux system, but not OSX and Windows).
|
|
441
|
+
# When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
|
|
442
|
+
# it doesn't repect cpuset.
|
|
443
|
+
# We don't take threading into account since each worker process is single threaded
|
|
444
|
+
# at this time.
|
|
445
|
+
#
|
|
446
|
+
# We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
|
|
447
|
+
# other than `torch.set_num_threads` to 1 in the worker process, if the passing
|
|
448
|
+
# in functions use 3rd party modules that rely on those threading flags to determine
|
|
449
|
+
# how many thread to create (eg. numpy, etc), then it is caller's responsibility to
|
|
450
|
+
# set those flags correctly.
|
|
451
|
+
def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
|
|
452
|
+
|
|
453
|
+
suggested_max_worker_msg = ((
|
|
454
|
+
"Our suggested max number of worker in current system is {}{}, which is smaller "
|
|
455
|
+
"than what this DataLoader is going to create.").format(
|
|
456
|
+
num_worker_suggest,
|
|
457
|
+
("" if cpuset_checked else " (`cpuset` is not taken into account)"))
|
|
458
|
+
) if num_worker_suggest is not None else (
|
|
459
|
+
"DataLoader is not able to compute a suggested max number of worker in current system.")
|
|
460
|
+
|
|
461
|
+
warn_msg = (
|
|
462
|
+
"This DataLoader will create {} worker processes in total. {} "
|
|
463
|
+
"Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
|
|
464
|
+
"lower the worker number to avoid potential slowness/freeze if necessary.").format(
|
|
465
|
+
num_worker_created,
|
|
466
|
+
suggested_max_worker_msg)
|
|
467
|
+
return warn_msg
|
|
468
|
+
|
|
469
|
+
if not self.num_workers or self.num_workers == 0:
|
|
470
|
+
return
|
|
471
|
+
|
|
472
|
+
# try to compute a suggested max number of worker based on system's resource
|
|
473
|
+
max_num_worker_suggest = None
|
|
474
|
+
cpuset_checked = False
|
|
475
|
+
if hasattr(os, 'sched_getaffinity'):
|
|
476
|
+
try:
|
|
477
|
+
max_num_worker_suggest = len(os.sched_getaffinity(0))
|
|
478
|
+
cpuset_checked = True
|
|
479
|
+
except Exception:
|
|
480
|
+
pass
|
|
481
|
+
if max_num_worker_suggest is None:
|
|
482
|
+
# os.cpu_count() could return Optional[int]
|
|
483
|
+
# get cpu count first and check None in order to satify mypy check
|
|
484
|
+
cpu_count = os.cpu_count()
|
|
485
|
+
if cpu_count is not None:
|
|
486
|
+
max_num_worker_suggest = cpu_count
|
|
487
|
+
|
|
488
|
+
if max_num_worker_suggest is None:
|
|
489
|
+
warnings.warn(_create_warning_msg(
|
|
490
|
+
max_num_worker_suggest,
|
|
491
|
+
self.num_workers,
|
|
492
|
+
cpuset_checked))
|
|
493
|
+
return
|
|
494
|
+
|
|
495
|
+
if self.num_workers > max_num_worker_suggest:
|
|
496
|
+
warnings.warn(_create_warning_msg(
|
|
497
|
+
max_num_worker_suggest,
|
|
498
|
+
self.num_workers,
|
|
499
|
+
cpuset_checked))
|
|
500
|
+
|
|
501
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
import torch.utils.data.datapipes.iter
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from torch.utils.data.datapipes.iter.listdirfiles import ListDirFilesIterDataPipe as ListDirFiles
|
|
2
|
+
from torch.utils.data.datapipes.iter.loadfilesfromdisk import LoadFilesFromDiskIterDataPipe as LoadFilesFromDisk
|
|
3
|
+
from torch.utils.data.datapipes.iter.readfilesfromtar import ReadFilesFromTarIterDataPipe as ReadFilesFromTar
|
|
4
|
+
from torch.utils.data.datapipes.iter.readfilesfromzip import ReadFilesFromZipIterDataPipe as ReadFilesFromZip
|
|
5
|
+
|
|
6
|
+
# Functional DataPipe
|
|
7
|
+
from torch.utils.data.datapipes.iter.batch import BatchIterDataPipe as Batch, BucketBatchIterDataPipe as BucketBatch
|
|
8
|
+
from torch.utils.data.datapipes.iter.callable import CallableIterDataPipe as Callable, CollateIterDataPipe as Collate
|
|
9
|
+
from torch.utils.data.datapipes.iter.sampler import SamplerIterDataPipe as Sampler
|
|
10
|
+
|
|
11
|
+
__all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFormTar', 'ReadFilesFromZip'
|
|
12
|
+
'Batch', 'BucketBatch', 'Callable', 'Collate', 'Sampler']
|