python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from .common import *
|
|
3
|
+
import sys
|
|
4
|
+
from .build import METRICS_REGISTRY
|
|
5
|
+
from wml.object_detection2.keypoints import mckps_distance_matrix
|
|
6
|
+
import math
|
|
7
|
+
import copy
|
|
8
|
+
def getMCKpsPrecision(gtkps,gtlabels,kps,labels,sigma=3,ext_info=False,is_crowd=None):
    '''
    Greedily match predicted keypoints to ground-truth keypoints and compute
    precision and recall, both in percent.

    Ground-truth keypoints are processed in index order. Each one is matched to
    the closest prediction that has the same label and is not matched yet, as
    long as that distance is not greater than ``sigma``.

    Args:
        gtkps: [N,2] ground-truth keypoints (array-like).
        gtlabels: [N] ground-truth labels (array-like).
        kps: [M,2] predicted keypoints (array-like).
        labels: [M] predicted labels (array-like).
        sigma: maximum distance at which a prediction can match a ground truth.
        ext_info: when True, also return the unmatched labels and TP/FP/P counts.
        is_crowd: optional [N] mask; crowd ground truths always count as recalled.

    Returns:
        ``(precision, recall)``, or, when ``ext_info`` is True,
        ``(precision, recall, gt_label_list, pred_label_list, TP, FP, P)``.
        ``gt_label_list`` holds the labels of unmatched ground truths and
        ``pred_label_list`` the labels of unmatched predictions. ``TP`` is the
        number of matched predictions, ``FP`` the number of unmatched
        predictions and ``P`` the number of ground truths.
    '''
    gtkps = np.asarray(gtkps)
    gtlabels = np.asarray(gtlabels)
    kps = np.asarray(kps)
    labels = np.asarray(labels)
    gt_size = gtlabels.shape[0]
    kps_size = labels.shape[0]
    if is_crowd is None:
        # builtin bool: the np.bool alias was removed in NumPy 1.24
        is_crowd = np.zeros([gt_size],dtype=bool)
    else:
        is_crowd = np.asarray(is_crowd)

    if kps.size == 0 or gtkps.size == 0:
        if kps.size == 0:
            precision = 100.0
            recall = 100.0 if gtkps.size == 0 else 0.0
        else:
            precision,recall = 0.0,100.0
        if not ext_info:
            return precision,recall
        # Nothing can be matched: every ground truth is missed and every
        # prediction is a false positive. Keep the 7-tuple shape of the
        # ext_info result so callers can always unpack it.
        return precision,recall,list(gtlabels),list(labels),0,kps_size,gt_size

    # gt_mask[i] == 1 once ground truth i has been matched to a prediction
    gt_mask = np.zeros([gtkps.shape[0]],dtype=np.int32)
    # kps_mask[j] == 1 once prediction j has been matched to a ground truth
    kps_mask = np.zeros([kps.shape[0]],dtype=np.int32)
    # NOTE(review): assumed to be an [N,M] ground-truth-to-prediction distance matrix
    dis_m = mckps_distance_matrix(gtkps,kps)
    for i in range(gt_size):
        cur_dis = dis_m[i]
        # visit candidates from nearest to farthest
        for idx in np.argsort(cur_dis):
            if kps_mask[idx] or gtlabels[i] != labels[idx]:
                continue
            # nearest free same-label candidate; every later one is farther away
            if cur_dis[idx] > sigma:
                break
            gt_mask[i] = 1
            kps_mask[idx] = 1
            break

    r_gt_mask = np.logical_or(gt_mask,is_crowd)
    correct_gt_num = np.sum(r_gt_mask)
    correct_bkps_num = np.sum(kps_mask)

    recall = safe_persent(correct_gt_num,gt_size)
    precision = safe_persent(correct_bkps_num,kps_size)

    if not ext_info:
        return precision,recall

    P_v = gt_size
    TP_v = correct_bkps_num
    FP_v = kps_size-correct_bkps_num
    gt_label_list = [gtlabels[i] for i in range(gt_mask.shape[0]) if gt_mask[i] != 1]
    pred_label_list = [labels[i] for i in range(kps_size) if kps_mask[i] != 1]
    return precision,recall,gt_label_list,pred_label_list,TP_v,FP_v,P_v
|
|
78
|
+
|
|
79
|
+
def getMCKpsAccuracy(gtkps,gtlabels,kps,labels,sigma=3,ext_info=False,is_crowd=None):
    '''
    Compute a keypoint matching accuracy in percent, defined as
    ``correct / (N + M - correct)``.

    Ground-truth keypoints are processed in index order. Each one is matched to
    the closest prediction that has the same label and is not matched yet, as
    long as that distance is not greater than ``sigma``. ``correct`` counts the
    matched ground truths plus the crowd ground truths.

    Args:
        gtkps: [N,2] ground-truth keypoints (array-like).
        gtlabels: [N] ground-truth labels (array-like).
        kps: [M,2] predicted keypoints (array-like).
        labels: [M] predicted labels (array-like).
        sigma: maximum distance at which a prediction can match a ground truth.
        ext_info: accepted for signature compatibility; not used.
        is_crowd: optional [N] mask of crowd ground truths.

    Returns:
        The accuracy. It is 100.0 when both sets are empty and 0.0 when exactly
        one of them is empty.
        NOTE(review): unmatched crowd ground truths add to ``correct`` without a
        matching prediction, so the value can exceed 100.
    '''
    gtkps = np.asarray(gtkps)
    gtlabels = np.asarray(gtlabels)
    kps = np.asarray(kps)
    labels = np.asarray(labels)
    gt_size = gtlabels.shape[0]
    kps_size = labels.shape[0]
    if is_crowd is None:
        # builtin bool: the np.bool alias was removed in NumPy 1.24
        is_crowd = np.zeros([gt_size],dtype=bool)
    else:
        is_crowd = np.asarray(is_crowd)

    if kps.size == 0:
        return 100.0 if gtkps.size == 0 else 0.0
    if gtkps.size == 0:
        return 0.0

    # gt_mask[i] == 1 once ground truth i has been matched to a prediction
    gt_mask = np.zeros([gtkps.shape[0]],dtype=np.int32)
    # kps_mask[j] == 1 once prediction j has been matched to a ground truth
    kps_mask = np.zeros([kps.shape[0]],dtype=np.int32)
    # NOTE(review): assumed to be an [N,M] ground-truth-to-prediction distance matrix
    dis_m = mckps_distance_matrix(gtkps,kps)
    for i in range(gt_size):
        cur_dis = dis_m[i]
        # visit candidates from nearest to farthest
        for idx in np.argsort(cur_dis):
            if kps_mask[idx] or gtlabels[i] != labels[idx]:
                continue
            # nearest free same-label candidate; every later one is farther away
            if cur_dis[idx] > sigma:
                break
            gt_mask[i] = 1
            kps_mask[idx] = 1
            break

    correct_gt_num = np.sum(np.logical_or(gt_mask,is_crowd))
    # union of ground truths and predictions, counting each match once
    all_num = gt_size+kps_size-correct_gt_num
    return safe_persent(correct_gt_num,all_num)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def getKpsmAP(gtkps,gtlabels,kps,labels,scores=None,sigma=3,is_crowd=None):
    '''
    Compute an 11-point interpolated average precision for multi-class keypoints.

    Precision/recall pairs are computed with getMCKpsPrecision for a sequence
    of score cut-offs. A precision envelope is applied, the curve is padded to
    recall 0 and recall 100, and precision is averaged at recalls
    0, 10, ..., 100.

    Args:
        gtkps: [N,2] ground-truth keypoints (array-like).
        gtlabels: [N] ground-truth labels (array-like).
        kps: [M,2] predicted keypoints (array-like).
        labels: [M] predicted labels (array-like).
        scores: optional [M] prediction scores. When None, the predictions are
            treated as already ordered by ascending score.
        sigma: maximum distance at which a prediction can match a ground truth.
        is_crowd: optional [N] crowd mask forwarded to getMCKpsPrecision.

    Returns:
        mAP in [0,100].
    '''
    gtkps = np.asarray(gtkps)
    gtlabels = np.asarray(gtlabels)
    kps = np.asarray(kps)
    labels = np.asarray(labels)
    if is_crowd is None:
        # builtin bool: the np.bool alias was removed in NumPy 1.24
        is_crowd = np.zeros([gtlabels.shape[0]],dtype=bool)
    else:
        is_crowd = np.asarray(is_crowd)
    if scores is not None:
        # sort by ascending score so that kps[v:] keeps the v-th lowest score and above
        index = np.argsort(np.asarray(scores))
        kps = kps[index]
        labels = labels[index]

    # limits the number of evaluated cut-offs to fewer than 2*max_nr
    max_nr = 20
    data_nr = kps.shape[0]

    if data_nr==0:
        return 100.0 if gtkps.size == 0 else 0.0

    step = data_nr//max_nr if data_nr>max_nr else 1

    # [precision, recall] per cut-off; recall generally decreases along the list
    # as low-score predictions are dropped
    t_res = []
    for v in range(0,data_nr,step):
        p,r = getMCKpsPrecision(gtkps,gtlabels,kps[v:],labels[v:],sigma,is_crowd=is_crowd)
        t_res.append([p,r])

    # Precision envelope: walking from high to low recall, precision becomes
    # the running maximum (updated in place).
    old_v = None
    for v in t_res:
        if old_v is not None and v[0]<old_v[0]:
            v[0] = old_v[0]
        old_v = v

    # Walk from low to high recall: force recall to be non-decreasing and, for
    # (near-)equal recalls, keep the larger precision.
    res = []
    old_v = None
    for v in reversed(t_res):
        if old_v is not None:
            if v[1]<old_v[1]:
                v[1] = old_v[1]
            if math.fabs(v[1]-old_v[1])<1e-3 and v[0]<old_v[0]:
                v[0] = old_v[0]
        res.append(v)
        old_v = v

    min_r = res[0][1]
    max_r = res[-1][1]

    if min_r > 1e-2:
        # anchor the curve at recall 0 using the lowest-recall precision
        res = np.concatenate([np.array([[res[0][0],0.]]),res],axis=0)
    if max_r <100.0-1e-2:
        # extrapolate to recall 100 with the last precision scaled by its recall
        l_precisions = res[-1][0]
        l_recall = res[-1][1]
        t_precision = min(l_precisions*l_recall/100.0,l_precisions)
        res = np.concatenate([res,np.array([[t_precision,100.0]])])

    precisions,recall = np.array(res).transpose()
    new_r = np.arange(0.,100.01,10.)
    new_p = np.interp(new_r,recall,precisions)
    return np.mean(new_p)
|
|
230
|
+
|
|
231
|
+
class BaseMCKpsMetrics(BaseMetrics):
    '''
    Base class for multi-class keypoint metrics.

    Per-image ground truths and predictions are accumulated by ``__call__`` and
    concatenated into the ``npall_*`` arrays by ``evaluate``. Subclasses call
    ``super().evaluate()`` and store their result in ``self.value``.
    '''
    def __init__(self,sigma=5,*args,**kwargs):
        super().__init__()
        # maximum match distance used by subclasses when evaluating
        self.sigma = sigma
        # per-image inputs, appended by __call__
        self.all_gt_keypoints = []
        self.all_gt_labels = []
        self.all_keypoints = []
        self.all_labels = []
        self.all_scores = []
        # concatenated inputs, filled in by evaluate()
        self.npall_gt_keypoints = None
        self.npall_gt_labels = None
        self.npall_keypoints = None
        self.npall_labels = None
        self.npall_scores = None
        # running coordinate shift that keeps different images apart
        self.offset = 0
        # number of images added so far
        self.img_id = 0
        # cached result of evaluate(); None means "not computed yet"
        self.value = None
        print(f"Sigma={self.sigma}")

    def __call__(self, gtkps,gtlabels,kps,labels,scores=None,area=None,iscrowd=None,probability=None):
        '''
        Add the ground truths and predictions of one image.

        Args:
            gtkps: [N,2] ground-truth keypoints.
            gtlabels: [N] ground-truth labels.
            kps: [M,2] predicted keypoints.
            labels: [M] predicted labels.
            scores: [M] prediction scores; ``probability`` is used when this is None.
            area, iscrowd: accepted for interface compatibility; currently ignored.
        '''
        if probability is not None and scores is None:
            scores = probability
        gtkps = np.asarray(gtkps)
        kps = np.asarray(kps)
        # Shift this image's keypoints past every coordinate it contains (plus
        # sigma+1), so that after concatenation in evaluate() keypoints of
        # different images are too far apart to be matched with each other.
        # NOTE(review): this relies on non-negative keypoint coordinates.
        c_offset = max(np.max(gtkps) if gtkps.size>0 else 0,np.max(kps) if kps.size>0 else 0)
        c_offset += self.sigma+1
        self.all_gt_keypoints.append(gtkps+self.offset)
        self.all_gt_labels.append(np.asarray(gtlabels))
        self.all_keypoints.append(kps+self.offset)
        self.all_labels.append(np.asarray(labels))
        self.all_scores.append(None if scores is None else np.asarray(scores))
        self.offset += c_offset
        self.img_id += 1
        # new data invalidates any previously computed result
        self.value = None

    def evaluate(self):
        '''
        Concatenate the accumulated per-image data into the ``npall_*`` arrays.

        ``npall_scores`` is set to None when any image was added without scores.
        '''
        self.npall_gt_keypoints = np.concatenate(self.all_gt_keypoints,axis=0)
        self.npall_gt_labels = np.concatenate(self.all_gt_labels,axis=0)
        self.npall_keypoints = np.concatenate(self.all_keypoints,axis=0)
        self.npall_labels = np.concatenate(self.all_labels,axis=0)
        if any(s is None for s in self.all_scores):
            self.npall_scores = None
        else:
            self.npall_scores = np.concatenate(self.all_scores,axis=0)

    def num_examples(self):
        '''Return the number of images added so far.'''
        return self.img_id

    def show(self,name=""):
        '''Print the optional ``name`` and the result string, and return the string.'''
        if len(name)>0:
            print(name)
        res = self.to_string()
        print(res)
        return res

    def to_string(self):
        '''Return ``str(self.value)``, calling evaluate() first if no value is cached.'''
        if self.value is None:
            self.evaluate()
        return str(self.value)
|
|
291
|
+
|
|
292
|
+
@METRICS_REGISTRY.register()
class MCKpsMap(BaseMCKpsMetrics):
    '''
    Keypoint mAP metric: scores all accumulated images with getKpsmAP.
    '''
    def evaluate(self):
        '''
        Build the concatenated arrays and compute the keypoint mAP.

        Returns:
            The mAP value, which is also cached in ``self.value``.
        '''
        super().evaluate()
        map_value = getKpsmAP(gtkps=self.npall_gt_keypoints,
                              gtlabels=self.npall_gt_labels,
                              kps=self.npall_keypoints,
                              labels=self.npall_labels,
                              scores=self.npall_scores,
                              sigma=self.sigma)
        self.value = map_value
        return map_value
|
|
303
|
+
|
|
304
|
+
@METRICS_REGISTRY.register()
class MCKpsPrecisionAndRecall(BaseMCKpsMetrics):
    '''
    Precision, recall and accuracy metric for multi-class keypoints.

    Args:
        threshold: if not None, only predictions with score >= threshold are
            evaluated.
        *args, **kwargs: forwarded to BaseMCKpsMetrics (e.g. ``sigma``).
    '''
    def __init__(self,threshold=None,*args,**kwargs):
        super().__init__(*args,**kwargs)
        self.threshold = threshold
        # accuracy from getMCKpsAccuracy, filled in by evaluate()
        self.acc = None

    def evaluate(self):
        '''
        Compute precision/recall and accuracy over all accumulated images.

        Precision and recall are stored in ``self.value``, ``self.p`` and
        ``self.r``; accuracy is stored in ``self.acc``.

        Returns:
            ``(precision, recall)``.
        '''
        super().evaluate()
        if self.threshold is not None:
            keep = self.npall_scores>=self.threshold
            for attr in ("npall_scores","npall_labels","npall_keypoints"):
                setattr(self,attr,getattr(self,attr)[keep])
        match_args = dict(gtkps=self.npall_gt_keypoints,
                          gtlabels=self.npall_gt_labels,
                          kps=self.npall_keypoints,
                          labels=self.npall_labels,
                          sigma=self.sigma)
        self.value = getMCKpsPrecision(**match_args)
        self.acc = getMCKpsAccuracy(**match_args)
        self.p,self.r = self.value
        return self.value

    def to_string(self):
        '''
        Format precision, recall, F1 and accuracy, evaluating first if needed.
        '''
        if self.value is None:
            self.evaluate()
        p,r = self.value
        denom = max(p+r,1e-6)
        f1 = 2*p*r/denom
        return f"P={p:.2f}, R={r:.2f}, f1={f1:.2f}, acc={self.acc:.2f}"
|