magic-pdf 0.5.13__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. magic_pdf/cli/magicpdf.py +18 -7
  2. magic_pdf/libs/config_reader.py +10 -0
  3. magic_pdf/libs/version.py +1 -1
  4. magic_pdf/model/__init__.py +1 -0
  5. magic_pdf/model/doc_analyze_by_custom_model.py +38 -15
  6. magic_pdf/model/model_list.py +1 -0
  7. magic_pdf/model/pdf_extract_kit.py +196 -0
  8. magic_pdf/model/pek_sub_modules/__init__.py +0 -0
  9. magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py +0 -0
  10. magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py +179 -0
  11. magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py +671 -0
  12. magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py +476 -0
  13. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py +7 -0
  14. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py +2 -0
  15. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py +171 -0
  16. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py +124 -0
  17. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py +136 -0
  18. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py +284 -0
  19. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py +213 -0
  20. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py +7 -0
  21. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +24 -0
  22. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +60 -0
  23. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +1282 -0
  24. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +32 -0
  25. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +34 -0
  26. magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +150 -0
  27. magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py +163 -0
  28. magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py +1236 -0
  29. magic_pdf/model/pek_sub_modules/post_process.py +36 -0
  30. magic_pdf/model/pek_sub_modules/self_modify.py +260 -0
  31. magic_pdf/model/pp_structure_v2.py +7 -0
  32. magic_pdf/pipe/AbsPipe.py +8 -14
  33. magic_pdf/pipe/OCRPipe.py +12 -8
  34. magic_pdf/pipe/TXTPipe.py +12 -8
  35. magic_pdf/pipe/UNIPipe.py +9 -7
  36. magic_pdf/resources/model_config/UniMERNet/demo.yaml +46 -0
  37. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +351 -0
  38. magic_pdf/resources/model_config/model_configs.yaml +9 -0
  39. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/METADATA +18 -8
  40. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/RECORD +44 -18
  41. magic_pdf/model/360_layout_analysis.py +0 -8
  42. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/LICENSE.md +0 -0
  43. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/WHEEL +0 -0
  44. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/entry_points.txt +0 -0
  45. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py
@@ -0,0 +1,1236 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import colorsys
+ import logging
+ import math
+ import numpy as np
+ from enum import Enum, unique
+ import cv2
+ import matplotlib as mpl
+ import matplotlib.colors as mplc
+ import matplotlib.figure as mplfigure
+ import pycocotools.mask as mask_util
+ import torch
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
+ from PIL import Image
+
+ from detectron2.data import MetadataCatalog
+ from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
+ from detectron2.utils.file_io import PathManager
+
+ from detectron2.utils.colormap import random_color
+
+ import pdb
+
+ logger = logging.getLogger(__name__)
+
+ __all__ = ["ColorMode", "VisImage", "Visualizer"]
+
+
+ _SMALL_OBJECT_AREA_THRESH = 1000
+ _LARGE_MASK_AREA_THRESH = 120000
+ _OFF_WHITE = (1.0, 1.0, 240.0 / 255)
+ _BLACK = (0, 0, 0)
+ _RED = (1.0, 0, 0)
+
+ _KEYPOINT_THRESHOLD = 0.05
+
+ #CLASS_NAMES = ["footnote", "footer", "header"]
+
+ @unique
+ class ColorMode(Enum):
+     """
+     Enum of different color modes to use for instance visualizations.
+     """
+
+     IMAGE = 0
+     """
+     Picks a random color for every instance and overlay segmentations with low opacity.
+     """
+     SEGMENTATION = 1
+     """
+     Let instances of the same category have similar colors
+     (from metadata.thing_colors), and overlay them with
+     high opacity. This provides more attention on the quality of segmentation.
+     """
+     IMAGE_BW = 2
+     """
+     Same as IMAGE, but convert all areas without masks to gray-scale.
+     Only available for drawing per-instance mask predictions.
+     """
+
+
+ class GenericMask:
+     """
+     Attribute:
+         polygons (list[ndarray]): list[ndarray]: polygons for this mask.
+             Each ndarray has format [x, y, x, y, ...]
+         mask (ndarray): a binary mask
+     """
+
+     def __init__(self, mask_or_polygons, height, width):
+         self._mask = self._polygons = self._has_holes = None
+         self.height = height
+         self.width = width
+
+         m = mask_or_polygons
+         if isinstance(m, dict):
+             # RLEs
+             assert "counts" in m and "size" in m
+             if isinstance(m["counts"], list): # uncompressed RLEs
+                 h, w = m["size"]
+                 assert h == height and w == width
+                 m = mask_util.frPyObjects(m, h, w)
+             self._mask = mask_util.decode(m)[:, :]
+             return
+
+         if isinstance(m, list): # list[ndarray]
+             self._polygons = [np.asarray(x).reshape(-1) for x in m]
+             return
+
+         if isinstance(m, np.ndarray): # assumed to be a binary mask
+             assert m.shape[1] != 2, m.shape
+             assert m.shape == (
+                 height,
+                 width,
+             ), f"mask shape: {m.shape}, target dims: {height}, {width}"
+             self._mask = m.astype("uint8")
+             return
+
+         raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m)))
+
+     @property
+     def mask(self):
+         if self._mask is None:
+             self._mask = self.polygons_to_mask(self._polygons)
+         return self._mask
+
+     @property
+     def polygons(self):
+         if self._polygons is None:
+             self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
+         return self._polygons
+
+     @property
+     def has_holes(self):
+         if self._has_holes is None:
+             if self._mask is not None:
+                 self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
+             else:
+                 self._has_holes = False # if original format is polygon, does not have holes
+         return self._has_holes
+
+     def mask_to_polygons(self, mask):
+         # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
+         # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
+         # Internal contours (holes) are placed in hierarchy-2.
+         # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
+         mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr
+         res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
+         hierarchy = res[-1]
+         if hierarchy is None: # empty mask
+             return [], False
+         has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
+         res = res[-2]
+         res = [x.flatten() for x in res]
+         # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
+         # We add 0.5 to turn them into real-value coordinate space. A better solution
+         # would be to first +0.5 and then dilate the returned polygon by 0.5.
+         res = [x + 0.5 for x in res if len(x) >= 6]
+         return res, has_holes
+
+     def polygons_to_mask(self, polygons):
+         rle = mask_util.frPyObjects(polygons, self.height, self.width)
+         rle = mask_util.merge(rle)
+         return mask_util.decode(rle)[:, :]
+
+     def area(self):
+         return self.mask.sum()
+
+     def bbox(self):
+         p = mask_util.frPyObjects(self.polygons, self.height, self.width)
+         p = mask_util.merge(p)
+         bbox = mask_util.toBbox(p)
+         bbox[2] += bbox[0]
+         bbox[3] += bbox[1]
+         return bbox
+
+
+ class _PanopticPrediction:
+     """
+     Unify different panoptic annotation/prediction formats
+     """
+
+     def __init__(self, panoptic_seg, segments_info, metadata=None):
+         if segments_info is None:
+             assert metadata is not None
+             # If "segments_info" is None, we assume "panoptic_img" is a
+             # H*W int32 image storing the panoptic_id in the format of
+             # category_id * label_divisor + instance_id. We reserve -1 for
+             # VOID label.
+             label_divisor = metadata.label_divisor
+             segments_info = []
+             for panoptic_label in np.unique(panoptic_seg.numpy()):
+                 if panoptic_label == -1:
+                     # VOID region.
+                     continue
+                 pred_class = panoptic_label // label_divisor
+                 isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
+                 segments_info.append(
+                     {
+                         "id": int(panoptic_label),
+                         "category_id": int(pred_class),
+                         "isthing": bool(isthing),
+                     }
+                 )
+         del metadata
+
+         self._seg = panoptic_seg
+
+         self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
+         segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
+         areas = areas.numpy()
+         sorted_idxs = np.argsort(-areas)
+         self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
+         self._seg_ids = self._seg_ids.tolist()
+         for sid, area in zip(self._seg_ids, self._seg_areas):
+             if sid in self._sinfo:
+                 self._sinfo[sid]["area"] = float(area)
+
+     def non_empty_mask(self):
+         """
+         Returns:
+             (H, W) array, a mask for all pixels that have a prediction
+         """
+         empty_ids = []
+         for id in self._seg_ids:
+             if id not in self._sinfo:
+                 empty_ids.append(id)
+         if len(empty_ids) == 0:
+             return np.zeros(self._seg.shape, dtype=np.uint8)
+         assert (
+             len(empty_ids) == 1
+         ), ">1 ids corresponds to no labels. This is currently not supported"
+         return (self._seg != empty_ids[0]).numpy().astype(np.bool)
+
+     def semantic_masks(self):
+         for sid in self._seg_ids:
+             sinfo = self._sinfo.get(sid)
+             if sinfo is None or sinfo["isthing"]:
+                 # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
+                 continue
+             yield (self._seg == sid).numpy().astype(np.bool), sinfo
+
+     def instance_masks(self):
+         for sid in self._seg_ids:
+             sinfo = self._sinfo.get(sid)
+             if sinfo is None or not sinfo["isthing"]:
+                 continue
+             mask = (self._seg == sid).numpy().astype(np.bool)
+             if mask.sum() > 0:
+                 yield mask, sinfo
+
+
+ def _create_text_labels(classes, scores, class_names, is_crowd=None):
+     """
+     Args:
+         classes (list[int] or None):
+         scores (list[float] or None):
+         class_names (list[str] or None):
+         is_crowd (list[bool] or None):
+
+     Returns:
+         list[str] or None
+     """
+     #class_names = CLASS_NAMES
+     labels = None
+     if classes is not None:
+         if class_names is not None and len(class_names) > 0:
+             labels = [class_names[i] for i in classes]
+         else:
+             labels = [str(i) for i in classes]
+
+     if scores is not None:
+         if labels is None:
+             labels = ["{:.0f}%".format(s * 100) for s in scores]
+         else:
+             labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
+     if labels is not None and is_crowd is not None:
+         labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
+     return labels
+
+
+ class VisImage:
+     def __init__(self, img, scale=1.0):
+         """
+         Args:
+             img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
+             scale (float): scale the input image
+         """
+         self.img = img
+         self.scale = scale
+         self.width, self.height = img.shape[1], img.shape[0]
+         self._setup_figure(img)
+
+     def _setup_figure(self, img):
+         """
+         Args:
+             Same as in :meth:`__init__()`.
+
+         Returns:
+             fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
+             ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
+         """
+         fig = mplfigure.Figure(frameon=False)
+         self.dpi = fig.get_dpi()
+         # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
+         # (https://github.com/matplotlib/matplotlib/issues/15363)
+         fig.set_size_inches(
+             (self.width * self.scale + 1e-2) / self.dpi,
+             (self.height * self.scale + 1e-2) / self.dpi,
+         )
+         self.canvas = FigureCanvasAgg(fig)
+         # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
+         ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
+         ax.axis("off")
+         self.fig = fig
+         self.ax = ax
+         self.reset_image(img)
+
+     def reset_image(self, img):
+         """
+         Args:
+             img: same as in __init__
+         """
+         img = img.astype("uint8")
+         self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
+
+     def save(self, filepath):
+         """
+         Args:
+             filepath (str): a string that contains the absolute path, including the file name, where
+                 the visualized image will be saved.
+         """
+         self.fig.savefig(filepath)
+
+     def get_image(self):
+         """
+         Returns:
+             ndarray:
+                 the visualized image of shape (H, W, 3) (RGB) in uint8 type.
+                 The shape is scaled w.r.t the input image using the given `scale` argument.
+         """
+         canvas = self.canvas
+         s, (width, height) = canvas.print_to_buffer()
+         # buf = io.BytesIO() # works for cairo backend
+         # canvas.print_rgba(buf)
+         # width, height = self.width, self.height
+         # s = buf.getvalue()
+
+         buffer = np.frombuffer(s, dtype="uint8")
+
+         img_rgba = buffer.reshape(height, width, 4)
+         rgb, alpha = np.split(img_rgba, [3], axis=2)
+         return rgb.astype("uint8")
+
+
+ class Visualizer:
+     """
+     Visualizer that draws data about detection/segmentation on images.
+
+     It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
+     that draw primitive objects to images, as well as high-level wrappers like
+     `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
+     that draw composite data in some pre-defined style.
+
+     Note that the exact visualization style for the high-level wrappers are subject to change.
+     Style such as color, opacity, label contents, visibility of labels, or even the visibility
+     of objects themselves (e.g. when the object is too small) may change according
+     to different heuristics, as long as the results still look visually reasonable.
+
+     To obtain a consistent style, you can implement custom drawing functions with the
+     abovementioned primitive methods instead. If you need more customized visualization
+     styles, you can process the data yourself following their format documented in
+     tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
+     intend to satisfy everyone's preference on drawing styles.
+
+     This visualizer focuses on high rendering quality rather than performance. It is not
+     designed to be used for real-time applications.
+     """
+
+     # TODO implement a fast, rasterized version using OpenCV
+
+     def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
+         """
+         Args:
+             img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
+                 the height and width of the image respectively. C is the number of
+                 color channels. The image is required to be in RGB format since that
+                 is a requirement of the Matplotlib library. The image is also expected
+                 to be in the range [0, 255].
+             metadata (Metadata): dataset metadata (e.g. class names and colors)
+             instance_mode (ColorMode): defines one of the pre-defined style for drawing
+                 instances on an image.
+         """
+         self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
+         if metadata is None:
+             metadata = MetadataCatalog.get("__nonexist__")
+         self.metadata = metadata
+         self.output = VisImage(self.img, scale=scale)
+         self.cpu_device = torch.device("cpu")
+
+         # too small texts are useless, therefore clamp to 9
+         self._default_font_size = max(
+             np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
+         )
+         self._instance_mode = instance_mode
+         self.keypoint_threshold = _KEYPOINT_THRESHOLD
+
+     def draw_instance_predictions(self, predictions):
+         """
+         Draw instance-level prediction results on an image.
+
+         Args:
+             predictions (Instances): the output of an instance detection/segmentation
+                 model. Following fields will be used to draw:
+                 "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
+
+         Returns:
+             output (VisImage): image object with visualizations.
+         """
+         boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
+         scores = predictions.scores if predictions.has("scores") else None
+         classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
+         labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
+         keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
+
+         if predictions.has("pred_masks"):
+             masks = np.asarray(predictions.pred_masks)
+             masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
+         else:
+             masks = None
+
+         if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
+             colors = [
+                 self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
+             ]
+             alpha = 0.8
+         else:
+             colors = None
+             alpha = 0.5
+
+         if self._instance_mode == ColorMode.IMAGE_BW:
+             self.output.reset_image(
+                 self._create_grayscale_image(
+                     (predictions.pred_masks.any(dim=0) > 0).numpy()
+                     if predictions.has("pred_masks")
+                     else None
+                 )
+             )
+             alpha = 0.3
+
+         self.overlay_instances(
+             masks=masks,
+             boxes=boxes,
+             labels=labels,
+             keypoints=keypoints,
+             assigned_colors=colors,
+             alpha=alpha,
+         )
+         return self.output
+
+     def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):
+         """
+         Draw semantic segmentation predictions/labels.
+
+         Args:
+             sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
+                 Each value is the integer label of the pixel.
+             area_threshold (int): segments with less than `area_threshold` are not drawn.
+             alpha (float): the larger it is, the more opaque the segmentations are.
+
+         Returns:
+             output (VisImage): image object with visualizations.
+         """
+         if isinstance(sem_seg, torch.Tensor):
+             sem_seg = sem_seg.numpy()
+         labels, areas = np.unique(sem_seg, return_counts=True)
+         sorted_idxs = np.argsort(-areas).tolist()
+         labels = labels[sorted_idxs]
+         for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
+             try:
+                 mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
+             except (AttributeError, IndexError):
+                 mask_color = None
+
+             binary_mask = (sem_seg == label).astype(np.uint8)
+             text = self.metadata.stuff_classes[label]
+             self.draw_binary_mask(
+                 binary_mask,
+                 color=mask_color,
+                 edge_color=_OFF_WHITE,
+                 text=text,
+                 alpha=alpha,
+                 area_threshold=area_threshold,
+             )
+         return self.output
+
+     def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7):
+         """
+         Draw panoptic prediction annotations or results.
+
+         Args:
+             panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
+                 segment.
+             segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
+                 If it is a ``list[dict]``, each dict contains keys "id", "category_id".
+                 If None, category id of each pixel is computed by
+                 ``pixel // metadata.label_divisor``.
+             area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
+
+         Returns:
+             output (VisImage): image object with visualizations.
+         """
+         pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
+
+         if self._instance_mode == ColorMode.IMAGE_BW:
+             self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
+
+         # draw mask for all semantic segments first i.e. "stuff"
+         for mask, sinfo in pred.semantic_masks():
+             category_idx = sinfo["category_id"]
+             try:
+                 mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
+             except AttributeError:
+                 mask_color = None
+
+             text = self.metadata.stuff_classes[category_idx]
+             self.draw_binary_mask(
+                 mask,
+                 color=mask_color,
+                 edge_color=_OFF_WHITE,
+                 text=text,
+                 alpha=alpha,
+                 area_threshold=area_threshold,
+             )
+
+         # draw mask for all instances second
+         all_instances = list(pred.instance_masks())
+         if len(all_instances) == 0:
+             return self.output
+         masks, sinfo = list(zip(*all_instances))
+         category_ids = [x["category_id"] for x in sinfo]
+
+         try:
+             scores = [x["score"] for x in sinfo]
+         except KeyError:
+             scores = None
+         labels = _create_text_labels(
+             category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo]
+         )
+
+         try:
+             colors = [
+                 self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids
+             ]
+         except AttributeError:
+             colors = None
+         self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)
+
+         return self.output
+
+     draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility
+
+     def draw_dataset_dict(self, dic):
+         """
+         Draw annotations/segmentaions in Detectron2 Dataset format.
+
+         Args:
+             dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.
+
+         Returns:
+             output (VisImage): image object with visualizations.
+         """
+         annos = dic.get("annotations", None)
+         if annos:
+             if "segmentation" in annos[0]:
+                 masks = [x["segmentation"] for x in annos]
+             else:
+                 masks = None
+             if "keypoints" in annos[0]:
+                 keypts = [x["keypoints"] for x in annos]
+                 keypts = np.array(keypts).reshape(len(annos), -1, 3)
+             else:
+                 keypts = None
+
+             boxes = [
+                 BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
+                 if len(x["bbox"]) == 4
+                 else x["bbox"]
+                 for x in annos
+             ]
+
+             colors = None
+             category_ids = [x["category_id"] for x in annos]
+             if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
+                 colors = [
+                     self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
+                     for c in category_ids
+                 ]
+             names = self.metadata.get("thing_classes", None)
+             labels = _create_text_labels(
+                 category_ids,
+                 scores=None,
+                 class_names=names,
+                 is_crowd=[x.get("iscrowd", 0) for x in annos],
+             )
+             self.overlay_instances(
+                 labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
+             )
+
+         sem_seg = dic.get("sem_seg", None)
+         if sem_seg is None and "sem_seg_file_name" in dic:
+             with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
+                 sem_seg = Image.open(f)
+                 sem_seg = np.asarray(sem_seg, dtype="uint8")
+         if sem_seg is not None:
+             self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)
+
+         pan_seg = dic.get("pan_seg", None)
+         if pan_seg is None and "pan_seg_file_name" in dic:
+             with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
+                 pan_seg = Image.open(f)
+                 pan_seg = np.asarray(pan_seg)
+                 from panopticapi.utils import rgb2id
+
+                 pan_seg = rgb2id(pan_seg)
+         if pan_seg is not None:
+             segments_info = dic["segments_info"]
+             pan_seg = torch.tensor(pan_seg)
+             self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5)
+         return self.output
+
+     def overlay_instances(
+         self,
+         *,
+         boxes=None,
+         labels=None,
+         masks=None,
+         keypoints=None,
+         assigned_colors=None,
+         alpha=0.5,
+     ):
+         """
+         Args:
+             boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
+                 or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
+                 or a :class:`RotatedBoxes`,
+                 or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
+                 for the N objects in a single image,
+             labels (list[str]): the text to be displayed for each instance.
+             masks (masks-like object): Supported types are:
+
+                 * :class:`detectron2.structures.PolygonMasks`,
+                 :class:`detectron2.structures.BitMasks`.
+                 * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
+                 The first level of the list corresponds to individual instances. The second
+                 level to all the polygon that compose the instance, and the third level
+                 to the polygon coordinates. The third level should have the format of
+                 [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
+                 * list[ndarray]: each ndarray is a binary mask of shape (H, W).
+                 * list[dict]: each dict is a COCO-style RLE.
+             keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
+                 where the N is the number of instances and K is the number of keypoints.
+                 The last dimension corresponds to (x, y, visibility or score).
+             assigned_colors (list[matplotlib.colors]): a list of colors, where each color
+                 corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
+                 for full list of formats that the colors are accepted in.
+
+         Returns:
+             output (VisImage): image object with visualizations.
+         """
+         num_instances = 0
+         if boxes is not None:
+             boxes = self._convert_boxes(boxes)
+             num_instances = len(boxes)
+         if masks is not None:
+             masks = self._convert_masks(masks)
+             if num_instances:
+                 assert len(masks) == num_instances
+             else:
+                 num_instances = len(masks)
+         if keypoints is not None:
+             if num_instances:
+                 assert len(keypoints) == num_instances
+             else:
+                 num_instances = len(keypoints)
+             keypoints = self._convert_keypoints(keypoints)
+         if labels is not None:
+             assert len(labels) == num_instances
+         if assigned_colors is None:
+             assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
+         if num_instances == 0:
+             return self.output
+         if boxes is not None and boxes.shape[1] == 5:
+             return self.overlay_rotated_instances(
+                 boxes=boxes, labels=labels, assigned_colors=assigned_colors
+             )
+
+         # Display in largest to smallest order to reduce occlusion.
+         areas = None
+         if boxes is not None:
+             areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
+         elif masks is not None:
+             areas = np.asarray([x.area() for x in masks])
+
+         if areas is not None:
+             sorted_idxs = np.argsort(-areas).tolist()
+             # Re-order overlapped instances in descending order.
+             boxes = boxes[sorted_idxs] if boxes is not None else None
+             labels = [labels[k] for k in sorted_idxs] if labels is not None else None
+             masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
+             assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
+             keypoints = keypoints[sorted_idxs] if keypoints is not None else None
+
+         for i in range(num_instances):
+             color = assigned_colors[i]
+             if boxes is not None:
+                 self.draw_box(boxes[i], edge_color=color)
+
+             if masks is not None:
+                 for segment in masks[i].polygons:
+                     self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)
+
+             if labels is not None:
+                 # first get a box
+                 if boxes is not None:
+                     x0, y0, x1, y1 = boxes[i]
+                     text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
+                     horiz_align = "left"
+                 elif masks is not None:
+                     # skip small mask without polygon
+                     if len(masks[i].polygons) == 0:
+                         continue
+
+                     x0, y0, x1, y1 = masks[i].bbox()
+
+                     # draw text in the center (defined by median) when box is not drawn
+                     # median is less sensitive to outliers.
+                     text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
+                     horiz_align = "center"
+                 else:
+                     continue # drawing the box confidence for keypoints isn't very useful.
+                 # for small objects, draw text at the side to avoid occlusion
+                 instance_area = (y1 - y0) * (x1 - x0)
+                 if (
+                     instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
+                     or y1 - y0 < 40 * self.output.scale
+                 ):
+                     if y1 >= self.output.height - 5:
+                         text_pos = (x1, y0)
+                     else:
+                         text_pos = (x0, y1)
+
+                 height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
+                 lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
+                 font_size = (
+                     np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
+                     * 0.5
+                     * self._default_font_size
+                 )
+                 self.draw_text(
+                     labels[i],
+                     text_pos,
+                     color=lighter_color,
+                     horizontal_alignment=horiz_align,
+                     font_size=font_size,
+                 )
+
+         # draw keypoints
+         if keypoints is not None:
+             for keypoints_per_instance in keypoints:
+                 self.draw_and_connect_keypoints(keypoints_per_instance)
+
+         return self.output
+
+     def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
+         """
+         Args:
+             boxes (ndarray): an Nx5 numpy array of
+                 (x_center, y_center, width, height, angle_degrees) format
+                 for the N objects in a single image.
+             labels (list[str]): the text to be displayed for each instance.
+             assigned_colors (list[matplotlib.colors]): a list of colors, where each color
+                 corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
+                 for full list of formats that the colors are accepted in.
+
+         Returns:
+             output (VisImage): image object with visualizations.
+         """
+         num_instances = len(boxes)
+
+         if assigned_colors is None:
+             assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
+         if num_instances == 0:
+             return self.output
+
+         # Display in largest to smallest order to reduce occlusion.
+         if boxes is not None:
+             areas = boxes[:, 2] * boxes[:, 3]
+
+         sorted_idxs = np.argsort(-areas).tolist()
+         # Re-order overlapped instances in descending order.
+         boxes = boxes[sorted_idxs]
+         labels = [labels[k] for k in sorted_idxs] if labels is not None else None
+         colors = [assigned_colors[idx] for idx in sorted_idxs]
+
+         for i in range(num_instances):
+             self.draw_rotated_box_with_label(
+                 boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
+             )
+
+         return self.output
+
+     def draw_and_connect_keypoints(self, keypoints):
+         """
+         Draws keypoints of an instance and follows the rules for keypoint connections
+         to draw lines between appropriate keypoints. This follows color heuristics for
+         line color.
+
+         Args:
+             keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
+                 and the last dimension corresponds to (x, y, probability).
+
+         Returns:
+             output (VisImage): image object with visualizations.
+         """
+         visible = {}
+         keypoint_names = self.metadata.get("keypoint_names")
+         for idx, keypoint in enumerate(keypoints):
+             # draw keypoint
+             x, y, prob = keypoint
+             if prob > self.keypoint_threshold:
+                 self.draw_circle((x, y), color=_RED)
+                 if keypoint_names:
+                     keypoint_name = keypoint_names[idx]
+                     visible[keypoint_name] = (x, y)
+
+         if self.metadata.get("keypoint_connection_rules"):
+             for kp0, kp1, color in self.metadata.keypoint_connection_rules:
+                 if kp0 in visible and kp1 in visible:
+                     x0, y0 = visible[kp0]
+                     x1, y1 = visible[kp1]
+                     color = tuple(x / 255.0 for x in color)
+                     self.draw_line([x0, x1], [y0, y1], color=color)
+
+         # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
+         # Note that this strategy is specific to person keypoints.
+         # For other keypoints, it should just do nothing
+         try:
+             ls_x, ls_y = visible["left_shoulder"]
+             rs_x, rs_y = visible["right_shoulder"]
+             mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
+         except KeyError:
+             pass
+         else:
+             # draw line from nose to mid-shoulder
+             nose_x, nose_y = visible.get("nose", (None, None))
+             if nose_x is not None:
+                 self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)
+
+             try:
+                 # draw line from mid-shoulder to mid-hip
+                 lh_x, lh_y = visible["left_hip"]
+                 rh_x, rh_y = visible["right_hip"]
+             except KeyError:
+                 pass
+             else:
+                 mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
+                 self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
+         return self.output
+
+     """
+     Primitive drawing functions:
+     """
+
+     def draw_text(
+         self,
+         text,
+         position,
+         *,
+         font_size=None,
+         color="g",
+         horizontal_alignment="center",
+         rotation=0,
+     ):
+         """
+         Args:
+             text (str): class label
+             position (tuple): a tuple of the x and y coordinates to place text on image.
+             font_size (int, optional): font of the text. If not provided, a font size
+                 proportional to the image width is calculated and used.
+             color: color of the text. Refer to `matplotlib.colors` for full list
+                 of formats that are accepted.
+             horizontal_alignment (str): see `matplotlib.text.Text`
+             rotation: rotation angle in degrees CCW
+
+         Returns:
+             output (VisImage): image object with text drawn.
+         """
+         if not font_size:
+             font_size = self._default_font_size
+
+         # since the text background is dark, we don't want the text to be dark
+         color = np.maximum(list(mplc.to_rgb(color)), 0.2)
+         color[np.argmax(color)] = max(0.8, np.max(color))
+
+         x, y = position
+         self.output.ax.text(
+             x,
+             y,
+             text,
+             size=font_size * self.output.scale,
+             family="sans-serif",
+             bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
+             verticalalignment="top",
+             horizontalalignment=horizontal_alignment,
+             color=color,
+             zorder=10,
+             rotation=rotation,
+         )
+         return self.output
+
+     def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
+         """
+         Args:
+             box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
+                 are the coordinates of the image's top left corner. x1 and y1 are the
+                 coordinates of the image's bottom right corner.
+             alpha (float): blending efficient. Smaller values lead to more transparent masks.
+             edge_color: color of the outline of the box. Refer to `matplotlib.colors`
+                 for full list of formats that are accepted.
+             line_style (string): the string to use to create the outline of the boxes.
+
+         Returns:
+             output (VisImage): image object with box drawn.
+         """
+         x0, y0, x1, y1 = box_coord
+         width = x1 - x0
+         height = y1 - y0
+
+         linewidth = max(self._default_font_size / 4, 1)
+
+         self.output.ax.add_patch(
+             mpl.patches.Rectangle(
+                 (x0, y0),
+                 width,
+                 height,
+                 fill=False,
+                 edgecolor=edge_color,
+                 linewidth=linewidth * self.output.scale,
+                 alpha=alpha,
+                 linestyle=line_style,
+             )
+         )
+         return self.output
+
+     def draw_rotated_box_with_label(
+         self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
+     ):
+         """
+         Draw a rotated box with label on its top-left corner.
+
+         Args:
+             rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
+                 where cnt_x and cnt_y are the center coordinates of the box.
+                 w and h are the width and height of the box. angle represents how
+                 many degrees the box is rotated CCW with regard to the 0-degree box.
+             alpha (float): blending efficient. Smaller values lead to more transparent masks.
+             edge_color: color of the outline of the box. Refer to `matplotlib.colors`
+                 for full list of formats that are accepted.
+             line_style (string): the string to use to create the outline of the boxes.
+             label (string): label for rotated box. It will not be rendered when set to None.
+
+         Returns:
+             output (VisImage): image object with box drawn.
+         """
+         cnt_x, cnt_y, w, h, angle = rotated_box
+         area = w * h
+         # use thinner lines when the box is small
+         linewidth = self._default_font_size / (
+             6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
+         )
+
+         theta = angle * math.pi / 180.0
+         c = math.cos(theta)
+         s = math.sin(theta)
+         rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
+         # x: left->right ; y: top->down
+         rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
+         for k in range(4):
+             j = (k + 1) % 4
+             self.draw_line(
+                 [rotated_rect[k][0], rotated_rect[j][0]],
+                 [rotated_rect[k][1], rotated_rect[j][1]],
+                 color=edge_color,
+                 linestyle="--" if k == 1 else line_style,
+                 linewidth=linewidth,
+             )
+
+         if label is not None:
+             text_pos = rotated_rect[1] # topleft corner
+
+             height_ratio = h / np.sqrt(self.output.height * self.output.width)
+             label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
+             font_size = (
+                 np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
+             )
+             self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)
+
+         return self.output
+
+     def draw_circle(self, circle_coord, color, radius=3):
+         """
+         Args:
+             circle_coord (list(int) or tuple(int)): contains the x and y coordinates
+                 of the center of the circle.
+             color: color of the polygon. Refer to `matplotlib.colors` for a full list of
+                 formats that are accepted.
+             radius (int): radius of the circle.
+
+         Returns:
+             output (VisImage): image object with box drawn.
+         """
+         x, y = circle_coord
+         self.output.ax.add_patch(
+             mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
+         )
+         return self.output
+
+     def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
+         """
+         Args:
+             x_data (list[int]): a list containing x values of all the points being drawn.
+                 Length of list should match the length of y_data.
+             y_data (list[int]): a list containing y values of all the points being drawn.
+                 Length of list should match the length of x_data.
+             color: color of the line. Refer to `matplotlib.colors` for a full list of
+                 formats that are accepted.
+             linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
+                 for a full list of formats that are accepted.
+             linewidth (float or None): width of the line. When it's None,
+                 a default value will be computed and used.
+
+         Returns:
+             output (VisImage): image object with line drawn.
+         """
+         if linewidth is None:
+             linewidth = self._default_font_size / 3
+         linewidth = max(linewidth, 1)
+         self.output.ax.add_line(
+             mpl.lines.Line2D(
+                 x_data,
+                 y_data,
+                 linewidth=linewidth * self.output.scale,
+                 color=color,
+                 linestyle=linestyle,
+             )
+         )
+         return self.output
+
+     def draw_binary_mask(
+         self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=0
+     ):
+         """
+         Args:
+             binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
+                 W is the image width. Each value in the array is either a 0 or 1 value of uint8
+                 type.
+             color: color of the mask. Refer to `matplotlib.colors` for a full list of
+                 formats that are accepted. If None, will pick a random color.
+             edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
+                 full list of formats that are accepted.
+             text (str): if None, will be drawn in the object's center of mass.
+             alpha (float): blending efficient. Smaller values lead to more transparent masks.
+             area_threshold (float): a connected component small than this will not be shown.
+
+         Returns:
+             output (VisImage): image object with mask drawn.
+         """
+         if color is None:
+             color = random_color(rgb=True, maximum=1)
+         color = mplc.to_rgb(color)
+
+         has_valid_segment = False
+         binary_mask = binary_mask.astype("uint8") # opencv needs uint8
+         mask = GenericMask(binary_mask, self.output.height, self.output.width)
+         shape2d = (binary_mask.shape[0], binary_mask.shape[1])
+
+         if not mask.has_holes:
+             # draw polygons for regular masks
+             for segment in mask.polygons:
+                 area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
+                 if area < (area_threshold or 0):
+                     continue
+                 has_valid_segment = True
+                 segment = segment.reshape(-1, 2)
+                 self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
+         else:
+             # TODO: Use Path/PathPatch to draw vector graphics:
+             # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
+             rgba = np.zeros(shape2d + (4,), dtype="float32")
+             rgba[:, :, :3] = color
+             rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
+             has_valid_segment = True
+             self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
+
+         if text is not None and has_valid_segment:
+             # TODO sometimes drawn on wrong objects. the heuristics here can improve.
+             lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
+             _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
+             largest_component_id = np.argmax(stats[1:, -1]) + 1
+
+             # draw text on the largest component, as well as other very large components.
+             for cid in range(1, _num_cc):
+                 if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
+                     # median is more stable than centroid
+                     # center = centroids[largest_component_id]
+                     center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
+                     self.draw_text(text, center, color=lighter_color)
+         return self.output
+
+     def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
+         """
+         Args:
+             segment: numpy array of shape Nx2, containing all the points in the polygon.
+             color: color of the polygon. Refer to `matplotlib.colors` for a full list of
+                 formats that are accepted.
+             edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
+                 full list of formats that are accepted. If not provided, a darker shade
+                 of the polygon color will be used instead.
+             alpha (float): blending efficient. Smaller values lead to more transparent masks.
+
+         Returns:
+             output (VisImage): image object with polygon drawn.
+         """
+         if edge_color is None:
+             # make edge color darker than the polygon color
+             if alpha > 0.8:
+                 edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
+             else:
+                 edge_color = color
+         edge_color = mplc.to_rgb(edge_color) + (1,)
+
+         polygon = mpl.patches.Polygon(
+             segment,
+             fill=True,
+             facecolor=mplc.to_rgb(color) + (alpha,),
+             edgecolor=edge_color,
+             linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
+         )
+         self.output.ax.add_patch(polygon)
+         return self.output
+
+     """
+     Internal methods:
+     """
+
+     def _jitter(self, color):
+         """
+         Randomly modifies given color to produce a slightly different color than the color given.
+
+         Args:
+             color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
+                 picked. The values in the list are in the [0.0, 1.0] range.
+
+         Returns:
+             jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
+                 color after being jittered. The values in the list are in the [0.0, 1.0] range.
+         """
+         color = mplc.to_rgb(color)
+         vec = np.random.rand(3)
+         # better to do it in another color space
+         vec = vec / np.linalg.norm(vec) * 0.5
+         res = np.clip(vec + color, 0, 1)
+         return tuple(res)
+
+     def _create_grayscale_image(self, mask=None):
+         """
+         Create a grayscale version of the original image.
+         The colors in masked area, if given, will be kept.
+         """
+         img_bw = self.img.astype("f4").mean(axis=2)
+         img_bw = np.stack([img_bw] * 3, axis=2)
+         if mask is not None:
+             img_bw[mask] = self.img[mask]
+         return img_bw
+
+     def _change_color_brightness(self, color, brightness_factor):
+         """
+         Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
+         less or more saturation than the original color.
+
+         Args:
+             color: color of the polygon. Refer to `matplotlib.colors` for a full list of
+                 formats that are accepted.
+             brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
+                 0 will correspond to no change, a factor in [-1.0, 0) range will result in
+                 a darker color and a factor in (0, 1.0] range will result in a lighter color.
+
+         Returns:
+             modified_color (tuple[double]): a tuple containing the RGB values of the
+                 modified color. Each value in the tuple is in the [0.0, 1.0] range.
+         """
+         assert brightness_factor >= -1.0 and brightness_factor <= 1.0
+         color = mplc.to_rgb(color)
+         polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
+         modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
+         modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
+         modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
+         modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
+         return modified_color
+
+     def _convert_boxes(self, boxes):
+         """
+         Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
+         """
+         if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
+             return boxes.tensor.detach().numpy()
+         else:
+             return np.asarray(boxes)
+
+     def _convert_masks(self, masks_or_polygons):
+         """
+         Convert different format of masks or polygons to a tuple of masks and polygons.
+
+         Returns:
+             list[GenericMask]:
+         """
+
+         m = masks_or_polygons
+         if isinstance(m, PolygonMasks):
+             m = m.polygons
+         if isinstance(m, BitMasks):
+             m = m.tensor.numpy()
+         if isinstance(m, torch.Tensor):
+             m = m.numpy()
+         ret = []
+         for x in m:
+             if isinstance(x, GenericMask):
+                 ret.append(x)
+             else:
+                 ret.append(GenericMask(x, self.output.height, self.output.width))
+         return ret
+
+     def _convert_keypoints(self, keypoints):
+         if isinstance(keypoints, Keypoints):
+             keypoints = keypoints.tensor
+         keypoints = np.asarray(keypoints)
+         return keypoints
+
+     def get_output(self):
+         """
+         Returns:
+             output (VisImage): the image output containing the visualizations added
+                 to the image.
+         """
+         return self.output
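
For reference, a minimal usage sketch of the vendored Visualizer follows. This is an editor's illustration, not part of the released diff: it assumes detectron2, OpenCV, and the 0.6.0 wheel are installed, the import path is inferred from the file list above, and the image file name is a placeholder.

# Editor's sketch (hypothetical "page.png" input); exercises only the primitive
# drawing API shown in the diff above, so no detection model is required.
import cv2
from magic_pdf.model.pek_sub_modules.layoutlmv3.visualizer import ColorMode, Visualizer

img_bgr = cv2.imread("page.png")   # OpenCV loads BGR
img_rgb = img_bgr[:, :, ::-1]      # Visualizer expects RGB values in [0, 255]

vis = Visualizer(img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE)
out = vis.draw_box((50, 80, 420, 300), edge_color="g")                   # XYXY_ABS box
out = vis.draw_text("title 97%", (50, 80), horizontal_alignment="left")  # label at the box corner
cv2.imwrite("page_vis.png", out.get_image()[:, :, ::-1])                 # get_image() returns RGB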