python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164) hide show
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,93 @@
1
+ import warnings
2
+ from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional
3
+ import torch
4
+ from . import _utils
5
+
6
class _DatasetKind(object):
    """Tags the two dataset flavours a DataLoader can drive."""

    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        """Return the `_utils.fetch` fetcher matching *kind*.

        Any value other than ``_DatasetKind.Map`` selects the iterable-style
        fetcher, mirroring the original if/else.
        """
        if kind != _DatasetKind.Map:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
16
+
17
+ class _BaseDataLoaderIter(object):
18
+ def __init__(self, loader) -> None:
19
+ self._dataset = loader.dataset
20
+ self._dataset_kind = loader._dataset_kind
21
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
22
+ self._auto_collation = loader._auto_collation
23
+ self._drop_last = loader.drop_last
24
+ self._index_sampler = loader._index_sampler
25
+ self._num_workers = loader.num_workers
26
+ self._prefetch_factor = loader.prefetch_factor
27
+ self._pin_memory = loader.pin_memory and torch.cuda.is_available()
28
+ self._timeout = loader.timeout
29
+ self._collate_fn = loader.collate_fn
30
+ self._sampler_iter = iter(self._index_sampler)
31
+ self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
32
+ self._persistent_workers = loader.persistent_workers
33
+ self._num_yielded = 0
34
+ self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
35
+
36
+ def __iter__(self) -> '_BaseDataLoaderIter':
37
+ return self
38
+
39
+ def _reset(self, loader, first_iter=False):
40
+ self._sampler_iter = iter(self._index_sampler)
41
+ self._num_yielded = 0
42
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
43
+
44
+ def _next_index(self):
45
+ res = next(self._sampler_iter) # may raise StopIteration
46
+ return _BaseDataLoaderIter.to_item(res)
47
+
48
+ @staticmethod
49
+ def to_item(data):
50
+ if torch.is_tensor(data):
51
+ return data.item()
52
+ elif isinstance(data,Iterable):
53
+ return [_BaseDataLoaderIter.to_item(x) for x in data]
54
+ else:
55
+ return data
56
+
57
+
58
+ def _next_data(self):
59
+ raise NotImplementedError
60
+
61
+ def __next__(self) -> Any:
62
+ with torch.autograd.profiler.record_function(self._profile_name):
63
+ if self._sampler_iter is None:
64
+ self._reset()
65
+ data = self._next_data()
66
+ self._num_yielded += 1
67
+ if self._dataset_kind == _DatasetKind.Iterable and \
68
+ self._IterableDataset_len_called is not None and \
69
+ self._num_yielded > self._IterableDataset_len_called:
70
+ warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
71
+ "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
72
+ self._num_yielded)
73
+ if self._num_workers > 0:
74
+ warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
75
+ "IterableDataset replica at each worker. Please see "
76
+ "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
77
+ warnings.warn(warn_msg)
78
+ return data
79
+
80
+ next = __next__ # Python 2 compatibility
81
+
82
+ def __len__(self) -> int:
83
+ return len(self._index_sampler)
84
+
85
+ def __getstate__(self):
86
+ # TODO: add limited pickling support for sharing an iterator
87
+ # across multiple threads for HOGWILD.
88
+ # Probably the best way to do this is by moving the sample pushing
89
+ # to a separate thread and then just sharing the data queue
90
+ # but signalling the end is tricky without a non-blocking API
91
+ raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
92
+
93
+
@@ -0,0 +1,501 @@
1
+ r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
2
+
3
+ To support these two classes, in `./_utils` we define many utility methods and
4
+ functions to be run in multiprocessing. E.g., the data loading worker loop is
5
+ in `./_utils/worker.py`.
6
+ """
7
+
8
+ '''
9
+ 相对于torch.DataLoader的改进点
10
+ 1, 原始dataloader的处理流程为主进程采样,工作进程处理,主进程取工作进程的结果数据,在取数据时
11
+ 取的顺序与采样的顺序一致,为了保持这种一致会导致工作进程无法满负荷运行,数据处理效率底,同时由于
12
+ 采样大多为随机采样,保持这种顺序并不是十分必要,因些这里删除了顺序一致性的要求;
13
+ 2,原始的dataloader按顺序给工作进程分派任务,而不考虑工作进程的实际状态,这时修改为按工作进程的排队
14
+ 任务状态分配任务;
15
+ 3,原始dataloader生成数据的粒度为一个batch, 当工作进程数较多是会导致工作进程已经处理了大量的数据,
16
+ 但主进程却无法取到任何一个完整的batch而引起训练阻塞,产生的效果就是工作进程数据多了之后,数据处理反
17
+ 而变慢了,这里引入一个参数batch_split_nr,让工作进程以batch_size/batch_split_nr的粒度进行数据数据;
18
+ 4, 当batch size较大时,数据从CPU复制到GPU时会消耗不少时间,这里在pin_memory时,直接将数据复制到GPU;
19
+
20
+ 参数设置参考
21
+ 1, num_workers设置为4~32
22
+ 2, pin_memory设置True
23
+ 3, batch_split_nr设置为4左右(需要能整除batch_size)
24
+ 4, ulimit -n 65535 #设置可以打开的最大文件数
25
+ '''
26
+
27
+ import os
28
+ import threading
29
+ import itertools
30
+ import warnings
31
+ from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional
32
+ import multiprocessing as python_multiprocessing
33
+ import torch
34
+ import torch.multiprocessing as multiprocessing
35
+ from torch._utils import ExceptionWrapper
36
+ import wml.wtorch.utils as wtu
37
+ import time
38
+ from . import IterableDataset, Sampler,InfiniteSequentialSampler, SequentialSampler, RandomSampler, BatchSampler, Dataset
39
+ from . import _utils
40
+ import wml.wml_utils as wmlu
41
+ from . import _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter, _DatasetKind
42
+
43
# String-like types, used by the multiprocessing_context setter to accept a
# start-method *name* (e.g. "spawn") as well as a context object.
string_classes = (str, bytes)
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
# Signature of a DataLoader ``worker_init_fn``: called with the worker id.
_worker_init_fn_t = Callable[[int], None]

# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]


# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate: _collate_fn_t = _utils.collate.default_collate

# Re-exported so callers can query worker state without importing _utils.
get_worker_info = _utils.worker.get_worker_info
62
+
63
+
64
+
65
+
66
class _InfiniteConstantSampler(Sampler):
    r"""Sampler that yields ``None`` forever.

    Analogous to ``itertools.repeat(None, None)``. Used as the index sampler
    for :class:`~torch.utils.data.IterableDataset`, whose items come from the
    dataset itself rather than from index lookups.
    """

    def __init__(self):
        super().__init__(None)

    def __iter__(self):
        # An endless stream of None placeholders; the fetcher ignores them.
        return itertools.repeat(None)
80
+
81
+
82
class DataLoader(Generic[T_co]):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    NOTE(review): this is a customized fork of ``torch.utils.data.DataLoader``.
    Defaults deviate from upstream: ``shuffle``, ``pin_memory`` and
    ``persistent_workers`` default to ``True`` here, and a ``batch_split_nr``
    option is added (see the module docstring for the rationale).

    Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``True``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them. If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below. (default: ``True``)
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of samples loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers samples prefetched across all workers. (default: ``2``)
        persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``True``)
        batch_split_nr: Each element in data_queue's sample number reduced from batch_size to
            batch_size/batch_split_nr


    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
                 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
                 rounding depending on :attr:`drop_last`, regardless of multi-process loading
                 configurations. This represents the best guess PyTorch can make because PyTorch
                 trusts user :attr:`dataset` code in correctly handling multi-process
                 loading to avoid duplicate data.

                 However, if sharding results in multiple workers having incomplete last batches,
                 this estimate can still be inaccurate, because (1) an otherwise complete batch can
                 be broken into multiple ones and (2) more than one batch worth of samples can be
                 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
                 cases in general.

                 See `Dataset Types`_ for more details on these two types of datasets and how
                 :class:`~torch.utils.data.IterableDataset` interacts with
                 `Multi-process data loading`_.

    .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
                 :ref:`data-loading-randomness` notes for random seed related questions.
    """
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Sampler
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    # Guards __setattr__: becomes True at the end of __init__, after which the
    # sampler/batching attributes are frozen (see __setattr__).
    __initialized = False

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = True, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
                 pin_memory: bool = True, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = True,
                 batch_split_nr:int =2):
        torch._C._log_api_usage_once("python.data_loader")  # type: ignore

        # Work-queue granularity: workers produce batch_size/batch_split_nr
        # sized chunks; clamped so it is never below 1.
        self.batch_split_nr = max(1,batch_split_nr)

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        # Routed through the property setter below, which validates it.
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and IterableDataset ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data lost (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `Iterabledataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        # NOTE(review): unlike upstream torch (which raises), conflicting
        # sampler+shuffle only warns here and the explicit sampler wins.
        if sampler is not None and shuffle:
            print('WARNING: sampler option is mutually exclusive with '
                  'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            # NOTE(review): upstream raises on this conflict; this fork only
            # warns and silently resets batch_size/drop_last.
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                print('WARNING: batch_sampler option is mutually exclusive '
                      'with batch_size, shuffle, sampler, and '
                      'drop_last')
            batch_size = None
            drop_last = False
            print(f"Use batch sampler")
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')
            print(f"batch_size={batch_size}")

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                # NOTE(review): deviation from upstream — with num_workers > 0
                # a RandomSampler is used even when shuffle=False (order is not
                # preserved by this loader anyway); single-process non-shuffle
                # loading gets an InfiniteSequentialSampler.
                if shuffle or self.num_workers>0:
                    # Cannot statically verify that dataset is Sized
                    # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore
                else:
                    sampler = InfiniteSequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        # From here on, the attributes listed in __setattr__ are frozen.
        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None

        self.check_worker_number_rationality()

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        # Single-process iteration when no workers are requested; otherwise a
        # multi-process iterator that honours batch_split_nr.
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self,batch_split_nr=self.batch_split_nr)

    @property
    def multiprocessing_context(self):
        return self.__multiprocessing_context

    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        # Accepts either a start-method name ("fork"/"spawn"/...) or a
        # multiprocessing context object; validates both forms.
        if multiprocessing_context is not None:
            if self.num_workers > 0:
                if isinstance(multiprocessing_context, string_classes):
                    valid_start_methods = multiprocessing.get_all_start_methods()
                    if multiprocessing_context not in valid_start_methods:
                        raise ValueError(
                            ('multiprocessing_context option '
                             'should specify a valid start method in {!r}, but got '
                             'multiprocessing_context={!r}').format(valid_start_methods, multiprocessing_context))
                    # error: Argument 1 to "get_context" has incompatible type "Union[str, bytes]"; expected "str"  [arg-type]
                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)  # type: ignore

                if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                    raise TypeError(('multiprocessing_context option should be a valid context '
                                     'object or a string specifying the start method, but got '
                                     'multiprocessing_context={}').format(multiprocessing_context))
            else:
                raise ValueError(('multiprocessing_context can only be used with '
                                  'multi-process loading (num_workers > 0), but got '
                                  'num_workers={}').format(self.num_workers))

        self.__multiprocessing_context = multiprocessing_context

    def __setattr__(self, attr, val):
        # Freeze sampler/batching configuration once __init__ has finished, so
        # an in-flight iterator can't be invalidated by attribute mutation.
        if self.__initialized and attr in (
                'batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset', 'persistent_workers'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
    # since '_BaseDataLoaderIter' references 'DataLoader'.
    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            if self._iterator is not None:
                del self._iterator
                self._iterator = None
            return self._get_iterator()

    @property
    def _auto_collation(self):
        # Auto-batching is on exactly when a batch_sampler exists.
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self) -> int:
        if self._dataset_kind == _DatasetKind.Iterable:
            # NOTE [ IterableDataset and __len__ ]
            #
            # For `IterableDataset`, `__len__` could be inaccurate when one naively
            # does multi-processing data loading, since the samples will be duplicated.
            # However, no real use case should be actually using that behavior, so
            # it should count as a user error. We should generally trust user
            # code to do the proper thing (e.g., configure each replica differently
            # in `__iter__`), and give us the correct `__len__` if they choose to
            # implement it (this will still throw if the dataset does not implement
            # a `__len__`).
            #
            # To provide a further warning, we track if `__len__` was called on the
            # `DataLoader`, save the returned value in `self._len_called`, and warn
            # if the iterator ends up yielding more than this number of samples.

            # Cannot statically verify that dataset is Sized
            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore
            if self.batch_size is not None:  # IterableDataset doesn't allow custom sampler or batch_sampler
                from math import ceil
                if self.drop_last:
                    length = length // self.batch_size
                else:
                    length = ceil(length / self.batch_size)
            return length
        else:
            return len(self._index_sampler)

    def check_worker_number_rationality(self):
        # This function check whether the dataloader's worker number is rational based on
        # current system's resource. Current rule is that if the number of workers this
        # Dataloader will create is bigger than the number of logical cpus that is allowed to
        # use, than we will pop up a warning to let user pay attention.
        #
        # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
        #     threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
        #     DataLoader process can use half of them which is 32, then the rational max number of
        #     worker that initiated from this process is 32.
        #     Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.
        #     So the warning message is triggered to notify the user to lower the worker number if
        #     necessary.
        #
        #
        # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is
        #        available (available in most of Linux system, but not OSX and Windows).
        #        When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
        #        it doesn't repect cpuset.
        #        We don't take threading into account since each worker process is single threaded
        #        at this time.
        #
        #        We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
        #        other than `torch.set_num_threads` to 1 in the worker process, if the passing
        #        in functions use 3rd party modules that rely on those threading flags to determine
        #        how many thread to create (eg. numpy, etc), then it is caller's responsibility to
        #        set those flags correctly.
        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
            # Builds the human-readable warning; suggestion may be unknown (None).
            suggested_max_worker_msg = ((
                "Our suggested max number of worker in current system is {}{}, which is smaller "
                "than what this DataLoader is going to create.").format(
                    num_worker_suggest,
                    ("" if cpuset_checked else " (`cpuset` is not taken into account)"))
            ) if num_worker_suggest is not None else (
                "DataLoader is not able to compute a suggested max number of worker in current system.")

            warn_msg = (
                "This DataLoader will create {} worker processes in total. {} "
                "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
                "lower the worker number to avoid potential slowness/freeze if necessary.").format(
                    num_worker_created,
                    suggested_max_worker_msg)
            return warn_msg

        # Nothing to check in single-process mode.
        if not self.num_workers or self.num_workers == 0:
            return

        # try to compute a suggested max number of worker based on system's resource
        max_num_worker_suggest = None
        cpuset_checked = False
        if hasattr(os, 'sched_getaffinity'):
            try:
                max_num_worker_suggest = len(os.sched_getaffinity(0))
                cpuset_checked = True
            except Exception:
                pass
        if max_num_worker_suggest is None:
            # os.cpu_count() could return Optional[int]
            # get cpu count first and check None in order to satify mypy check
            cpu_count = os.cpu_count()
            if cpu_count is not None:
                max_num_worker_suggest = cpu_count

        if max_num_worker_suggest is None:
            warnings.warn(_create_warning_msg(
                max_num_worker_suggest,
                self.num_workers,
                cpuset_checked))
            return

        if self.num_workers > max_num_worker_suggest:
            warnings.warn(_create_warning_msg(
                max_num_worker_suggest,
                self.num_workers,
                cpuset_checked))
500
+
501
+
@@ -0,0 +1 @@
1
+ import torch.utils.data.datapipes.iter
@@ -0,0 +1,12 @@
1
from torch.utils.data.datapipes.iter.listdirfiles import ListDirFilesIterDataPipe as ListDirFiles
from torch.utils.data.datapipes.iter.loadfilesfromdisk import LoadFilesFromDiskIterDataPipe as LoadFilesFromDisk
from torch.utils.data.datapipes.iter.readfilesfromtar import ReadFilesFromTarIterDataPipe as ReadFilesFromTar
from torch.utils.data.datapipes.iter.readfilesfromzip import ReadFilesFromZipIterDataPipe as ReadFilesFromZip

# Functional DataPipe
from torch.utils.data.datapipes.iter.batch import BatchIterDataPipe as Batch, BucketBatchIterDataPipe as BucketBatch
from torch.utils.data.datapipes.iter.callable import CallableIterDataPipe as Callable, CollateIterDataPipe as Collate
from torch.utils.data.datapipes.iter.sampler import SamplerIterDataPipe as Sampler

# Fixed two bugs in __all__:
#  * 'ReadFilesFormTar' was a typo for the imported name 'ReadFilesFromTar'.
#  * A missing comma after 'ReadFilesFromZip' made Python implicitly
#    concatenate it with 'Batch' into the bogus entry 'ReadFilesFromZipBatch',
#    which also dropped 'Batch' from the public API.
__all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFromTar', 'ReadFilesFromZip',
           'Batch', 'BucketBatch', 'Callable', 'Collate', 'Sampler']