EasyMetrics 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- easyMetrics/__init__.py +4 -0
- easyMetrics/core/__init__.py +1 -0
- easyMetrics/core/base.py +37 -0
- easyMetrics/tasks/__init__.py +1 -0
- easyMetrics/tasks/detection/__init__.py +2 -0
- easyMetrics/tasks/detection/format_converter.py +277 -0
- easyMetrics/tasks/detection/interface.py +67 -0
- easyMetrics/tasks/detection/map.py +413 -0
- easyMetrics/tasks/detection/matcher.py +73 -0
- easyMetrics/tasks/detection/utils.py +77 -0
- easymetrics-0.1.0.dist-info/METADATA +136 -0
- easymetrics-0.1.0.dist-info/RECORD +15 -0
- easymetrics-0.1.0.dist-info/WHEEL +5 -0
- easymetrics-0.1.0.dist-info/licenses/LICENSE +21 -0
- easymetrics-0.1.0.dist-info/top_level.txt +1 -0
easyMetrics/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import Metric
|
easyMetrics/core/base.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
class Metric(ABC):
    """Abstract base class shared by every metric.

    Concrete metrics keep their accumulated state on the instance and must
    implement :meth:`reset`, :meth:`update` and :meth:`compute`.
    """

    def __init__(self):
        # Delegate to the subclass's ``reset`` so a new instance starts clean.
        self.reset()

    @abstractmethod
    def reset(self):
        """Reset the internal state of the metric."""
        return None

    @abstractmethod
    def update(self, preds: Any, target: Any):
        """Fold a new batch of predictions and ground truth into the state.

        Args:
            preds: The model's predictions.
            target: The ground-truth labels.
        """
        return None

    @abstractmethod
    def compute(self) -> Any:
        """Compute the final metric value.

        Returns:
            The computed metric value.
        """
        return None
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .detection import MeanAveragePrecision
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
class DetectionFormatConverter:
    """Convert common detection layouts into the evaluator's per-image format.

    The internal format is a list with one dict per image. Prediction dicts
    hold "boxes" (rows of ``[x1, y1, x2, y2]``), "scores" and "labels";
    target dicts hold "boxes" and "labels".

    Supported input formats:
    - coco: a flat list of annotation records such as
      ``{"image_id": 0, "category_id": 0, "bbox": [x, y, w, h], "score": 0.9}``
      ("score" is only read for predictions). Records are grouped per
      ``image_id``. Entries that have "boxes" but no "bbox" are treated as
      already being in the internal format and are passed through unchanged.
    - voc: one sequence per image with rows
      ``[x1, y1, x2, y2, class_id(, confidence)]``.
    - yolo: one sequence per image with rows
      ``[class_id, x, y, w, h(, confidence)]`` where x, y, w, h are
      normalised centre / size values.
    - custom: converted by a user supplied ``custom_converter`` callable.
    """

    # Containers accepted as "the rows of one image" for voc / yolo input.
    _ROW_CONTAINERS = (list, tuple, np.ndarray)

    @staticmethod
    def convert(
        preds: List[Any],
        targets: List[Any],
        format: str = "coco",
        pred_format: Optional[str] = None,
        target_format: Optional[str] = None,
        **kwargs
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Convert predictions and targets into the internal per-image format.

        Args:
            preds: Predictions in ``pred_format`` (or ``format``).
            targets: Ground truth in ``target_format`` (or ``format``).
            format: Format used for whichever of ``pred_format`` /
                ``target_format`` is not given. One of "coco", "voc",
                "yolo", "custom".
            pred_format: Format of ``preds``; takes priority over ``format``.
            target_format: Format of ``targets``; takes priority over
                ``format``.
            **kwargs: Extra conversion options.
                - image_size: ``(width, height)`` used to de-normalise YOLO
                  boxes; defaults to ``(640, 640)``.
                - custom_converter: callable for the "custom" format. It is
                  called as ``custom_converter(preds, [], **kwargs)`` for
                  predictions and ``custom_converter([], targets, **kwargs)``
                  for targets (``kwargs`` still contains ``custom_converter``
                  itself) and must return a ``(preds, targets)`` pair.

        Returns:
            ``(converted_preds, converted_targets)``.

        Raises:
            ValueError: If a format is unsupported, or "custom" is requested
                without a ``custom_converter``.
        """
        cls = DetectionFormatConverter
        final_pred_format = pred_format or format
        final_target_format = target_format or format

        # When both sides are COCO records, group them over one shared list
        # of image ids so that entry i of both outputs is the same image.
        image_ids = None
        if final_pred_format == "coco" and final_target_format == "coco":
            image_ids = cls._shared_coco_image_ids(preds, targets)

        # Predictions
        if final_pred_format == "coco":
            converted_preds = cls._convert_coco_preds(preds, image_ids)
        elif final_pred_format == "voc":
            converted_preds = cls._convert_voc_preds(preds)
        elif final_pred_format == "yolo":
            image_size = kwargs.get("image_size", (640, 640))
            converted_preds = cls._convert_yolo_preds(preds, image_size)
        elif final_pred_format == "custom":
            custom_converter = kwargs.get("custom_converter")
            if not custom_converter:
                raise ValueError("custom format requires custom_converter function")
            # Run the converter on the predictions only and keep that half.
            converted_preds, _ = custom_converter(preds, [], **kwargs)
        else:
            raise ValueError(f"Unsupported pred_format: {final_pred_format}")

        # Targets
        if final_target_format == "coco":
            converted_targets = cls._convert_coco_targets(targets, image_ids)
        elif final_target_format == "voc":
            converted_targets = cls._convert_voc_targets(targets)
        elif final_target_format == "yolo":
            image_size = kwargs.get("image_size", (640, 640))
            converted_targets = cls._convert_yolo_targets(targets, image_size)
        elif final_target_format == "custom":
            custom_converter = kwargs.get("custom_converter")
            if not custom_converter:
                raise ValueError("custom format requires custom_converter function")
            # Run the converter on the targets only and keep that half.
            _, converted_targets = custom_converter([], targets, **kwargs)
        else:
            raise ValueError(f"Unsupported target_format: {final_target_format}")

        return converted_preds, converted_targets

    # ------------------------------------------------------------------ coco

    @staticmethod
    def _is_internal_entry(record: Any) -> bool:
        """Return True when *record* has "boxes" but no COCO "bbox" key."""
        return "boxes" in record and "bbox" not in record

    @staticmethod
    def _shared_coco_image_ids(preds: List[Any], targets: List[Any]) -> Optional[List[Any]]:
        """Collect the image ids used by both record lists.

        Ids are returned in first-seen order, targets first. Returns None
        when any entry is already a per-image dict, because such entries
        carry no image id to align on.
        """
        seen: Dict[Any, None] = {}
        for records in (targets, preds):
            for record in records:
                if DetectionFormatConverter._is_internal_entry(record):
                    return None
                seen.setdefault(record.get('image_id', 0), None)
        return list(seen)

    @staticmethod
    def _group_coco_records(records: List[Any], image_ids: Optional[List[Any]],
                            with_scores: bool) -> List[Dict[str, Any]]:
        """Group flat COCO records into one dict per image.

        Args:
            records: COCO annotation / result records, or per-image dicts
                (passed through unchanged).
            image_ids: Optional ids for which an entry is created up front,
                in this order, even if no record refers to them.
            with_scores: Whether to collect the "score" field.
        """
        def new_entry() -> Dict[str, Any]:
            entry: Dict[str, Any] = {"boxes": []}
            if with_scores:
                entry["scores"] = []
            entry["labels"] = []
            return entry

        converted: List[Dict[str, Any]] = []
        by_image: Dict[Any, Dict[str, Any]] = {}
        for image_id in image_ids or ():
            by_image[image_id] = new_entry()
            converted.append(by_image[image_id])

        for record in records:
            if DetectionFormatConverter._is_internal_entry(record):
                converted.append(record)
                continue
            image_id = record.get('image_id', 0)
            entry = by_image.get(image_id)
            if entry is None:
                entry = by_image[image_id] = new_entry()
                converted.append(entry)
            # COCO bbox is [x, y, width, height]; store corner coordinates.
            x, y, width, height = record['bbox']
            entry["boxes"].append([x, y, x + width, y + height])
            if with_scores:
                entry["scores"].append(record['score'])
            entry["labels"].append(record['category_id'])

        # Only pre-created placeholders can still be empty here. Give them a
        # (0, 4) box array so column slicing on the boxes keeps working.
        for entry in by_image.values():
            if not entry["boxes"]:
                entry["boxes"] = np.zeros((0, 4))
        return converted

    @staticmethod
    def _convert_coco_preds(coco_preds: List[Any],
                            image_ids: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
        """Convert COCO result records into per-image prediction dicts.

        Example record:
        ``{"image_id": 0, "category_id": 0, "bbox": [0, 0, 100, 100], "score": 0.9}``

        Args:
            coco_preds: Flat list of result records.
            image_ids: Optional image ids that fix the output order and get
                an (empty) entry even without predictions.
        """
        return DetectionFormatConverter._group_coco_records(coco_preds, image_ids, with_scores=True)

    @staticmethod
    def _convert_coco_targets(coco_targets: List[Any],
                              image_ids: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
        """Convert COCO annotation records into per-image target dicts.

        Example record:
        ``{"image_id": 0, "category_id": 0, "bbox": [0, 0, 100, 100], "area": 10000, "iscrowd": 0}``

        Args:
            coco_targets: Flat list of annotation records.
            image_ids: Optional image ids that fix the output order and get
                an (empty) entry even without annotations.
        """
        return DetectionFormatConverter._group_coco_records(coco_targets, image_ids, with_scores=False)

    # ------------------------------------------------------------ voc / yolo

    @staticmethod
    def _convert_rows(images: List[Any], parse_row: Any, with_scores: bool) -> List[Dict[str, Any]]:
        """Convert per-image row sequences with ``parse_row``.

        ``parse_row(row)`` must return ``(box, score, label)``. Dict entries
        are passed through unchanged, rows with fewer than five values are
        skipped, and entries of any other type are ignored.
        """
        converted: List[Dict[str, Any]] = []
        for image in images:
            if isinstance(image, dict):
                # Already in the internal per-image format.
                converted.append(image)
                continue
            if not isinstance(image, DetectionFormatConverter._ROW_CONTAINERS):
                continue
            boxes, scores, labels = [], [], []
            for row in image:
                if len(row) < 5:
                    continue
                box, score, label = parse_row(row)
                boxes.append(box)
                scores.append(score)
                labels.append(label)
            entry: Dict[str, Any] = {"boxes": boxes}
            if with_scores:
                entry["scores"] = scores
            entry["labels"] = labels
            converted.append(entry)
        return converted

    @staticmethod
    def _parse_voc_row(row: Any) -> Tuple[List[Any], Any, int]:
        """Split ``[x1, y1, x2, y2, class_id(, confidence)]`` into box, score, label."""
        score = row[5] if len(row) >= 6 else 1.0  # default confidence
        return list(row[:4]), score, int(row[4])

    @staticmethod
    def _parse_yolo_row(row: Any, width: Any, height: Any) -> Tuple[List[Any], Any, int]:
        """Turn ``[class_id, x, y, w, h(, confidence)]`` into a pixel-space box.

        Only the first six values are read, so extra trailing columns are
        ignored instead of breaking the unpacking.
        """
        class_id, x, y, w, h = row[:5]
        score = row[5] if len(row) >= 6 else 1.0  # default confidence
        box = [
            (x - w / 2) * width,
            (y - h / 2) * height,
            (x + w / 2) * width,
            (y + h / 2) * height,
        ]
        return box, score, int(class_id)

    @staticmethod
    def _convert_voc_preds(voc_preds: List[Any]) -> List[Dict[str, Any]]:
        """Convert VOC-style predictions.

        Each image is a sequence of ``[x1, y1, x2, y2, class_id, confidence]``
        rows; rows without a confidence get a score of 1.0.
        """
        return DetectionFormatConverter._convert_rows(
            voc_preds, DetectionFormatConverter._parse_voc_row, with_scores=True
        )

    @staticmethod
    def _convert_voc_targets(voc_targets: List[Any]) -> List[Dict[str, Any]]:
        """Convert VOC-style ground truth (rows of ``[x1, y1, x2, y2, class_id]``)."""
        return DetectionFormatConverter._convert_rows(
            voc_targets, DetectionFormatConverter._parse_voc_row, with_scores=False
        )

    @staticmethod
    def _convert_yolo_preds(yolo_preds: List[Any], image_size: Tuple[int, int]) -> List[Dict[str, Any]]:
        """Convert YOLO-style predictions.

        Each image is a sequence of ``[class_id, x, y, w, h, confidence]``
        rows with normalised x, y, w, h; rows without a confidence get a
        score of 1.0.

        Args:
            yolo_preds: Per-image row sequences.
            image_size: ``(width, height)`` used to scale the boxes.
        """
        width, height = image_size
        return DetectionFormatConverter._convert_rows(
            yolo_preds,
            lambda row: DetectionFormatConverter._parse_yolo_row(row, width, height),
            with_scores=True,
        )

    @staticmethod
    def _convert_yolo_targets(yolo_targets: List[Any], image_size: Tuple[int, int]) -> List[Dict[str, Any]]:
        """Convert YOLO-style ground truth (rows of ``[class_id, x, y, w, h]``).

        Args:
            yolo_targets: Per-image row sequences with normalised x, y, w, h.
            image_size: ``(width, height)`` used to scale the boxes.
        """
        width, height = image_size
        return DetectionFormatConverter._convert_rows(
            yolo_targets,
            lambda row: DetectionFormatConverter._parse_yolo_row(row, width, height),
            with_scores=False,
        )
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
2
|
+
from .map import MeanAveragePrecision
|
|
3
|
+
from .format_converter import DetectionFormatConverter
|
|
4
|
+
|
|
5
|
+
def evaluate_detection(
    preds: List[Any],
    targets: List[Any],
    metrics: Optional[List[str]] = None,
    n_jobs: int = 1,
    score_criteria: Optional[List[Tuple[float, float]]] = None,
    format: str = "coco",
    pred_format: Optional[str] = None,
    target_format: Optional[str] = None,
    progress: bool = True,
    **kwargs
) -> Dict[str, float]:
    """Evaluate object-detection results with COCO-style metrics.

    The inputs are converted with ``DetectionFormatConverter.convert`` and
    then scored by a default-configured ``MeanAveragePrecision``.

    Args:
        preds: Predictions, in ``pred_format`` (or ``format``).
        targets: Ground truth, in ``target_format`` (or ``format``).
        metrics: Names of the metrics to return. When empty or None every
            computed metric is returned; requested names that were not
            computed are silently left out.
        n_jobs: Number of worker threads; 1 runs serially, -1 lets the
            thread pool pick its default size.
        score_criteria: Optional list of ``(iou_thresh, min_precision)``
            pairs, e.g. ``[(0.5, 0.9)]``, for which a best confidence
            threshold is reported.
        format: Input format used when ``pred_format`` / ``target_format``
            is not given. One of "coco", "voc", "yolo", "custom".
        pred_format: Format of ``preds``; takes priority over ``format``.
        target_format: Format of ``targets``; takes priority over ``format``.
        progress: Whether to show progress output.
        **kwargs: Extra conversion options forwarded to the converter.
            - image_size: ``(width, height)`` needed for the YOLO format.
            - custom_converter: conversion callable for the custom format.

    Returns:
        A dict mapping metric names to their values.
    """
    # Bring both inputs into the internal per-image format.
    converted_preds, converted_targets = DetectionFormatConverter.convert(
        preds,
        targets,
        format=format,
        pred_format=pred_format,
        target_format=target_format,
        **kwargs
    )

    # Feed everything into a fresh metric and compute all values.
    metric = MeanAveragePrecision()
    metric.update(converted_preds, converted_targets)
    all_results = metric.compute(n_jobs=n_jobs, score_criteria=score_criteria, progress=progress)

    if not metrics:
        return all_results

    # Keep only the requested metrics that were actually computed.
    return {name: all_results[name] for name in metrics if name in all_results}
|
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
2
|
+
import numpy as np
|
|
3
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
4
|
+
from easyMetrics.core.base import Metric
|
|
5
|
+
from .utils import calculate_iou, compute_ap_coco
|
|
6
|
+
from .matcher import BaseMatcher, GreedyIoUMatcher
|
|
7
|
+
|
|
8
|
+
class MeanAveragePrecision(Metric):
    """Mean average precision (mAP) and average recall (AR) for detection.

    The defaults follow the COCO settings: IoU thresholds 0.50:0.05:0.95,
    101 recall points and at most 1 / 10 / 100 detections per image.
    """

    def __init__(self,
                 iou_thresholds: Optional[List[float]] = None,
                 rec_thresholds: Optional[List[float]] = None,
                 max_detection_thresholds: Optional[List[int]] = None,
                 matcher: Optional[BaseMatcher] = None):
        """Configure thresholds and the matching strategy.

        Args:
            iou_thresholds: IoU thresholds; None or an empty sequence selects
                ``np.linspace(0.5, 0.95, 10)``.
            rec_thresholds: Recall thresholds; None or an empty sequence
                selects ``np.linspace(0.0, 1.0, 101)``.
            max_detection_thresholds: Per-image detection limits used for
                AR; None or an empty sequence selects ``[1, 10, 100]``.
            matcher: Matching strategy; defaults to ``GreedyIoUMatcher``.
        """
        # Metric.__init__ calls reset(), which creates self.preds / self.targets.
        super().__init__()
        # Explicit None / empty checks instead of ``x or default`` so that
        # numpy arrays can be passed without an ambiguous-truth-value error.
        self.iou_thresholds = self._value_or_default(iou_thresholds, np.linspace(0.5, 0.95, 10))
        self.rec_thresholds = self._value_or_default(rec_thresholds, np.linspace(0.0, 1.00, 101))
        self.max_detection_thresholds = self._value_or_default(max_detection_thresholds, [1, 10, 100])
        self.matcher = matcher if matcher is not None else GreedyIoUMatcher()

    @staticmethod
    def _value_or_default(value, default):
        """Return *value* unless it is None or empty, else *default*."""
        if value is None or len(value) == 0:
            return default
        return value

    def reset(self):
        """Drop all accumulated predictions and targets."""
        self.preds = []
        self.targets = []

    def _ensure_numpy_array(self, data, dtype=None):
        """Return *data* as a numpy array.

        Args:
            data: A list, numpy array, other array-like, or None.
            dtype: Target dtype; None lets numpy infer it.

        Returns:
            numpy.ndarray: *data* itself when it already is an array of the
            requested dtype, an empty array for None, a converted array
            otherwise.
        """
        if data is None:
            return np.array([], dtype=dtype)
        return np.asarray(data, dtype=dtype)

    def _as_box_array(self, data):
        """Return *data* as a float array, shaping empty input as (0, 4).

        An empty list would otherwise become a 1-D array and break the
        ``boxes[:, k]`` column slicing used during evaluation.
        """
        boxes = self._ensure_numpy_array(data, dtype=float)
        if boxes.size == 0:
            return boxes.reshape(0, 4)
        return boxes

    def update(self, preds: List[Dict[str, Any]], target: List[Dict[str, Any]]):
        """Append one batch of per-image predictions and targets.

        Args:
            preds: List of dicts with 'boxes', 'scores' and 'labels'.
            target: List of dicts with 'boxes' and 'labels'.
        """
        # Convert everything first so a bad entry leaves the state untouched.
        converted_preds = [
            {
                'boxes': self._as_box_array(pred['boxes']),
                'scores': self._ensure_numpy_array(pred['scores'], dtype=float),
                'labels': self._ensure_numpy_array(pred['labels'], dtype=int),
            }
            for pred in preds
        ]
        converted_targets = [
            {
                'boxes': self._as_box_array(tgt['boxes']),
                'labels': self._ensure_numpy_array(tgt['labels'], dtype=int),
            }
            for tgt in target
        ]

        self.preds.extend(converted_preds)
        self.targets.extend(converted_targets)

    def _iter_class_stats(self, classes, area_rng, criteria, n_jobs):
        """Yield ``(cls_id, stats)`` for every class, serially or threaded.

        ``n_jobs == 1`` computes in the calling thread. Any other value uses
        a ThreadPoolExecutor, with -1 mapped to the executor's default
        worker count. Results are yielded in the order of *classes*.
        """
        if n_jobs == 1:
            for cls_id in classes:
                yield cls_id, self._compute_class_stats(cls_id, area_rng, criteria)
            return

        max_workers = None if n_jobs == -1 else n_jobs
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # All arguments are bound at submit time, so no worker depends
            # on loop variables of the caller.
            futures = [
                (cls_id, executor.submit(self._compute_class_stats, cls_id, area_rng, criteria))
                for cls_id in classes
            ]
            for cls_id, future in futures:
                yield cls_id, future.result()

    def compute(self, n_jobs: int = 1, score_criteria: Optional[List[Tuple[float, float]]] = None, progress: bool = True) -> Dict[str, float]:
        """Compute the mAP and AR metrics.

        Args:
            n_jobs: Number of worker threads. 1 (default) runs serially,
                -1 uses the thread pool's default worker count.
            score_criteria: Optional list of ``(iou_thresh, min_precision)``
                pairs. For example ``[(0.5, 0.9)]`` looks for the lowest
                confidence at IoU 0.5 whose precision is still >= 0.9.
                Only evaluated for the 'all' area range.
            progress: Whether to show a tqdm progress bar (or a printed
                line when tqdm is not installed).

        Returns:
            A dict with, among others:
            - mAP, mAP_50, mAP_75, mAP_s, mAP_m, mAP_l
            - AP_{cls_id}, AP_50_{cls_id}, AP_75_{cls_id}
            - AR_{max_det}
            - BestScore_IoU{iou}_P{prec}_{cls_id} for met score criteria
        """
        # 1. Every class id that occurs in the targets or the predictions.
        unique_classes = set()
        for entry in self.targets:
            unique_classes.update(entry['labels'].tolist())
        for entry in self.preds:
            unique_classes.update(entry['labels'].tolist())

        sorted_classes = sorted(unique_classes)
        cls_id_to_idx = {cls_id: i for i, cls_id in enumerate(sorted_classes)}

        # Box-area ranges (COCO convention).
        area_rngs = {
            'all': (0, 1e10),
            'small': (0, 32 ** 2),
            'medium': (32 ** 2, 96 ** 2),
            'large': (96 ** 2, 1e10)
        }

        results = {}

        # 2. Evaluate each area range separately.
        for area_name, area_rng in area_rngs.items():
            # Score criteria are only evaluated on the 'all' range.
            criteria = score_criteria if area_name == 'all' else None

            aps = np.zeros((len(sorted_classes), len(self.iou_thresholds)))
            ars = np.zeros((len(sorted_classes), len(self.iou_thresholds), len(self.max_detection_thresholds)))
            class_score_results = {}

            progress_bar = None
            if progress and len(sorted_classes) > 0:
                try:
                    from tqdm import tqdm
                    progress_bar = tqdm(total=len(sorted_classes), desc=f"评估 {area_name}")
                except ImportError:
                    # tqdm is optional; fall back to a plain message.
                    print(f"开始评估 {area_name}...")

            try:
                for cls_id, cls_stats in self._iter_class_stats(sorted_classes, area_rng, criteria, n_jobs):
                    idx = cls_id_to_idx[cls_id]

                    # _compute_class_stats returns a 3-tuple only when
                    # criteria were passed in.
                    if criteria:
                        cls_aps, cls_ars, cls_scores = cls_stats
                        class_score_results[cls_id] = cls_scores
                    else:
                        cls_aps, cls_ars = cls_stats

                    aps[idx, :] = cls_aps
                    ars[idx, :, :] = cls_ars

                    if progress_bar is not None:
                        progress_bar.update(1)
            finally:
                # Close the bar even when a class computation raised.
                if progress_bar is not None:
                    progress_bar.close()

            # 3. Aggregate.
            if aps.size > 0:
                mean_ap = float(np.mean(aps))

                if area_name == 'all':
                    results['mAP'] = mean_ap

                    # mAP @ 0.5
                    idx_50 = np.where(np.isclose(self.iou_thresholds, 0.5))[0]
                    if len(idx_50) > 0:
                        results['mAP_50'] = float(np.mean(aps[:, idx_50[0]]))

                    # mAP @ 0.75
                    idx_75 = np.where(np.isclose(self.iou_thresholds, 0.75))[0]
                    if len(idx_75) > 0:
                        results['mAP_75'] = float(np.mean(aps[:, idx_75[0]]))

                    # Per-class AP (reported for the 'all' range only).
                    for cls_id in sorted_classes:
                        idx = cls_id_to_idx[cls_id]
                        results[f'AP_{cls_id}'] = float(np.mean(aps[idx, :]))

                        if len(idx_50) > 0:
                            results[f'AP_50_{cls_id}'] = float(aps[idx, idx_50[0]])
                        if len(idx_75) > 0:
                            results[f'AP_75_{cls_id}'] = float(aps[idx, idx_75[0]])

                        # Best confidence thresholds, if any were found.
                        if cls_id in class_score_results:
                            for k, v in class_score_results[cls_id].items():
                                results[f'{k}_{cls_id}'] = v

                    # AR per max-detection limit (reported for 'all' only).
                    if ars.size > 0:
                        mean_ars = np.mean(ars, axis=(0, 1))  # one value per limit
                        for i, max_det in enumerate(self.max_detection_thresholds):
                            results[f'AR_{max_det}'] = float(mean_ars[i])

                elif area_name == 'small':
                    results['mAP_s'] = mean_ap
                elif area_name == 'medium':
                    results['mAP_m'] = mean_ap
                elif area_name == 'large':
                    results['mAP_l'] = mean_ap
            else:
                if area_name == 'all':
                    results['mAP'] = 0.0

        return results

    def _prepare_data(self, cls_id: int, area_rng: Tuple[float, float]) -> Tuple[List[Tuple], Dict, int]:
        """Collect the boxes of one class that fall into one area range.

        Args:
            cls_id: Class id to select.
            area_rng: ``(min_area, max_area)``; a box is kept when
                ``min_area <= area < max_area``.

        Returns:
            ``(class_preds, class_gt, n_pos)`` where ``class_preds`` is a
            list of ``(score, img_idx, box, rank)`` sorted by descending
            score (``rank`` is the box's position within its own image),
            ``class_gt`` maps ``img_idx`` to ``{'boxes', 'used'}`` and
            ``n_pos`` is the number of selected ground-truth boxes.
        """
        min_area, max_area = area_rng
        class_preds = []
        class_gt = {}
        n_pos = 0

        # Ground truth
        for img_idx, target in enumerate(self.targets):
            boxes = target['boxes'][target['labels'] == cls_id]
            areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            valid_boxes = boxes[(areas >= min_area) & (areas < max_area)]

            n_pos += len(valid_boxes)
            class_gt[img_idx] = {
                'boxes': valid_boxes,
                'used': np.zeros(len(valid_boxes), dtype=bool)
            }

        # Predictions, capped per image at the largest detection limit.
        limit_dets = max(self.max_detection_thresholds) if self.max_detection_thresholds else 100

        for img_idx, pred in enumerate(self.preds):
            mask = pred['labels'] == cls_id
            scores = pred['scores'][mask]
            boxes = pred['boxes'][mask]

            areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            valid_area_mask = (areas >= min_area) & (areas < max_area)
            scores = scores[valid_area_mask]
            boxes = boxes[valid_area_mask]

            if len(scores) == 0:
                continue

            order = np.argsort(-scores)
            scores = scores[order][:limit_dets]
            boxes = boxes[order][:limit_dets]
            for rank, (score, box) in enumerate(zip(scores, boxes)):
                class_preds.append((score, img_idx, box, rank))

        # Global ordering by descending score.
        class_preds.sort(key=lambda x: x[0], reverse=True)

        return class_preds, class_gt, n_pos

    def _compute_class_stats(self, cls_id: int, area_rng: Tuple[float, float],
                             score_criteria: Optional[List[Tuple[float, float]]] = None) -> Any:
        """Compute AP and recall of one class for every threshold.

        Args:
            cls_id: Class id to evaluate.
            area_rng: ``(min_area, max_area)`` box-area range.
            score_criteria: Optional ``(iou_thresh, min_precision)`` pairs.

        Returns:
            ``(aps, recalls, scores_dict)`` when ``score_criteria`` is
            given (non-empty), otherwise ``(aps, recalls)``. ``aps`` has
            one value per IoU threshold, ``recalls`` has shape
            (num IoU thresholds, num detection limits).
        """
        # 1. Gather data
        class_preds, class_gt, n_pos = self._prepare_data(cls_id, area_rng)

        num_iou = len(self.iou_thresholds)
        num_dets = len(self.max_detection_thresholds)

        aps = np.zeros(num_iou)
        recs = np.zeros((num_iou, num_dets))
        scores_result = {}  # key: "BestScore_IoU{}_P{}"

        if n_pos == 0 or len(class_preds) == 0:
            if score_criteria:
                return aps, recs, scores_result
            return aps, recs

        # Per-image ranks (for the detection limits) and scores; class_preds
        # is already sorted by descending score.
        pred_ranks = np.array([p[3] for p in class_preds])
        pred_scores = np.array([p[0] for p in class_preds])

        # 2. IoU of every prediction with the ground truth of its image.
        pred_gt_ious = []
        for _, img_idx, pred_box, _ in class_preds:
            # A prediction whose image has no target entry at all is
            # handled like an image without ground-truth boxes.
            gt_entry = class_gt.get(img_idx)
            if gt_entry is None or len(gt_entry['boxes']) == 0:
                pred_gt_ious.append(np.array([]))
            else:
                pred_gt_ious.append(calculate_iou(pred_box[None, :], gt_entry['boxes'])[0])

        # 3. Sweep the IoU thresholds.
        for t_idx, iou_thresh in enumerate(self.iou_thresholds):
            tp, fp = self.matcher.match(class_preds, class_gt, pred_gt_ious, iou_thresh)

            cum_tp = np.cumsum(tp)
            cum_fp = np.cumsum(fp)

            recall = cum_tp / n_pos
            precision = cum_tp / (cum_tp + cum_fp)

            aps[t_idx] = compute_ap_coco(recall, precision)

            # Best confidence thresholds (only when requested).
            if score_criteria:
                for (crit_iou, crit_prec) in score_criteria:
                    # Compare floats with a tolerance.
                    if not np.isclose(iou_thresh, crit_iou):
                        continue
                    valid_mask = precision >= crit_prec
                    if not np.any(valid_mask):
                        continue
                    # Predictions are ordered by descending score, so the
                    # last position that still meets the precision is the
                    # lowest qualifying score (and the highest recall).
                    best_idx = np.where(valid_mask)[0][-1]
                    key = f"BestScore_IoU{crit_iou:.2f}_P{crit_prec:.2f}"
                    scores_result[key] = float(pred_scores[best_idx])

            # Recall for each per-image detection limit.
            for d_idx, max_det in enumerate(self.max_detection_thresholds):
                tp_sum = np.sum(tp[pred_ranks < max_det])
                recs[t_idx, d_idx] = tp_sum / n_pos

        if score_criteria:
            return aps, recs, scores_result
        return aps, recs

    def _match_predictions(self, *args, **kwargs):
        """Deprecated; use ``self.matcher.match`` instead.

        Kept for compatibility only: every call raises DeprecationWarning.
        """
        raise DeprecationWarning("Please use self.matcher.match instead of _match_predictions")
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List, Dict, Tuple, Any
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
class BaseMatcher(ABC):
    """Abstract base class for matching strategies.

    A matcher decides how predicted boxes are paired with ground-truth
    boxes.
    """

    @abstractmethod
    def match(self,
              class_preds: List[Tuple],
              class_gt: Dict,
              pred_gt_ious: List[np.ndarray],
              iou_thresh: float) -> Tuple[np.ndarray, np.ndarray]:
        """Run the matching logic.

        Args:
            class_preds: Predictions, each a ``(score, img_idx, box, rank)``
                tuple.
            class_gt: Ground truth keyed by ``img_idx``; each value is a
                dict ``{'boxes': [], 'used': []}``.
            pred_gt_ious: Precomputed IoU arrays, one per entry of
                ``class_preds``.
            iou_thresh: The IoU threshold currently being evaluated.

        Returns:
            ``(tp, fp)``: two arrays of shape (N,) where 1 marks a true
            positive / false positive respectively.
        """
        return None
|
|
31
|
+
|
|
32
|
+
class GreedyIoUMatcher(BaseMatcher):
    """Greedy IoU matching.

    Predictions are visited in the order given (the caller supplies them
    sorted by descending score). Each one is compared against the
    ground-truth box it overlaps most.
    """

    def match(self,
              class_preds: List[Tuple],
              class_gt: Dict,
              pred_gt_ious: List[np.ndarray],
              iou_thresh: float) -> Tuple[np.ndarray, np.ndarray]:
        """Label every prediction as true or false positive.

        A prediction is a true positive when its highest-IoU ground-truth
        box reaches ``iou_thresh`` and has not been claimed by an earlier
        prediction. Otherwise it is a false positive: no ground truth, IoU
        too low, or a duplicate detection.

        Note: the ``used`` flags inside ``class_gt`` are cleared and then
        updated in place; this side effect is intentional.

        Returns:
            ``(tp, fp)`` arrays with one 0/1 entry per prediction.
        """
        num_preds = len(class_preds)
        tp = np.zeros(num_preds)
        fp = np.zeros(num_preds)

        # Start from a clean slate for this threshold.
        for gt_entry in class_gt.values():
            gt_entry['used'][:] = False

        for pred_idx, ious in enumerate(pred_gt_ious):
            matched = False

            if len(ious) != 0:
                # Ground-truth box with the largest overlap.
                best_gt_idx = np.argmax(ious)
                best_iou = ious[best_gt_idx]
                img_idx = class_preds[pred_idx][1]

                if best_iou >= iou_thresh:
                    used_flags = class_gt[img_idx]['used']
                    # Only the first prediction may claim a given box.
                    if not used_flags[best_gt_idx]:
                        used_flags[best_gt_idx] = True
                        matched = True

            if matched:
                tp[pred_idx] = 1
            else:
                fp[pred_idx] = 1

        return tp, fp
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
def calculate_iou(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """
    Compute the pairwise intersection-over-union (IoU) of two box sets.

    Args:
        boxes1: (N, 4) array in [x1, y1, x2, y2] format. A single box may
            be passed as a flat (4,) array and is treated as N = 1.
        boxes2: (M, 4) array in [x1, y1, x2, y2] format. A single box may
            be passed as a flat (4,) array and is treated as M = 1.

    Returns:
        (N, M) array where entry [i, j] is the IoU of boxes1[i] and
        boxes2[j]. If either set is empty the result is an all-zero array
        of shape (N, M).
    """
    def _as_box_rows(boxes):
        # Promote a flat single box to one row; a flat empty array means
        # "no boxes" and becomes a (0, 4) matrix.
        boxes = np.asarray(boxes)
        if boxes.ndim == 1:
            if boxes.size == 0:
                return boxes.reshape(0, 4)
            return boxes[np.newaxis, :]
        return boxes

    # Normalise to 2D *before* the empty check so the row counts used for
    # the empty result are box counts, not coordinate counts.
    boxes1 = _as_box_rows(boxes1)
    boxes2 = _as_box_rows(boxes2)

    if boxes1.size == 0 or boxes2.size == 0:
        return np.zeros((boxes1.shape[0], boxes2.shape[0]))

    # Box areas.
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # (N,)
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # (M,)

    # Broadcast to get the intersection's top-left and bottom-right corners.
    lt = np.maximum(boxes1[:, None, :2], boxes2[None, :, :2])  # (N, M, 2) [x1, y1]
    rb = np.minimum(boxes1[:, None, 2:], boxes2[None, :, 2:])  # (N, M, 2) [x2, y2]

    # Intersection width/height; clipping at 0 handles non-overlapping pairs.
    wh = np.clip(rb - lt, 0, None)  # (N, M, 2) [w, h]
    inter = wh[:, :, 0] * wh[:, :, 1]  # (N, M)

    # Union area, floored at a small epsilon to avoid division by zero.
    union = area1[:, None] + area2[None, :] - inter
    union = np.maximum(union, 1e-6)

    return inter / union
|
|
43
|
+
|
|
44
|
+
def compute_ap_coco(recall: np.ndarray, precision: np.ndarray) -> float:
    """
    Compute Average Precision with COCO-style 101-point interpolation.

    Args:
        recall: (N,) array of recall values; must be monotonically
            non-decreasing because it is searched with ``np.searchsorted``.
        precision: (N,) array of the matching precision values.

    Returns:
        The mean of the interpolated precision sampled at the 101 recall
        levels 0.00, 0.01, ..., 1.00.
    """
    # Pad both curves with sentinels so the recall axis spans [0, 1]; the
    # trailing precision sentinel is 0, so recall levels that were never
    # reached contribute nothing.
    padded_recall = np.concatenate(([0.0], recall, [1.0]))
    padded_precision = np.concatenate(([0.0], precision, [0.0]))

    # Precision envelope: each entry becomes the maximum precision found at
    # that position or anywhere to its right (a running max from the end).
    envelope = np.maximum.accumulate(padded_precision[::-1])[::-1]

    # 101 evenly spaced recall levels from 0.0 to 1.0.
    sample_points = np.linspace(0.0, 1.00, 101)

    # For each level, the first padded-recall position that is >= the level;
    # the envelope value there is the best precision at that recall or more.
    positions = np.searchsorted(padded_recall, sample_points, side='left')

    return float(np.mean(envelope[positions]))
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: EasyMetrics
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: 轻量级、零依赖的机器学习指标评估平台
|
|
5
|
+
Author-email: EasyMetrics Team <team@easymetrics.example>
|
|
6
|
+
License: MIT License
|
|
7
|
+
|
|
8
|
+
Copyright (c) 2026 EasyMetrics Team
|
|
9
|
+
|
|
10
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
+
in the Software without restriction, including without limitation the rights
|
|
13
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
+
furnished to do so, subject to the following conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be included in all
|
|
18
|
+
copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
+
SOFTWARE.
|
|
27
|
+
|
|
28
|
+
Project-URL: homepage, https://github.com/easymetrics/easymetrics
|
|
29
|
+
Project-URL: documentation, https://easymetrics.readthedocs.io
|
|
30
|
+
Project-URL: repository, https://github.com/easymetrics/easymetrics
|
|
31
|
+
Project-URL: issues, https://github.com/easymetrics/easymetrics/issues
|
|
32
|
+
Classifier: Programming Language :: Python :: 3
|
|
33
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
34
|
+
Classifier: Operating System :: OS Independent
|
|
35
|
+
Classifier: Intended Audience :: Developers
|
|
36
|
+
Classifier: Intended Audience :: Science/Research
|
|
37
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
38
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
39
|
+
Requires-Python: >=3.7
|
|
40
|
+
Description-Content-Type: text/markdown
|
|
41
|
+
License-File: LICENSE
|
|
42
|
+
Requires-Dist: numpy
|
|
43
|
+
Provides-Extra: progress
|
|
44
|
+
Requires-Dist: tqdm; extra == "progress"
|
|
45
|
+
Dynamic: license-file
|
|
46
|
+
|
|
47
|
+
# EasyMetrics
|
|
48
|
+
|
|
49
|
+
一个轻量级、零依赖的机器学习指标评估平台,基于 `numpy` 从零构建,专注于提供简单易用且准确的模型评估工具。
|
|
50
|
+
|
|
51
|
+
## ✨ 核心特性
|
|
52
|
+
- **零依赖**: 仅需 Python 和 NumPy,无需安装大型深度学习框架
|
|
53
|
+
- **易于扩展**: 模块化设计,通过继承 `Metric` 基类即可添加新任务
|
|
54
|
+
- **功能强大**: 完美支持目标检测任务的全方位评估
|
|
55
|
+
- 标准 COCO 指标: mAP、mAP_50、mAP_75、mAP_s/m/l
|
|
56
|
+
- 平均召回率 (AR) 指标
|
|
57
|
+
- 每类别独立评估
|
|
58
|
+
- **独家功能**: 自动计算满足特定精度要求的最佳置信度阈值
|
|
59
|
+
|
|
60
|
+
## 📁 目录结构
|
|
61
|
+
```
|
|
62
|
+
easyMetrics/
|
|
63
|
+
├── easyMetrics/ # 核心代码
|
|
64
|
+
│ ├── core/ # 抽象基类
|
|
65
|
+
│ │ └── base.py
|
|
66
|
+
│ └── tasks/ # 任务实现
|
|
67
|
+
│ └── detection/ # 目标检测
|
|
68
|
+
│ ├── interface.py # 对外接口
|
|
69
|
+
│ ├── map.py # mAP 核心逻辑
|
|
70
|
+
│ ├── matcher.py # 匹配策略
|
|
71
|
+
│ ├── utils.py # 辅助函数
|
|
72
|
+
│ └── format_converter.py # 格式转换器
|
|
73
|
+
├── docs/ # 文档
|
|
74
|
+
│ ├── 使用指南.md
|
|
75
|
+
│ └── 指标详解.md
|
|
76
|
+
├── demo.py # 使用示例
|
|
77
|
+
└── README.md
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
## 🚀 快速上手
|
|
81
|
+
|
|
82
|
+
### 目标检测评估
|
|
83
|
+
|
|
84
|
+
使用 `evaluate_detection` 函数,一行代码完成评估:
|
|
85
|
+
|
|
86
|
+
```python
|
|
87
|
+
import numpy as np
|
|
88
|
+
from easyMetrics.tasks.detection import evaluate_detection
|
|
89
|
+
|
|
90
|
+
# 准备数据 - 每张图片一个字典
|
|
91
|
+
preds = [{
|
|
92
|
+
'boxes': np.array([[10, 10, 50, 50]]), # [x1, y1, x2, y2] 格式
|
|
93
|
+
'scores': np.array([0.9]), # 置信度分数
|
|
94
|
+
'labels': np.array([0]) # 类别索引
|
|
95
|
+
}]
|
|
96
|
+
targets = [{
|
|
97
|
+
'boxes': np.array([[10, 10, 50, 50]]), # 真实边界框
|
|
98
|
+
'labels': np.array([0]) # 真实类别
|
|
99
|
+
}]
|
|
100
|
+
|
|
101
|
+
# 1. 计算标准 COCO 指标
|
|
102
|
+
results = evaluate_detection(preds, targets)
|
|
103
|
+
print(f"mAP: {results['mAP']:.4f}")
|
|
104
|
+
print(f"mAP_50: {results['mAP_50']:.4f}")
|
|
105
|
+
|
|
106
|
+
# 2. 寻找最佳置信度阈值
|
|
107
|
+
# 场景: IoU=0.5 时精度至少达到 90%
|
|
108
|
+
results = evaluate_detection(
|
|
109
|
+
preds, targets,
|
|
110
|
+
score_criteria=[(0.5, 0.9)]
|
|
111
|
+
)
|
|
112
|
+
print(f"推荐阈值: {results.get('BestScore_IoU0.50_P0.90_0')}")
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
### 并行加速
|
|
116
|
+
|
|
117
|
+
对于大规模数据集,启用多核并行计算:
|
|
118
|
+
|
|
119
|
+
```python
|
|
120
|
+
# 使用 4 个核心
|
|
121
|
+
results = evaluate_detection(preds, targets, n_jobs=4)
|
|
122
|
+
|
|
123
|
+
# 使用所有可用核心
|
|
124
|
+
results = evaluate_detection(preds, targets, n_jobs=-1)
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
## 🔧 扩展新任务
|
|
128
|
+
|
|
129
|
+
添加新指标(例如分类任务的准确率):
|
|
130
|
+
|
|
131
|
+
1. 在 `easyMetrics/tasks/` 下创建新目录(如 `classification`)
|
|
132
|
+
2. 继承 `easyMetrics.core.Metric` 基类
|
|
133
|
+
3. 实现 `reset()`, `update()` 和 `compute()` 方法
|
|
134
|
+
|
|
135
|
+
---
|
|
136
|
+
*Created with ❤️ by EasyMetrics Team*
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
easyMetrics/__init__.py,sha256=QRWXyUWj4TDlsR-fmXUNs2Nu7VTuaU64OKk9AbPiCAA,100
|
|
2
|
+
easyMetrics/core/__init__.py,sha256=L-JsvnYEbmJZPzf8oBbldRA-B6Oni3AtJaZZJEuDOwA,25
|
|
3
|
+
easyMetrics/core/base.py,sha256=pMRZCl_SJtfCo-_fzuGijkYMSq1DDIZegI96EWxUX6k,789
|
|
4
|
+
easyMetrics/tasks/__init__.py,sha256=Cwc3WM7ZgblWq1jfIjrHt_bBrJRobx4q_WbHn3LaN68,44
|
|
5
|
+
easyMetrics/tasks/detection/__init__.py,sha256=74SsA-fJbamX7WGIWJJmeTXjjHnogYU1jwYse65d45U,80
|
|
6
|
+
easyMetrics/tasks/detection/format_converter.py,sha256=EhGz-8j46bf5higaKyJmtqwILp8V2NtpiK0i-gOAgdE,11700
|
|
7
|
+
easyMetrics/tasks/detection/interface.py,sha256=ZbR9aJVXsRiSNGrhcVZAEGkfH2SFX0kJi0XY_rUEbMA,2568
|
|
8
|
+
easyMetrics/tasks/detection/map.py,sha256=F2qVDTQR1qr1SHEf3w3Xg7yvUytzfQtxXdF6p25CcwA,17471
|
|
9
|
+
easyMetrics/tasks/detection/matcher.py,sha256=C-4hEVBJseQKgcHupBLlaDNFGhKxOxN-YdPNnpuit9I,2443
|
|
10
|
+
easyMetrics/tasks/detection/utils.py,sha256=jNDPHxWzGo7478QDI1f8uVGK6HUlwvHlPbfwzEk-H0A,2770
|
|
11
|
+
easymetrics-0.1.0.dist-info/licenses/LICENSE,sha256=Z-thSEoGCfxHrpC8rpgrqY-LSoJuyHbgPj5Gt8BuG_0,1073
|
|
12
|
+
easymetrics-0.1.0.dist-info/METADATA,sha256=dHEOCgWpY4ECOtByCGk-0fJYz4lFYSq18oYnMCte4Eg,5265
|
|
13
|
+
easymetrics-0.1.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
14
|
+
easymetrics-0.1.0.dist-info/top_level.txt,sha256=2ytZxZ9LMTdKpeZunCE1rpvkpDAhWeQLiiXTrFy9niw,12
|
|
15
|
+
easymetrics-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 EasyMetrics Team
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
easyMetrics
|