python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164) hide show
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,126 @@
1
+ import warnings
2
+ from torch.utils.data import IterDataPipe
3
+ from typing import TypeVar, Optional, Iterator, List, Sized, Callable
4
+
5
T_co = TypeVar('T_co', covariant=True)


class BatchIterDataPipe(IterDataPipe[List[T_co]]):
    r"""Iterable DataPipe that groups consecutive items into mini-batches.

    Each yielded batch is a fresh ``list`` of up to ``batch_size`` items.
    When ``drop_last`` is ``False`` the final batch may hold
    ``length % batch_size`` items; when ``True`` an incomplete trailing
    batch is discarded.

    args:
        datapipe: Iterable DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full
    """
    datapipe: IterDataPipe[T_co]
    batch_size: int
    drop_last: bool
    length: Optional[int]

    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *,
                 batch_size: int,
                 drop_last: bool = False,
                 ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Nominal length; computed lazily and cached by __len__.
        self.length = None

    def __iter__(self) -> Iterator[List[T_co]]:
        batch: List[T_co] = []
        for x in self.datapipe:
            batch.append(x)
            if len(batch) == self.batch_size:
                yield batch
                # BUGFIX: rebind to a new list instead of batch.clear().
                # clear() mutates the very object just handed to the consumer,
                # so e.g. list(dp) would end up containing emptied batches.
                batch = []
        # Emit the trailing partial batch unless the caller opted out.
        if batch and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        if self.length is not None:
            return self.length
        if isinstance(self.datapipe, Sized) and len(self.datapipe) >= 0:
            if self.drop_last:
                self.length = len(self.datapipe) // self.batch_size
            else:
                # Ceiling division without floats.
                self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size
            return self.length
        raise NotImplementedError
59
+
60
+
61
class BucketBatchIterDataPipe(IterDataPipe[List[T_co]]):
    r"""Iterable DataPipe yielding mini-batches drawn from sorted buckets.

    Items are first gathered into buckets of ``batch_size * bucket_size_mul``
    elements. When ``sort_key`` is given, each bucket is sorted with it before
    being sliced into batches; otherwise batching degenerates to plain
    :class:`BatchIterDataPipe` behaviour. The last batch holds
    ``length % batch_size`` items unless ``drop_last`` is ``True``.

    args:
        datapipe: Iterable DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full
        bucket_size_mul: The multiplier to specify the size of bucket
        sort_key: Callable to specify the comparison key for sorting within bucket
    """
    datapipe: IterDataPipe[T_co]
    batch_size: int
    drop_last: bool
    bucket_size_mul: int
    sort_key: Optional[Callable]
    length: Optional[int]

    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *,
                 batch_size: int,
                 drop_last: bool = False,
                 bucket_size_mul: int = 100,
                 sort_key: Optional[Callable] = None,
                 ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.bucket_size = batch_size * bucket_size_mul
        self.sort_key = sort_key
        # BUGFIX: use getattr -- callables such as functools.partial or
        # objects with __call__ have no __name__ attribute and previously
        # crashed this constructor with AttributeError.
        if sort_key is not None and getattr(sort_key, '__name__', None) == '<lambda>':
            warnings.warn("Lambda function is not supported for pickle, "
                          "please use regular python function instead.")
        # Upstream pipe that hands us whole (unsorted) buckets to sort locally.
        self.bucket_ds = BatchIterDataPipe(datapipe, batch_size=self.bucket_size, drop_last=False)
        self.length = None

    def __iter__(self) -> Iterator[List[T_co]]:
        # Without a sort key the order is untouched; plain batching suffices.
        if self.sort_key is None:
            yield from BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last)
            return
        for bucket in self.bucket_ds:
            bucket.sort(key=self.sort_key)  # in-place sort within bucket
            for start in range(0, len(bucket), self.batch_size):
                batch = bucket[start: start + self.batch_size]
                # Only the bucket's tail slice can be short; honour drop_last.
                if len(batch) == self.batch_size or not self.drop_last:
                    yield batch

    def __len__(self) -> int:
        if self.length is not None:
            return self.length
        if isinstance(self.datapipe, Sized) and len(self.datapipe) >= 0:
            if self.drop_last:
                self.length = len(self.datapipe) // self.batch_size
            else:
                self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size
            return self.length
        raise NotImplementedError
@@ -0,0 +1,92 @@
1
+ import warnings
2
+ from torch.utils.data import IterDataPipe, _utils
3
+ from typing import TypeVar, Callable, Iterator, Sized
4
+
5
T_co = TypeVar('T_co', covariant=True)


# Default function to return each item directly.
# A named function (rather than a lambda) keeps the datapipe picklable.
def default_fn(data):
    """Identity mapping: return *data* unchanged."""
    return data


class CallableIterDataPipe(IterDataPipe[T_co]):
    r"""Iterable DataPipe that applies a function to every item of the source.

    args:
        datapipe: Source Iterable DataPipe
        fn: Function called over each item; extra positional/keyword
            arguments given at construction are forwarded to every call.
    """
    datapipe: IterDataPipe
    fn: Callable

    def __init__(self,
                 datapipe: IterDataPipe,
                 *args,
                 fn: Callable = default_fn,
                 **kwargs,
                 ) -> None:
        super().__init__()
        self.datapipe = datapipe
        # BUGFIX: use getattr -- callables such as functools.partial or
        # objects with __call__ have no __name__ attribute and previously
        # crashed this constructor with AttributeError.
        if getattr(fn, '__name__', None) == '<lambda>':
            warnings.warn("Lambda function is not supported for pickle, "
                          "please use regular python function instead.")
        self.fn = fn  # type: ignore
        self.args = args
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[T_co]:
        for data in self.datapipe:
            yield self.fn(data, *self.args, **self.kwargs)

    def __len__(self) -> int:
        # Length is simply the source's length when the source declares one.
        if isinstance(self.datapipe, Sized) and len(self.datapipe) >= 0:
            return len(self.datapipe)
        raise NotImplementedError
49
+
50
+
51
class CollateIterDataPipe(CallableIterDataPipe):
    r"""Iterable DataPipe that collates each sample from the source pipe.

    By default samples are converted to Tensor(s) via
    ``torch.utils.data._utils.collate.default_collate``; pass ``collate_fn``
    to build a customized data structure instead.

    args:
        datapipe: Iterable DataPipe being collated
        collate_fn: Customized collate function to collect and combine data or a
            batch of data. Default function collates to Tensor(s) based on data type.

    Example: turn a pipe yielding the ints 3..6 into float tensors
        >>> def collate_fn(batch):
        ...     return torch.tensor(batch, dtype=torch.float)
        ...
        >>> list(CollateIterDataPipe(ds, collate_fn=collate_fn))
        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
    """

    def __init__(self,
                 datapipe: IterDataPipe,
                 *args,
                 collate_fn: Callable = _utils.collate.default_collate,
                 **kwargs,
                 ) -> None:
        # Collation is just the callable pipe with a specific mapping function.
        super().__init__(datapipe, *args, fn=collate_fn, **kwargs)
@@ -0,0 +1,37 @@
1
+ from torch.utils.data import IterDataPipe
2
+ from torch.utils.data.datapipes.utils.common import get_file_pathnames_from_root
3
+ from typing import List, Union, Iterator
4
+
5
+ class ListDirFilesIterDataPipe(IterDataPipe):
6
+ r""" :class:`ListDirFilesIterDataPipe`
7
+
8
+ Iterable DataPipe to load file pathname(s) (path + filename), yield pathname from given disk root dir.
9
+ args:
10
+ root : root dir
11
+ mask : a unix style filter string or string list for filtering file name(s)
12
+ abspath : whether to return relative pathname or absolute pathname
13
+ length : a nominal length of the datapipe
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ root: str = '.',
19
+ masks: Union[str, List[str]] = '',
20
+ *,
21
+ recursive: bool = False,
22
+ abspath: bool = False,
23
+ length: int = -1):
24
+ super().__init__()
25
+ self.root : str = root
26
+ self.masks : Union[str, List[str]] = masks
27
+ self.recursive : bool = recursive
28
+ self.abspath : bool = abspath
29
+ self.length : int = length
30
+
31
+ def __iter__(self) -> Iterator[str] :
32
+ yield from get_file_pathnames_from_root(self.root, self.masks, self.recursive, self.abspath)
33
+
34
+ def __len__(self):
35
+ if self.length == -1:
36
+ raise NotImplementedError
37
+ return self.length
@@ -0,0 +1,30 @@
1
+ from torch.utils.data import IterDataPipe
2
+ from torch.utils.data.datapipes.utils.common import get_file_binaries_from_pathnames
3
+ from typing import Iterable, Iterator, Tuple
4
+ from io import BufferedIOBase
5
+
6
+ class LoadFilesFromDiskIterDataPipe(IterDataPipe):
7
+ r""" :class:`LoadFilesFromDiskIterDataPipe`.
8
+
9
+ Iterable Datapipe to load file binary streams from given pathnames,
10
+ yield pathname and binary stream in a tuple.
11
+ args:
12
+ datapipe: Iterable datapipe that provides pathnames
13
+ length: a nominal length of the datapipe
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ datapipe : Iterable[str],
19
+ length : int = -1):
20
+ super().__init__()
21
+ self.datapipe : Iterable = datapipe
22
+ self.length : int = length
23
+
24
+ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]] :
25
+ yield from get_file_binaries_from_pathnames(self.datapipe)
26
+
27
+ def __len__(self):
28
+ if self.length == -1:
29
+ raise NotImplementedError
30
+ return self.length
@@ -0,0 +1,60 @@
1
+ from torch.utils.data import IterDataPipe
2
+ from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
3
+ from typing import Iterable, Iterator, Tuple, Optional, IO, cast
4
+ from io import BufferedIOBase
5
+
6
+ import os
7
+ import tarfile
8
+ import warnings
9
+
10
+ class ReadFilesFromTarIterDataPipe(IterDataPipe):
11
+ r""" :class:`ReadFilesFromTarIDP`.
12
+
13
+ Iterable datapipe to extract tar binary streams from input iterable which contains tuples of
14
+ pathname and tar binary stream, yields pathname and extracted binary stream in a tuple.
15
+ args:
16
+ datapipe: Iterable datapipe that provides pathname and tar binary stream in tuples
17
+ length: a nominal length of the datapipe
18
+ """
19
+ def __init__(
20
+ self,
21
+ datapipe : Iterable[Tuple[str, BufferedIOBase]],
22
+ length : int = -1):
23
+ super().__init__()
24
+ self.datapipe : Iterable[Tuple[str, BufferedIOBase]] = datapipe
25
+ self.length : int = length
26
+
27
+
28
+ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
29
+ if not isinstance(self.datapipe, Iterable):
30
+ raise TypeError("datapipe must be Iterable type but got {}".format(type(self.datapipe)))
31
+ for data in self.datapipe:
32
+ validate_pathname_binary_tuple(data)
33
+ pathname, data_stream = data
34
+ try:
35
+ # typing.cast is used here to silence mypy's type checker
36
+ tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode="r:*")
37
+ for tarinfo in tar:
38
+ if not tarinfo.isfile():
39
+ continue
40
+ extracted_fobj = tar.extractfile(tarinfo)
41
+ if extracted_fobj is None:
42
+ warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
43
+ raise tarfile.ExtractError
44
+ inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
45
+ # Add a reference of the source tarfile into extracted_fobj, so the source
46
+ # tarfile handle won't be released until all the extracted file objs are destroyed.
47
+ # Add `# type: ignore` to silence mypy's type checker
48
+ extracted_fobj.source_tarfile_ref = tar # type: ignore
49
+ # typing.cast is used here to silence mypy's type checker
50
+ yield (inner_pathname, cast(BufferedIOBase, extracted_fobj))
51
+ except Exception as e:
52
+ warnings.warn(
53
+ "Unable to extract files from corrupted tarfile stream {} due to: {}, abort!".format(pathname, e))
54
+ raise e
55
+
56
+
57
+ def __len__(self):
58
+ if self.length == -1:
59
+ raise NotImplementedError
60
+ return self.length
@@ -0,0 +1,63 @@
1
+ from torch.utils.data import IterDataPipe
2
+ from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
3
+ from typing import Iterable, Iterator, Tuple, IO, cast
4
+ from io import BufferedIOBase
5
+
6
+ import os
7
+ import sys
8
+ import zipfile
9
+ import warnings
10
+
11
+ class ReadFilesFromZipIterDataPipe(IterDataPipe):
12
+ r""" :class:`ReadFilesFromZipIterDataPipe`.
13
+
14
+ Iterable data pipe to extract zip binary streams from input iterable which contains tuples of
15
+ pathname and zip binary stream, yields pathname and extracted binary stream in a tuple.
16
+ args:
17
+ datapipe: Iterable datapipe that provides pathname and zip binary stream in tuples
18
+ length: a nominal length of the datapipe
19
+ """
20
+ def __init__(
21
+ self,
22
+ datapipe : Iterable[Tuple[str, BufferedIOBase]],
23
+ length : int = -1):
24
+ super().__init__()
25
+ self.datapipe : Iterable[Tuple[str, BufferedIOBase]] = datapipe
26
+ self.length : int = length
27
+
28
+
29
+ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
30
+ if not isinstance(self.datapipe, Iterable):
31
+ raise TypeError("datapipe must be Iterable type but got {}".format(type(self.datapipe)))
32
+ for data in self.datapipe:
33
+ validate_pathname_binary_tuple(data)
34
+ pathname, data_stream = data
35
+ try:
36
+ # typing.cast is used here to silence mypy's type checker
37
+ zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
38
+ for zipinfo in zips.infolist():
39
+ # major version should always be 3 here.
40
+ if sys.version_info[1] >= 6:
41
+ if zipinfo.is_dir():
42
+ continue
43
+ elif zipinfo.filename.endswith('/'):
44
+ continue
45
+
46
+ extracted_fobj = zips.open(zipinfo)
47
+ inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
48
+ # Add a reference of the source zipfile into extracted_fobj, so the source
49
+ # zipfile handle won't be released until all the extracted file objs are destroyed.
50
+ # Add `# type: ignore` to silence mypy's type checker
51
+ extracted_fobj.source_zipfile_ref = zips # type: ignore
52
+ # typing.cast is used here to silence mypy's type checker
53
+ yield (inner_pathname, cast(BufferedIOBase, extracted_fobj))
54
+ except Exception as e:
55
+ warnings.warn(
56
+ "Unable to extract files from corrupted zipfile stream {} due to: {}, abort!".format(pathname, e))
57
+ raise e
58
+
59
+
60
+ def __len__(self):
61
+ if self.length == -1:
62
+ raise NotImplementedError
63
+ return self.length
@@ -0,0 +1,94 @@
1
+ from torch.utils.data import IterDataPipe, Sampler, SequentialSampler
2
+ from typing import TypeVar, Type, Iterator, Sized,Optional
3
+ import itertools
4
+ import torch
5
+
6
T_co = TypeVar('T_co', covariant=True)


class SamplerIterDataPipe(IterDataPipe[T_co]):
    r"""Iterable DataPipe that yields the elements produced by a Sampler.

    args:
        datapipe: IterDataPipe sampled from
        sampler: Sampler class used to generate sample elements from the
            input DataPipe. Default is :class:`SequentialSampler`.
    """
    datapipe: IterDataPipe
    sampler: Sampler

    def __init__(self,
                 datapipe: IterDataPipe,
                 *,
                 sampler: Type[Sampler] = SequentialSampler,
                 **kwargs
                 ) -> None:
        assert isinstance(datapipe, Sized), \
            "Sampler class requires input datapipe implemented `__len__`"
        super().__init__()
        self.datapipe = datapipe
        # mypy cannot verify the sampler accepts data_source as a keyword;
        # https://github.com/python/mypy/pull/9629 will solve
        self.sampler = sampler(data_source=self.datapipe, **kwargs)  # type: ignore

    def __iter__(self) -> Iterator[T_co]:
        return iter(self.sampler)

    def __len__(self) -> int:
        # The datapipe is known Sized (checked in __init__); delegate to the
        # sampler when it also declares a length.
        if isinstance(self.sampler, Sized) and len(self.sampler) >= 0:
            return len(self.sampler)
        raise NotImplementedError
42
+
43
class InfiniteSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The samplers in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(
        self,
        size: int,
        shuffle: bool = True,
        seed: Optional[int] = 0,
        rank=0,
        world_size=1,
    ):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, a seed is drawn from torch's
                global RNG state.
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        # BUGFIX: seed=None previously crashed with TypeError in int(None),
        # although the docstring explicitly allows it. Fall back to torch's
        # initial seed. NOTE(review): this fallback is NOT synchronized across
        # workers -- confirm callers pass an explicit seed in distributed runs.
        if seed is None:
            seed = torch.initial_seed() % 2 ** 31
        self._seed = int(seed)

        self._rank = rank
        self._world_size = world_size

    def __iter__(self):
        # Each rank takes a strided slice of the shared infinite index stream.
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size
        )

    def _infinite_indices(self):
        # A dedicated Generator keeps the stream reproducible for a fixed seed.
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g)
            else:
                yield from torch.arange(self._size)

    def __len__(self):
        # Nominal per-rank length; the index stream itself is infinite.
        return self._size // self._world_size
File without changes
@@ -0,0 +1,65 @@
1
+ import os
2
+ import fnmatch
3
+ import warnings
4
+ from typing import List, Union, Iterable
5
+ from io import BufferedIOBase
6
+
7
+
8
def match_masks(name: str, masks: Union[str, List[str]]) -> bool:
    """Return True if *name* matches the given unix-style mask(s).

    An empty mask (empty string or empty list) matches every name; a single
    string is treated as one pattern, a list as a set of alternatives.
    """
    if not masks:
        return True
    if isinstance(masks, str):
        return fnmatch.fnmatch(name, masks)
    return any(fnmatch.fnmatch(name, mask) for mask in masks)


def get_file_pathnames_from_root(
        root: str,
        masks: Union[str, List[str]],
        recursive: bool = False,
        abspath: bool = False) -> Iterable[str]:
    """Yield pathnames of files under *root* whose basenames match *masks*.

    Only the top directory is scanned unless *recursive* is True; paths are
    joined onto *root* as given unless *abspath* requests absolute paths.
    """

    # Surface any os.walk error as a warning, then propagate it.
    def onerror(err: OSError):
        warnings.warn(err.filename + " : " + err.strerror)
        raise err

    for dirpath, _, filenames in os.walk(root, onerror=onerror):
        if abspath:
            dirpath = os.path.abspath(dirpath)
        for filename in filenames:
            if match_masks(filename, masks):
                yield os.path.join(dirpath, filename)
        if not recursive:
            break
41
+
42
+
43
def get_file_binaries_from_pathnames(pathnames: Iterable):
    """Yield ``(pathname, binary file object)`` for every given pathname.

    Files are opened in ``'rb'`` mode; closing them is the caller's
    responsibility. Raises TypeError (after warning) on non-iterable input
    or non-string pathnames.
    """
    if not isinstance(pathnames, Iterable):
        warnings.warn("get_file_binaries_from_pathnames needs the input be an Iterable")
        raise TypeError

    for pathname in pathnames:
        if not isinstance(pathname, str):
            warnings.warn("file pathname must be string type, but got {}".format(type(pathname)))
            raise TypeError
        yield (pathname, open(pathname, 'rb'))
55
+
56
+
57
def validate_pathname_binary_tuple(data):
    """Raise TypeError unless *data* is a ``(str pathname, binary stream)`` pair."""
    if not isinstance(data, tuple):
        raise TypeError("pathname binary data should be tuple type, but got {}".format(type(data)))
    if len(data) != 2:
        raise TypeError("pathname binary tuple length should be 2, but got {}".format(str(len(data))))
    pathname, stream = data
    if not isinstance(pathname, str):
        raise TypeError("pathname binary tuple should have string type pathname, but got {}".format(type(pathname)))
    if not isinstance(stream, BufferedIOBase):
        raise TypeError("pathname binary tuple should have BufferedIOBase based binary type, but got {}".format(type(stream)))