python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import sys
|
|
3
|
+
import copy
|
|
4
|
+
from .common import BaseClassifierMetrics
|
|
5
|
+
from .build import CLASSIFIER_METRICS_REGISTRY
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _safe_persent(v0,v1):
    '''
    Express v0 as a percentage of v1.

    A zero denominator does not raise; it yields 100. instead.
    '''
    return 100. if v1==0 else v0*100./v1
|
|
13
|
+
|
|
14
|
+
@CLASSIFIER_METRICS_REGISTRY.register()
class Accuracy(BaseClassifierMetrics):
    '''
    Top-k classification accuracy accumulated over successive calls.

    Every call stores one boolean per sample (True when the sample was
    classified correctly); evaluate() converts everything accumulated so
    far into a percentage.
    '''
    def __init__(self,topk=1,**kwargs):
        '''
        topk: how many of the highest scoring classes may match the target
        kwargs: forwarded to the base class
        '''
        super().__init__(**kwargs)
        self.topk = topk
        self.all_correct = []
        self.accuracy = 100.0

    def __call__(self,output,target):
        '''
        output: [N0,...,Nn,num_classes] (scores) or [N0,...,Nn] (predicted labels)
        target: [N0,...,Nn]
        '''
        output = np.asarray(output)
        target = np.asarray(target)
        if output.ndim==target.ndim:
            # output already holds labels: flatten BOTH sides before comparing,
            # otherwise an N-D label array is compared against a 1-D target
            correct = np.reshape(output,[-1])==np.reshape(target,[-1])
        else:
            idx = np.argsort(output,axis=-1)[...,-self.topk:]
            hits = idx==np.expand_dims(target,axis=-1)
            # a sample is correct when ANY of its top-k candidates matches;
            # keeping one flag per candidate would divide the accuracy by topk
            correct = np.reshape(np.any(hits,axis=-1),[-1])
        self.all_correct.append(correct)

    def num_examples(self):
        '''
        Number of samples accumulated so far, None when nothing was accumulated.
        '''
        if len(self.all_correct)==0:
            return
        return sum(c.size for c in self.all_correct)

    def evaluate(self):
        '''
        Recompute, store and return the accuracy in percent.
        Without any sample the accuracy stays at 100.
        '''
        self.accuracy = 100
        if len(self.all_correct)==0:
            return self.accuracy
        all_correct = np.concatenate(self.all_correct,axis=0)
        if all_correct.size == 0:
            return self.accuracy

        print(f"Total {all_correct.size} samples")
        correct = float(np.sum(all_correct))

        self.accuracy = _safe_persent(correct,all_correct.size)

        return self.accuracy

    def show(self,name=""):
        '''
        Evaluate, print the sample count and the accuracy, return the accuracy.
        name is accepted for interface compatibility and not used here.
        '''
        sys.stdout.flush()
        self.evaluate()
        print(f"Test size {self.num_examples()}")
        print(f"accuracy={self.accuracy}")
        return self.accuracy

    def value(self):
        '''Accuracy (percent) from the last evaluate().'''
        return self.accuracy

    def to_string(self):
        return f"{self.accuracy:.2f}"

    def __repr__(self):
        return self.to_string()
|
|
80
|
+
|
|
81
|
+
@CLASSIFIER_METRICS_REGISTRY.register()
class BAccuracy(Accuracy):
    def __init__(self,num_classes,**kwargs):
        '''
        Binary accuracy: the last class is the background and every other
        class is foreground; a prediction counts as correct as soon as the
        background/foreground decision is right.
        '''
        self.bk_classes = num_classes-1
        super().__init__(**kwargs)

    def __call__(self,output,target):
        '''
        output: [N0,...,Nn,num_classes]
        target: [N0,...,Nn]
        '''
        # index of the highest score along the class axis
        top1 = np.argsort(output,axis=-1)[...,-1]
        # collapse both sides to "is foreground" flags before delegating
        pred_is_fg = np.reshape(top1!=self.bk_classes,[-1])
        gt_is_fg = np.reshape(target!=self.bk_classes,[-1])
        return super().__call__(pred_is_fg,gt_is_fg)

    def to_string(self):
        '''Accuracy formatted with two decimals.'''
        return f"{self.accuracy:.2f}"
|
|
105
|
+
|
|
106
|
+
@CLASSIFIER_METRICS_REGISTRY.register()
class PrecisionAndRecall(BaseClassifierMetrics):
    '''
    Precision and recall (both in percent) accumulated over successive calls.

    The counts are obtained by summing the stored values, so outputs and
    targets are presumably 0/1 indicators -- confirm against callers.
    '''
    def __init__(self,**kwargs):
        super().__init__(**kwargs)
        self.all_output = []
        self.all_target = []
        self.recall = 100.0
        self.precision = 100.0

    def __call__(self,output,target):
        '''
        output: [N0,...,Nn]
        target: [N0,...,Nn]
        '''
        self.all_output.append(np.reshape(output,[-1]))
        self.all_target.append(np.reshape(target,[-1]))

    def num_examples(self):
        '''
        Number of samples accumulated so far, None when nothing was accumulated.
        '''
        if len(self.all_output)==0:
            return
        # summing the sizes avoids building a throw-away concatenated array
        return sum(o.size for o in self.all_output)

    def evaluate(self):
        '''
        Recompute, store and return (precision, recall) in percent.
        Both stay at 100 when there is no sample.
        '''
        self.recall = 100
        self.precision = 100
        if len(self.all_output)==0:
            return self.precision,self.recall
        all_output = np.concatenate(self.all_output,axis=0)
        if all_output.size == 0:
            return self.precision,self.recall

        all_target = np.concatenate(self.all_target,axis=0)
        tmp_mask = all_output==all_target
        # sum of the outputs that agree with the target
        correct = np.sum(all_output[tmp_mask].astype(np.float32))

        tp_fp = np.sum(all_output.astype(np.float32))
        tp_fn = np.sum(all_target.astype(np.float32))

        self.precision = _safe_persent(correct,tp_fp)
        self.recall = _safe_persent(correct,tp_fn)

        return self.precision,self.recall

    def show(self,name=""):
        '''
        Evaluate, then print the sample count and the formatted result.
        name is accepted for interface compatibility and not used here.
        '''
        sys.stdout.flush()
        self.evaluate()
        print(f"Test size {self.num_examples()}")
        print(self.to_string())

    def to_string(self):
        return f"P={self.precision:.2f}, R={self.recall:.2f}"

    def __repr__(self):
        return self.to_string()

    def value(self):
        '''
        F1-style score: 100*P*R/(P+R) with P and R in percent, i.e. 50 times
        the F1 percentage (scale kept unchanged for existing callers).
        Returns 0.0 when precision and recall are both 0.
        '''
        total = self.precision+self.recall
        if total==0:
            # the harmonic mean of two zeros is zero; the generic
            # "empty denominator -> 100" fallback would reward the worst case
            return 0.0
        return self.precision*self.recall*100./total
|
|
170
|
+
|
|
171
|
+
@CLASSIFIER_METRICS_REGISTRY.register()
class BPrecisionAndRecall(PrecisionAndRecall):
    def __init__(self,num_classes,**kwargs):
        '''
        Binary precision and recall: the last class is the background and
        every other class is foreground; only the background/foreground
        decision has to be right.
        '''
        self.bk_classes = num_classes-1
        super().__init__(**kwargs)

    def __call__(self,output,target):
        '''
        output: [N0,...,Nn,num_classes]
        target: [N0,...,Nn]
        '''
        # index of the highest score along the class axis
        top1 = np.argsort(output,axis=-1)[...,-1]
        # collapse both sides to "is foreground" flags before delegating
        pred_is_fg = np.reshape(top1!=self.bk_classes,[-1])
        gt_is_fg = np.reshape(target!=self.bk_classes,[-1])
        return super().__call__(pred_is_fg,gt_is_fg)

    def to_string(self):
        '''Binary precision / recall formatted with two decimals.'''
        return f"BP={self.precision:.2f}, BR={self.recall:.2f}"
|
|
196
|
+
|
|
197
|
+
@CLASSIFIER_METRICS_REGISTRY.register()
class ConfusionMatrix(BaseClassifierMetrics):
    '''
    Confusion matrix accumulated over successive calls.

    cm[i,j] is the number of samples of ground-truth class i that were
    predicted as class j.
    '''
    def __init__(self,num_classes=-1,**kwargs):
        '''
        num_classes: matrix size; when <=0 it is inferred from the data
        kwargs: forwarded to the base class
        '''
        super().__init__(**kwargs)
        self.all_target = []
        self.all_pred = []
        # NOTE: never updated by this class, show() returns it as is
        self.accuracy = 100.0
        self.num_classes = num_classes
        self.cm = []

    def __call__(self,output,target):
        '''
        output: [N0,...,Nn,num_classes] (scores) or [N0,...,Nn] (predicted labels)
        target: [N0,...,Nn]
        '''
        if len(output.shape)>len(target.shape):
            if self.num_classes<=0:
                self.num_classes = output.shape[-1]
            pred = np.argsort(output,axis=-1)[...,-1]
        else:
            pred = output
        # np.array copies, so later in-place changes by the caller cannot leak in
        self.all_pred.append(np.array(pred).reshape(-1))
        self.all_target.append(np.array(target).reshape(-1))

    def num_examples(self):
        '''
        Number of samples accumulated so far, None when nothing was accumulated.
        '''
        if len(self.all_pred)==0:
            return
        return sum(p.size for p in self.all_pred)

    def evaluate(self):
        '''
        Build, store and return the confusion matrix.
        Returns "" when nothing was accumulated.
        '''
        if len(self.all_pred)==0:
            return ""

        all_pred = np.concatenate(self.all_pred,axis=0).astype(np.intp)
        all_target = np.concatenate(self.all_target,axis=0).astype(np.intp)

        if self.num_classes<=0 and all_pred.size>0:
            # only label arrays were seen, so the size was never inferred
            # from a score tensor: fall back to the largest label
            self.num_classes = int(max(all_pred.max(),all_target.max()))+1

        size = max(self.num_classes,0)
        cm = np.zeros([size,size],dtype=np.int32)
        # unbuffered add: repeated (t,p) pairs are all counted
        np.add.at(cm,(all_target,all_pred),1)

        self.cm = cm

        return cm

    def show(self,name=""):
        '''
        Evaluate and print the matrix, return self.accuracy.
        name is accepted for interface compatibility and not used here.
        '''
        sys.stdout.flush()
        self.evaluate()
        print(self.to_string())
        return self.accuracy

    def to_string(self,blod=True):
        '''
        Text rendering, one row per ground-truth class; with blod the
        diagonal cells are marked with a trailing "*".
        '''
        res = "\n"
        if len(self.cm)==0:
            # not evaluated yet (or no data): nothing to render
            return res
        for i in range(self.num_classes):
            cells = []
            for j in range(self.num_classes):
                if blod and i==j:
                    cells.append(f"{self.cm[i,j]:<4}*, ")
                else:
                    cells.append(f"{self.cm[i,j]:<5}, ")
            res += "".join(cells)+"\n"
        return res

    def __repr__(self):
        return self.to_string()

    def value(self):
        '''Matrix from the last evaluate() ([] before the first one).'''
        return self.cm
|
|
283
|
+
|
|
284
|
+
@CLASSIFIER_METRICS_REGISTRY.register()
class ClassesWiseModelPerformace(BaseClassifierMetrics):
    '''
    Runs one metric instance per class plus an overall top-1 Accuracy.

    For every class the labels are reduced to "is this class" indicators
    before they are handed to that class' metric.
    '''
    def __init__(self,num_classes,classes_begin_value=0,model_type=PrecisionAndRecall,model_args=None,label_trans=None,
                 name=None,
                 use_gt_and_pred_select=False,
                 classes=None,
                 **kwargs):
        '''
        num_classes: number of classes to track
        classes_begin_value: label value of the first class
        model_type: metric class, or its name in CLASSIFIER_METRICS_REGISTRY
        model_args: extra keyword arguments for every per-class metric
        label_trans: optional callable applied to both labels and targets
        name: optional prefix used by to_string()
        use_gt_and_pred_select: select samples where gt OR prediction is the
            class, instead of samples where gt is the class
        classes: display names, defaults to C1..Cn
        '''
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.clases_begin_value = classes_begin_value
        # work on a copy: neither the caller's dict nor a shared default is mutated
        model_args = dict(model_args) if model_args is not None else {}
        model_args['classes_begin_value'] = classes_begin_value

        if isinstance(model_type,(str,bytes)):
            model_type = CLASSIFIER_METRICS_REGISTRY.get(model_type)

        if classes is None:
            classes = [f"C{i+1}" for i in range(num_classes)]

        self.classes = classes

        self.data = [model_type(num_classes=num_classes,**model_args) for _ in range(self.num_classes)]
        self.label_trans = label_trans
        # builtin bool: the np.bool alias no longer exists in NumPy >= 1.24
        self.have_data = np.zeros([num_classes],dtype=bool)
        self.accuracy = Accuracy(topk=1)
        self.name = name
        self.use_gt_and_pred_select = use_gt_and_pred_select
        self.total_eval_samples = 0

    def select_labels(self,labels,target,classes):
        '''Dispatch to the selection strategy chosen at construction.'''
        if self.use_gt_and_pred_select:
            return self.select_labels_by_gt_and_pred(labels,target,classes)
        else:
            return self.select_labels_by_gt(labels,target,classes)

    @staticmethod
    def select_labels_by_gt_and_pred(labels,target,classes):
        '''
        Keep the samples whose prediction or target equals classes and return
        (labels, target) as int32 0/1 indicators of "equals classes".
        '''
        if len(labels) == 0:
            return np.array([],dtype=np.int32),np.array([],dtype=np.int32)
        labels = np.asarray(labels)
        target = np.asarray(target)
        mask = np.logical_or(np.equal(labels,classes),np.equal(target,classes))
        nlabels = (labels[mask]==classes).astype(np.int32)
        ntarget = (target[mask]==classes).astype(np.int32)
        return nlabels,ntarget

    @staticmethod
    def select_labels_by_gt(labels,target,classes):
        '''
        Keep the samples whose target equals classes and return
        (labels, target) as int32 0/1 indicators of "equals classes".
        '''
        if len(labels) == 0:
            return np.array([],dtype=np.int32),np.array([],dtype=np.int32)
        labels = np.asarray(labels)
        target = np.asarray(target)
        mask = np.equal(target,classes)
        nlabels = (labels[mask]==classes).astype(np.int32)
        ntarget = (target[mask]==classes).astype(np.int32)
        return nlabels,ntarget

    def __call__(self,output,target):
        '''
        output: [N0,...,Nn,num_classes] or [N0,...,Nn]
        target: [N0,...,Nn]
        '''
        self.accuracy(output,target)
        if len(output.shape) > len(target.shape):
            idx = np.argsort(output,axis=-1)
            labels = np.reshape(idx[...,-1:],[-1])
        else:
            # flatten like target, the per-class masks below are 1-D
            labels = np.reshape(output,[-1])
        target = np.reshape(target,[-1])
        self.total_eval_samples += target.size

        if self.label_trans is not None:
            labels = self.label_trans(labels)
            target = self.label_trans(target)

        for i in range(self.num_classes):
            classes = i+self.clases_begin_value
            clabels,ctarget = self.select_labels(labels,target,classes)
            self.have_data[i] = True
            self.data[i](clabels,ctarget)

        self._current_info = ""

    def show(self,name=""):
        '''Print the metric of every class that received data.'''
        sys.stdout.flush()
        for i in range(self.num_classes):
            if not self.have_data[i]:
                continue
            classes = i+self.clases_begin_value
            print(f"Classes:{classes}")
            try:
                self.data[i].show(name=name)
            except Exception:
                # best effort: one failing class must not stop the report,
                # but KeyboardInterrupt/SystemExit are no longer swallowed
                print("N.A.")

    def evaluate(self):
        '''Evaluate every per-class metric and the overall accuracy.'''
        print(f"Total eval samples {self.total_eval_samples}")
        for d in self.data:
            d.evaluate()
        self.accuracy.evaluate()

    def to_string(self):
        res = ";".join([str(self.classes[idx])+": "+d.to_string() for idx,d in enumerate(self.data)])
        res += ";; ALL "+self.accuracy.to_string()
        if self.name is not None:
            res = f"{self.name}: "+res
        return res

    def mark_down(self,name=""):
        '''Print the results as a three-line markdown table.'''
        str0 = "|配置|"
        str1 = "|---|"
        str2 = f"|CFG:{name}|"
        str0 += "ALL|"
        str1 += "---|"
        str2 += f"{str(self.accuracy.to_string())}|"

        for i,d in enumerate(self.data):
            str0 += f"{self.classes[i]}|"
            str1 += "---|"
            str2 += f"{str(d.to_string())}|"
        print(str0)
        print(str1)
        print(str2)

    def __repr__(self):
        return self.to_string()
|
|
418
|
+
|
|
419
|
+
@CLASSIFIER_METRICS_REGISTRY.register()
class ComposeMetrics(BaseClassifierMetrics):
    '''
    Fans every operation out to a list of metrics.

    Positional arguments and keyword-argument values are all treated as
    metric objects, in that order.
    '''
    def __init__(self,*args,**kwargs):
        # the base initialiser creates _current_info, which current_info() reads
        super().__init__()
        self.metrics = list(args)+list(kwargs.values())

    def __call__(self, *args,**kwargs):
        '''Feed the same arguments to every metric.'''
        for m in self.metrics:
            m(*args,**kwargs)

    def evaluate(self):
        for m in self.metrics:
            m.evaluate()

    def show(self,name=""):
        for m in self.metrics:
            m.show(name=name)

    def to_string(self):
        return ";".join([m.to_string() for m in self.metrics])

    def value(self):
        '''Value of the first metric.'''
        return self.metrics[0].value()

    def mark_down(self,name=""):
        for m in self.metrics:
            m.mark_down(name=name)
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from abc import ABCMeta, abstractmethod
|
|
2
|
+
|
|
3
|
+
class BaseMetrics(metaclass=ABCMeta):
    '''
    Abstract base of the metric objects; subclasses must implement show().
    '''
    def __init__(self,cfg_name="",**kwargs) -> None:
        '''
        cfg_name: stored as self.cfg_name
        kwargs: accepted and ignored
        '''
        self._current_info = ""
        self.cfg_name = cfg_name

    def current_info(self):
        return self._current_info

    def __repr__(self):
        return self.to_string()

    @abstractmethod
    def show(self,name=""):
        '''
        Display the metric. name is optional so that callers passing
        show(name=...) and callers passing nothing both work.
        '''
        pass

    def value(self):
        '''Metric value; the base implementation returns None.'''
        pass

    def to_string(self):
        '''
        Default text form used by __repr__: str() of value().
        Subclasses may override it.
        '''
        return str(self.value())

    def detail_valus(self):
        return self.value()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def safe_persent(v0,v1):
    '''
    Express v0 as a percentage of v1.

    A zero denominator does not raise; it yields 100. instead.
    '''
    return 100. if v1==0 else v0*100./v1
|
|
31
|
+
|
|
32
|
+
class ComposeMetrics(BaseMetrics):
    '''
    Fans every operation out to a list of metrics.

    Positional arguments and keyword-argument values are all treated as
    metric objects, in that order.
    '''
    def __init__(self,*args,**kwargs):
        super().__init__()
        members = list(args)
        members.extend(kwargs.values())
        self.metrics = members

    def __call__(self, *args,**kwargs):
        '''Feed the same arguments to every metric, then collect their infos.'''
        for metric in self.metrics:
            metric(*args,**kwargs)
        infos = [metric.current_info() for metric in self.metrics]
        self._current_info = "; ".join(infos)

    def evaluate(self):
        for metric in self.metrics:
            metric.evaluate()

    def show(self,name=""):
        for metric in self.metrics:
            metric.show(name=name)

    def to_string(self):
        parts = [metric.to_string() for metric in self.metrics]
        return ";".join(parts)

    def value(self):
        '''Value of the first metric.'''
        first = self.metrics[0]
        return first.value()
|
|
52
|
+
|
|
53
|
+
class BaseClassifierMetrics(metaclass=ABCMeta):
    '''
    Common base of the classifier metrics: default text rendering built on
    value() plus the current-info accessor.
    '''
    def __init__(self,*args,**kwargs):
        # positional and keyword arguments are accepted and ignored
        self._current_info = ""

    def value(self):
        '''Metric value; the base implementation returns None.'''
        return None

    def to_string(self):
        '''str() of value().'''
        current = self.value()
        return str(current)

    def __repr__(self):
        return self.to_string()

    def mark_down(self,name=""):
        '''Print name followed by the text form of the metric.'''
        text = self.to_string()
        print(name,text)

    def current_info(self):
        return self._current_info
|