python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164) hide show
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,956 @@
1
+ #coding=utf-8
2
+ import cv2
3
+ import random
4
+ import numpy as np
5
+ import wml.semantic.visualization_utils as smv
6
+ from PIL import Image
7
+ from wml.basic_data_def import COCO_JOINTS_PAIR,colors_tableau ,colors_tableau_large, PSEUDOCOLOR
8
+ from wml.basic_data_def import DEFAULT_COLOR_MAP as _DEFAULT_COLOR_MAP
9
+ import wml.object_detection2.bboxes as odb
10
+ from wml.wstructures import WPolygonMasks,WBitmapMasks, WMCKeypoints, WMCKeypointsItem
11
+ import math
12
+ import wml.basic_img_utils as bwmli
13
+ from .basic_visualization import *
14
+
15
+ DEFAULT_COLOR_MAP = _DEFAULT_COLOR_MAP
16
+
17
def draw_text_on_image(img,text,font_scale=1.2,color=(0.,255.,0.),pos=None,thickness=1):
    '''
    Draw one line of text on img (in place) and return img.

    Args:
        img: image array accepted by cv2.putText; img.shape[0] is read as the height
        text: str, bytes (decoded as utf-8) or any other object (converted with str())
        font_scale: font scale handed to OpenCV
        color: text color
        pos: None -> left edge, vertically centred
             "tl" -> top-left, "bl" -> bottom-left (case-insensitive)
             anything else is passed to cv2.putText unchanged as the text origin
        thickness: kept for interface compatibility, see the note below
    Returns:
        img
    '''
    if isinstance(text,bytes):
        text = str(text,encoding="utf-8")
    if not isinstance(text,str):
        text = str(text)
    # NOTE(review): the stroke thickness has always been forced to 2 here, which
    # makes the `thickness` argument ineffective; kept to preserve the rendering.
    thickness = 2
    font = cv2.FONT_HERSHEY_COMPLEX
    # Measure with the same font that is used for drawing, so that the computed
    # anchor positions match the rendered text.
    (_,text_height),_ = cv2.getTextSize(text,fontFace=font,fontScale=font_scale,thickness=thickness)
    if pos is None:
        pos = (0,(img.shape[0]+text_height)//2)
    elif isinstance(pos,str):
        anchor = pos.lower()
        if anchor == "tl":
            pos = (0,text_height+5)
        elif anchor == "bl":
            pos = (0,img.shape[0]-text_height-5)

    cv2.putText(img, text, pos, font, fontScale=font_scale, color=color, thickness=thickness)
    return img
37
+
38
def draw_lines(img, lines, color=[255, 0, 0], thickness=2):
    '''
    Draw line segments on img in place.

    Args:
        img: image handed to cv2.line
        lines: iterable of groups, every group is an iterable of (x1,y1,x2,y2)
        color: line color
        thickness: line thickness
    '''
    for segments in lines:
        for xa, ya, xb, yb in segments:
            start = (xa, ya)
            stop = (xb, yb)
            cv2.line(img, start, stop, color, thickness)
42
+
43
def draw_rectangle(img, p1, p2, color=[255, 0, 0], thickness=2):
    '''
    Draw a rectangle on img in place.

    p1, p2: opposite corners; each one is reversed before it is handed to cv2.rectangle
    '''
    first_corner = p1[::-1]
    second_corner = p2[::-1]
    cv2.rectangle(img, first_corner, second_corner, color, thickness)
45
+
46
def draw_bbox(img, bbox, shape=None, label=None, color=[255, 0, 0], thickness=2,is_relative_bbox=False,xy_order=True):
    '''
    Draw a single bounding box (and an optional label) on img in place.

    Args:
        img: image handed to cv2.rectangle / cv2.putText
        bbox: [y0,x0,y1,x1] if xy_order = False else [x0,y0,x1,y1]
        shape: (H,W,...) used to scale relative boxes, defaults to img.shape
        label: optional label, drawn with str(label) 15 pixels below the first corner
        color: box and label color
        thickness: box line thickness
        is_relative_bbox: True if bbox holds relative coordinates
        xy_order: coordinate order of bbox, see above
    Returns:
        img
    '''
    # Normalise to named coordinates first, so that the relative scaling below
    # always multiplies y by the height and x by the width.
    if xy_order:
        x0,y0,x1,y1 = bbox[0],bbox[1],bbox[2],bbox[3]
    else:
        y0,x0,y1,x1 = bbox[0],bbox[1],bbox[2],bbox[3]
    if is_relative_bbox:
        if shape is None:
            shape = img.shape
        y0,y1 = y0*shape[0],y1*shape[0]
        x0,x1 = x0*shape[1],x1*shape[1]
    # OpenCV expects points as (x,y)
    first_corner = (int(x0),int(y0))
    second_corner = (int(x1),int(y1))
    cv2.rectangle(img, first_corner, second_corner, color, thickness)
    if label is not None:
        text_origin = (first_corner[0],first_corner[1]+15)
        cv2.putText(img, str(label), text_origin, cv2.FONT_HERSHEY_DUPLEX, 0.5, color, 1)
    return img
64
+
65
+ '''
66
+ pmin: (y0,x0)
67
+ pmax: (y1,x1)
68
+ '''
69
def get_text_pos_fn(pmin,pmax,bbox,label,text_size):
    '''
    Default text anchor: just inside the top-left corner of the box.

    pmin: (y0,x0), text_size: (text_width,text_height); pmax, bbox and label are unused
    Returns (y,x) = (y0+text_height+5, x0+5)
    '''
    _,text_height = text_size
    top,left = pmin[0],pmin[1]
    return (top+text_height+5,left+5)
72
+ '''if bbox[0]<text_height:
73
+ p1 = (pmax[0],pmin[1])
74
+ else:
75
+ p1 = pmin
76
+ return (p1[0]-5,p1[1])'''
77
+
78
def get_text_pos_tr(pmin,pmax,bbox,label,text_size=None):
    '''
    Text anchor built from pmax[0] and pmin[1]: returns (pmax[0]-5, pmin[1]).

    text_size is accepted (and ignored) because draw_bboxes calls the position
    callback with five arguments.
    '''
    p1 = (pmax[0],pmin[1])
    return (p1[0]-5,p1[1])
81
+
82
def get_text_pos_tm(pmin,pmax,bbox,label,text_size=None):
    '''
    Text anchor halfway between pmin[0] and pmax[0]: returns ((pmin[0]+pmax[0])//2, pmin[1]).

    text_size is accepted (and ignored) because draw_bboxes calls the position
    callback with five arguments.
    '''
    p1 = ((pmin[0]+pmax[0])//2,pmin[1])
    return p1
85
def get_text_pos_br(pmin,pmax,bbox,label,text_size=None):
    '''
    Text anchor built from pmax only: returns (pmax[0]-5, pmax[1]).

    text_size is accepted (and ignored) because draw_bboxes calls the position
    callback with five arguments.
    '''
    p1 = (pmax[0],pmax[1])
    return (p1[0]-5,p1[1])
88
+
89
def random_color_fn(label,probs=None):
    '''
    Return a randomly selected entry of colors_tableau; label and probs are ignored.
    '''
    del label
    last_idx = len(colors_tableau)-1
    picked = random.randint(0,last_idx)
    return colors_tableau[picked]
93
+
94
def fixed_color_fn(label,probs=None):
    '''
    Map a label to a fixed entry of colors_tableau (label modulo the palette size).
    probs is ignored.
    '''
    palette_idx = label%len(colors_tableau)
    return colors_tableau[palette_idx]
97
+
98
def fixed_color_large_fn(label,probs=None):
    '''
    Map a label to a fixed entry of colors_tableau_large.

    Args:
        label: an integer label, or a str/bytes label (its length selects the color)
        probs: ignored
    Returns:
        the selected palette entry
    '''
    color_nr = len(colors_tableau_large)
    if isinstance(label,(str,bytes)):
        # Wrap around, a label longer than the palette used to raise IndexError
        return colors_tableau_large[len(label)%color_nr]
    return colors_tableau_large[label%color_nr]
103
+
104
def pesudo_color_fn(label,probs):
    '''
    Map probs to an entry of PSEUDOCOLOR: index int(probs*len(PSEUDOCOLOR)),
    wrapped around with a modulo. label is not used.
    '''
    palette_size = len(PSEUDOCOLOR)
    bucket = int(probs*palette_size)
    return PSEUDOCOLOR[bucket%palette_size]
108
+
109
def red_color_fn(label,probs=None):
    '''
    Always return (255,0,0).

    probs is optional so the function can be used as the color_fn of draw_bboxes,
    which calls color_fn(label,score).
    '''
    del label,probs
    return (255,0,0)
112
+
113
def blue_color_fn(label,probs=None):
    '''
    Always return (0,0,255).

    probs is optional so the function can be used as the color_fn of draw_bboxes,
    which calls color_fn(label,score).
    '''
    del label,probs
    return (0,0,255)
116
+
117
def green_color_fn(label,probs=None):
    '''
    Always return (0,255,0).

    probs is optional so the function can be used as the color_fn of draw_bboxes,
    which calls color_fn(label,score).
    '''
    del label,probs
    return (0,255,0)
120
+
121
def default_text_fn(label,score=None):
    '''
    Default label formatter: the text is str(label), score is ignored.
    '''
    text = str(label)
    return text
123
+
124
+ '''
125
+ bboxes: [N,4] (y0,x0,y1,x1)
126
+ color_fn: tuple(3) (*f)(label)
127
+ text_fn: str (*f)(label,score)
128
+ get_text_pos_fn: tuple(2) (*f)(lt_corner,br_corner,bboxes,label)
129
+ '''
130
def draw_bboxes(img, classes=None, scores=None, bboxes=None,
                color_fn=fixed_color_large_fn,
                text_fn=default_text_fn,
                get_text_pos_fn=get_text_pos_fn,
                thickness=2,show_text=True,font_scale=1.2,text_color=(0.,255.,0.),
                is_relative_coordinate=True,
                is_show_text=None,
                fill_bboxes=False,
                is_crowd=None):
    '''
    Draw bounding boxes and (optionally) their labels.

    Args:
        img: image array; it is copied with np.array() before drawing
        classes: [N] labels, defaults to zeros
        scores: [N] scores, defaults to ones
        bboxes: [N,4] (y0,x0,y1,x1)
        color_fn: called as color_fn(label,score), returns the box color;
            None selects a random color for every box
        text_fn: called as text_fn(label,score), returns the text to draw
        get_text_pos_fn: called as get_text_pos_fn(lt_corner,br_corner,bbox,label,text_size)
            with corners given as (y,x), returns the text origin as (y,x)
        thickness: box line thickness (ignored when fill_bboxes is True)
        show_text: draw the label text if True
        font_scale: font scale of the label text
        text_color: color of the label text
        is_relative_coordinate: True if bboxes are relative to the image size
        is_show_text: optional predicate called as is_show_text(lt_corner,br_corner)
        fill_bboxes: draw filled boxes
        is_crowd: optional [N] flags; crowd boxes get a thinner outline, a filled
            dot on the first corner and a fixed text color
    Returns:
        the image with the boxes drawn; img itself if there is nothing to draw
    '''
    if bboxes is None:
        return img

    bboxes = np.array(bboxes)
    if len(bboxes) == 0:
        return img
    if classes is None:
        classes = np.zeros([bboxes.shape[0]],dtype=np.int32)
    if is_relative_coordinate and np.any(bboxes>1.1):
        print(f"Use relative coordinate and max bboxes value is {np.max(bboxes)}")
    elif not is_relative_coordinate and np.all(bboxes<1.1):
        print(f"Use absolute coordinate and max bboxes value is {np.max(bboxes)}")

    # Convert before .shape is read, so array-likes without a shape attribute work too
    img = np.array(img)
    if len(img.shape)<2:
        print(f"Error img size {img.shape}.")
        return img
    bboxes_thickness = thickness if not fill_bboxes else -1
    if is_relative_coordinate:
        shape = img.shape
    else:
        shape = [1.0,1.0]
    if scores is None:
        scores = np.ones([len(classes)],dtype=np.float32)
    for i in range(bboxes.shape[0]):
        # Pre-set so that the error report below never touches unbound names
        color = None
        p10 = None
        p2 = None
        try:
            bbox = bboxes[i]
            if color_fn is not None:
                color = color_fn(classes[i],scores[i])
            else:
                color = (int(random.random()*255), int(random.random()*255), int(random.random()*255))
            p10 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))
            p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))
            cur_is_crowd = False if is_crowd is None else is_crowd[i]
            if not cur_is_crowd:
                cv2.rectangle(img, p10[::-1], p2[::-1], color, bboxes_thickness)
            else:
                cv2.rectangle(img, p10[::-1], p2[::-1], color, int(max(bboxes_thickness//2,1)))
                t_r = min(min(math.fabs(p10[0]-p2[0]),math.fabs(p10[1]-p2[1]))/10,5)
                t_r = int(max(t_r,2))
                cv2.circle(img,p10[::-1],t_r,color=color,thickness=-1)
            if show_text and text_fn is not None:
                f_show_text = True
                if is_show_text is not None:
                    f_show_text = is_show_text(p10,p2)

                if f_show_text:
                    text_thickness = 1
                    s = text_fn(classes[i], scores[i])
                    text_size,_ = cv2.getTextSize(s,cv2.FONT_HERSHEY_DUPLEX,fontScale=font_scale,thickness=text_thickness)
                    p = get_text_pos_fn(p10,p2,bbox,classes[i],text_size)
                    cv2.putText(img, s, p[::-1], cv2.FONT_HERSHEY_DUPLEX,
                                fontScale=font_scale,
                                color=text_color if not cur_is_crowd else (110,160,110),
                                thickness=text_thickness)
        except Exception as e:
            # Best effort: report the failing box and go on with the next one.
            # Only values that are already computed are printed; calling color_fn
            # again here could raise a second time and abort the whole drawing.
            print("ERROR: object_detection2.visualization ",img.shape,shape,bboxes[i],classes[i],p10,p2,color,thickness,e)

    return img
208
+
209
def draw_bboxes_xy(img, classes=None, scores=None, bboxes=None,
                   color_fn=fixed_color_large_fn,
                   text_fn=default_text_fn,
                   get_text_pos_fn=get_text_pos_fn,
                   thickness=2,show_text=True,font_scale=1.2,text_color=(0.,255.,0.),
                   is_relative_coordinate=False,
                   is_show_text=None,
                   fill_bboxes=False,
                   is_crowd=None):
    '''
    Same as draw_bboxes, but the coordinate order of bboxes is changed with
    odb.npchangexyorder before drawing. All other arguments are forwarded unchanged.
    '''
    converted = bboxes if bboxes is None else odb.npchangexyorder(bboxes)
    options = dict(scores=scores,
                   bboxes=converted,
                   color_fn=color_fn,
                   text_fn=text_fn,
                   get_text_pos_fn=get_text_pos_fn,
                   thickness=thickness,
                   show_text=show_text,
                   font_scale=font_scale,
                   text_color=text_color,
                   is_relative_coordinate=is_relative_coordinate,
                   is_show_text=is_show_text,
                   fill_bboxes=fill_bboxes,
                   is_crowd=is_crowd)
    return draw_bboxes(img,classes,**options)
227
+
228
def draw_legend(labels,text_fn,img_size,color_fn,thickness=4,font_scale=1.2,text_color=(0.,255.,0.),fill_bboxes=True):
    '''
    Generate a legend image
    Args:
        labels: list[int] labels
        text_fn: str fn(label) trans label to text
        img_size: (H,W) the legend image size, the legend is drawn in vertical direction
        color_fn: tuple(3) fn(label,score): trans label to RGB color, it is handed
            to draw_bboxes unchanged and therefore called with two arguments
        thickness: box line thickness (ignored when fill_bboxes is True)
        font_scale: font size
        text_color: text color
        fill_bboxes: draw filled color boxes
    Returns:
        [H,W,3] uint8 legend image
    '''
    img = np.ones([img_size[0],img_size[1],3],dtype=np.uint8)
    if len(labels) == 0:
        # Nothing to draw, also avoids the division by zero below
        return img
    boxes_width = max(img_size[1]//3,20)
    boxes_height = img_size[0]/(2*len(labels))

    # draw_bboxes passes text_size as fifth argument; without this parameter the
    # call failed inside draw_bboxes and the legend text was never drawn.
    def lget_text_pos_fn(pmin, pmax, bbox, label, text_size=None):
        return (pmax[0]+5, pmax[1]+5)

    bboxes = []
    for i in range(len(labels)):
        xmin = 5
        xmax = xmin+boxes_width
        ymin = int((2*i+0.5)*boxes_height)
        ymax = ymin + boxes_height
        bboxes.append([ymin,xmin,ymax,xmax])

    def _text_fn(x,_):
        return text_fn(x)
    return draw_bboxes(img,labels,bboxes=bboxes,color_fn=color_fn,text_fn=_text_fn,
                       get_text_pos_fn=lget_text_pos_fn,
                       thickness=thickness,
                       show_text=True,
                       font_scale=font_scale,
                       text_color=text_color,
                       is_relative_coordinate=False,
                       fill_bboxes=fill_bboxes)
266
+
267
+
268
+
269
+ '''
270
+ img: [H,W,C]
271
+ mask only include the area within bbox
272
+ bboxes: [N,4](y0,x0,y1,x1)
273
+ mask: [N,h,w]
274
+ '''
275
def draw_mask(img,classes,bboxes,masks,
              color_fn=fixed_color_large_fn,
              is_relative_coordinate=True):
    '''
    Blend per-box masks into img (in place) with a weight of 0.4.

    Args:
        img: [H,W,C]
        classes: [N] labels, handed to color_fn
        bboxes: [N,4] (y0,x0,y1,x1)
        masks: [N,h,w], every mask only covers the area within its bbox and is
            resized to the box size
        color_fn: called as color_fn(label); None selects a random color
        is_relative_coordinate: True if bboxes are relative to the image size
    Returns:
        img
    '''
    masks = masks.astype(np.uint8)
    img_h,img_w = img.shape[0],img.shape[1]
    if is_relative_coordinate:
        # bboxes are (y0,x0,y1,x1): y is scaled by the height, x by the width
        scales = np.array([[img_h,img_w,img_h,img_w]],dtype=np.float32)
        bboxes = bboxes*scales
    for i,bbox in enumerate(bboxes):
        if color_fn is not None:
            color = list(color_fn(classes[i]))
        else:
            color = [random.random()*255, random.random()*255, random.random()*255]
        color = np.reshape(np.array(color,dtype=np.float32),[1,1,-1])
        x = int(bbox[1])
        y = int(bbox[0])
        w = int((bbox[3]-bbox[1]))
        h = int((bbox[2]-bbox[0]))
        if w<=0 or h<=0:
            continue
        # Clip the box to the image, so that boxes crossing the border are drawn
        # partially instead of being dropped
        left,top = max(x,0),max(y,0)
        right,bottom = min(x+w,img_w),min(y+h,img_h)
        if right<=left or bottom<=top:
            continue
        mask = cv2.resize(masks[i],(w,h))
        mask = np.expand_dims(mask[top-y:bottom-y,left-x:right-x],axis=-1)
        try:
            region = img[top:bottom,left:right,:]
            img[top:bottom,left:right,:] = (region*(1.0-mask*0.4)).astype(np.uint8)+(mask*color*0.4).astype(np.uint8)
        except (ValueError,IndexError) as e:
            # Best effort: skip a mask that can not be blended (e.g. channel mismatch)
            print(f"WARNING: draw_mask skip mask {i}: {e}")

    return img
303
+
304
+ '''
305
+ mask only include the area within bbox
306
+ '''
307
def draw_mask_xy(img,classes,bboxes,masks,
                 color_fn=fixed_color_large_fn,
                 is_relative_coordinate=False):
    '''
    Same as draw_mask, but the coordinate order of bboxes is changed with
    odb.npchangexyorder first. mask only include the area within bbox.
    '''
    converted = odb.npchangexyorder(bboxes)
    return draw_mask(img=img,
                     classes=classes,
                     bboxes=converted,
                     masks=masks,
                     color_fn=color_fn,
                     is_relative_coordinate=is_relative_coordinate)
316
+ '''
317
+ mask only include the area within bbox
318
+ '''
319
def draw_bboxes_and_mask(img,classes,scores,bboxes,masks,
                         color_fn=fixed_color_large_fn,
                         text_fn=default_text_fn,
                         thickness=4,
                         show_text=False,
                         font_scale=0.8,
                         is_relative_coordinate=False):
    '''
    Draw masks with draw_mask and then the boxes with draw_bboxes.

    Args:
        img: [H,W,C]
        classes: [N] labels
        scores: [N] scores
        bboxes: [N,4] (y0,x0,y1,x1)
        masks: [N,h,w], mask only include the area within bbox
        color_fn, text_fn, thickness, show_text, font_scale: forwarded to draw_bboxes
            (color_fn is also used by draw_mask)
        is_relative_coordinate: True if bboxes are relative to the image size
    Returns:
        the image returned by draw_bboxes
    '''
    masks = masks.astype(np.uint8)
    img = draw_mask(img=img,
                    classes=classes,bboxes=bboxes,
                    masks=masks,color_fn=color_fn,
                    is_relative_coordinate=is_relative_coordinate)
    # draw_bboxes has no `fontScale` parameter (the old keyword raised TypeError),
    # and it must use the same coordinate convention as draw_mask above.
    img = draw_bboxes(img,classes,scores,bboxes,
                      color_fn=color_fn,
                      text_fn=text_fn,
                      thickness=thickness,
                      show_text=show_text,
                      font_scale=font_scale,
                      is_relative_coordinate=is_relative_coordinate)
    return img
338
+
339
+ '''
340
+ img: [H,W,C]
341
+ mask: [N,H,W], include the area of whole image
342
+ bboxes: [N,4], [y0,x0,y1,x1]
343
+ '''
344
def draw_maskv2_bitmap(img,classes,bboxes=None,masks=None,
                       color_fn=fixed_color_large_fn,
                       is_relative_coordinate=True,
                       alpha=0.4,
                       ):
    '''
    Draw whole-image bitmap masks with smv.draw_mask_on_image_array.

    Args:
        img: [H,W,C]
        classes: [N] labels, handed to color_fn
        bboxes: optional [N,4] (y0,x0,y1,x1); only used to skip instances whose
            box has a non-positive width or height
        masks: [N,H,W], include the area of whole image; masks smaller than the
            image are zero padded at the bottom/right
        color_fn: called as color_fn(label); None selects a random color
        is_relative_coordinate: True if bboxes are relative to the image size
        alpha: forwarded to smv.draw_mask_on_image_array
    Returns:
        the image returned by the last smv.draw_mask_on_image_array call, or img
        if nothing was drawn
    '''
    if not isinstance(masks,np.ndarray):
        masks = np.array(masks)
    if masks.shape[0] == 0:
        # Nothing to draw; an empty array may not even have H/W axes
        return img
    if is_relative_coordinate and bboxes is not None:
        # bboxes are (y0,x0,y1,x1): y is scaled by the height, x by the width
        scales = np.array([[img.shape[0],img.shape[1],img.shape[0],img.shape[1]]],dtype=np.float32)
        bboxes = bboxes*scales
    masks = masks.astype(np.uint8)
    # Pad every axis on its own: a negative width (mask larger than the image
    # along one axis only) makes np.pad raise ValueError
    pad_h = max(img.shape[0]-masks.shape[1],0)
    pad_w = max(img.shape[1]-masks.shape[2],0)
    if pad_h>0 or pad_w>0:
        masks = np.pad(masks,[[0,0],[0,pad_h],[0,pad_w]])
    for i in range(masks.shape[0]):
        if color_fn is not None:
            color = list(color_fn(classes[i]))
        else:
            color = [random.random()*255, random.random()*255, random.random()*255]
        if bboxes is not None:
            bbox = bboxes[i]
            w = bbox[3]-bbox[1]
            h = bbox[2]-bbox[0]
            if w<=0 or h<=0:
                continue
        mask = masks[i]
        img = smv.draw_mask_on_image_array(img,mask,color=color,alpha=alpha)

    return img
372
+
373
def draw_maskv2_polygon(img,classes,bboxes=None,masks=None,
                        color_fn=fixed_color_large_fn,
                        is_relative_coordinate=True,
                        alpha=0.4,
                        fill=False,
                        thickness=1,
                        ):
    '''
    Draw polygon masks.

    With fill=True the masks are converted with masks.bitmap() and drawn by
    draw_maskv2_bitmap; otherwise the points of every mask are drawn with
    smv.draw_polygon_mask_on_image_array using the given thickness.
    color_fn is called as color_fn(label); None selects a random color.
    '''
    if fill:
        return draw_maskv2_bitmap(img,
                                  classes=classes,
                                  bboxes=bboxes,
                                  masks=masks.bitmap(),
                                  color_fn=color_fn,
                                  is_relative_coordinate=is_relative_coordinate,
                                  alpha=alpha)
    if is_relative_coordinate and bboxes is not None:
        # The scaled boxes are not read again in this branch
        img_h,img_w = img.shape[0],img.shape[1]
        scales = np.array([[img_w,img_h,img_w,img_h]],dtype=np.float32)
        bboxes = bboxes*scales
    for idx in range(masks.shape[0]):
        if color_fn is None:
            color = [random.random()*255, random.random()*255, random.random()*255]
        else:
            color = list(color_fn(classes[idx]))
        polygon = masks[idx]
        img = smv.draw_polygon_mask_on_image_array(img, polygon.points, color=color, thickness=thickness)

    return img
402
+
403
def draw_maskv2(img,classes,bboxes=None,masks=None,
                color_fn=fixed_color_large_fn,
                is_relative_coordinate=True,
                alpha=0.4,
                fill=False,
                thickness=1,
                ):
    '''
    Draw masks on img, dispatching on the type of masks.

    bboxes: [N,4] (y0,x0,y1,x1)
    mask:
        [N,H,W], mask include the area of whole image
        or WPolygonMasks / WBitmapMasks / WMCKeypoints
    An unsupported mask type only prints a warning, img is returned unchanged.
    '''
    if isinstance(masks,WPolygonMasks):
        img = draw_maskv2_polygon(img,
                                  classes=classes,
                                  bboxes=bboxes,
                                  masks=masks,
                                  color_fn=color_fn,
                                  is_relative_coordinate=is_relative_coordinate,
                                  alpha=alpha,
                                  fill=fill,
                                  thickness=thickness)
        return img
    elif isinstance(masks,WBitmapMasks):
        img = draw_maskv2_bitmap(img,
                                 classes=classes,
                                 bboxes=bboxes,
                                 masks=masks.to_ndarray(),
                                 color_fn=color_fn,
                                 is_relative_coordinate=is_relative_coordinate,
                                 alpha=alpha)
        return img
    elif isinstance(masks,WMCKeypoints):
        img = draw_mckeypoints(img,
                               labels=classes,
                               keypoints=masks,
                               color_fn=color_fn)
        return img

    try:
        if not isinstance(masks,np.ndarray):
            masks = np.array(masks)
        masks = masks.astype(np.uint8)
    except Exception:
        # Best effort conversion: a failure is reported by the type check below.
        # `except Exception` (instead of a bare except) lets KeyboardInterrupt
        # and SystemExit pass through.
        pass

    if isinstance(masks,np.ndarray):
        img = draw_maskv2_bitmap(img,
                                 classes=classes,
                                 bboxes=bboxes,
                                 masks=masks,
                                 color_fn=color_fn,
                                 is_relative_coordinate=is_relative_coordinate,
                                 alpha=alpha)
    else:
        info = f"Unknow mask type {type(masks).__name__}"
        print(f"WARNING: {info}")

    return img
463
+
464
+ '''
465
+ bboxes: [N,4] (x0,y0,x1,y1)
466
+ mask:
467
+ [N,H,W], mask include the area of whole image
468
+ or WPolygonMasks
469
+ '''
470
def draw_maskv2_xy(img,classes,bboxes=None,masks=None,
                   color_fn=fixed_color_large_fn,
                   is_relative_coordinate=False,
                   alpha=0.4,
                   fill=False,
                   thickness=1,
                   ):
    '''
    Same as draw_maskv2 for bboxes given as (x0,y0,x1,y1): the coordinate order
    is changed with odb.npchangexyorder (if bboxes is not None) before drawing.
    '''
    converted = bboxes if bboxes is None else odb.npchangexyorder(bboxes)
    return draw_maskv2(img=img,
                       classes=classes,
                       bboxes=converted,
                       masks=masks,
                       color_fn=color_fn,
                       is_relative_coordinate=is_relative_coordinate,
                       alpha=alpha,
                       fill=fill,
                       thickness=thickness)
487
+ '''
488
+ bboxes: [N,4] (x0,y0,x1,y1)
489
+ mask:
490
+ [N,H,W], mask include the area of whole image
491
+ or WPolygonMasks
492
+ '''
493
def draw_bboxes_and_maskv2(img,classes,scores=None,bboxes=None,masks=None,
                           color_fn=fixed_color_large_fn,
                           text_fn=default_text_fn,
                           thickness=4,
                           show_text=False,
                           is_relative_coordinate=True,
                           font_scale=0.8):
    '''
    Draw the masks with draw_maskv2 first and the boxes with draw_bboxes on top.
    Both calls receive the same bboxes, color_fn and is_relative_coordinate.
    '''
    with_masks = draw_maskv2(img=img,
                             classes=classes,
                             bboxes=bboxes,
                             masks=masks,
                             color_fn=color_fn,
                             is_relative_coordinate=is_relative_coordinate)
    box_options = dict(color_fn=color_fn,
                       text_fn=text_fn,
                       thickness=thickness,
                       show_text=show_text,
                       is_relative_coordinate=is_relative_coordinate,
                       font_scale=font_scale)
    return draw_bboxes(with_masks,classes,scores,bboxes,**box_options)
513
+
514
+
515
+
516
def draw_heatmap_on_image(image,scores,color_pos=(255,0,0),color_neg=(0,0,0),alpha=0.4):
    '''
    Blend a two-color heatmap into an image.
    Args:
        image: image array, its last axis is combined with the 3 color channels
        scores: [H,W] scores value, a score of 1 selects color_pos, 0 selects color_neg
        color_pos: [r,g,b] color for high scores
        color_neg: [r,g,b] color for low scores
        alpha: heatmap percent
    Returns:
        uint8 image: image*(1-alpha)+heatmap*alpha, clipped to 0~255
    '''
    canvas = np.ones_like(image).astype(np.float32)
    pos_plane = np.reshape(np.array(color_pos),[1,1,3])*canvas
    neg_plane = np.reshape(np.array(color_neg),[1,1,3])*canvas
    weights = np.expand_dims(scores,axis=-1)
    heat = np.clip(pos_plane*weights+neg_plane*(1-weights),0,255)
    blended = image.astype(np.float32)*(1-alpha)+heat*alpha
    return np.clip(blended,0,255).astype(np.uint8)
539
+
540
def draw_heatmap_on_imagev2(image,scores,palette=[(0,(0,0,255)),(0.5,(255,255,255)),(1.0,(255,0,0))],alpha=0.4):
    '''
    Blend a heatmap into an image, using a richer pseudo-color palette.
    Args:
        image: image array
        scores: [H,W] scores value
        palette: list of (value,(r,g,b)) entries, handed to bwmli.pseudocolor_img
        alpha: heatmap percent
    Returns:
        uint8 image: image*(1-alpha)+heatmap*alpha, clipped to 0~255
    '''
    heat = bwmli.pseudocolor_img(img=scores,palette=palette,auto_norm=False)
    heat = np.clip(heat,0,255)
    blended = image.astype(np.float32)*(1-alpha)+heat*alpha
    return np.clip(blended,0,255).astype(np.uint8)
559
+
560
def try_draw_rgb_heatmap_on_image(image,scores,color_pos=(255,0,0),color_neg=(0,0,0),alpha=0.4):
    '''
    Blend a multi-channel heatmap into an image.
    Args:
        image: [H,W,3/1]
        scores: [C,H,W] scores value, in (0~1)
            C>3: the channels are summed and drawn by draw_heatmap_on_image
            C<3: zero channels are appended until there are 3 channels
        color_pos: [r,g,b] color for high scores
        color_neg: [r,g,b] color for low scores
        alpha: heatmap percent
    Returns:
        uint8 image, clipped to 0~255
    '''
    channels = scores.shape[0]
    if channels>3:
        summed = np.sum(scores,axis=0,keepdims=False)
        return draw_heatmap_on_image(image=image,
                                     scores=summed,
                                     color_pos=color_pos,color_neg=color_neg,alpha=alpha)
    if channels<3:
        filler = np.zeros([3-channels,scores.shape[1],scores.shape[2]],dtype=scores.dtype)
        scores = np.concatenate([scores,filler],axis=0)
    canvas = np.ones_like(image).astype(np.float32)
    pos_plane = np.reshape(np.array(color_pos),[1,1,3])*canvas
    neg_plane = np.reshape(np.array(color_neg),[1,1,3])*canvas
    # NOTE(review): alpha is applied to the scores here and again in the blend below
    weights = np.transpose(scores,[1,2,0])*alpha
    heat = pos_plane*weights+neg_plane*(1-weights)
    blended = image.astype(np.float32)*(1-alpha)+heat*alpha
    return np.clip(blended,0,255).astype(np.uint8)
589
+
590
def try_draw_rgb_heatmap_on_imagev2(image,scores,palette=[(0,(0,0,255)),(0.5,(255,255,255)),(1.0,(255,0,0))],alpha=0.4):
    '''
    Blend a multi-channel heatmap into an image, using a richer pseudo-color
    palette when there are more than 3 channels.
    Args:
        image: [H,W,3/1]
        scores: [C,H,W] scores value, in (0~1)
            C>3: the channels are summed and drawn by draw_heatmap_on_imagev2
            C<3: zero channels are appended until there are 3 channels
        palette: only used for C>3; otherwise the fixed colors (255,0,0) and
            (0,0,0) are used
        alpha: heatmap percent
    Returns:
        uint8 image, clipped to 0~255
    '''
    channels = scores.shape[0]
    if channels>3:
        summed = np.sum(scores,axis=0,keepdims=False)
        return draw_heatmap_on_imagev2(image=image,
                                       scores=summed,
                                       palette=palette,
                                       alpha=alpha)
    if channels<3:
        filler = np.zeros([3-channels,scores.shape[1],scores.shape[2]],dtype=scores.dtype)
        scores = np.concatenate([scores,filler],axis=0)
    canvas = np.ones_like(image).astype(np.float32)
    pos_plane = np.reshape(np.array((255,0,0)),[1,1,3])*canvas
    neg_plane = np.reshape(np.array((0,0,0)),[1,1,3])*canvas
    # NOTE(review): alpha is applied to the scores here and again in the blend below
    weights = np.transpose(scores,[1,2,0])*alpha
    heat = pos_plane*weights+neg_plane*(1-weights)
    blended = image.astype(np.float32)*(1-alpha)+heat*alpha
    return np.clip(blended,0,255).astype(np.uint8)
623
+
624
def draw_mckeypoints(image,labels,keypoints,r=2,
                     color_fn=fixed_color_large_fn,
                     text_fn=default_text_fn,
                     show_text=False,
                     font_scale=0.8,
                     text_thickness=1,
                     text_color=(0,0,255)):
    '''
    Draw multi-class keypoints as filled circles on image (in place).

    labels: [N]
    keypoints: WMCKeypoints or list (size is N) of [M,2]
    color_fn is called as color_fn(label), text_fn as text_fn(label).
    '''
    # NOTE(review): the nesting of the show_text branch was reconstructed from a
    # view without indentation; the text is drawn next to every point - confirm.
    for idx, group in enumerate(keypoints):
        label = labels[idx]
        color = color_fn(label)
        pts = group.points if isinstance(group,WMCKeypointsItem) else group
        for pt in pts:
            center = (int(pt[0]), int(pt[1]))
            cv2.circle(image, center, r, color, -1)
            if not show_text:
                continue
            text = text_fn(labels[idx])
            cv2.putText(image, text, (int(pt[0]), int(pt[1])), cv2.FONT_HERSHEY_DUPLEX,
                        fontScale=font_scale,
                        color=text_color,
                        thickness=text_thickness)

    return image
650
+
651
def draw_npmckeypoints(image,labels,keypoints,r=2,
                       color_fn=fixed_color_large_fn,
                       text_fn=default_text_fn,
                       show_text=False,
                       font_scale=0.8,
                       text_thickness=1,
                       text_color=(0,0,255)):
    '''
    Draw one filled circle per (label,keypoint) pair on image (in place).

    labels: [N]
    keypoints: [N,2]
    The circle center is int(p+0.5), the text origin int(p).
    '''
    for label,pt in zip(labels,keypoints):
        center = (int(pt[0]+0.5), int(pt[1]+0.5))
        cv2.circle(image, center, r, color_fn(label), -1)
        if not show_text:
            continue
        text = text_fn(label)
        text_origin = (int(pt[0]), int(pt[1]))
        cv2.putText(image, text, text_origin, cv2.FONT_HERSHEY_DUPLEX,
                    fontScale=font_scale,
                    color=text_color,
                    thickness=text_thickness)

    return image
674
+
675
def add_jointsv1(image, joints, color, r=5,line_thickness=2,no_line=False,joints_pair=None,left_node=None):
    '''
    Draw joints without a visibility value on image (in place).

    Args:
        image: image handed to cv2.line / cv2.circle
        joints: sequence of joints, joint[0] and joint[1] are used as the point
        color: color of the links
        r: radius of the joint circles
        line_thickness: thickness of the links
        no_line: do not draw links if True
        joints_pair: [[first idx,second idx],...] links to draw
        left_node: indices drawn with (0,255,0), all others with (0,0,255);
            None colors joint i with colors_tableau (wrapping around)
    Returns:
        image
    '''
    def link(a, b, color):
        jointa = joints[a]
        jointb = joints[b]
        cv2.line(
            image,
            (int(jointa[0]), int(jointa[1])),
            (int(jointb[0]), int(jointb[1])),
            color, line_thickness )

    # add link
    if not no_line and joints_pair is not None:
        for pair in joints_pair:
            link(pair[0], pair[1], color)

    # add joints
    for i, joint in enumerate(joints):
        if left_node is None:
            # Wrap around, more joints than palette entries used to raise IndexError
            node_color = colors_tableau[i%len(colors_tableau)]
        elif i in left_node:
            node_color = (0,255,0)
        else:
            node_color = (0,0,255)
        cv2.circle(image, (int(joint[0]), int(joint[1])), r, node_color, -1)

    return image
703
+
704
def add_jointsv2(image, joints, color, r=5,line_thickness=2,no_line=False,joints_pair=None,left_node=None):
    '''
    Draw joints with a third value per joint on image (in place).

    Args:
        image: image handed to cv2.line / cv2.circle
        joints: sequence of joints, joint[0] and joint[1] are used as the point;
            a link is drawn only if joint[2] of both ends is > 0.01, a joint only
            if joint[2] > 0.05 and both coordinates are > 1
        color: color of the links
        r: radius of the joint circles
        line_thickness: thickness of the links
        no_line: do not draw links if True
        joints_pair: [[first idx,second idx],...] links to draw
        left_node: indices drawn with (0,255,0), all others with (0,0,255);
            None colors joint i with colors_tableau (wrapping around)
    Returns:
        image
    '''
    def link(a, b, color):
        jointa = joints[a]
        jointb = joints[b]
        if jointa[2] > 0.01 and jointb[2] > 0.01:
            cv2.line(
                image,
                (int(jointa[0]), int(jointa[1])),
                (int(jointb[0]), int(jointb[1])),
                color, line_thickness )

    # add link
    if not no_line and joints_pair is not None:
        for pair in joints_pair:
            link(pair[0], pair[1], color)

    # add joints
    for i, joint in enumerate(joints):
        if joint[2] > 0.05 and joint[0] > 1 and joint[1] > 1:
            if left_node is None:
                # Wrap around, more joints than palette entries used to raise IndexError
                node_color = colors_tableau[i%len(colors_tableau)]
            elif i in left_node:
                node_color = (0,255,0)
            else:
                node_color = (0,0,255)
            cv2.circle(image, (int(joint[0]), int(joint[1])), r, node_color, -1)

    return image
733
+
734
def draw_keypoints(image, joints, color=[0,255,0],no_line=False,joints_pair=COCO_JOINTS_PAIR,left_node=list(range(1,17,2)),r=5,line_thickness=2):
    '''
    Draw the skeleton of one or several persons on a contiguous copy/view of image.

    Args:
        image: [H,W,3]
        joints: [N,kps_nr,C] or [kps_nr,C]; persons whose last dimension is 3 are
            drawn by add_jointsv2 (third value used as score), all others by add_jointsv1
        color: color of the links; None picks a random color for every person
        no_line: when True only the joints are drawn
        joints_pair: [[first idx,second idx],...]
        left_node: joint indices treated as left nodes by add_jointsv1/add_jointsv2
        r: radius of the joint dots
        line_thickness: thickness of the links
    Returns:
        the image that was drawn on (result of np.ascontiguousarray(image))
    '''
    image = np.ascontiguousarray(image)
    joints = np.array(joints)
    use_random_color = color is None

    rank = len(joints.shape)
    if rank == 2:
        # single person: wrap it so the loop below handles both layouts
        joints = [joints]
    else:
        assert rank==3,"keypoints need to be 3-dimensional."

    for person in joints:
        if use_random_color:
            color = [int(c) for c in np.random.randint(0, 255, size=3)]

        painter = add_jointsv2 if person.shape[-1] == 3 else add_jointsv1
        painter(image, person, color=color,no_line=no_line,joints_pair=joints_pair,left_node=left_node,r=r,
                line_thickness=line_thickness)

    return image
770
+
771
+
772
def draw_keypoints_diff(image, joints0, joints1,color=[0,255,0]):
    '''
    Draw a line from every keypoint in joints0 to the keypoint with the same index in joints1.

    image: [H,W,3]
    joints0: [N,kps_nr,C] or [kps_nr,C]
    joints1: same layout as joints0
    color: line color; None picks a random color for every person
    return: the image that was drawn on (result of np.ascontiguousarray(image))

    When the last dimension of a person in joints0 is 3, the last value of each
    keypoint is used as a score and a line is only drawn if both scores exceed 0.015.
    '''
    image = np.ascontiguousarray(image)
    joints0 = np.array(joints0)
    joints1 = np.array(joints1)
    use_random_color = color is None

    rank = len(joints0.shape)
    # Check the rank before reading shape[1]: malformed input (e.g. 1-D) now fails
    # with the descriptive assertion instead of a bare IndexError.
    assert rank in (2,3),"keypoints need to be 3-dimensional."
    if rank==2:
        points_nr = joints0.shape[0]
        joints0 = [joints0]
        joints1 = [joints1]
    else:
        points_nr = joints0.shape[1]

    for person0,person1 in zip(joints0,joints1):
        if use_random_color:
            color = np.random.randint(0, 255, size=3)
            color = [int(i) for i in color]
        has_score = person0.shape[-1] == 3
        for i in range(points_nr):
            jointa = person0[i]
            jointb = person1[i]
            # scored keypoints are skipped unless both ends are confident enough
            if has_score and not (jointa[-1]>0.015 and jointb[-1]>0.015):
                continue
            cv2.line(
                image,
                (int(jointa[0]), int(jointa[1])),
                (int(jointb[0]), int(jointb[1])),
                color, 2 )

    return image
810
+
811
+
812
def draw_points(img,points,classes=None,show_text=False,r=2,
                color_fn=fixed_color_large_fn,
                text_fn=default_text_fn,
                font_scale=0.8,
                thickness=2):
    '''
    Draw every point as a filled dot and optionally write a text next to it.

    img: [H,W,3]
    points: [N,2]/[N,3], x and y in the first two columns; with 3 or more columns
            the last column is used as the score handed to text_fn
    classes: [N]/None, None labels every point with 1
    show_text: when True text_fn(label,score) is written at the point position
    r: dot radius
    color_fn: tuple(3) (*f)(label)
    text_fn: str (*f)(label,score)
    font_scale/thickness: forwarded to cv2.putText
    return: the image that was drawn on (result of np.ascontiguousarray(img))
    '''
    img = np.ascontiguousarray(img)
    nr = points.shape[0]
    if classes is None:
        classes = np.ones([nr],dtype=np.int32)
    has_score = points.shape[1]>=3
    scores = points[:,-1] if has_score else np.ones([nr],dtype=np.float32)

    for idx in range(nr):
        joint = points[idx]
        label = classes[idx]
        color = color_fn(label)
        pos = (int(joint[0]), int(joint[1]))
        cv2.circle(img, pos, r, color, -1)
        if not show_text:
            continue
        # the text reuses the color of the dot
        text = text_fn(label,scores[idx])
        cv2.putText(img, text, pos, cv2.FONT_HERSHEY_COMPLEX, fontScale=font_scale, color=color, thickness=thickness)

    return img
841
+
842
def get_fullsize_mask(boxes,masks,size,mask_bg_value=0):
    '''
    Resize every per-box mask to its box and paste it into a full-size canvas.

    boxes: [(ymin,xmin,ymax,xmax),....] value in range[0,1], clipped to that range first
    masks: [X,h,w], masks[i] belongs to boxes[i]
    size: [H,W], size of the full canvas
    mask_bg_value: value of every canvas pixel outside the pasted box
    return: [X,H,W]; an empty [0,H,W] array of masks.dtype when there is no box

    A box whose pixel width or height is not larger than 1 gets a canvas that
    holds only mask_bg_value.
    '''
    dtype = masks.dtype

    res_masks = []
    boxes = np.clip(boxes,0.0,1.0)
    for i,bbox in enumerate(boxes):
        x = int(bbox[1]*size[1])
        y = int(bbox[0]*size[0])
        w = int((bbox[3]-bbox[1])*size[1])
        h = int((bbox[2]-bbox[0])*size[0])
        res_mask = np.ones(size,dtype=dtype)*mask_bg_value
        if w>1 and h>1:
            # cv2.resize takes the target size as (width,height)
            res_mask[y:y+h,x:x+w] = cv2.resize(masks[i],(w,h))
        res_masks.append(res_mask)

    if len(res_masks)==0:
        return np.zeros([0,size[0],size[1]],dtype=dtype)
    return np.stack(res_masks,axis=0)
868
+
869
def generate_mask_by_boxes(boxes,masks,mask_value=1):
    '''
    Fill the area of every box in masks with mask_value.

    boxes: [N,4],[x0,y0,x1,y1] in pixels; clipped to the mask size, left unmodified
    masks: [N,H,W]/[H,W]; with [N,H,W] box i is written to masks[i], with [H,W]
           all boxes are written to the same mask. Modified in place.
    mask_value: value written inside the boxes
    return: masks (the same array that was passed in)
    '''
    per_box_mask = len(masks.shape)==3
    shape = masks.shape[1:] if per_box_mask else masks.shape

    # Clip on a private copy: the clipped coordinates used to be written back
    # into the caller's array, silently changing its boxes.
    boxes = np.array(boxes)
    boxes[:,0:4:2] = np.clip(boxes[:,0:4:2],0.0,shape[1])
    boxes[:,1:4:2] = np.clip(boxes[:,1:4:2],0.0,shape[0])
    boxes = boxes.astype(np.int32)
    for i,bbox in enumerate(boxes):
        x0 = bbox[0]
        y0 = bbox[1]
        x1 = bbox[2]
        y1 = bbox[3]
        if per_box_mask:
            masks[i,y0:y1,x0:x1] = mask_value
        else:
            masks[y0:y1,x0:x1] = mask_value

    return masks
893
+
894
def draw_polygon(img,polygon,color=(255,255,255),is_line=True,isClosed=True):
    '''
    Draw a single polygon on img, either as an outline or filled.

    img: image handed to cv2
    polygon: points of the polygon, handed to cv2 as a one-element list
    color: outline/fill color
    is_line: True draws the outline with cv2.polylines, False fills with cv2.fillPoly
    isClosed: forwarded to cv2.polylines, unused when filling
    return: whatever the cv2 call returns
    '''
    contours = [polygon]
    if not is_line:
        return cv2.fillPoly(img,contours,color=color)
    return cv2.polylines(img, contours, isClosed=isClosed, color=color)
899
+
900
def colorize_semantic_seg(seg,color_mapping):
    '''
    Turn a class-index map into an RGB image through a palette.

    seg: [H,W], value in [0,classes_num-1]; cast to uint8 before the conversion
    color_mapping: list of color size is color_nr*3, used as the palette
    return: array built from the palettized image after converting it to RGB
    '''
    indexed = Image.fromarray(seg.astype(np.uint8))
    paletted = indexed.convert('P')
    paletted.putpalette(color_mapping)
    return np.array(paletted.convert('RGB'))
909
+
910
def colorize_semantic_seg_by_label(seg,label,color_mapping):
    '''
    Paint the pixels of one binary mask with the color of its label.

    seg: [H,W], value in set([0,1])
    label: value in range [0,classes_num-1], selects color_mapping[label*3:label*3+3]
    color_mapping: list of color size is color_nr*3
    return: [H,W,3] int32, the label color multiplied by seg at every pixel
    '''
    height, width = seg.shape[0], seg.shape[1]
    first = label*3
    rgb = np.array(color_mapping[first:first+3])

    # a canvas holding the label color at every pixel
    canvas = np.empty([height,width,3],dtype=np.int32)
    for channel in range(3):
        canvas[...,channel] = rgb[channel]

    painted = canvas*seg[...,np.newaxis]
    return painted.astype(np.int32)
923
+
924
def draw_seg_on_img(img,seg,color_mapping=DEFAULT_COLOR_MAP,alpha=0.4,ignore_idx=255):
    '''
    Blend per-class segmentation masks onto an image.

    img: [H,W,3/1]
    seg: [num_classes,H,W], seg[i] is the mask of class i; entries equal to
         ignore_idx are treated as 0
    color_mapping: list of color size is color_nr*3
    alpha: weight of the segmentation color on pixels covered by at least one class
    ignore_idx: value in seg that is ignored
    return: [H,W,C] uint8 blended image; img itself is returned untouched when seg is empty

    Where several classes overlap, their colors are averaged (each color is
    divided by the per-pixel sum over seg).
    '''
    if seg.size == 0:
        return img
    seg = np.where(seg==ignore_idx,np.zeros_like(seg),seg)

    coverage = np.sum(seg,axis=0,keepdims=False)
    # Decide which pixels are covered BEFORE clipping. The clip below raises every
    # value to at least 1 (to avoid dividing by zero), so testing the clipped sum
    # marked every pixel as covered and darkened the uncovered background by 1-alpha.
    valid_mask = (coverage>0).astype(np.float32)
    sum_seg = np.clip(coverage,a_min=1,a_max=10000)
    inv_seg = 1.0/sum_seg.astype(np.float32)
    inv_seg = np.expand_dims(inv_seg,axis=-1)

    res = []
    for i in range(seg.shape[0]):
        c_seg = colorize_semantic_seg_by_label(seg[i],i,color_mapping=color_mapping)
        res.append(c_seg*inv_seg)
    res = np.stack(res,axis=0)
    res = np.sum(res,axis=0,keepdims=False)

    # per-pixel blend weight: alpha on covered pixels, 0 elsewhere
    alpha = valid_mask*alpha
    img_scale = 1.0-alpha

    img = img*np.expand_dims(img_scale,axis=-1)+res*np.expand_dims(alpha,axis=-1)
    img = np.clip(img,a_min=0,a_max=255).astype(np.uint8)

    return img
954
+
955
+
956
+