python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164)
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,354 @@
1
+ import bisect
2
+ import random
3
+ import warnings
4
+
5
+ from torch._utils import _accumulate
6
+ from torch import randperm
7
+ # No 'default_generator' in torch/__init__.pyi
8
+ from torch import default_generator # type: ignore
9
+ from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
10
+ from torch import Tensor, Generator
11
+
12
# TypeVars for dataset element types: T_co is covariant (a dataset only
# produces values of its element type), T is the invariant variable used
# by random_split below.
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
14
+
15
+
16
class Dataset(Generic[T_co]):
    r"""Abstract base class for map-style datasets.

    A map-style dataset maps keys (typically integer indices) to data
    samples.  Subclasses must implement :meth:`__getitem__`; implementing
    :meth:`__len__` is optional but is expected by many
    :class:`~torch.utils.data.Sampler` implementations and by the default
    options of :class:`~torch.utils.data.DataLoader`.

    .. note::
        The default :class:`~torch.utils.data.DataLoader` index sampler
        yields integral indices.  For a map-style dataset keyed by
        anything else, supply a custom sampler.
    """

    def __getitem__(self, index) -> T_co:
        # Subclasses provide the actual key -> sample lookup.
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        # `ds_a + ds_b` concatenates two map-style datasets.
        return ConcatDataset([self, other])

    # Deliberately no default __len__ here; see
    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
41
+
42
+
43
class IterableDataset(Dataset[T_co]):
    r"""Abstract base class for iterable-style datasets.

    Subclass this when samples come from a stream rather than a keyed
    lookup, and implement :meth:`__iter__` to return an iterator over the
    samples.

    When used with a :class:`~torch.utils.data.DataLoader` and
    :attr:`num_workers > 0`, every worker process receives its own copy
    of the dataset object.  Unless each copy is configured to emit a
    distinct shard, the loader will yield duplicate samples.
    Worker-specific configuration can live either inside
    :meth:`__iter__` or in the loader's :attr:`worker_init_fn`, using
    :func:`~torch.utils.data.get_worker_info` to discover the worker id
    and worker count.

    Example (sharding inside ``__iter__``)::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start
        ...         self.start, self.end = start, end
        ...
        ...     def __iter__(self):
        ...         info = torch.utils.data.get_worker_info()
        ...         lo, hi = self.start, self.end
        ...         if info is not None:  # split the workload per worker
        ...             per_worker = int(math.ceil((hi - lo) / float(info.num_workers)))
        ...             lo = self.start + info.id * per_worker
        ...             hi = min(lo + per_worker, self.end)
        ...         return iter(range(lo, hi))

    The same split can instead be applied from a ``worker_init_fn`` that
    rewrites ``worker_info.dataset.start`` / ``worker_info.dataset.end``
    in each worker before iteration begins; without such a split,
    multi-worker loading of a naive ``__iter__`` yields each sample once
    per worker.
    """

    def __iter__(self) -> Iterator[T_co]:
        # Subclasses supply the sample stream.
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        # `ds_a + ds_b` chains two iterable datasets back to back.
        return ChainDataset([self, other])

    # Deliberately no default __len__ here; see
    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
154
+
155
+
156
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Map-style dataset backed by one or more tensors.

    Sample ``i`` is the tuple formed by indexing every wrapped tensor at
    position ``i`` along its first dimension.

    Args:
        *tensors (Tensor): tensors sharing the same first-dimension size.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # All tensors must agree on dim 0; it defines the dataset length.
        reference_len = tensors[0].size(0)
        assert all(t.size(0) == reference_len for t in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(t[index] for t in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
175
+
176
+
177
class ConcatDataset(Dataset[T_co]):
    r"""Map-style dataset formed by concatenating several datasets.

    Useful for assembling different existing datasets into one logical
    dataset.

    Args:
        datasets (sequence): datasets to concatenate, in order.
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        # Running totals of len() over `sequence`; entry i is the total
        # number of samples contributed by datasets 0..i.
        totals, running = [], 0
        for ds in sequence:
            running += len(ds)
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # Cannot verify statically that `datasets` is Sized.
        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
        self.datasets = list(datasets)
        for ds in self.datasets:
            assert not isinstance(ds, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx += len(self)
        # Binary-search for the child dataset owning `idx`, then translate
        # to a local index within that child.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        sample_idx = idx if dataset_idx == 0 else idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        # Misspelled legacy alias kept for backward compatibility.
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
226
+
227
+
228
class ChainDataset(IterableDataset):
    r"""Iterable-style dataset that chains several :class:`IterableDataset` s.

    Child streams are consumed one after another, on the fly, so chaining
    large-scale datasets stays memory-efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to chain together.
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ChainDataset, self).__init__()
        self.datasets = datasets

    def __iter__(self):
        for ds in self.datasets:
            assert isinstance(ds, IterableDataset), "ChainDataset only supports IterableDataset"
            yield from ds

    def __len__(self):
        # Sum of child lengths; children are assumed to be Sized even
        # though the type system cannot verify it.
        total = 0
        for ds in self.datasets:
            assert isinstance(ds, IterableDataset), "ChainDataset only supports IterableDataset"
            total += len(ds)  # type: ignore
        return total
255
+
256
+
257
class BufferedShuffleDataset(IterableDataset[T_co]):
    r"""Iterable-style dataset that shuffles another dataset via a buffer.

    A buffer of ``buffer_size`` items is first filled from the wrapped
    dataset.  Afterwards each incoming item displaces a uniformly chosen
    buffered item (reservoir sampling), which is yielded; once the source
    is exhausted, the remaining buffer is shuffled and drained.

    ``buffer_size`` must be larger than 0.  With ``buffer_size == 1`` the
    order is unchanged; a full shuffle of the whole dataset requires
    ``buffer_size`` greater than or equal to the dataset size.

    When used with :class:`~torch.utils.data.DataLoader`, the way to seed
    the shuffle depends on :attr:`num_workers`: for ``num_workers == 0``
    call ``random.seed(...)`` in the main process before iterating the
    loader; for ``num_workers > 0`` seed ``random`` inside a
    ``worker_init_fn`` passed to the loader.

    Args:
        dataset (IterableDataset): the original IterableDataset.
        buffer_size (int): the buffer size used for shuffling.
    """
    dataset: IterableDataset[T_co]
    buffer_size: int

    def __init__(self, dataset: IterableDataset[T_co], buffer_size: int) -> None:
        super(BufferedShuffleDataset, self).__init__()
        assert buffer_size > 0, "buffer_size should be larger than 0"
        self.dataset = dataset
        self.buffer_size = buffer_size

    def __iter__(self) -> Iterator[T_co]:
        buffer: List[T_co] = []
        for item in self.dataset:
            if len(buffer) < self.buffer_size:
                # Still filling the buffer: nothing is yielded yet.
                buffer.append(item)
            else:
                # Buffer full: emit a random resident and take its slot.
                slot = random.randint(0, self.buffer_size - 1)
                yield buffer[slot]
                buffer[slot] = item
        # Source exhausted: drain whatever is left in random order.
        random.shuffle(buffer)
        while buffer:
            yield buffer.pop()
312
+
313
+
314
class Subset(Dataset[T_co]):
    r"""View of a dataset restricted to a chosen set of indices.

    Args:
        dataset (Dataset): the whole underlying dataset.
        indices (sequence): indices of ``dataset`` selected for the subset.
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Translate the subset-local index into the parent dataset.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
334
+
335
+
336
def random_split(dataset: Dataset[T], lengths: Sequence[int],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""Randomly partition a dataset into non-overlapping subsets.

    Pass a seeded generator for reproducible splits, e.g.::

        >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

    Args:
        dataset (Dataset): dataset to be split.
        lengths (sequence): desired length of each produced subset; must
            sum to ``len(dataset)``.
        generator (Generator): generator driving the random permutation.

    Raises:
        ValueError: if the lengths do not sum to the dataset length.
    """
    # Cannot verify statically that `dataset` is Sized.
    if sum(lengths) != len(dataset):  # type: ignore
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()
    # Each subset takes the next `length` entries of the permutation.
    splits: List[Subset[T]] = []
    for offset, length in zip(_accumulate(lengths), lengths):
        splits.append(Subset(dataset, indices[offset - length:offset]))
    return splits
@@ -0,0 +1,4 @@
1
+ from .listdirfilesdataset import ListDirFilesIterableDataset
2
+ from .loadfilesfromdiskdataset import LoadFilesFromDiskIterableDataset
3
+
4
+ __all__ = ['ListDirFilesIterableDataset', 'LoadFilesFromDiskIterableDataset']
@@ -0,0 +1,53 @@
1
+ import os
2
+ import fnmatch
3
+ import warnings
4
+ from typing import List, Union, Iterable
5
+
6
+
7
def match_masks(name: str, masks: Union[str, List[str]]) -> bool:
    """Return True if `name` matches `masks` (unix shell wildcards).

    An empty/falsy mask matches every name; a string is a single pattern;
    a list matches if any of its patterns does.
    """
    if not masks:
        return True
    if isinstance(masks, str):
        return fnmatch.fnmatch(name, masks)
    return any(fnmatch.fnmatch(name, pattern) for pattern in masks)


def get_file_pathnames_from_root(
        root: str,
        masks: Union[str, List[str]],
        recursive: bool = False,
        abspath: bool = False) -> Iterable[str]:
    """Yield pathnames of files under `root` whose names match `masks`.

    Args:
        root: directory to scan.
        masks: unix-style pattern or list of patterns filtering file names.
        recursive: descend into subdirectories when True; otherwise only
            the top-level directory is scanned.
        abspath: yield absolute pathnames instead of root-relative ones.

    Raises:
        OSError: re-raised from `os.walk` after emitting a warning.
    """
    # Emit a warning for the failure, then propagate the error.
    def onerror(err: OSError):
        # BUG FIX: `err.filename` and `err.strerror` may each be None, so
        # the previous string concatenation (`filename + " : " + strerror`)
        # could raise TypeError and mask the real OSError.  Format instead.
        warnings.warn("{} : {}".format(err.filename, err.strerror))
        raise err

    for path, dirs, files in os.walk(root, onerror=onerror):
        if abspath:
            path = os.path.abspath(path)
        for f in files:
            if match_masks(f, masks):
                yield os.path.join(path, f)
        if not recursive:
            # os.walk is top-down, so breaking after the first entry
            # limits the scan to `root` itself.
            break
40
+
41
+
42
def get_file_binaries_from_pathnames(pathnames : Iterable):
    """Yield ``(pathname, binary stream)`` tuples for each pathname given.

    The opened streams are handed to the caller, who is responsible for
    closing them.

    Raises:
        TypeError: if `pathnames` is not iterable, or yields a non-string.
    """
    if not isinstance(pathnames, Iterable):
        warnings.warn("get_file_binaries_from_pathnames needs the input be an Iterable")
        raise TypeError

    for pathname in pathnames:
        if not isinstance(pathname, str):
            warnings.warn(f"file pathname must be string type, but got {type(pathname)}")
            raise TypeError
        yield pathname, open(pathname, 'rb')
@@ -0,0 +1,36 @@
1
+ from torch.utils.data.dataset import IterableDataset
2
+ from torch.utils.data.datasets.common import get_file_pathnames_from_root
3
+
4
+ from typing import List, Union, Iterator
5
+
6
class ListDirFilesIterableDataset(IterableDataset):
    r""" :class:`ListDirFilesIterableDataset`

    IterableDataset that yields the pathname (path + filename) of every
    file under a root directory whose name matches the given mask(s).

    args:
        root : root dir
        masks : a unix style filter string or string list for filtering file name(s)
        abspath : whether to return relative pathname or absolute pathname
        length : a nominal length of the dataset
    """

    def __init__(
            self,
            root: str = '.',
            masks: Union[str, List[str]] = '*.tar',
            *,
            abspath: bool = False,
            length: int = -1):
        super().__init__()
        self.root : str = root
        self.masks : Union[str, List[str]] = masks
        self.abspath : bool = abspath
        self.length : int = length

    def __iter__(self) -> Iterator[str]:
        # BUG FIX: `self.abspath` was previously passed as the third
        # positional argument, where get_file_pathnames_from_root consumes
        # it as `recursive` (signature: root, masks, recursive=False,
        # abspath=False).  `abspath=True` therefore silently enabled a
        # recursive walk instead of absolute paths.  Pass it by keyword.
        yield from get_file_pathnames_from_root(self.root, self.masks, abspath=self.abspath)

    def __len__(self):
        # `length` is only nominal; without one the dataset is unsized.
        if self.length == -1:
            raise NotImplementedError
        return self.length
@@ -0,0 +1,30 @@
1
+ from torch.utils.data.dataset import IterableDataset
2
+ from torch.utils.data.datasets.common import get_file_binaries_from_pathnames
3
+
4
+ from typing import Iterable, Iterator
5
+
6
class LoadFilesFromDiskIterableDataset(IterableDataset):
    r""" :class:`LoadFilesFromDiskIterableDataset`.

    IterableDataset that opens each pathname produced by a wrapped
    dataset and yields ``(pathname, binary stream)`` tuples.

    args:
        dataset: Iterable dataset that provides pathnames
        length: a nominal length of the dataset
    """

    def __init__(
            self,
            dataset : Iterable,
            length : int = -1):
        super().__init__()
        self.dataset : Iterable = dataset
        self.length : int = length

    def __iter__(self) -> Iterator[tuple]:
        yield from get_file_binaries_from_pathnames(self.dataset)

    def __len__(self):
        # `length` is only nominal; without one the dataset is unsized.
        if self.length == -1:
            raise NotImplementedError
        return self.length
@@ -0,0 +1,135 @@
1
+ import math
2
+ from typing import TypeVar, Optional, Iterator
3
+
4
+ import torch
5
+ from . import Sampler, Dataset
6
+ import torch.distributed as dist
7
+
8
+
9
# Covariant element type produced by the sampler.
T_co = TypeVar('T_co', covariant=True)
10
+
11
+
12
class DistributedSampler(Sampler[T_co]):
    r"""Sampler restricting data loading to one shard of the dataset.

    Designed for use with
    :class:`torch.nn.parallel.DistributedDataParallel`: each process
    passes its own :class:`~torch.utils.data.DistributedSampler` to its
    :class:`~torch.utils.data.DataLoader` and thereby loads a subset of
    the original dataset exclusive to it.

    .. note::
        The dataset is assumed to be of constant size.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): number of processes participating in
            distributed training; by default retrieved from the current
            distributed group's world size.
        rank (int, optional): rank of the current process within
            :attr:`num_replicas`; by default retrieved from the current
            distributed group.
        shuffle (bool, optional): if ``True`` (default), the sampler
            shuffles the indices.
        seed (int, optional): random seed used when :attr:`shuffle=True`;
            must be identical across all processes in the group.
            Default: ``0``.
        drop_last (bool, optional): if ``True``, drop the tail of the data
            so it divides evenly across replicas; if ``False``, pad with
            extra (repeated) indices instead. Default: ``False``.

    .. warning::
        In distributed mode, call :meth:`set_epoch` at the beginning of
        each epoch **before** creating the :class:`DataLoader` iterator;
        otherwise the same shuffled ordering is reused every epoch.

    Example::

        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if not (0 <= rank < num_replicas):
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # `type: ignore` below: Dataset has no default __len__ (see NOTE
        # in pytorch/torch/utils/data/sampler.py).
        dataset_len = len(self.dataset)  # type: ignore
        if self.drop_last and dataset_len % self.num_replicas != 0:
            # Not evenly divisible: trim to the nearest length that is, so
            # every rank receives the same amount of data.
            self.num_samples = math.ceil((dataset_len - self.num_replicas) / self.num_replicas)
        else:
            # Evenly divisible (or padding allowed): no data is dropped.
            self.num_samples = math.ceil(dataset_len / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # Deterministic shuffle keyed on (seed, epoch) so every rank
            # derives the same permutation.
            gen = torch.Generator()
            gen.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=gen).tolist()  # type: ignore
        else:
            indices = list(range(len(self.dataset)))  # type: ignore

        if self.drop_last:
            # Remove the tail so the index list divides evenly.
            indices = indices[:self.total_size]
        else:
            # Pad by repeating indices until the list divides evenly.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        assert len(indices) == self.total_size

        # Each rank takes a strided slice of the shared index list.
        shard = indices[self.rank:self.total_size:self.num_replicas]
        assert len(shard) == self.num_samples

        return iter(shard)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""Set the epoch used to derive the shuffle ordering.

        When :attr:`shuffle=True`, this makes all replicas use a
        different (shared) permutation each epoch; otherwise iteration
        keeps yielding the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch