python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml has been flagged as potentially problematic; consult the registry's advisory page for details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from torch.utils.data import IterDataPipe
|
|
3
|
+
from typing import TypeVar, Optional, Iterator, List, Sized, Callable
|
|
4
|
+
|
|
5
|
+
T_co = TypeVar('T_co', covariant=True)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BatchIterDataPipe(IterDataPipe[List[T_co]]):
    r""" :class:`BatchIterDataPipe`.

    Iterable DataPipe to create mini-batches of data. An outer dimension will be added as
    `batch_size` if `drop_last` is set to `True`, or `length % batch_size` for the
    last batch if `drop_last` is set to `False`.
    args:
        datapipe: Iterable DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full
    """
    datapipe: IterDataPipe[T_co]
    batch_size: int
    drop_last: bool
    length: Optional[int]

    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *,
                 batch_size: int,
                 drop_last: bool = False,
                 ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Computed lazily on the first __len__ call (requires a Sized source).
        self.length = None

    def __iter__(self) -> Iterator[List[T_co]]:
        """Yield consecutive lists of up to `batch_size` elements.

        BUGFIX: the original yielded `batch` and then called `batch.clear()`,
        reusing a single list object for every batch.  Any consumer that kept
        a reference to a yielded batch (e.g. `list(dp)`) saw it emptied and
        overwritten as iteration advanced.  Yielding and then rebinding to a
        fresh list preserves each emitted batch.
        """
        batch: List[T_co] = []
        for x in self.datapipe:
            batch.append(x)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            if not self.drop_last:
                yield batch

    def __len__(self) -> int:
        """Number of batches; raises NotImplementedError for unsized sources."""
        if self.length is not None:
            return self.length
        if isinstance(self.datapipe, Sized) and len(self.datapipe) >= 0:
            if self.drop_last:
                self.length = len(self.datapipe) // self.batch_size
            else:
                # Ceiling division: the trailing partial batch counts too.
                self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size
            return self.length
        raise NotImplementedError
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class BucketBatchIterDataPipe(IterDataPipe[List[T_co]]):
    r""" :class:`BucketBatchIterDataPipe`.

    Iterable DataPipe that batches from a sorted bucket: elements are first
    collected into buckets of `batch_size * bucket_size_mul` items, each bucket
    is sorted with `sort_key`, and mini-batches of `batch_size` are cut from
    the sorted bucket.  The final batch may hold `length % batch_size` items
    and is dropped when `drop_last` is True.
    args:
        datapipe: Iterable DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full
        bucket_size_mul: The multiplier to specify the size of bucket
        sort_key: Callable to specify the comparison key for sorting within bucket
    """
    datapipe: IterDataPipe[T_co]
    batch_size: int
    drop_last: bool
    bucket_size_mul: int
    sort_key: Optional[Callable]
    length: Optional[int]

    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *,
                 batch_size: int,
                 drop_last: bool = False,
                 bucket_size_mul: int = 100,
                 sort_key: Optional[Callable] = None,
                 ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.bucket_size = batch_size * bucket_size_mul
        self.sort_key = sort_key
        if sort_key is not None and sort_key.__name__ == '<lambda>':
            warnings.warn("Lambda function is not supported for pickle, "
                          "please use regular python function instead.")
        # Inner pipe that groups the source into whole buckets.
        self.bucket_ds = BatchIterDataPipe(datapipe, batch_size=self.bucket_size, drop_last=False)
        self.length = None

    def __iter__(self) -> Iterator[List[T_co]]:
        if self.sort_key is None:
            # No sorting requested: plain batching preserves source order.
            yield from BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last)
            return
        for bucket in self.bucket_ds:
            bucket.sort(key=self.sort_key)  # in-place sort within bucket
            for offset in range(0, len(bucket), self.batch_size):
                chunk = bucket[offset: offset + self.batch_size]
                # Only the bucket's trailing chunk can be short; honour drop_last.
                if len(chunk) == self.batch_size or not self.drop_last:
                    yield chunk

    def __len__(self) -> int:
        if self.length is not None:
            return self.length
        if isinstance(self.datapipe, Sized) and len(self.datapipe) >= 0:
            n = len(self.datapipe)
            if self.drop_last:
                self.length = n // self.batch_size
            else:
                self.length = (n + self.batch_size - 1) // self.batch_size
            return self.length
        raise NotImplementedError
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from torch.utils.data import IterDataPipe, _utils
|
|
3
|
+
from typing import TypeVar, Callable, Iterator, Sized
|
|
4
|
+
|
|
5
|
+
T_co = TypeVar('T_co', covariant=True)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Default function to return each item directly
|
|
9
|
+
# In order to keep datapipe picklable, eliminates the usage
|
|
10
|
+
# of python lambda function
|
|
11
|
+
def default_fn(data):
    """Identity function used as the default `fn`.

    Defined at module level (instead of a lambda) so that datapipes holding a
    reference to it stay picklable.
    """
    return data


class CallableIterDataPipe(IterDataPipe[T_co]):
    r""" :class:`CallableIterDataPipe`.

    Iterable DataPipe that applies a function to every item of the source
    DataPipe and yields the result.
    args:
        datapipe: Source Iterable DataPipe
        fn: Function called over each item
    """
    datapipe: IterDataPipe
    fn: Callable

    def __init__(self,
                 datapipe: IterDataPipe,
                 *args,
                 fn: Callable = default_fn,
                 **kwargs,
                 ) -> None:
        super().__init__()
        self.datapipe = datapipe
        if fn.__name__ == '<lambda>':
            warnings.warn("Lambda function is not supported for pickle, "
                          "please use regular python function instead.")
        self.fn = fn  # type: ignore
        # Extra positional/keyword arguments are forwarded to `fn` per item.
        self.args = args
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[T_co]:
        for item in self.datapipe:
            yield self.fn(item, *self.args, **self.kwargs)

    def __len__(self) -> int:
        # Mapping is 1:1, so the length is the source length when known.
        if isinstance(self.datapipe, Sized) and len(self.datapipe) >= 0:
            return len(self.datapipe)
        raise NotImplementedError
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class CollateIterDataPipe(CallableIterDataPipe):
    r""" :class:`CollateIterDataPipe`.

    Iterable DataPipe that collates each sample from the source datapipe into
    Tensor(s) via `_utils.collate.default_collate`, or into a customized data
    structure when a `collate_fn` is supplied.
    args:
        datapipe: Iterable DataPipe being collated
        collate_fn: Customized collate function to collect and combine data or a batch of data.
            Default function collates to Tensor(s) based on data type.

    Example: Convert integer data to float Tensor
        >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
        ...     def __init__(self, start, end):
        ...         super(MyIterDataPipe).__init__()
        ...         assert end > start, "this example code only works with end >= start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        ...     def __len__(self):
        ...         return self.end - self.start
        ...
        >>> ds = MyIterDataPipe(start=3, end=7)
        >>> print(list(ds))
        [3, 4, 5, 6]

        >>> def collate_fn(batch):
        ...     return torch.tensor(batch, dtype=torch.float)
        ...
        >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
        >>> print(list(collated_ds))
        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
    """

    def __init__(self,
                 datapipe: IterDataPipe,
                 *args,
                 collate_fn: Callable = _utils.collate.default_collate,
                 **kwargs,
                 ) -> None:
        # Thin wrapper: delegate everything to CallableIterDataPipe with the
        # collate function installed as the per-item callable.
        super().__init__(datapipe, *args, fn=collate_fn, **kwargs)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from torch.utils.data import IterDataPipe
|
|
2
|
+
from torch.utils.data.datapipes.utils.common import get_file_pathnames_from_root
|
|
3
|
+
from typing import List, Union, Iterator
|
|
4
|
+
|
|
5
|
+
class ListDirFilesIterDataPipe(IterDataPipe):
    r""" :class:`ListDirFilesIterDataPipe`

    Iterable DataPipe that walks a root directory on disk and yields the
    pathname (path + filename) of every file whose name matches the mask(s).
    args:
        root : root dir
        masks : a unix style filter string or string list for filtering file name(s)
        abspath : whether to return relative pathname or absolute pathname
        length : a nominal length of the datapipe
    """

    def __init__(
            self,
            root: str = '.',
            masks: Union[str, List[str]] = '',
            *,
            recursive: bool = False,
            abspath: bool = False,
            length: int = -1):
        super().__init__()
        # Parameters are stored verbatim; the directory walk happens lazily
        # in __iter__.
        self.root: str = root
        self.masks: Union[str, List[str]] = masks
        self.recursive: bool = recursive
        self.abspath: bool = abspath
        self.length: int = length

    def __iter__(self) -> Iterator[str]:
        yield from get_file_pathnames_from_root(self.root, self.masks, self.recursive, self.abspath)

    def __len__(self):
        # -1 means "unknown": the caller never supplied a nominal length.
        if self.length == -1:
            raise NotImplementedError
        return self.length
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from torch.utils.data import IterDataPipe
|
|
2
|
+
from torch.utils.data.datapipes.utils.common import get_file_binaries_from_pathnames
|
|
3
|
+
from typing import Iterable, Iterator, Tuple
|
|
4
|
+
from io import BufferedIOBase
|
|
5
|
+
|
|
6
|
+
class LoadFilesFromDiskIterDataPipe(IterDataPipe):
    r""" :class:`LoadFilesFromDiskIterDataPipe`.

    Iterable Datapipe that opens each pathname produced by the source datapipe
    in binary mode, yielding a (pathname, binary stream) tuple per file.
    args:
        datapipe: Iterable datapipe that provides pathnames
        length: a nominal length of the datapipe
    """

    def __init__(
            self,
            datapipe: Iterable[str],
            length: int = -1):
        super().__init__()
        self.datapipe: Iterable = datapipe
        self.length: int = length

    def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
        # Files are opened lazily, one per consumed item; the consumer owns
        # (and must close) the yielded streams.
        yield from get_file_binaries_from_pathnames(self.datapipe)

    def __len__(self):
        # -1 means "unknown": the caller never supplied a nominal length.
        if self.length == -1:
            raise NotImplementedError
        return self.length
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from torch.utils.data import IterDataPipe
|
|
2
|
+
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
|
|
3
|
+
from typing import Iterable, Iterator, Tuple, Optional, IO, cast
|
|
4
|
+
from io import BufferedIOBase
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import tarfile
|
|
8
|
+
import warnings
|
|
9
|
+
|
|
10
|
+
class ReadFilesFromTarIterDataPipe(IterDataPipe):
|
|
11
|
+
r""" :class:`ReadFilesFromTarIDP`.
|
|
12
|
+
|
|
13
|
+
Iterable datapipe to extract tar binary streams from input iterable which contains tuples of
|
|
14
|
+
pathname and tar binary stream, yields pathname and extracted binary stream in a tuple.
|
|
15
|
+
args:
|
|
16
|
+
datapipe: Iterable datapipe that provides pathname and tar binary stream in tuples
|
|
17
|
+
length: a nominal length of the datapipe
|
|
18
|
+
"""
|
|
19
|
+
def __init__(
|
|
20
|
+
self,
|
|
21
|
+
datapipe : Iterable[Tuple[str, BufferedIOBase]],
|
|
22
|
+
length : int = -1):
|
|
23
|
+
super().__init__()
|
|
24
|
+
self.datapipe : Iterable[Tuple[str, BufferedIOBase]] = datapipe
|
|
25
|
+
self.length : int = length
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
|
|
29
|
+
if not isinstance(self.datapipe, Iterable):
|
|
30
|
+
raise TypeError("datapipe must be Iterable type but got {}".format(type(self.datapipe)))
|
|
31
|
+
for data in self.datapipe:
|
|
32
|
+
validate_pathname_binary_tuple(data)
|
|
33
|
+
pathname, data_stream = data
|
|
34
|
+
try:
|
|
35
|
+
# typing.cast is used here to silence mypy's type checker
|
|
36
|
+
tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode="r:*")
|
|
37
|
+
for tarinfo in tar:
|
|
38
|
+
if not tarinfo.isfile():
|
|
39
|
+
continue
|
|
40
|
+
extracted_fobj = tar.extractfile(tarinfo)
|
|
41
|
+
if extracted_fobj is None:
|
|
42
|
+
warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
|
|
43
|
+
raise tarfile.ExtractError
|
|
44
|
+
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
|
|
45
|
+
# Add a reference of the source tarfile into extracted_fobj, so the source
|
|
46
|
+
# tarfile handle won't be released until all the extracted file objs are destroyed.
|
|
47
|
+
# Add `# type: ignore` to silence mypy's type checker
|
|
48
|
+
extracted_fobj.source_tarfile_ref = tar # type: ignore
|
|
49
|
+
# typing.cast is used here to silence mypy's type checker
|
|
50
|
+
yield (inner_pathname, cast(BufferedIOBase, extracted_fobj))
|
|
51
|
+
except Exception as e:
|
|
52
|
+
warnings.warn(
|
|
53
|
+
"Unable to extract files from corrupted tarfile stream {} due to: {}, abort!".format(pathname, e))
|
|
54
|
+
raise e
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def __len__(self):
|
|
58
|
+
if self.length == -1:
|
|
59
|
+
raise NotImplementedError
|
|
60
|
+
return self.length
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from torch.utils.data import IterDataPipe
|
|
2
|
+
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
|
|
3
|
+
from typing import Iterable, Iterator, Tuple, IO, cast
|
|
4
|
+
from io import BufferedIOBase
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
|
+
import zipfile
|
|
9
|
+
import warnings
|
|
10
|
+
|
|
11
|
+
class ReadFilesFromZipIterDataPipe(IterDataPipe):
|
|
12
|
+
r""" :class:`ReadFilesFromZipIterDataPipe`.
|
|
13
|
+
|
|
14
|
+
Iterable data pipe to extract zip binary streams from input iterable which contains tuples of
|
|
15
|
+
pathname and zip binary stream, yields pathname and extracted binary stream in a tuple.
|
|
16
|
+
args:
|
|
17
|
+
datapipe: Iterable datapipe that provides pathname and zip binary stream in tuples
|
|
18
|
+
length: a nominal length of the datapipe
|
|
19
|
+
"""
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
datapipe : Iterable[Tuple[str, BufferedIOBase]],
|
|
23
|
+
length : int = -1):
|
|
24
|
+
super().__init__()
|
|
25
|
+
self.datapipe : Iterable[Tuple[str, BufferedIOBase]] = datapipe
|
|
26
|
+
self.length : int = length
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
|
|
30
|
+
if not isinstance(self.datapipe, Iterable):
|
|
31
|
+
raise TypeError("datapipe must be Iterable type but got {}".format(type(self.datapipe)))
|
|
32
|
+
for data in self.datapipe:
|
|
33
|
+
validate_pathname_binary_tuple(data)
|
|
34
|
+
pathname, data_stream = data
|
|
35
|
+
try:
|
|
36
|
+
# typing.cast is used here to silence mypy's type checker
|
|
37
|
+
zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
|
|
38
|
+
for zipinfo in zips.infolist():
|
|
39
|
+
# major version should always be 3 here.
|
|
40
|
+
if sys.version_info[1] >= 6:
|
|
41
|
+
if zipinfo.is_dir():
|
|
42
|
+
continue
|
|
43
|
+
elif zipinfo.filename.endswith('/'):
|
|
44
|
+
continue
|
|
45
|
+
|
|
46
|
+
extracted_fobj = zips.open(zipinfo)
|
|
47
|
+
inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
|
|
48
|
+
# Add a reference of the source zipfile into extracted_fobj, so the source
|
|
49
|
+
# zipfile handle won't be released until all the extracted file objs are destroyed.
|
|
50
|
+
# Add `# type: ignore` to silence mypy's type checker
|
|
51
|
+
extracted_fobj.source_zipfile_ref = zips # type: ignore
|
|
52
|
+
# typing.cast is used here to silence mypy's type checker
|
|
53
|
+
yield (inner_pathname, cast(BufferedIOBase, extracted_fobj))
|
|
54
|
+
except Exception as e:
|
|
55
|
+
warnings.warn(
|
|
56
|
+
"Unable to extract files from corrupted zipfile stream {} due to: {}, abort!".format(pathname, e))
|
|
57
|
+
raise e
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def __len__(self):
|
|
61
|
+
if self.length == -1:
|
|
62
|
+
raise NotImplementedError
|
|
63
|
+
return self.length
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from torch.utils.data import IterDataPipe, Sampler, SequentialSampler
|
|
2
|
+
from typing import TypeVar, Type, Iterator, Sized,Optional
|
|
3
|
+
import itertools
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
T_co = TypeVar('T_co', covariant=True)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SamplerIterDataPipe(IterDataPipe[T_co]):
    r""" :class:`SamplerIterDataPipe`.

    Iterable DataPipe that yields the sample elements produced by a `Sampler`
    constructed over the input DataPipe.
    args:
        datapipe: IterDataPipe sampled from
        sampler: Sampler class to generate sample elements from input DataPipe.
            Default is :class:`SequentialSampler` for IterDataPipe
    """
    datapipe: IterDataPipe
    sampler: Sampler

    def __init__(self,
                 datapipe: IterDataPipe,
                 *,
                 sampler: Type[Sampler] = SequentialSampler,
                 **kwargs
                 ) -> None:
        assert isinstance(datapipe, Sized), \
            "Sampler class requires input datapipe implemented `__len__`"
        super().__init__()
        self.datapipe = datapipe
        # Instantiate the sampler class eagerly over the source pipe.
        # https://github.com/python/mypy/pull/9629 will solve
        self.sampler = sampler(data_source=self.datapipe, **kwargs)  # type: ignore

    def __iter__(self) -> Iterator[T_co]:
        return iter(self.sampler)

    def __len__(self) -> int:
        # Length is delegated to the sampler when it exposes one.
        if isinstance(self.sampler, Sized) and len(self.sampler) >= 0:
            return len(self.sampler)
        raise NotImplementedError
|
|
42
|
+
|
|
43
|
+
class InfiniteSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The samplers in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(
        self,
        size: int,
        shuffle: bool = True,
        seed: Optional[int] = 0,
        rank=0,
        world_size=1,
    ):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, a fresh random seed is drawn
                locally; NOTE(review): unlike the detectron2 original this is
                NOT synchronized among workers - callers relying on seed=None
                must broadcast a shared seed themselves.
            rank: index of this worker in the [0, world_size) stride split
            world_size: total number of cooperating workers
        """
        # Validate before storing (the original asserted after assignment).
        assert size > 0
        self._size = size
        self._shuffle = shuffle
        if seed is None:
            # BUGFIX: the original crashed with `int(None)` although the
            # signature and docstring advertise seed=None support.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        self._seed = int(seed)

        self._rank = rank
        self._world_size = world_size

    def __iter__(self):
        # Each rank takes every `world_size`-th index, offset by its rank.
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size
        )

    def _infinite_indices(self):
        # A dedicated generator keeps the stream reproducible regardless of
        # the global torch RNG state.
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g)
            else:
                yield from torch.arange(self._size)

    def __len__(self):
        # Nominal per-worker epoch length; the iterator itself never stops.
        return self._size // self._world_size
|
|
File without changes
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import fnmatch
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import List, Union, Iterable
|
|
5
|
+
from io import BufferedIOBase
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def match_masks(name: str, masks: Union[str, List[str]]) -> bool:
    """Return True when `name` matches the given mask(s).

    An empty mask (or empty mask list) matches every name; a single string is
    used directly as an fnmatch pattern; a list matches if any pattern does.
    """
    if not masks:
        return True
    if isinstance(masks, str):
        return fnmatch.fnmatch(name, masks)
    return any(fnmatch.fnmatch(name, pattern) for pattern in masks)


def get_file_pathnames_from_root(
        root: str,
        masks: Union[str, List[str]],
        recursive: bool = False,
        abspath: bool = False) -> Iterable[str]:
    """Yield pathnames under `root` whose file names match `masks`.

    Only the top directory is visited unless `recursive` is True; pathnames
    are made absolute when `abspath` is True.  Walk errors are warned about
    and then re-raised.
    """

    def onerror(err: OSError):
        # Surface the failing path before propagating the error out.
        warnings.warn(err.filename + " : " + err.strerror)
        raise err

    for dirpath, _, filenames in os.walk(root, onerror=onerror):
        if abspath:
            dirpath = os.path.abspath(dirpath)
        for fname in filenames:
            if match_masks(fname, masks):
                yield os.path.join(dirpath, fname)
        if not recursive:
            break
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_file_binaries_from_pathnames(pathnames: Iterable):
    """Open each pathname in binary mode and yield `(pathname, stream)` pairs.

    The opened streams are handed over unclosed; the consumer owns them.
    Raises TypeError when the input is not iterable or contains a non-string.
    """
    if not isinstance(pathnames, Iterable):
        warnings.warn("get_file_binaries_from_pathnames needs the input be an Iterable")
        raise TypeError

    for pathname in pathnames:
        if not isinstance(pathname, str):
            warnings.warn("file pathname must be string type, but got {}".format(type(pathname)))
            raise TypeError
        yield (pathname, open(pathname, 'rb'))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def validate_pathname_binary_tuple(data):
    """Ensure `data` is a `(str pathname, BufferedIOBase stream)` pair.

    Raises TypeError naming the first violated constraint; returns None when
    the tuple is valid.
    """
    if not isinstance(data, tuple):
        raise TypeError("pathname binary data should be tuple type, but got {}".format(type(data)))
    if len(data) != 2:
        raise TypeError("pathname binary tuple length should be 2, but got {}".format(str(len(data))))
    pathname, stream = data
    if not isinstance(pathname, str):
        raise TypeError("pathname binary tuple should have string type pathname, but got {}".format(type(pathname)))
    if not isinstance(stream, BufferedIOBase):
        raise TypeError("pathname binary tuple should have BufferedIOBase based binary type, but got {}".format(type(stream)))