python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164) hide show
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,34 @@
1
+ import numpy as np
2
+ from numpy import seterr
3
+ from typing import Optional
4
+ import traceback
5
+ import sys
6
+
7
# Process-wide NumPy setting: every floating-point anomaly (divide, overflow,
# underflow, invalid) raises FloatingPointError instead of emitting a warning.
np.seterr(all='raise')
8
+
9
def npsafe_divide(numerator, denominator, name=None):
    """Element-wise division that yields 0 wherever the denominator is <= 0.

    Args:
        numerator: scalar or array-like dividend.
        denominator: scalar or array-like divisor, broadcastable against
            ``numerator``.
        name: unused; kept for interface compatibility.

    Returns:
        ``numerator / denominator`` where ``denominator > 0`` and 0 elsewhere
        (this includes negative denominators).  If an unexpected error occurs
        (for example incompatible shapes or a floating-point overflow) the
        error is printed to stdout and ``None`` is returned.
    """
    try:
        valid = np.greater(denominator, 0)
        # np.where needs the quotient for *every* element before it can pick,
        # so the division by zero really happens.  Silence only the
        # divide/invalid signals locally; otherwise a global
        # ``seterr(all='raise')`` turns every zero denominator into an
        # exception and the "safe" divide never returns a result.
        with np.errstate(divide='ignore', invalid='ignore'):
            quotient = np.divide(numerator, denominator)
        return np.where(valid, quotient, np.zeros_like(numerator))
    except Exception as e:
        # Best-effort behaviour kept for callers: report and return None.
        print(e)
        traceback.print_exc(file=sys.stdout)
        return None
18
+
19
+
20
+
21
def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    Round ``v`` to the nearest multiple of ``divisor``.

    The result is never smaller than ``min_value`` (which defaults to
    ``divisor``) and never more than 10% below ``v``.  Adapted from the
    TensorFlow MobileNet implementation:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    lower_bound = divisor if min_value is None else min_value
    rounded = int(v + divisor / 2) // divisor * divisor
    result = rounded if rounded > lower_bound else lower_bound
    # Rounding down must not lose more than 10% of the requested value.
    if result < 0.9 * v:
        result += divisor
    return result
File without changes
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+ import cv2
3
+ import wml.basic_img_utils as bwmli
4
+ import sys
5
+
6
def find_contours_in_bbox(mask,bbox):
    '''
    Find the contours of `mask` inside `bbox` and return them in the
    coordinate frame of the full mask.

    mask: image passed to cv2.findContours after cropping
    bbox: absolute crop window handed to bwmli.crop_img_absolute_xy;
          bbox[0], bbox[1] are added back to every contour point
          (presumably [xmin,ymin,xmax,ymax] -- confirm against the helper)
    return:
    list of contours as produced by cv2 (shifted), or [] when the crop is at
    most one pixel high/wide or nothing is found
    '''
    bbox = np.array(bbox).astype(np.int32)
    patch = bwmli.crop_img_absolute_xy(mask, bbox)
    patch_h, patch_w = patch.shape[0], patch.shape[1]
    if patch_h <= 1 or patch_w <= 1:
        return []
    contours, _ = cv2.findContours(patch, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        return []
    # Contours are relative to the crop; shift them back by the crop origin.
    shift = np.reshape(np.array([bbox[0], bbox[1]], dtype=np.int32), [1, 1, 2])
    return [contour + shift for contour in contours]
20
+
21
def get_bboxes_by_contours(contours):
    '''
    Compute the single bounding box that encloses all given contours.

    contours: sequence of point arrays, each reshapeable to [N,2] with
              columns (x, y)
    return:
    float32 array [xmin,ymin,xmax,ymax]; all zeros when `contours` is empty
    '''
    if len(contours) == 0:
        return np.zeros([4], dtype=np.float32)

    point_sets = [np.reshape(contour, [-1, 2]) for contour in contours]
    xmin = min(np.min(points[:, 0]) for points in point_sets)
    xmax = max(np.max(points[:, 0]) for points in point_sets)
    ymin = min(np.min(points[:, 1]) for points in point_sets)
    ymax = max(np.max(points[:, 1]) for points in point_sets)

    return np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
40
+
41
def findContours(mask,mode=cv2.RETR_TREE,method=cv2.CHAIN_APPROX_SIMPLE):
    '''
    Wrapper around cv2.findContours that drops contours which are smaller
    than their parent contour and strips the singleton axis cv2 adds.

    mask: [H,W] value is 0 or 1, np.uint8
    mode, method: forwarded unchanged to cv2.findContours
    return:
    contours: list[[N,2]]
    hierarchy: the hierarchy reshaped to [-1,4]; it is NOT filtered, so its
               rows do not line up with the returned contours
    '''
    _contours, hierarchy = cv2.findContours(mask, mode, method)
    try:
        hierarchy = np.reshape(hierarchy, [-1, 4])
    except Exception as e:
        # The hierarchy cannot be reshaped (e.g. nothing was found); fall
        # back to an empty table so the loop below simply does nothing.
        if len(_contours) != 0:
            print(f"ERROR: {e}, {_contours} {hierarchy}")
        hierarchy = np.zeros([0, 4])
        sys.stdout.flush()

    contours = []
    for node, contour in zip(hierarchy, _contours):
        parent = node[-1]
        # Skip a contour whose area is smaller than its parent's area.
        if parent >= 0 and cv2.contourArea(contour) < cv2.contourArea(_contours[parent]):
            continue
        dims = len(contour.shape)
        if dims == 3 and contour.shape[1] == 1:
            contours.append(np.squeeze(contour, axis=1))
        elif dims == 2 and contour.shape[0] > 2:
            contours.append(contour)
    return contours, hierarchy
65
+
@@ -0,0 +1,156 @@
1
+ #coding=utf-8
2
+ import numpy as np
3
+ import logging
4
+ import wml.basic_img_utils as bwmli
5
+ import cv2
6
+ from .basic_toolkit import *
7
+ import torch
8
+ import math
9
+
10
def np_iou(mask0,mask1):
    '''
    Intersection-over-union of two masks, expressed as a percentage.

    mask0, mask1: array-like masks; any non-zero value counts as foreground.
                  Both must have the same number of dimensions.
    return:
    value in [0,100]; 100.0 when both masks are empty, 0. (with a logged
    warning) when the masks differ in number of dimensions
    '''
    mask0 = np.asarray(mask0)
    mask1 = np.asarray(mask1)
    # Compare dtypes with != (a dtype instance is never identical to a scalar
    # type) and use np.bool_, which exists in every NumPy release.
    if mask0.dtype != np.bool_:
        mask0 = mask0.astype(np.bool_)
    if mask1.dtype != np.bool_:
        mask1 = mask1.astype(np.bool_)

    if mask0.ndim != mask1.ndim:
        logging.warning("Mask not compatible with each other")
        return 0.

    # Integer pixel counts stay exact for masks of any size.
    different = int(np.count_nonzero(np.logical_xor(mask0, mask1)))
    union = int(np.count_nonzero(np.logical_or(mask0, mask1)))

    if union == 0:
        return 100.0

    # |xor| = |union| - |intersection|, so this equals 100 * inter / union.
    return 100.0 - different * 100.0 / union
33
+
34
def np_mask2masklabels(mask,begin_label=1):
    '''
    Convert a per-class mask stack into a single label map.

    mask: [H,W,NUM_CLASSES], a value > 0 marks the pixel as belonging to
          that class
    begin_label: label assigned to class channel 0; channel k gets
                 k+begin_label
    return:
    [H,W] np.int32 label map. Each pixel receives the label of the FIRST
    channel that is > 0, pixels with no positive channel stay 0.
    '''
    mask = np.asarray(mask)
    res = np.zeros(mask.shape[:2], np.int32)
    if mask.shape[2] == 0:
        # No class channels: argmax over an empty axis is undefined.
        return res

    positive = mask > 0
    # argmax on a boolean array returns the index of the first True along
    # the axis, i.e. the lowest class index that is set for each pixel.
    first_hit = np.argmax(positive, axis=2)
    has_label = np.any(positive, axis=2)
    res[has_label] = first_hit[has_label] + begin_label
    return res
52
+
53
def resize_img_and_mask(img,mask,size,img_pad_value=127,mask_pad_value=255,pad_type=1):
    '''
    Resize an image (keeping its aspect ratio) and its mask to `size`, then
    pad both with the same margins.

    Args:
        img: image, converted with np.array
        mask: mask, converted with np.array; resized with nearest-neighbour
            interpolation to the resized image's size
        size: (w,h)
        img_pad_value: fill value for the image padding
        mask_pad_value: fill value for the mask padding
        pad_type: forwarded to bwmli.pad_img

    Returns:
        (img, mask) after resizing and padding
    '''
    img = np.array(img)
    mask = np.array(mask)

    img = bwmli.resize_img(img, size, keep_aspect_ratio=True)
    resized_wh = img.shape[:2][::-1]
    mask = bwmli.resize_img(mask, resized_wh, keep_aspect_ratio=False,
                            interpolation=cv2.INTER_NEAREST)

    # Pad the image first and reuse the reported margins for the mask so
    # both stay pixel-aligned.
    img, px0, px1, py0, py1 = bwmli.pad_img(img, size, pad_value=img_pad_value,
                                            pad_type=pad_type, return_pad_value=True)
    mask = bwmli.pad_imgv2(mask, px0, px1, py0, py1, pad_value=mask_pad_value)
    return img, mask
74
+
75
def cut_mask(mask,rect):
    '''
    Cut `rect` out of `mask` and report how much of the mask survived.

    mask: [H,W] value is 1 or 0
    rect: [ymin,xmin,ymax,xmax], forwarded to bwmli.sub_image
    return:
    (cuted_mask, bbox, ratio)
    cuted_mask: the mask inside the sub image
    bbox: int32 [xmin,ymin,xmax,ymax] of the non-zero pixels of cuted_mask
    ratio: sum(cuted_mask)/sum(mask)
    cuted_mask and bbox are None when ratio <= 1e-6
    '''
    total_area = np.sum(mask)
    cuted_mask = bwmli.sub_image(mask, rect)
    ratio = np.sum(cuted_mask) / max(1, total_area)
    if ratio <= 1e-6:
        return None, None, ratio

    rows, cols = np.where(cuted_mask)
    bbox = np.array([np.min(cols), np.min(rows), np.max(cols), np.max(rows)],
                    dtype=np.int32)
    return cuted_mask, bbox, ratio
94
+
95
def cut_masks(masks,bboxes):
    '''
    Apply cut_mask to every (mask, bbox) pair.

    masks: [N,H,W] value is 1 or 0
    bboxes: [N,4] [ymin,xmin,ymax,xmax]
    return:
    (new_masks, new_bboxes, ratios), three lists of length N holding the
    per-item results of cut_mask (mask/bbox entries may be None)
    '''
    new_masks = []
    new_bboxes = []
    ratios = []
    # Iterate over the items themselves; ``range(len(nr))`` on the integer
    # count raised TypeError on every call.
    for i, mask in enumerate(masks):
        n_mask, n_bbox, ratio = cut_mask(mask, bboxes[i])
        new_masks.append(n_mask)
        new_bboxes.append(n_bbox)
        ratios.append(ratio)

    return new_masks, new_bboxes, ratios
113
+
114
+
115
def resize_mask(mask,size=None,r=None,mode='nearest'):
    '''
    Resize a stack of masks held in a torch tensor.

    mask: [N,H,W]
    size: (new_w,new_h); when None it is derived from the scale factor `r`
    r: scale factor applied to W and H, only used when size is None
    mode (str): algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area' | 'nearest-exact'.
    Default: 'nearest'
    return:
    tensor [N,new_h,new_w]
    '''
    if size is None:
        size = (int(mask.shape[2] * r), int(mask.shape[1] * r))
    new_w, new_h = size[0], size[1]

    if mask.numel() == 0:
        return mask.new_zeros([mask.shape[0], new_h, new_w])

    # interpolate expects [batch,channels,H,W]; treat the N masks as channels.
    batched = mask[None]
    resized = torch.nn.functional.interpolate(batched, size=(new_h, new_w), mode=mode)
    return resized.squeeze(0)
131
+
132
def npresize_mask(mask,size=None,r=None):
    '''
    Resize a stack of numpy masks with nearest-neighbour interpolation.

    mask: [N,H,W]
    size: (new_w,new_h); when None it is derived from the scale factor `r`
    r: scale factor applied to W and H, only used when size is None
    return:
    array [N,new_h,new_w] with the dtype of `mask`
    '''
    if size is None:
        # Same convention as resize_mask: fall back to the scale factor
        # instead of failing on ``size[1]``.
        size = (int(mask.shape[2] * r), int(mask.shape[1] * r))
    if mask.shape[0] == 0:
        return np.zeros([0, size[1], size[0]], dtype=mask.dtype)
    new_mask = [
        cv2.resize(cur_m, dsize=(size[0], size[1]), interpolation=cv2.INTER_NEAREST)
        for cur_m in mask
    ]
    return np.stack(new_mask, axis=0)
145
+
146
def resize_mask_structures(mask,size):
    '''
    Resize a mask container, dispatching on its type.

    mask: np.ndarray (-> npresize_mask), torch tensor (-> resize_mask) or
          any object exposing a resize(size) method
    size:[W,H]
    raises RuntimeError for any other type
    '''
    # The ndarray check must come first: ndarrays also have a .resize method.
    if isinstance(mask, np.ndarray):
        resized = npresize_mask(mask, size)
    elif torch.is_tensor(mask):
        resized = resize_mask(mask, size)
    elif hasattr(mask, "resize"):
        resized = mask.resize(size)
    else:
        raise RuntimeError("Unimplement")
    return resized
@@ -0,0 +1,21 @@
1
+ from wml.semantic.structures import *
2
+ from wml.iotoolkit.labelme_toolkit import LabelMeData
3
+ import wml.img_utils as wmli
4
+ import wml.object_detection2.visualization as odv
5
+
6
if __name__ == "__main__":
    # Manual smoke test: load the first LabelMe sample, crop its mask and
    # image to the same window, draw the mask and write the result to tmp.jpg.
    dataset = LabelMeData(use_polygon_mask=True)
    dataset.read_data("~/ai/mldata1/B10CF/datasets/testv1.0/")
    sample = dataset[0]

    # Crop windows tried previously: [1095,500,1235,600], [800,500,1235,600],
    # [800,300,1235,600]. Only the last assignment ever took effect.
    crop_bbox = [600,300,1600,600]

    # sample[5] is the mask object, sample[0] is what imread loads and
    # sample[3] is forwarded to draw_maskv2 (presumably labels -- confirm).
    cropped_mask = sample[5].crop(crop_bbox)
    img = wmli.crop_img_absolute_xy(wmli.imread(sample[0]), crop_bbox)
    img = odv.draw_maskv2(img, sample[3], None, cropped_mask,
                          is_relative_coordinate=False)
    wmli.imwrite("tmp.jpg", img)
20
+
21
+
@@ -0,0 +1 @@
1
+ from wml.wstructures.mask_structures import WBaseMask, WBitmapMasks, WPolygonMaskItem, WPolygonMasks
@@ -0,0 +1,105 @@
1
+ #coding=utf-8
2
+ import sys
3
+ import os
4
+ from .mask_utils import np_iou
5
+ import numpy as np
6
+ from .visualization_utils import MIN_RANDOM_STANDARD_COLORS, draw_mask_on_image_array
7
+
8
def np_draw_masks_on_image(image, mask, colors, alpha=0.4):
    '''
    Overlay every channel of `mask` on `image`.

    image:[height,width,3]
    mask:[height,width,N]
    colors:[N], string; reused cyclically when N > len(colors)
    alpha: blending factor forwarded to draw_mask_on_image_array
    return:
    uint8 image with all N masks drawn
    '''
    # Both inputs are always converted: astype allocates fresh uint8 arrays,
    # so the arrays passed in by the caller are not the ones drawn on.
    image = image.astype(np.uint8)
    mask = mask.astype(np.uint8)
    # Channel-first layout so the masks can be iterated one by one.
    per_channel = np.transpose(mask, axes=[2, 0, 1])
    palette_size = len(colors)

    for idx, single_mask in enumerate(per_channel):
        color = colors[idx % palette_size]
        image = draw_mask_on_image_array(image, single_mask, color, alpha)

    return image
26
+
27
def np_draw_masks_on_images(image,mask,alpha,colors=MIN_RANDOM_STANDARD_COLORS,no_first_mask=False):
    '''
    Batched version of np_draw_masks_on_image.

    image:[batch_size,height,width,3]
    mask:[batch_size,height,width,N]
    alpha: blending factor
    colors: colour names, one per mask channel (reused cyclically)
    no_first_mask: when True, channel 0 of `mask` is dropped before drawing
    return:
    np.array of the drawn images
    '''
    if no_first_mask:
        mask = mask[:, :, :, 1:]

    drawn = [
        np_draw_masks_on_image(image=img, mask=msk, colors=colors, alpha=alpha)
        for img, msk in zip(image, mask)
    ]
    return np.array(drawn)
42
+
43
+
44
+
45
def merge_masks(masks,labels,num_classes,size=None,no_background=False):
    '''
    Merge instance masks into one mask per class.

    masks:[X,H,W]
    labels:[X]
    num_classes: number of classes including background
    size: optional (height,width) of the output; when None it is taken
          from masks, which must then have at least 3 dimensions
    no_background: if True, labels are expected in [1,num_classes] and are
                   shifted to labels-1 (never below 0) when merging
    output:
    int32 [num_classes,H,W]/[num_classes-1,H,W](no_background=True); a pixel
    is 1 in channel c if any mask with label c covers it
    raises ValueError when size is None and cannot be derived from masks
    '''
    if size is not None:
        height, width = size[0], size[1]
    else:
        # np.shape also accepts plain (nested) lists of masks.
        shape = np.shape(masks)
        if len(shape) < 3:
            raise ValueError(
                "size must be given when masks is not a [X,H,W] array, "
                f"got masks with shape {shape}")
        height, width = shape[1], shape[2]

    channels = num_classes - 1 if no_background else num_classes
    res = np.zeros([channels, height, width], dtype=np.int32)

    for i, mask in enumerate(masks):
        label = max(0, labels[i] - 1) if no_background else labels[i]
        # Slice (not index) so a label outside the channel range is ignored
        # instead of raising.
        res[label:label+1,:,:] = np.logical_or(res[label:label+1,:,:],
                                               np.expand_dims(mask, axis=0))

    return res
72
+
73
+ '''
74
+ def get_fullsize_merged_mask(masks,bboxes,labels,size,num_classes,no_background=True):
75
+ fullsize_masks = ivs.get_fullsize_mask(bboxes,masks,size)
76
+ return merge_masks(fullsize_masks,labels,num_classes,size,no_background)
77
+ '''
78
+
79
class ModelPerformance:
    '''
    Running mean of the IoU (as computed by np_iou) over successive batches.

    no_first_class: when True, channel 0 (background) of both masks is
                    dropped before the IoU is computed
    '''
    def __init__(self,no_first_class=True):
        self.test_nr = 0        # number of batches accumulated so far
        self.total_iou = 0.     # sum of the per-batch IoU values
        self.no_first_class = no_first_class

    def clear(self):
        '''Reset the accumulated statistics.'''
        self.test_nr = 0
        self.total_iou = 0.

    def __call__(self, mask_gt,mask_pred):
        '''
        Accumulate one batch.

        mask_gt: [batch_size,h,w,num_classes]
        mask_pred: [batch_size,h,w,num_classes]
        background is [:,:,:,0]
        return:
        (IoU of this batch, running mean IoU)
        '''
        if self.no_first_class:
            mask_gt = mask_gt[:,:,:,1:]
            mask_pred = mask_pred[:,:,:,1:]
        tmp_iou = np_iou(mask_gt,mask_pred)
        self.total_iou += tmp_iou
        self.test_nr += 1
        return tmp_iou, self.mIOU()

    def mIOU(self):
        '''Mean IoU over all accumulated batches, 0.0 when none were seen.'''
        # Guard the freshly-created / just-cleared state, which used to raise
        # ZeroDivisionError.
        if self.test_nr == 0:
            return 0.0
        return self.total_iou/self.test_nr