bplusplus 0.1.0-py3-none-any.whl → 1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of bplusplus might be problematic.

Files changed (95)
  1. bplusplus/__init__.py +5 -3
  2. bplusplus/{collect_images.py → collect.py} +3 -3
  3. bplusplus/prepare.py +573 -0
  4. bplusplus/train_validate.py +8 -64
  5. bplusplus/yolov5detect/__init__.py +1 -0
  6. bplusplus/yolov5detect/detect.py +444 -0
  7. bplusplus/yolov5detect/export.py +1530 -0
  8. bplusplus/yolov5detect/insect.yaml +8 -0
  9. bplusplus/yolov5detect/models/__init__.py +0 -0
  10. bplusplus/yolov5detect/models/common.py +1109 -0
  11. bplusplus/yolov5detect/models/experimental.py +130 -0
  12. bplusplus/yolov5detect/models/hub/anchors.yaml +56 -0
  13. bplusplus/yolov5detect/models/hub/yolov3-spp.yaml +52 -0
  14. bplusplus/yolov5detect/models/hub/yolov3-tiny.yaml +42 -0
  15. bplusplus/yolov5detect/models/hub/yolov3.yaml +52 -0
  16. bplusplus/yolov5detect/models/hub/yolov5-bifpn.yaml +49 -0
  17. bplusplus/yolov5detect/models/hub/yolov5-fpn.yaml +43 -0
  18. bplusplus/yolov5detect/models/hub/yolov5-p2.yaml +55 -0
  19. bplusplus/yolov5detect/models/hub/yolov5-p34.yaml +42 -0
  20. bplusplus/yolov5detect/models/hub/yolov5-p6.yaml +57 -0
  21. bplusplus/yolov5detect/models/hub/yolov5-p7.yaml +68 -0
  22. bplusplus/yolov5detect/models/hub/yolov5-panet.yaml +49 -0
  23. bplusplus/yolov5detect/models/hub/yolov5l6.yaml +61 -0
  24. bplusplus/yolov5detect/models/hub/yolov5m6.yaml +61 -0
  25. bplusplus/yolov5detect/models/hub/yolov5n6.yaml +61 -0
  26. bplusplus/yolov5detect/models/hub/yolov5s-LeakyReLU.yaml +50 -0
  27. bplusplus/yolov5detect/models/hub/yolov5s-ghost.yaml +49 -0
  28. bplusplus/yolov5detect/models/hub/yolov5s-transformer.yaml +49 -0
  29. bplusplus/yolov5detect/models/hub/yolov5s6.yaml +61 -0
  30. bplusplus/yolov5detect/models/hub/yolov5x6.yaml +61 -0
  31. bplusplus/yolov5detect/models/segment/yolov5l-seg.yaml +49 -0
  32. bplusplus/yolov5detect/models/segment/yolov5m-seg.yaml +49 -0
  33. bplusplus/yolov5detect/models/segment/yolov5n-seg.yaml +49 -0
  34. bplusplus/yolov5detect/models/segment/yolov5s-seg.yaml +49 -0
  35. bplusplus/yolov5detect/models/segment/yolov5x-seg.yaml +49 -0
  36. bplusplus/yolov5detect/models/tf.py +797 -0
  37. bplusplus/yolov5detect/models/yolo.py +495 -0
  38. bplusplus/yolov5detect/models/yolov5l.yaml +49 -0
  39. bplusplus/yolov5detect/models/yolov5m.yaml +49 -0
  40. bplusplus/yolov5detect/models/yolov5n.yaml +49 -0
  41. bplusplus/yolov5detect/models/yolov5s.yaml +49 -0
  42. bplusplus/yolov5detect/models/yolov5x.yaml +49 -0
  43. bplusplus/yolov5detect/utils/__init__.py +97 -0
  44. bplusplus/yolov5detect/utils/activations.py +134 -0
  45. bplusplus/yolov5detect/utils/augmentations.py +448 -0
  46. bplusplus/yolov5detect/utils/autoanchor.py +175 -0
  47. bplusplus/yolov5detect/utils/autobatch.py +70 -0
  48. bplusplus/yolov5detect/utils/aws/__init__.py +0 -0
  49. bplusplus/yolov5detect/utils/aws/mime.sh +26 -0
  50. bplusplus/yolov5detect/utils/aws/resume.py +41 -0
  51. bplusplus/yolov5detect/utils/aws/userdata.sh +27 -0
  52. bplusplus/yolov5detect/utils/callbacks.py +72 -0
  53. bplusplus/yolov5detect/utils/dataloaders.py +1385 -0
  54. bplusplus/yolov5detect/utils/docker/Dockerfile +73 -0
  55. bplusplus/yolov5detect/utils/docker/Dockerfile-arm64 +40 -0
  56. bplusplus/yolov5detect/utils/docker/Dockerfile-cpu +42 -0
  57. bplusplus/yolov5detect/utils/downloads.py +136 -0
  58. bplusplus/yolov5detect/utils/flask_rest_api/README.md +70 -0
  59. bplusplus/yolov5detect/utils/flask_rest_api/example_request.py +17 -0
  60. bplusplus/yolov5detect/utils/flask_rest_api/restapi.py +49 -0
  61. bplusplus/yolov5detect/utils/general.py +1294 -0
  62. bplusplus/yolov5detect/utils/google_app_engine/Dockerfile +25 -0
  63. bplusplus/yolov5detect/utils/google_app_engine/additional_requirements.txt +6 -0
  64. bplusplus/yolov5detect/utils/google_app_engine/app.yaml +16 -0
  65. bplusplus/yolov5detect/utils/loggers/__init__.py +476 -0
  66. bplusplus/yolov5detect/utils/loggers/clearml/README.md +222 -0
  67. bplusplus/yolov5detect/utils/loggers/clearml/__init__.py +0 -0
  68. bplusplus/yolov5detect/utils/loggers/clearml/clearml_utils.py +230 -0
  69. bplusplus/yolov5detect/utils/loggers/clearml/hpo.py +90 -0
  70. bplusplus/yolov5detect/utils/loggers/comet/README.md +250 -0
  71. bplusplus/yolov5detect/utils/loggers/comet/__init__.py +551 -0
  72. bplusplus/yolov5detect/utils/loggers/comet/comet_utils.py +151 -0
  73. bplusplus/yolov5detect/utils/loggers/comet/hpo.py +126 -0
  74. bplusplus/yolov5detect/utils/loggers/comet/optimizer_config.json +135 -0
  75. bplusplus/yolov5detect/utils/loggers/wandb/__init__.py +0 -0
  76. bplusplus/yolov5detect/utils/loggers/wandb/wandb_utils.py +210 -0
  77. bplusplus/yolov5detect/utils/loss.py +259 -0
  78. bplusplus/yolov5detect/utils/metrics.py +381 -0
  79. bplusplus/yolov5detect/utils/plots.py +517 -0
  80. bplusplus/yolov5detect/utils/segment/__init__.py +0 -0
  81. bplusplus/yolov5detect/utils/segment/augmentations.py +100 -0
  82. bplusplus/yolov5detect/utils/segment/dataloaders.py +366 -0
  83. bplusplus/yolov5detect/utils/segment/general.py +160 -0
  84. bplusplus/yolov5detect/utils/segment/loss.py +198 -0
  85. bplusplus/yolov5detect/utils/segment/metrics.py +225 -0
  86. bplusplus/yolov5detect/utils/segment/plots.py +152 -0
  87. bplusplus/yolov5detect/utils/torch_utils.py +482 -0
  88. bplusplus/yolov5detect/utils/triton.py +90 -0
  89. bplusplus-1.1.0.dist-info/METADATA +179 -0
  90. bplusplus-1.1.0.dist-info/RECORD +92 -0
  91. bplusplus/build_model.py +0 -38
  92. bplusplus-0.1.0.dist-info/METADATA +0 -91
  93. bplusplus-0.1.0.dist-info/RECORD +0 -8
  94. {bplusplus-0.1.0.dist-info → bplusplus-1.1.0.dist-info}/LICENSE +0 -0
  95. {bplusplus-0.1.0.dist-info → bplusplus-1.1.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,366 @@
+ # Ultralytics YOLOv5 🚀, AGPL-3.0 license
+ """Dataloaders."""
+
+ import os
+ import random
+
+ import cv2
+ import numpy as np
+ import torch
+ from torch.utils.data import DataLoader
+
+ from ..augmentations import augment_hsv, copy_paste, letterbox
+ from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, SmartDistributedSampler, seed_worker
+ from ..general import LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn
+ from ..torch_utils import torch_distributed_zero_first
+ from .augmentations import mixup, random_perspective
+
+ RANK = int(os.getenv("RANK", -1))
+
+
+ def create_dataloader(
+     path,
+     imgsz,
+     batch_size,
+     stride,
+     single_cls=False,
+     hyp=None,
+     augment=False,
+     cache=False,
+     pad=0.0,
+     rect=False,
+     rank=-1,
+     workers=8,
+     image_weights=False,
+     quad=False,
+     prefix="",
+     shuffle=False,
+     mask_downsample_ratio=1,
+     overlap_mask=False,
+     seed=0,
+ ):
+     """Creates a dataloader for training, validating, or testing YOLO models with various dataset options."""
+     if rect and shuffle:
+         LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
+         shuffle = False
+     with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+         dataset = LoadImagesAndLabelsAndMasks(
+             path,
+             imgsz,
+             batch_size,
+             augment=augment,  # augmentation
+             hyp=hyp,  # hyperparameters
+             rect=rect,  # rectangular batches
+             cache_images=cache,
+             single_cls=single_cls,
+             stride=int(stride),
+             pad=pad,
+             image_weights=image_weights,
+             prefix=prefix,
+             downsample_ratio=mask_downsample_ratio,
+             overlap=overlap_mask,
+             rank=rank,
+         )
+
+     batch_size = min(batch_size, len(dataset))
+     nd = torch.cuda.device_count()  # number of CUDA devices
+     nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
+     sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle)
+     loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
+     generator = torch.Generator()
+     generator.manual_seed(6148914691236517205 + seed + RANK)
+     return loader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=shuffle and sampler is None,
+         num_workers=nw,
+         sampler=sampler,
+         pin_memory=True,
+         collate_fn=LoadImagesAndLabelsAndMasks.collate_fn4 if quad else LoadImagesAndLabelsAndMasks.collate_fn,
+         worker_init_fn=seed_worker,
+         generator=generator,
+     ), dataset
+
+
+ class LoadImagesAndLabelsAndMasks(LoadImagesAndLabels):  # for training/testing
+     """Loads images, labels, and segmentation masks for training and testing YOLO models with augmentation support."""
+
+     def __init__(
+         self,
+         path,
+         img_size=640,
+         batch_size=16,
+         augment=False,
+         hyp=None,
+         rect=False,
+         image_weights=False,
+         cache_images=False,
+         single_cls=False,
+         stride=32,
+         pad=0,
+         min_items=0,
+         prefix="",
+         downsample_ratio=1,
+         overlap=False,
+         rank=-1,
+         seed=0,
+     ):
+         """Initializes the dataset with image, label, and mask loading capabilities for training/testing."""
+         super().__init__(
+             path,
+             img_size,
+             batch_size,
+             augment,
+             hyp,
+             rect,
+             image_weights,
+             cache_images,
+             single_cls,
+             stride,
+             pad,
+             min_items,
+             prefix,
+             rank,
+             seed,
+         )
+         self.downsample_ratio = downsample_ratio
+         self.overlap = overlap
+
+     def __getitem__(self, index):
+         """Returns a transformed item from the dataset at the specified index, handling indexing and image weighting."""
+         index = self.indices[index]  # linear, shuffled, or image_weights
+
+         hyp = self.hyp
+         mosaic = self.mosaic and random.random() < hyp["mosaic"]
+         masks = []
+         if mosaic:
+             # Load mosaic
+             img, labels, segments = self.load_mosaic(index)
+             shapes = None
+
+             # MixUp augmentation
+             if random.random() < hyp["mixup"]:
+                 img, labels, segments = mixup(img, labels, segments, *self.load_mosaic(random.randint(0, self.n - 1)))
+
+         else:
+             # Load image
+             img, (h0, w0), (h, w) = self.load_image(index)
+
+             # Letterbox
+             shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
+             img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
+             shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling
+
+             labels = self.labels[index].copy()
+             # [array, array, ....], array.shape=(num_points, 2), xyxyxyxy
+             segments = self.segments[index].copy()
+             if len(segments):
+                 for i_s in range(len(segments)):
+                     segments[i_s] = xyn2xy(
+                         segments[i_s],
+                         ratio[0] * w,
+                         ratio[1] * h,
+                         padw=pad[0],
+                         padh=pad[1],
+                     )
+             if labels.size:  # normalized xywh to pixel xyxy format
+                 labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
+
+             if self.augment:
+                 img, labels, segments = random_perspective(
+                     img,
+                     labels,
+                     segments=segments,
+                     degrees=hyp["degrees"],
+                     translate=hyp["translate"],
+                     scale=hyp["scale"],
+                     shear=hyp["shear"],
+                     perspective=hyp["perspective"],
+                 )
+
+         nl = len(labels)  # number of labels
+         if nl:
+             labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3)
+             if self.overlap:
+                 masks, sorted_idx = polygons2masks_overlap(
+                     img.shape[:2], segments, downsample_ratio=self.downsample_ratio
+                 )
+                 masks = masks[None]  # (640, 640) -> (1, 640, 640)
+                 labels = labels[sorted_idx]
+             else:
+                 masks = polygons2masks(img.shape[:2], segments, color=1, downsample_ratio=self.downsample_ratio)
+
+         masks = (
+             torch.from_numpy(masks)
+             if len(masks)
+             else torch.zeros(
+                 1 if self.overlap else nl, img.shape[0] // self.downsample_ratio, img.shape[1] // self.downsample_ratio
+             )
+         )
+         # TODO: albumentations support
+         if self.augment:
+             # Albumentations
+             # there are some augmentation that won't change boxes and masks,
+             # so just be it for now.
+             img, labels = self.albumentations(img, labels)
+             nl = len(labels)  # update after albumentations
+
+             # HSV color-space
+             augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])
+
+             # Flip up-down
+             if random.random() < hyp["flipud"]:
+                 img = np.flipud(img)
+                 if nl:
+                     labels[:, 2] = 1 - labels[:, 2]
+                     masks = torch.flip(masks, dims=[1])
+
+             # Flip left-right
+             if random.random() < hyp["fliplr"]:
+                 img = np.fliplr(img)
+                 if nl:
+                     labels[:, 1] = 1 - labels[:, 1]
+                     masks = torch.flip(masks, dims=[2])
+
+             # Cutouts  # labels = cutout(img, labels, p=0.5)
+
+         labels_out = torch.zeros((nl, 6))
+         if nl:
+             labels_out[:, 1:] = torch.from_numpy(labels)
+
+         # Convert
+         img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
+         img = np.ascontiguousarray(img)
+
+         return (torch.from_numpy(img), labels_out, self.im_files[index], shapes, masks)
+
+     def load_mosaic(self, index):
+         """Loads 1 image + 3 random images into a 4-image YOLOv5 mosaic, adjusting labels and segments accordingly."""
+         labels4, segments4 = [], []
+         s = self.img_size
+         yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
+
+         # 3 additional image indices
+         indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
+         for i, index in enumerate(indices):
+             # Load image
+             img, _, (h, w) = self.load_image(index)
+
+             # place img in img4
+             if i == 0:  # top left
+                 img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
+                 x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
+                 x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
+             elif i == 1:  # top right
+                 x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
+                 x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
+             elif i == 2:  # bottom left
+                 x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
+                 x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
+             elif i == 3:  # bottom right
+                 x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
+                 x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
+
+             img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
+             padw = x1a - x1b
+             padh = y1a - y1b
+
+             labels, segments = self.labels[index].copy(), self.segments[index].copy()
+
+             if labels.size:
+                 labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
+                 segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
+             labels4.append(labels)
+             segments4.extend(segments)
+
+         # Concat/clip labels
+         labels4 = np.concatenate(labels4, 0)
+         for x in (labels4[:, 1:], *segments4):
+             np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
+         # img4, labels4 = replicate(img4, labels4)  # replicate
+
+         # Augment
+         img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp["copy_paste"])
+         img4, labels4, segments4 = random_perspective(
+             img4,
+             labels4,
+             segments4,
+             degrees=self.hyp["degrees"],
+             translate=self.hyp["translate"],
+             scale=self.hyp["scale"],
+             shear=self.hyp["shear"],
+             perspective=self.hyp["perspective"],
+             border=self.mosaic_border,
+         )  # border to remove
+         return img4, labels4, segments4
+
+     @staticmethod
+     def collate_fn(batch):
+         """Custom collation function for DataLoader, batches images, labels, paths, shapes, and segmentation masks."""
+         img, label, path, shapes, masks = zip(*batch)  # transposed
+         batched_masks = torch.cat(masks, 0)
+         for i, l in enumerate(label):
+             l[:, 0] = i  # add target image index for build_targets()
+         return torch.stack(img, 0), torch.cat(label, 0), path, shapes, batched_masks
+
+
+ def polygon2mask(img_size, polygons, color=1, downsample_ratio=1):
+     """
+     Args:
+         img_size (tuple): The image size.
+         polygons (np.ndarray): [N, M], N is the number of polygons,
+             M is the number of points(Be divided by 2).
+     """
+     mask = np.zeros(img_size, dtype=np.uint8)
+     polygons = np.asarray(polygons)
+     polygons = polygons.astype(np.int32)
+     shape = polygons.shape
+     polygons = polygons.reshape(shape[0], -1, 2)
+     cv2.fillPoly(mask, polygons, color=color)
+     nh, nw = (img_size[0] // downsample_ratio, img_size[1] // downsample_ratio)
+     # NOTE: fillPoly firstly then resize is trying the keep the same way
+     # of loss calculation when mask-ratio=1.
+     mask = cv2.resize(mask, (nw, nh))
+     return mask
+
+
+ def polygons2masks(img_size, polygons, color, downsample_ratio=1):
+     """
+     Args:
+         img_size (tuple): The image size.
+         polygons (list[np.ndarray]): each polygon is [N, M],
+             N is the number of polygons,
+             M is the number of points(Be divided by 2).
+     """
+     masks = []
+     for si in range(len(polygons)):
+         mask = polygon2mask(img_size, [polygons[si].reshape(-1)], color, downsample_ratio)
+         masks.append(mask)
+     return np.array(masks)
+
+
+ def polygons2masks_overlap(img_size, segments, downsample_ratio=1):
+     """Return a (640, 640) overlap mask."""
+     masks = np.zeros(
+         (img_size[0] // downsample_ratio, img_size[1] // downsample_ratio),
+         dtype=np.int32 if len(segments) > 255 else np.uint8,
+     )
+     areas = []
+     ms = []
+     for si in range(len(segments)):
+         mask = polygon2mask(
+             img_size,
+             [segments[si].reshape(-1)],
+             downsample_ratio=downsample_ratio,
+             color=1,
+         )
+         ms.append(mask)
+         areas.append(mask.sum())
+     areas = np.asarray(areas)
+     index = np.argsort(-areas)
+     ms = np.array(ms)[index]
+     for i in range(len(segments)):
+         mask = ms[i] * (i + 1)
+         masks = masks + mask
+         masks = np.clip(masks, a_min=0, a_max=i + 1)
+     return masks, index
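
For reference, a minimal usage sketch (not part of the published diff) of the polygon-to-mask helpers added above. The import path simply mirrors the file layout listed in this diff; whether the installed wheel actually exposes the module at that path is an assumption.

import numpy as np

# Hypothetical import path, taken from the file layout shown above.
from bplusplus.yolov5detect.utils.segment.dataloaders import polygon2mask, polygons2masks_overlap

# Two toy polygons in pixel coordinates, flattened as x1, y1, x2, y2, ...
square = np.array([100, 100, 300, 100, 300, 300, 100, 300], dtype=np.float32)
triangle = np.array([350, 350, 500, 350, 425, 500], dtype=np.float32)

# One binary mask for a single polygon, downsampled 4x (640 -> 160).
mask = polygon2mask((640, 640), [square], color=1, downsample_ratio=4)
print(mask.shape, mask.max())  # (160, 160) 1

# Overlap encoding: a single array where pixel value i+1 marks the i-th instance (sorted largest first).
overlap, sorted_idx = polygons2masks_overlap((640, 640), [square, triangle], downsample_ratio=4)
print(overlap.shape, overlap.max(), sorted_idx)  # (160, 160) 2 [0 1]
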
@@ -0,0 +1,160 @@
+ # Ultralytics YOLOv5 🚀, AGPL-3.0 license
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+
+ def crop_mask(masks, boxes):
+     """
+     "Crop" predicted masks by zeroing out everything not in the predicted bbox. Vectorized by Chong (thanks Chong).
+
+     Args:
+         - masks should be a size [n, h, w] tensor of masks
+         - boxes should be a size [n, 4] tensor of bbox coords in relative point form
+     """
+     n, h, w = masks.shape
+     x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(1,1,n)
+     r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,w,1)
+     c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(h,1,1)
+
+     return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
+
+
+ def process_mask_upsample(protos, masks_in, bboxes, shape):
+     """
+     Crop after upsample.
+     protos: [mask_dim, mask_h, mask_w]
+     masks_in: [n, mask_dim], n is number of masks after nms
+     bboxes: [n, 4], n is number of masks after nms
+     shape: input_image_size, (h, w).
+
+     return: h, w, n
+     """
+     c, mh, mw = protos.shape  # CHW
+     masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
+     masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
+     masks = crop_mask(masks, bboxes)  # CHW
+     return masks.gt_(0.5)
+
+
+ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
+     """
+     Crop before upsample.
+     proto_out: [mask_dim, mask_h, mask_w]
+     out_masks: [n, mask_dim], n is number of masks after nms
+     bboxes: [n, 4], n is number of masks after nms
+     shape:input_image_size, (h, w).
+
+     return: h, w, n
+     """
+     c, mh, mw = protos.shape  # CHW
+     ih, iw = shape
+     masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW
+
+     downsampled_bboxes = bboxes.clone()
+     downsampled_bboxes[:, 0] *= mw / iw
+     downsampled_bboxes[:, 2] *= mw / iw
+     downsampled_bboxes[:, 3] *= mh / ih
+     downsampled_bboxes[:, 1] *= mh / ih
+
+     masks = crop_mask(masks, downsampled_bboxes)  # CHW
+     if upsample:
+         masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
+     return masks.gt_(0.5)
+
+
+ def process_mask_native(protos, masks_in, bboxes, shape):
+     """
+     Crop after upsample.
+     protos: [mask_dim, mask_h, mask_w]
+     masks_in: [n, mask_dim], n is number of masks after nms
+     bboxes: [n, 4], n is number of masks after nms
+     shape: input_image_size, (h, w).
+
+     return: h, w, n
+     """
+     c, mh, mw = protos.shape  # CHW
+     masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
+     gain = min(mh / shape[0], mw / shape[1])  # gain = old / new
+     pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2  # wh padding
+     top, left = int(pad[1]), int(pad[0])  # y, x
+     bottom, right = int(mh - pad[1]), int(mw - pad[0])
+     masks = masks[:, top:bottom, left:right]
+
+     masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
+     masks = crop_mask(masks, bboxes)  # CHW
+     return masks.gt_(0.5)
+
+
+ def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
+     """
+     img1_shape: model input shape, [h, w]
+     img0_shape: origin pic shape, [h, w, 3]
+     masks: [h, w, num].
+     """
+     # Rescale coordinates (xyxy) from im1_shape to im0_shape
+     if ratio_pad is None:  # calculate from im0_shape
+         gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
+         pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
+     else:
+         pad = ratio_pad[1]
+     top, left = int(pad[1]), int(pad[0])  # y, x
+     bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
+
+     if len(masks.shape) < 2:
+         raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
+     masks = masks[top:bottom, left:right]
+     # masks = masks.permute(2, 0, 1).contiguous()
+     # masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0]
+     # masks = masks.permute(1, 2, 0).contiguous()
+     masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
+
+     if len(masks.shape) == 2:
+         masks = masks[:, :, None]
+     return masks
+
+
+ def mask_iou(mask1, mask2, eps=1e-7):
+     """
+     mask1: [N, n] m1 means number of predicted objects
+     mask2: [M, n] m2 means number of gt objects
+     Note: n means image_w x image_h.
+
+     return: masks iou, [N, M]
+     """
+     intersection = torch.matmul(mask1, mask2.t()).clamp(0)
+     union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection  # (area1 + area2) - intersection
+     return intersection / (union + eps)
+
+
+ def masks_iou(mask1, mask2, eps=1e-7):
+     """
+     mask1: [N, n] m1 means number of predicted objects
+     mask2: [N, n] m2 means number of gt objects
+     Note: n means image_w x image_h.
+
+     return: masks iou, (N, )
+     """
+     intersection = (mask1 * mask2).sum(1).clamp(0)  # (N, )
+     union = (mask1.sum(1) + mask2.sum(1))[None] - intersection  # (area1 + area2) - intersection
+     return intersection / (union + eps)
+
+
+ def masks2segments(masks, strategy="largest"):
+     """Converts binary (n,160,160) masks to polygon segments with options for concatenation or selecting the largest
+     segment.
+     """
+     segments = []
+     for x in masks.int().cpu().numpy().astype("uint8"):
+         c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+         if c:
+             if strategy == "concat":  # concatenate all segments
+                 c = np.concatenate([x.reshape(-1, 2) for x in c])
+             elif strategy == "largest":  # select largest segment
+                 c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
+         else:
+             c = np.zeros((0, 2))  # no segments found
+         segments.append(c.astype("float32"))
+     return segments
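
Similarly, a minimal sketch (not part of the diff) exercising the mask post-processing helpers above with random tensors; the import path again follows the file layout in this diff and is an assumption.

import torch

# Hypothetical import path, taken from the file layout shown above.
from bplusplus.yolov5detect.utils.segment.general import mask_iou, process_mask

protos = torch.randn(32, 160, 160)  # prototype masks [mask_dim, mask_h, mask_w]
coeffs = torch.randn(5, 32)         # per-detection mask coefficients after NMS
boxes = torch.tensor([[80.0, 80.0, 400.0, 400.0]]).repeat(5, 1)  # xyxy boxes at 640x640 input resolution

# Combine prototypes and coefficients, crop to the boxes, then upsample to the input size.
masks = process_mask(protos, coeffs, boxes, shape=(640, 640), upsample=True)
print(masks.shape)  # torch.Size([5, 640, 640]); values are 0./1. after the in-place gt_(0.5)

# IoU between two sets of flattened binary masks: [N, H*W] x [M, H*W] -> [N, M].
m1 = (torch.rand(3, 640 * 640) > 0.5).float()
m2 = (torch.rand(4, 640 * 640) > 0.5).float()
print(mask_iou(m1, m2).shape)  # torch.Size([3, 4])
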