python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164) hide show
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,267 @@
1
+ import torch
2
+ #from torch._six import int_classes as _int_classes
3
+ from torch import Tensor
4
+
5
+ from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized
6
+ import time
7
+ import itertools
8
+
9
+ T_co = TypeVar('T_co', covariant=True)
10
+ _int_classes = (int,bool)
11
+
12
class Sampler(Generic[T_co]):
    r"""Abstract base for every sampler.

    A concrete sampler implements :meth:`__iter__`, which produces indices
    (or lists of indices) into a dataset, and usually :meth:`__len__`,
    which reports how many items one pass over that iterator produces.

    .. note:: :class:`~torch.utils.data.DataLoader` does not strictly need
              :meth:`__len__`, but any computation of a
              :class:`~torch.utils.data.DataLoader`'s length relies on it.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        # The base class keeps no state; subclasses store what they need.
        pass

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # Abstract collection-like classes such as this one (e.g.
    # `torch.utils.data.Sampler`) deliberately leave `__len__` to their
    # subclasses, because neither obvious default behaves well:
    #
    # + `return NotImplemented`: calling `len(subclass_instance)` fails with
    #   TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    # + `raise NotImplementedError()`: this defeats fallback logic. For
    #   example, the built-in `list(X)` tries `len(X)` first and takes a
    #   different code path when the method is missing or returns
    #   `NotImplemented`; a raised NotImplementedError propagates instead and
    #   makes the call fail even though `__iter__` alone could have finished it.
    #
    # The only sensible choices are therefore:
    #
    # + provide **no** default `__len__`, or
    #
    # + raise `TypeError`, which is what Python raises when a method that is
    #   not defined on an object is called.
    #   (@ssnl verified that this works on at least Python 3.7.)
56
+
57
+
58
class SequentialSampler(Sampler[int]):
    r"""Yield ``0, 1, ..., len(data_source) - 1`` in order on every pass.

    Args:
        data_source (Dataset): dataset to sample from; only its length is used
    """
    data_source: Sized

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        count = len(self.data_source)
        return iter(range(count))

    def __len__(self) -> int:
        return len(self.data_source)
74
+
75
class InfiniteSequentialSampler(Sampler[int]):
    r"""Yield ``0, 1, ..., n - 1`` in order, repeating the cycle forever.

    ``n`` is ``len(data_source)`` captured when the sampler is constructed.
    Indices are produced by iterating :func:`torch.arange`, i.e. as 0-d
    integer tensors. For a non-empty dataset the iterator never ends on its
    own; for an empty dataset it stops immediately instead of spinning.

    Args:
        data_source (Dataset): dataset to sample from; only its length is used
    """
    data_source: Sized

    def __init__(self, data_source):
        self.data_source = data_source
        self.num_samples = len(data_source)

    def __iter__(self):
        yield from self._infinite_indices()

    def _infinite_indices(self):
        # With zero samples, ``torch.arange(0)`` yields nothing, so the
        # ``while True`` loop would busy-wait forever without producing a
        # single index. Finish the iterator instead.
        if self.num_samples <= 0:
            return
        while True:
            yield from torch.arange(self.num_samples)

    def __len__(self):
        # Size of one pass over the dataset, not of the endless iterator.
        return self.num_samples
98
+
99
+
100
+
101
class RandomSampler(Sampler[int]):
    r"""Endless random sampler.

    Without replacement, every pass is a fresh random permutation of
    ``range(len(data_source))``. With replacement, every pass draws
    ``num_samples`` indices independently and uniformly from
    ``range(len(data_source))``. Passes repeat forever. Indices are produced
    by iterating int64 tensors, i.e. as 0-d tensors.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of indices per pass, default=`len(dataset)`. This argument
            may only be specified when `replacement` is ``True``.
        generator (Generator): optional ``torch.Generator`` used for sampling. When ``None``,
            a new generator seeded from ``time.time()`` is created each time iteration starts.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        yield from self._infinite_indices()

    def _infinite_indices(self):
        """Yield sampled indices forever (stops only if the dataset is empty)."""
        g = self.generator
        if g is None:
            g = torch.Generator()
            g.manual_seed(int(time.time()))
        while True:
            # Re-read the size each pass because the dataset may change.
            n = len(self.data_source)
            if n == 0:
                # Nothing to draw from; finish instead of spinning forever.
                return
            if self.replacement:
                # Draw within [0, n): num_samples may exceed n, so a
                # permutation of num_samples would go out of range.
                yield from torch.randint(high=n, size=(self.num_samples,),
                                         dtype=torch.int64, generator=g)
            else:
                yield from torch.randperm(n, generator=g)

    def __len__(self):
        # Size of one pass, not of the endless iterator.
        return self.num_samples
153
+
154
+
155
class SubsetRandomSampler(Sampler[int]):
    r"""Yield the given indices in a random order, each exactly once per pass.

    Args:
        indices (sequence): indices to draw from
        generator (Generator): optional ``torch.Generator`` driving the shuffle
    """
    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.indices = indices
        self.generator = generator

    def __iter__(self):
        order = torch.randperm(len(self.indices), generator=self.generator)
        return (self.indices[pos] for pos in order)

    def __len__(self):
        return len(self.indices)
173
+
174
+
175
class WeightedRandomSampler(Sampler[int]):
    r"""Draw indices from ``[0,..,len(weights)-1]`` with probability proportional to ``weights``.

    Args:
        weights (sequence): per-index weights; they need not sum to one
        num_samples (int): number of indices one pass yields
        replacement (bool): if ``True``, indices are drawn with replacement.
            Otherwise an index, once drawn, cannot be drawn again in that pass.
        generator (Generator): optional ``torch.Generator`` used for the draw.

    Example (outputs are random; shown for illustration only):
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """
    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        # bool is an int subclass, so it has to be rejected explicitly.
        count_ok = (isinstance(num_samples, _int_classes)
                    and not isinstance(num_samples, bool)
                    and num_samples > 0)
        if not count_ok:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self):
        drawn = torch.multinomial(self.weights, self.num_samples,
                                  self.replacement, generator=self.generator)
        return iter(drawn.tolist())

    def __len__(self):
        return self.num_samples
216
+
217
+
218
class BatchSampler(Sampler[List[int]]):
    r"""Group the indices of another sampler into mini-batch lists.

    Args:
        sampler (Sampler or Iterable): source of indices; any iterable works
        batch_size (int): number of indices per batch
        drop_last (bool): if ``True``, a final batch shorter than
            ``batch_size`` is discarded instead of yielded

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
        # No isinstance(sampler, Iterable) check: an object can be iterable
        # through __getitem__ alone, which collections.abc.Iterable misses.
        size_ok = (isinstance(batch_size, _int_classes)
                   and not isinstance(batch_size, bool)
                   and batch_size > 0)
        if not size_ok:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) == self.batch_size:
                yield pending
                pending = []
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        # Only valid when self.sampler implements __len__; that cannot be
        # enforced here, hence the type-check suppression. See also
        # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ].
        total = len(self.sampler)  # type: ignore
        full, rest = divmod(total, self.batch_size)
        if self.drop_last or rest == 0:
            return full
        return full + 1
@@ -0,0 +1,24 @@
1
+ from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional
2
+ import torch
3
+ import time
4
+ from . import _utils
5
+ import wml.wml_utils as wmlu
6
+ from . import _BaseDataLoaderIter,_DatasetKind
7
+
8
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    """Data-loader iterator that fetches every batch in the calling process."""

    def __init__(self, loader):
        super().__init__(loader)
        # Only the no-timeout, no-worker configuration is handled here.
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind,
            self._dataset,
            self._auto_collation,
            self._collate_fn,
            self._drop_last,
        )

    def _next_data(self):
        # Both the index source and the fetcher may raise StopIteration.
        batch_index = self._next_index()
        batch = self._dataset_fetcher.fetch(batch_index)
        return _utils.pin_memory.pin_memory(batch) if self._pin_memory else batch
23
+
24
+
@@ -0,0 +1,26 @@
1
+ from wml.wtorch.data.dataloader import DataLoader
2
+ import numpy as np
3
+ from wml.wtorch.data import SequentialSampler
4
+
5
class ExampleDataset(object):
    """Toy dataset of 50 rows ``[2*i, 2*i + 1]`` built from ``range(100)``."""

    def __init__(self):
        flat = np.arange(100)
        self.data = flat.reshape(50, 2)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, item):
        return self.data[item]
15
+
16
+
17
def main():
    """Iterate a shuffled multi-worker DataLoader over ExampleDataset three times."""
    dataset = ExampleDataset()
    dataloader = DataLoader(dataset, 4, shuffle=True, num_workers=4,
                            pin_memory=True, batch_split_nr=2)
    # To iterate in order instead, pass sampler=SequentialSampler(dataset)
    # in place of shuffle=True.

    idx = 0
    for _ in range(3):
        for x in iter(dataloader):
            print(idx, x)
            idx += 1


# Guard required for num_workers > 0: with the "spawn" start method, worker
# processes re-import this module and would otherwise re-run the loop. It
# also keeps test collection from executing the script.
if __name__ == "__main__":
    main()
26
+
@@ -0,0 +1,67 @@
1
+ import wml.wml_utils as wmlu
2
+ import random
3
+
4
class DataUnit:
    """A group of data items sampled in a shuffled, repeating order.

    ``sample()`` pops indices from an internal buffer. When the buffer is
    empty it is refilled with several independently shuffled permutations of
    ``range(len(data))``. Within each permutation every item appears once.
    """
    # Target size of the index buffer built by make_idxs (at least one
    # full permutation is always generated).
    MAX_IDXS_LEN = 100

    def __init__(self, data):
        if not isinstance(data, (list, tuple)):
            raise RuntimeError("Error data type")
        self.data = data
        self._idxs = []

    def make_idxs(self):
        """Refill the buffer with max(1, MAX_IDXS_LEN / len(data)) shuffled permutations.

        Raises:
            ValueError: if there is no data to sample from.
        """
        if len(self.data) == 0:
            raise ValueError("DataUnit has no data to sample from")
        nr = max(1, int(DataUnit.MAX_IDXS_LEN / len(self.data)))
        self._idxs = []
        for _ in range(nr):
            self._idxs.extend(self.make_one_idxs())

    def make_one_idxs(self):
        """Return one random permutation of ``range(len(self.data))``."""
        idxs = list(range(len(self.data)))
        random.shuffle(idxs)
        return idxs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

    def sample(self):
        """Return the next item, refilling the index buffer when it is empty."""
        if not self._idxs:
            self.make_idxs()
        # pop() drops the last index in O(1); slicing ([:-1]) would copy the
        # whole buffer on every call.
        return self.data[self._idxs.pop()]

    def __repr__(self):
        return type(self).__name__ + f",{self.data}"
40
+
41
+
42
def make_data_unit(datas, total_nr=None, nr_per_unit=None):
    """Optionally group ``datas`` into DataUnit objects.

    Args:
        datas: list of data items.
        total_nr: if given, ``datas`` is split with
            ``wmlu.list_to_2dlistv2(datas, total_nr)`` (presumably into
            ``total_nr`` groups; confirm in wml_utils). ``datas`` is returned
            unchanged when ``total_nr >= len(datas)``.
        nr_per_unit: if given, ``datas`` is split with
            ``wmlu.list_to_2dlist(datas, nr_per_unit)``. ``datas`` is
            returned unchanged when ``nr_per_unit <= 1``.

    At most one of ``total_nr`` / ``nr_per_unit`` may be given. When neither
    is given, ``datas`` is returned unchanged.

    Returns:
        Either ``datas`` itself or a list of DataUnit, one per group.
    """
    assert total_nr is None or nr_per_unit is None, "Error arguments"
    if total_nr is not None:
        if total_nr >= len(datas):
            return datas
        groups = wmlu.list_to_2dlistv2(datas, total_nr)
    elif nr_per_unit is None or nr_per_unit <= 1:
        # No grouping requested (``None <= 1`` would raise TypeError).
        return datas
    else:
        groups = wmlu.list_to_2dlist(datas, nr_per_unit)

    return [DataUnit(x) for x in groups]
55
+
56
class DataList:
    """Indexable wrapper; DataUnit entries are replaced by one of their samples."""

    def __init__(self, data) -> None:
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        entry = self.data[item]
        return entry.sample() if isinstance(entry, DataUnit) else entry
@@ -0,0 +1,98 @@
1
+ from typing import Dict, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .conv_module import ConvModule
7
+
8
+
9
class DepthwiseSeparableConvModule(nn.Module):
    """Depthwise separable convolution module.

    See https://arxiv.org/pdf/1704.04861.pdf for details.

    This module can replace a ConvModule whose conv block is split into two
    conv blocks: a depthwise conv block and a pointwise conv block. The
    depthwise conv block contains depthwise-conv/norm/activation layers and
    the pointwise conv block contains pointwise-conv/norm/activation layers.
    Note that the depthwise conv block also gets norm/activation layers when
    `norm_cfg` and `act_cfg` are specified.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``. Default: 1.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``. Default: 0.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``. Default: 1.
        norm_cfg (dict): Default norm config for both depthwise ConvModule and
            pointwise ConvModule. Default: None.
        act_cfg (dict): Default activation config for both depthwise ConvModule
            and pointwise ConvModule. Default: dict(type='ReLU').
        dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
            'default', it will be the same as `norm_cfg`. Default: 'default'.
        dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
            'default', it will be the same as `act_cfg`. Default: 'default'.
        pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
            'default', it will be the same as `norm_cfg`. Default: 'default'.
        pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
            'default', it will be the same as `act_cfg`. Default: 'default'.
        kwargs (optional): Other shared arguments for depthwise and pointwise
            ConvModule. See ConvModule for ref.

    Each inner ConvModule receives its own deep copy of its norm/activation
    config, so neither can alter the other's config, the caller's dicts, or
    the shared default ``act_cfg``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 norm_cfg: Optional[Dict] = None,
                 act_cfg: Dict = dict(type='ReLU'),
                 dw_norm_cfg: Union[Dict, str] = 'default',
                 dw_act_cfg: Union[Dict, str] = 'default',
                 pw_norm_cfg: Union[Dict, str] = 'default',
                 pw_act_cfg: Union[Dict, str] = 'default',
                 **kwargs):
        super().__init__()
        assert 'groups' not in kwargs, 'groups should not be specified'

        # if norm/activation config of depthwise/pointwise ConvModule is not
        # specified, use default config.
        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg  # type: ignore # noqa E501
        dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg  # type: ignore # noqa E501
        pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg

        # The default ``act_cfg`` dict is created once, when the function is
        # defined, and is shared by every instance; after the resolution
        # above both ConvModules would also receive the very same object.
        # Hand each one an independent copy.
        dw_norm_cfg = self._copy_cfg(dw_norm_cfg)
        dw_act_cfg = self._copy_cfg(dw_act_cfg)
        pw_norm_cfg = self._copy_cfg(pw_norm_cfg)
        pw_act_cfg = self._copy_cfg(pw_act_cfg)

        # depthwise convolution
        self.depthwise_conv = ConvModule(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            norm_cfg=dw_norm_cfg,  # type: ignore
            act_cfg=dw_act_cfg,  # type: ignore
            **kwargs)

        # pointwise (1x1) convolution
        self.pointwise_conv = ConvModule(
            in_channels,
            out_channels,
            1,
            norm_cfg=pw_norm_cfg,  # type: ignore
            act_cfg=pw_act_cfg,  # type: ignore
            **kwargs)

    @staticmethod
    def _copy_cfg(cfg):
        """Return an independent deep copy of ``cfg`` (``None`` stays ``None``)."""
        import copy
        return copy.deepcopy(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x