python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
import bisect
|
|
2
|
+
import random
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
from torch._utils import _accumulate
|
|
6
|
+
from torch import randperm
|
|
7
|
+
# No 'default_generator' in torch/__init__.pyi
|
|
8
|
+
from torch import default_generator # type: ignore
|
|
9
|
+
from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
|
|
10
|
+
from torch import Tensor, Generator
|
|
11
|
+
|
|
12
|
+
T_co = TypeVar('T_co', covariant=True)
|
|
13
|
+
T = TypeVar('T')
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that map keys to data samples should subclass it.  Every
    subclass must override :meth:`__getitem__` to fetch a data sample for a
    given key, and may optionally override :meth:`__len__`, which many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader` expect to return the dataset size.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        # Abstract: subclasses are responsible for sample lookup.
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        # `ds1 + ds2` builds a concatenated map-style dataset.
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class IterableDataset(Dataset[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator.  When :attr:`num_workers > 0`, every worker process holds its own
    copy of the dataset object, so the copies usually need to be configured
    independently to avoid duplicate data being returned.
    :func:`~torch.utils.data.get_worker_info`, when called in a worker process,
    returns information about the worker; it can be used either in the dataset's
    :meth:`__iter__` method or in the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start, "this example code only works with end >= start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]

        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
        [3, 4, 5, 6]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start, "this example code only works with end >= start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __iter__(self) -> Iterator[T_co]:
        # Abstract: subclasses must yield their samples.
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        # `ds1 + ds2` chains two iterable datasets on the fly.
        return ChainDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample is the tuple obtained by indexing every wrapped tensor along
    the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    # The wrapped tensors; all share the same size along dim 0.
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Lazy generator keeps `TensorDataset()` (no tensors) legal.
        assert all(tensors[0].size(0) == t.size(0) for t in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        # One slice per tensor, zipped into a tuple.
        return tuple(t[index] for t in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    # cumulative_sizes[i] == total number of samples in datasets[:i+1]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        """Running total of the lengths of *sequence*'s elements."""
        totals: List[int] = []
        running = 0
        for item in sequence:
            running += len(item)
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
        self.datasets = list(datasets)
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Normalize negative indices against the total length first.
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx += len(self)
        # Binary search locates which constituent dataset holds `idx`.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        sample_idx = idx if dataset_idx == 0 else idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        # Deprecated misspelled alias kept for backward compatibility.
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ChainDataset, self).__init__()
        self.datasets = datasets

    def __iter__(self):
        # Validate lazily, per dataset, so streams are never materialized.
        for ds in self.datasets:
            assert isinstance(ds, IterableDataset), "ChainDataset only supports IterableDataset"
            yield from ds

    def __len__(self):
        total = 0
        for ds in self.datasets:
            assert isinstance(ds, IterableDataset), "ChainDataset only supports IterableDataset"
            # Cannot verify that all self.datasets are Sized
            total += len(ds)  # type: ignore
        return total
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class BufferedShuffleDataset(IterableDataset[T_co]):
    r"""Dataset shuffled from the original dataset.

    This class is useful to shuffle an existing instance of an IterableDataset.
    The buffer with `buffer_size` is filled with the items from the dataset first. Then,
    each item will be yielded from the buffer by reservoir sampling via iterator.

    `buffer_size` is required to be larger than 0. For `buffer_size == 1`, the
    dataset is not shuffled. In order to fully shuffle the whole dataset, `buffer_size`
    is required to be greater than or equal to the size of dataset.

    When it is used with :class:`~torch.utils.data.DataLoader`, each item in the
    dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator.
    And, the method to set up a random seed is different based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is required to
    be set before the :class:`~torch.utils.data.DataLoader` in the main process.

        >>> ds = BufferedShuffleDataset(dataset)
        >>> random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))

    For multi-process mode (:attr:`num_workers > 0`), the random seed is set by a callable
    function in each worker.

        >>> ds = BufferedShuffleDataset(dataset)
        >>> def init_fn(worker_id):
        ...     random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn)))

    Args:
        dataset (IterableDataset): The original IterableDataset.
        buffer_size (int): The buffer size for shuffling.
    """
    dataset: IterableDataset[T_co]
    buffer_size: int

    def __init__(self, dataset: IterableDataset[T_co], buffer_size: int) -> None:
        super(BufferedShuffleDataset, self).__init__()
        assert buffer_size > 0, "buffer_size should be larger than 0"
        self.dataset = dataset
        self.buffer_size = buffer_size

    def __iter__(self) -> Iterator[T_co]:
        # NOTE: the RNG call pattern (randint per full buffer, one final
        # shuffle) is part of the observable behavior under a fixed seed.
        reservoir: List[T_co] = []
        for item in self.dataset:
            if len(reservoir) == self.buffer_size:
                # Buffer full: emit a random buffered element, keep the new one.
                slot = random.randint(0, self.buffer_size - 1)
                yield reservoir[slot]
                reservoir[slot] = item
            else:
                reservoir.append(item)
        # Drain whatever is left in random order.
        random.shuffle(reservoir)
        while reservoir:
            yield reservoir.pop()
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class Subset(Dataset[T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Translate the local index into the parent dataset's index space.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def random_split(dataset: Dataset[T], lengths: Sequence[int],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.
    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    # One global permutation, carved into consecutive runs of the given lengths.
    permutation = randperm(sum(lengths), generator=generator).tolist()
    splits: List[Subset[T]] = []
    offset = 0
    for length in lengths:
        offset += length
        splits.append(Subset(dataset, permutation[offset - length:offset]))
    return splits
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import fnmatch
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import List, Union, Iterable
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def match_masks(name : str, masks : Union[str, List[str]]) -> bool:
    """Return True if *name* matches the given Unix-style mask(s).

    An empty mask (``''``, ``[]``, ``None``) matches any input name.
    """
    if not masks:
        return True

    if isinstance(masks, str):
        return fnmatch.fnmatch(name, masks)

    # A list of masks matches if any single mask does.
    return any(fnmatch.fnmatch(name, mask) for mask in masks)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_file_pathnames_from_root(
        root: str,
        masks: Union[str, List[str]],
        recursive: bool = False,
        abspath: bool = False) -> Iterable[str]:
    """Yield pathnames under *root* whose file names match *masks*.

    Only the top directory level is scanned unless ``recursive`` is True;
    ``abspath`` controls whether yielded paths are absolute.
    """

    # Surface directory-walk errors (warn, then re-raise) instead of the
    # default os.walk behavior of silently skipping them.
    def onerror(err : OSError):
        warnings.warn(err.filename + " : " + err.strerror)
        raise err

    for dirpath, _, filenames in os.walk(root, onerror=onerror):
        prefix = os.path.abspath(dirpath) if abspath else dirpath
        for fname in filenames:
            if match_masks(fname, masks):
                yield os.path.join(prefix, fname)
        if not recursive:
            # Stop after the first (top) directory level.
            break
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_file_binaries_from_pathnames(pathnames : Iterable):
    """Yield ``(pathname, binary_stream)`` tuples for each path in *pathnames*.

    Streams are opened in ``'rb'`` mode; the caller owns them and is
    responsible for closing each one.

    Raises:
        TypeError: if *pathnames* is not iterable, or an element is not a str.
    """
    # FIX: the original warned via warnings.warn and then raised a bare,
    # message-less TypeError; fold the description into the exception so
    # callers (and tracebacks) see why it failed. Exception type unchanged.
    if not isinstance(pathnames, Iterable):
        raise TypeError("get_file_binaries_from_pathnames needs the input be an Iterable")

    for pathname in pathnames:
        if not isinstance(pathname, str):
            raise TypeError("file pathname must be string type, but got {}".format(type(pathname)))

        # NOTE(review): the open handle is intentionally yielded unclosed —
        # downstream datapipes consume and close it.
        yield (pathname, open(pathname, 'rb'))
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from torch.utils.data.dataset import IterableDataset
|
|
2
|
+
from torch.utils.data.datasets.common import get_file_pathnames_from_root
|
|
3
|
+
|
|
4
|
+
from typing import List, Union, Iterator
|
|
5
|
+
|
|
6
|
+
class ListDirFilesIterableDataset(IterableDataset):
    r""" :class:`ListDirFilesIterableDataset`

    IterableDataset to load file pathname(s) (path + filename), yield pathname from given disk root dir.
    args:
        root : root dir
        masks : a unix style filter string or string list for filtering file name(s)
        abspath : whether to return relative pathname or absolute pathname
        length : a nominal length of the dataset
    """

    def __init__(
            self,
            root: str = '.',
            masks: Union[str, List[str]] = '*.tar',
            *,
            abspath: bool = False,
            length: int = -1):
        super().__init__()
        self.root : str = root                          # directory to scan
        self.masks : Union[str, List[str]] = masks      # fnmatch-style filter(s)
        self.abspath : bool = abspath                   # yield absolute paths?
        self.length : int = length                      # nominal __len__, -1 = unknown

    def __iter__(self) -> Iterator[str] :
        # FIX: `self.abspath` was previously passed positionally and landed in
        # the `recursive` parameter of get_file_pathnames_from_root
        # (signature: root, masks, recursive=False, abspath=False), so
        # abspath=True silently enabled recursion and never produced absolute
        # paths. Pass it by keyword.
        yield from get_file_pathnames_from_root(self.root, self.masks, abspath=self.abspath)

    def __len__(self):
        # Only meaningful when a nominal length was supplied at construction.
        if self.length == -1:
            raise NotImplementedError
        return self.length
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from torch.utils.data.dataset import IterableDataset
|
|
2
|
+
from torch.utils.data.datasets.common import get_file_binaries_from_pathnames
|
|
3
|
+
|
|
4
|
+
from typing import Iterable, Iterator
|
|
5
|
+
|
|
6
|
+
class LoadFilesFromDiskIterableDataset(IterableDataset):
    r""" :class:`LoadFilesFromDiskIterableDataset`.

    IterableDataset to load file binary streams from given pathnames,
    yield pathname and binary stream in a tuple.
    args:
        dataset: Iterable dataset that provides pathnames
        length: a nominal length of the dataset
    """

    def __init__(
            self,
            dataset : Iterable,
            length : int = -1):
        super().__init__()
        self.dataset : Iterable = dataset   # upstream source of pathnames
        self.length : int = length          # nominal __len__, -1 = unknown

    def __iter__(self) -> Iterator[tuple] :
        # Delegate to the helper, which opens each path in binary mode and
        # yields (pathname, stream) pairs.
        yield from get_file_binaries_from_pathnames(self.dataset)

    def __len__(self):
        # Only meaningful when a nominal length was supplied at construction.
        if self.length == -1:
            raise NotImplementedError
        return self.length
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import TypeVar, Optional, Iterator
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from . import Sampler, Dataset
|
|
6
|
+
import torch.distributed as dist
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
T_co = TypeVar('T_co', covariant=True)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DistributedSampler(Sampler[T_co]):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
    :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
    original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.

    .. warning::
        In distributed mode, calling the :meth:`set_epoch` method at
        the beginning of each epoch **before** creating the :class:`DataLoader` iterator
        is necessary to make shuffling work properly across multiple epochs. Otherwise,
        the same ordering will be always used.

    Example::

        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        # Fall back to the active process group when replica/rank are omitted.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last

        # `type: ignore` is required because Dataset cannot provide a default __len__
        # see NOTE in pytorch/torch/utils/data/sampler.py
        dataset_len = len(self.dataset)  # type: ignore
        if self.drop_last and dataset_len % self.num_replicas != 0:
            # Uneven split with drop_last: trim to the nearest evenly
            # divisible length so every rank receives the same amount.
            self.num_samples = math.ceil((dataset_len - self.num_replicas) / self.num_replicas)
        else:
            # Evenly divisible, or padding mode: round up.
            self.num_samples = math.ceil(dataset_len / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # Deterministic shuffle keyed on (seed, epoch) so all ranks agree.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore
        else:
            indices = list(range(len(self.dataset)))  # type: ignore

        if self.drop_last:
            # Remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        else:
            # Pad by repeating indices from the front (tiling if the dataset
            # is smaller than the padding required).
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        assert len(indices) == self.total_size

        # Interleaved subsample: rank r takes positions r, r+R, r+2R, ...
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
|