dataset-toolkit 0.1.2__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dataset_toolkit-0.1.2/dataset_toolkit.egg-info → dataset_toolkit-0.2.0}/PKG-INFO +39 -1
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/README.md +38 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/__init__.py +9 -2
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/exporters/yolo_exporter.py +4 -2
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/loaders/local_loader.py +145 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/models.py +3 -1
- dataset_toolkit-0.2.0/dataset_toolkit/processors/__init__.py +9 -0
- dataset_toolkit-0.2.0/dataset_toolkit/processors/evaluator.py +535 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0/dataset_toolkit.egg-info}/PKG-INFO +39 -1
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit.egg-info/SOURCES.txt +2 -0
- dataset_toolkit-0.2.0/examples/evaluation_example.py +250 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/pyproject.toml +1 -1
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/setup.py +1 -1
- dataset_toolkit-0.1.2/dataset_toolkit/utils/__init__.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/LICENSE +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/MANIFEST.in +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/exporters/__init__.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/exporters/coco_exporter.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/exporters/txt_exporter.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/loaders/__init__.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/pipeline.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/processors/merger.py +0 -0
- {dataset_toolkit-0.1.2/dataset_toolkit/processors → dataset_toolkit-0.2.0/dataset_toolkit/utils}/__init__.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/utils/coords.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit.egg-info/dependency_links.txt +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit.egg-info/requires.txt +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit.egg-info/top_level.txt +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/examples/basic_usage.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/requirements.txt +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/setup.cfg +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/__init__.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/conftest.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/test_exporters.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/test_loaders.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/test_processors.py +0 -0
- {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/test_pypi_test.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: dataset-toolkit
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: 一个用于加载、处理和导出计算机视觉数据集的工具包
|
5
5
|
Home-page: https://github.com/yourusername/dataset-toolkit
|
6
6
|
Author: wenxiang.han
|
@@ -42,6 +42,7 @@ Dynamic: requires-python
|
|
42
42
|
- 📤 **灵活导出**:导出为 COCO JSON、TXT 等多种格式
|
43
43
|
- 🛠️ **工具函数**:提供坐标转换等实用工具
|
44
44
|
- 📦 **标准化数据模型**:统一的内部数据表示,方便扩展
|
45
|
+
- 📊 **模型评估**:完整的目标检测模型评估系统(v0.2.0+)
|
45
46
|
|
46
47
|
## 📦 安装
|
47
48
|
|
@@ -121,6 +122,43 @@ result = (pipeline
|
|
121
122
|
.execute())
|
122
123
|
```
|
123
124
|
|
125
|
+
### 模型评估(v0.2.0+)
|
126
|
+
|
127
|
+
```python
|
128
|
+
from dataset_toolkit import (
|
129
|
+
load_yolo_from_local,
|
130
|
+
load_predictions_from_streamlined,
|
131
|
+
Evaluator
|
132
|
+
)
|
133
|
+
|
134
|
+
# 1. 加载GT和预测结果
|
135
|
+
gt_dataset = load_yolo_from_local("/data/test/labels", {0: 'parcel'})
|
136
|
+
pred_dataset = load_predictions_from_streamlined(
|
137
|
+
"/results/predictions",
|
138
|
+
categories={0: 'parcel'},
|
139
|
+
image_dir="/data/test/images"
|
140
|
+
)
|
141
|
+
|
142
|
+
# 2. 创建评估器
|
143
|
+
evaluator = Evaluator(
|
144
|
+
positive_gt=gt_dataset,
|
145
|
+
positive_pred=pred_dataset,
|
146
|
+
iou_threshold=0.5
|
147
|
+
)
|
148
|
+
|
149
|
+
# 3. 计算指标
|
150
|
+
metrics = evaluator.calculate_metrics(confidence_threshold=0.5)
|
151
|
+
print(f"Precision: {metrics['precision']:.4f}")
|
152
|
+
print(f"Recall: {metrics['recall']:.4f}")
|
153
|
+
print(f"F1-Score: {metrics['f1']:.4f}")
|
154
|
+
|
155
|
+
# 4. 寻找最优阈值
|
156
|
+
optimal = evaluator.find_optimal_threshold(metric='f1')
|
157
|
+
print(f"最优阈值: {optimal['optimal_threshold']}")
|
158
|
+
```
|
159
|
+
|
160
|
+
详细文档请参考 [EVALUATION_GUIDE.md](EVALUATION_GUIDE.md)
|
161
|
+
|
124
162
|
## 📚 API 文档
|
125
163
|
|
126
164
|
### 数据加载器
|
@@ -9,6 +9,7 @@
|
|
9
9
|
- 📤 **灵活导出**:导出为 COCO JSON、TXT 等多种格式
|
10
10
|
- 🛠️ **工具函数**:提供坐标转换等实用工具
|
11
11
|
- 📦 **标准化数据模型**:统一的内部数据表示,方便扩展
|
12
|
+
- 📊 **模型评估**:完整的目标检测模型评估系统(v0.2.0+)
|
12
13
|
|
13
14
|
## 📦 安装
|
14
15
|
|
@@ -88,6 +89,43 @@ result = (pipeline
|
|
88
89
|
.execute())
|
89
90
|
```
|
90
91
|
|
92
|
+
### 模型评估(v0.2.0+)
|
93
|
+
|
94
|
+
```python
|
95
|
+
from dataset_toolkit import (
|
96
|
+
load_yolo_from_local,
|
97
|
+
load_predictions_from_streamlined,
|
98
|
+
Evaluator
|
99
|
+
)
|
100
|
+
|
101
|
+
# 1. 加载GT和预测结果
|
102
|
+
gt_dataset = load_yolo_from_local("/data/test/labels", {0: 'parcel'})
|
103
|
+
pred_dataset = load_predictions_from_streamlined(
|
104
|
+
"/results/predictions",
|
105
|
+
categories={0: 'parcel'},
|
106
|
+
image_dir="/data/test/images"
|
107
|
+
)
|
108
|
+
|
109
|
+
# 2. 创建评估器
|
110
|
+
evaluator = Evaluator(
|
111
|
+
positive_gt=gt_dataset,
|
112
|
+
positive_pred=pred_dataset,
|
113
|
+
iou_threshold=0.5
|
114
|
+
)
|
115
|
+
|
116
|
+
# 3. 计算指标
|
117
|
+
metrics = evaluator.calculate_metrics(confidence_threshold=0.5)
|
118
|
+
print(f"Precision: {metrics['precision']:.4f}")
|
119
|
+
print(f"Recall: {metrics['recall']:.4f}")
|
120
|
+
print(f"F1-Score: {metrics['f1']:.4f}")
|
121
|
+
|
122
|
+
# 4. 寻找最优阈值
|
123
|
+
optimal = evaluator.find_optimal_threshold(metric='f1')
|
124
|
+
print(f"最优阈值: {optimal['optimal_threshold']}")
|
125
|
+
```
|
126
|
+
|
127
|
+
详细文档请参考 [EVALUATION_GUIDE.md](EVALUATION_GUIDE.md)
|
128
|
+
|
91
129
|
## 📚 API 文档
|
92
130
|
|
93
131
|
### 数据加载器
|
@@ -15,7 +15,7 @@ Dataset Toolkit - 计算机视觉数据集处理工具包
|
|
15
15
|
>>> export_to_coco(dataset, "output.json")
|
16
16
|
"""
|
17
17
|
|
18
|
-
__version__ = "0.
|
18
|
+
__version__ = "0.2.0"
|
19
19
|
__author__ = "wenxiang.han"
|
20
20
|
__email__ = "wenxiang.han@anker-in.com"
|
21
21
|
|
@@ -28,13 +28,18 @@ from dataset_toolkit.models import (
|
|
28
28
|
|
29
29
|
from dataset_toolkit.loaders.local_loader import (
|
30
30
|
load_yolo_from_local,
|
31
|
-
load_csv_result_from_local
|
31
|
+
load_csv_result_from_local,
|
32
|
+
load_predictions_from_streamlined
|
32
33
|
)
|
33
34
|
|
34
35
|
from dataset_toolkit.processors.merger import (
|
35
36
|
merge_datasets
|
36
37
|
)
|
37
38
|
|
39
|
+
from dataset_toolkit.processors.evaluator import (
|
40
|
+
Evaluator
|
41
|
+
)
|
42
|
+
|
38
43
|
from dataset_toolkit.exporters.coco_exporter import (
|
39
44
|
export_to_coco
|
40
45
|
)
|
@@ -69,9 +74,11 @@ __all__ = [
|
|
69
74
|
# 加载器
|
70
75
|
"load_yolo_from_local",
|
71
76
|
"load_csv_result_from_local",
|
77
|
+
"load_predictions_from_streamlined",
|
72
78
|
|
73
79
|
# 处理器
|
74
80
|
"merge_datasets",
|
81
|
+
"Evaluator",
|
75
82
|
|
76
83
|
# 导出器
|
77
84
|
"export_to_coco",
|
@@ -146,8 +146,10 @@ def export_to_yolo_and_txt(
|
|
146
146
|
rel_path = os.path.relpath(img_in_yolo, txt_path.parent)
|
147
147
|
f.write(f"{rel_path}\n")
|
148
148
|
else:
|
149
|
-
#
|
150
|
-
|
149
|
+
# 绝对路径(规范化但不解析软链接)
|
150
|
+
# 使用 os.path.normpath 规范化路径,去除 .. 等
|
151
|
+
normalized_path = os.path.normpath(str(img_in_yolo.absolute()))
|
152
|
+
f.write(f"{normalized_path}\n")
|
151
153
|
|
152
154
|
print(f"✓ txt 列表已生成: {len(dataset.images)} 行")
|
153
155
|
|
@@ -186,4 +186,149 @@ def load_csv_result_from_local(dataset_path: str, categories: Dict[int, str] = N
|
|
186
186
|
|
187
187
|
print(f"加载完成. 共找到 {image_count} 张图片, {len(dataset.categories)} 个类别.")
|
188
188
|
print(f"类别映射: {dataset.categories}")
|
189
|
+
return dataset
|
190
|
+
|
191
|
+
|
192
|
+
def load_predictions_from_streamlined(
|
193
|
+
predictions_dir: str,
|
194
|
+
categories: Dict[int, str],
|
195
|
+
image_dir: str = None
|
196
|
+
) -> Dataset:
|
197
|
+
"""
|
198
|
+
从streamlined推理结果目录加载预测数据集。
|
199
|
+
|
200
|
+
预测文件格式(每行一个检测):
|
201
|
+
class_id,confidence,center_x,center_y,width,height
|
202
|
+
例如: 0,0.934679,354.00,388.00,274.00,102.00
|
203
|
+
|
204
|
+
参数:
|
205
|
+
predictions_dir: 预测结果txt文件所在目录
|
206
|
+
categories: 类别映射字典 {class_id: class_name}
|
207
|
+
image_dir: 图像目录(可选,用于读取图像尺寸)
|
208
|
+
如果不提供,将尝试从预测文件同级目录查找
|
209
|
+
|
210
|
+
返回:
|
211
|
+
Dataset: 预测数据集对象,dataset_type='pred'
|
212
|
+
"""
|
213
|
+
pred_path = Path(predictions_dir)
|
214
|
+
|
215
|
+
if not pred_path.is_dir():
|
216
|
+
raise FileNotFoundError(f"预测结果目录不存在: {pred_path}")
|
217
|
+
|
218
|
+
# 尝试自动查找图像目录
|
219
|
+
if image_dir is None:
|
220
|
+
# 尝试常见的图像目录位置
|
221
|
+
possible_image_dirs = [
|
222
|
+
pred_path.parent / 'images',
|
223
|
+
pred_path.parent.parent / 'images',
|
224
|
+
]
|
225
|
+
for possible_dir in possible_image_dirs:
|
226
|
+
if possible_dir.is_dir():
|
227
|
+
image_dir = str(possible_dir)
|
228
|
+
print(f"自动找到图像目录: {image_dir}")
|
229
|
+
break
|
230
|
+
|
231
|
+
dataset = Dataset(
|
232
|
+
name=pred_path.name,
|
233
|
+
categories=categories,
|
234
|
+
dataset_type="pred"
|
235
|
+
)
|
236
|
+
|
237
|
+
supported_extensions = ['.jpg', '.jpeg', '.png']
|
238
|
+
txt_files = list(pred_path.glob('*.txt'))
|
239
|
+
|
240
|
+
print(f"开始加载预测结果: {pred_path.name}...")
|
241
|
+
print(f"找到 {len(txt_files)} 个预测文件")
|
242
|
+
|
243
|
+
loaded_count = 0
|
244
|
+
skipped_count = 0
|
245
|
+
|
246
|
+
for txt_file in txt_files:
|
247
|
+
# 预测文件名对应的图像文件名(假设同名)
|
248
|
+
image_base_name = txt_file.stem
|
249
|
+
|
250
|
+
# 尝试查找对应的图像文件
|
251
|
+
image_path = None
|
252
|
+
img_width, img_height = None, None
|
253
|
+
|
254
|
+
if image_dir:
|
255
|
+
image_dir_path = Path(image_dir)
|
256
|
+
for ext in supported_extensions:
|
257
|
+
potential_image = image_dir_path / (image_base_name + ext)
|
258
|
+
if potential_image.exists():
|
259
|
+
image_path = str(potential_image.resolve())
|
260
|
+
try:
|
261
|
+
with Image.open(potential_image) as img:
|
262
|
+
img_width, img_height = img.size
|
263
|
+
except IOError:
|
264
|
+
print(f"警告: 无法打开图片 {potential_image}")
|
265
|
+
break
|
266
|
+
|
267
|
+
# 如果没有找到图像,使用默认值
|
268
|
+
if image_path is None:
|
269
|
+
# 假设一个默认的图像路径和尺寸
|
270
|
+
image_path = f"unknown/{image_base_name}.jpg"
|
271
|
+
img_width, img_height = 640, 640 # 默认尺寸
|
272
|
+
if image_dir:
|
273
|
+
skipped_count += 1
|
274
|
+
|
275
|
+
# 创建图像标注对象
|
276
|
+
image_annotation = ImageAnnotation(
|
277
|
+
image_id=image_base_name + '.jpg',
|
278
|
+
path=image_path,
|
279
|
+
width=img_width,
|
280
|
+
height=img_height
|
281
|
+
)
|
282
|
+
|
283
|
+
# 读取预测结果
|
284
|
+
try:
|
285
|
+
with open(txt_file, 'r') as f:
|
286
|
+
for line in f:
|
287
|
+
line = line.strip()
|
288
|
+
if not line:
|
289
|
+
continue
|
290
|
+
|
291
|
+
# 解析格式: class_id,confidence,center_x,center_y,width,height
|
292
|
+
parts = line.split(',')
|
293
|
+
if len(parts) != 6:
|
294
|
+
print(f"警告: 格式错误,已跳过: {txt_file} -> '{line}'")
|
295
|
+
continue
|
296
|
+
|
297
|
+
try:
|
298
|
+
class_id = int(parts[0])
|
299
|
+
confidence = float(parts[1])
|
300
|
+
center_x = float(parts[2])
|
301
|
+
center_y = float(parts[3])
|
302
|
+
width = float(parts[4])
|
303
|
+
height = float(parts[5])
|
304
|
+
|
305
|
+
# 转换为 [x_min, y_min, width, height] 格式
|
306
|
+
x_min = center_x - width / 2
|
307
|
+
y_min = center_y - height / 2
|
308
|
+
|
309
|
+
annotation = Annotation(
|
310
|
+
category_id=class_id,
|
311
|
+
bbox=[x_min, y_min, width, height],
|
312
|
+
confidence=confidence
|
313
|
+
)
|
314
|
+
image_annotation.annotations.append(annotation)
|
315
|
+
|
316
|
+
except (ValueError, IndexError) as e:
|
317
|
+
print(f"警告: 解析错误,已跳过: {txt_file} -> '{line}' ({e})")
|
318
|
+
continue
|
319
|
+
|
320
|
+
except Exception as e:
|
321
|
+
print(f"警告: 读取文件失败,已跳过: {txt_file} ({e})")
|
322
|
+
continue
|
323
|
+
|
324
|
+
dataset.images.append(image_annotation)
|
325
|
+
loaded_count += 1
|
326
|
+
|
327
|
+
print(f"加载完成. 成功加载 {loaded_count} 个预测文件")
|
328
|
+
if skipped_count > 0:
|
329
|
+
print(f"警告: {skipped_count} 个文件未找到对应图像,使用默认尺寸")
|
330
|
+
|
331
|
+
total_detections = sum(len(img.annotations) for img in dataset.images)
|
332
|
+
print(f"总检测数: {total_detections}")
|
333
|
+
|
189
334
|
return dataset
|
@@ -24,4 +24,6 @@ class Dataset:
|
|
24
24
|
"""代表一个完整的数据集对象,作为系统内部的标准化表示."""
|
25
25
|
name: str
|
26
26
|
images: List[ImageAnnotation] = field(default_factory=list)
|
27
|
-
categories: Dict[int, str] = field(default_factory=dict)
|
27
|
+
categories: Dict[int, str] = field(default_factory=dict)
|
28
|
+
dataset_type: str = "train" # 'train', 'gt', 'pred'
|
29
|
+
metadata: Dict = field(default_factory=dict) # 存储描述性信息,不包含处理参数
|
@@ -0,0 +1,535 @@
|
|
1
|
+
# dataset_toolkit/processors/evaluator.py
|
2
|
+
from typing import Dict, List, Optional, Tuple
|
3
|
+
from dataset_toolkit.models import Dataset, Annotation, ImageAnnotation
|
4
|
+
|
5
|
+
|
6
|
+
class Evaluator:
|
7
|
+
"""
|
8
|
+
评估器:支持正检集和误检集分离评估
|
9
|
+
|
10
|
+
用于比较GT和Pred数据集,计算Precision、Recall、F1、FPPI等指标。
|
11
|
+
支持在不同置信度阈值下动态评估,无需重新加载数据。
|
12
|
+
"""
|
13
|
+
|
14
|
+
def __init__(
|
15
|
+
self,
|
16
|
+
positive_gt: Dataset,
|
17
|
+
positive_pred: Dataset,
|
18
|
+
negative_gt: Optional[Dataset] = None,
|
19
|
+
negative_pred: Optional[Dataset] = None,
|
20
|
+
iou_threshold: float = 0.5
|
21
|
+
):
|
22
|
+
"""
|
23
|
+
初始化评估器
|
24
|
+
|
25
|
+
Args:
|
26
|
+
positive_gt: 正检集GT(必需)- 包含目标物体的测试集
|
27
|
+
positive_pred: 正检集预测(必需)- 对正检集的预测结果
|
28
|
+
negative_gt: 误检集GT(可选)- 不包含目标的背景图像GT
|
29
|
+
negative_pred: 误检集预测(可选)- 对误检集的预测结果
|
30
|
+
iou_threshold: IoU阈值,用于判断检测是否匹配GT(默认0.5)
|
31
|
+
"""
|
32
|
+
self.positive_gt = positive_gt
|
33
|
+
self.positive_pred = positive_pred
|
34
|
+
self.negative_gt = negative_gt
|
35
|
+
self.negative_pred = negative_pred
|
36
|
+
self.iou_threshold = iou_threshold
|
37
|
+
|
38
|
+
# 验证数据集
|
39
|
+
self._validate_datasets()
|
40
|
+
|
41
|
+
def _validate_datasets(self):
|
42
|
+
"""验证数据集的有效性"""
|
43
|
+
if self.positive_gt is None or self.positive_pred is None:
|
44
|
+
raise ValueError("正检集的GT和Pred是必需的")
|
45
|
+
|
46
|
+
if len(self.positive_gt.images) == 0:
|
47
|
+
raise ValueError("正检集GT为空")
|
48
|
+
|
49
|
+
if len(self.positive_pred.images) == 0:
|
50
|
+
raise ValueError("正检集Pred为空")
|
51
|
+
|
52
|
+
# 检查类别是否一致
|
53
|
+
if self.positive_gt.categories != self.positive_pred.categories:
|
54
|
+
print("警告: GT和Pred的类别映射不一致")
|
55
|
+
print(f" GT categories: {self.positive_gt.categories}")
|
56
|
+
print(f" Pred categories: {self.positive_pred.categories}")
|
57
|
+
|
58
|
+
def calculate_metrics(
|
59
|
+
self,
|
60
|
+
confidence_threshold: float = 0.5,
|
61
|
+
class_id: Optional[int] = None,
|
62
|
+
calculate_fppi: bool = True
|
63
|
+
) -> Dict:
|
64
|
+
"""
|
65
|
+
计算综合评估指标
|
66
|
+
|
67
|
+
Args:
|
68
|
+
confidence_threshold: 置信度阈值(动态传入,不存储在数据集中)
|
69
|
+
class_id: 指定类别ID,None表示所有类别
|
70
|
+
calculate_fppi: 是否计算FPPI(需要negative_pred)
|
71
|
+
|
72
|
+
Returns:
|
73
|
+
包含所有指标的字典:
|
74
|
+
- tp, fp, fn: True/False Positives/Negatives
|
75
|
+
- precision, recall, f1: 精确率、召回率、F1分数
|
76
|
+
- fppi: False Positives Per Image(如果计算)
|
77
|
+
- confidence_threshold, iou_threshold: 使用的阈值
|
78
|
+
- positive_set_size, negative_set_size: 数据集大小
|
79
|
+
"""
|
80
|
+
metrics = {}
|
81
|
+
|
82
|
+
# 1. 从正检集计算 Precision, Recall, F1
|
83
|
+
positive_metrics = self._calculate_positive_metrics(
|
84
|
+
confidence_threshold, class_id
|
85
|
+
)
|
86
|
+
metrics.update(positive_metrics)
|
87
|
+
|
88
|
+
# 2. 从误检集计算 FPPI
|
89
|
+
if calculate_fppi and self.negative_pred is not None:
|
90
|
+
fppi_metrics = self._calculate_fppi_metrics(
|
91
|
+
confidence_threshold, class_id
|
92
|
+
)
|
93
|
+
metrics.update(fppi_metrics)
|
94
|
+
else:
|
95
|
+
metrics['fppi'] = None
|
96
|
+
metrics['fppi_note'] = "未提供误检集" if self.negative_pred is None else "未计算FPPI"
|
97
|
+
|
98
|
+
# 3. 添加配置信息
|
99
|
+
metrics['confidence_threshold'] = confidence_threshold
|
100
|
+
metrics['iou_threshold'] = self.iou_threshold
|
101
|
+
metrics['positive_set_size'] = len(self.positive_gt.images)
|
102
|
+
metrics['negative_set_size'] = (
|
103
|
+
len(self.negative_pred.images)
|
104
|
+
if self.negative_pred else 0
|
105
|
+
)
|
106
|
+
|
107
|
+
return metrics
|
108
|
+
|
109
|
+
def _calculate_positive_metrics(
|
110
|
+
self,
|
111
|
+
confidence_threshold: float,
|
112
|
+
class_id: Optional[int] = None
|
113
|
+
) -> Dict:
|
114
|
+
"""
|
115
|
+
从正检集计算 Precision, Recall, F1
|
116
|
+
|
117
|
+
Args:
|
118
|
+
confidence_threshold: 置信度阈值
|
119
|
+
class_id: 指定类别ID
|
120
|
+
|
121
|
+
Returns:
|
122
|
+
包含TP, FP, FN, Precision, Recall, F1的字典
|
123
|
+
"""
|
124
|
+
# 1. 过滤预测结果
|
125
|
+
filtered_preds = self._filter_predictions(
|
126
|
+
self.positive_pred,
|
127
|
+
confidence_threshold,
|
128
|
+
class_id
|
129
|
+
)
|
130
|
+
|
131
|
+
# 2. 匹配GT和Pred
|
132
|
+
tp = 0 # True Positives
|
133
|
+
fp = 0 # False Positives
|
134
|
+
fn = 0 # False Negatives
|
135
|
+
|
136
|
+
matched_gt = set() # 记录已匹配的GT,格式: (image_id, gt_index)
|
137
|
+
|
138
|
+
# 遍历每张图像
|
139
|
+
for img_gt in self.positive_gt.images:
|
140
|
+
# 获取该图像的GT标注
|
141
|
+
gt_anns = [
|
142
|
+
ann for ann in img_gt.annotations
|
143
|
+
if class_id is None or ann.category_id == class_id
|
144
|
+
]
|
145
|
+
|
146
|
+
# 获取该图像的预测结果
|
147
|
+
img_preds = filtered_preds.get(img_gt.image_id, [])
|
148
|
+
|
149
|
+
# 匹配预测和GT
|
150
|
+
for pred in img_preds:
|
151
|
+
best_iou = 0
|
152
|
+
best_gt_idx = -1
|
153
|
+
|
154
|
+
for i, gt_ann in enumerate(gt_anns):
|
155
|
+
if (img_gt.image_id, i) not in matched_gt:
|
156
|
+
iou = self._calculate_iou(pred.bbox, gt_ann.bbox)
|
157
|
+
if iou > best_iou:
|
158
|
+
best_iou = iou
|
159
|
+
best_gt_idx = i
|
160
|
+
|
161
|
+
if best_iou >= self.iou_threshold:
|
162
|
+
tp += 1
|
163
|
+
matched_gt.add((img_gt.image_id, best_gt_idx))
|
164
|
+
else:
|
165
|
+
fp += 1
|
166
|
+
|
167
|
+
# 统计未匹配的GT(False Negatives)
|
168
|
+
fn += len([
|
169
|
+
i for i in range(len(gt_anns))
|
170
|
+
if (img_gt.image_id, i) not in matched_gt
|
171
|
+
])
|
172
|
+
|
173
|
+
# 3. 计算指标
|
174
|
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
175
|
+
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
176
|
+
f1 = (2 * precision * recall / (precision + recall)
|
177
|
+
if (precision + recall) > 0 else 0.0)
|
178
|
+
|
179
|
+
return {
|
180
|
+
'tp': tp,
|
181
|
+
'fp': fp,
|
182
|
+
'fn': fn,
|
183
|
+
'precision': precision,
|
184
|
+
'recall': recall,
|
185
|
+
'f1': f1,
|
186
|
+
'positive_set_note': f"基于{len(self.positive_gt.images)}张正检图像"
|
187
|
+
}
|
188
|
+
|
189
|
+
def _calculate_fppi_metrics(
|
190
|
+
self,
|
191
|
+
confidence_threshold: float,
|
192
|
+
class_id: Optional[int] = None
|
193
|
+
) -> Dict:
|
194
|
+
"""
|
195
|
+
从误检集计算 FPPI (False Positives Per Image)
|
196
|
+
|
197
|
+
Args:
|
198
|
+
confidence_threshold: 置信度阈值
|
199
|
+
class_id: 指定类别ID
|
200
|
+
|
201
|
+
Returns:
|
202
|
+
包含FPPI相关指标的字典
|
203
|
+
"""
|
204
|
+
# 1. 过滤预测结果
|
205
|
+
filtered_preds = self._filter_predictions(
|
206
|
+
self.negative_pred,
|
207
|
+
confidence_threshold,
|
208
|
+
class_id
|
209
|
+
)
|
210
|
+
|
211
|
+
# 2. 统计误检数
|
212
|
+
total_fp = 0
|
213
|
+
fp_per_image = []
|
214
|
+
|
215
|
+
for img_pred in self.negative_pred.images:
|
216
|
+
img_preds = filtered_preds.get(img_pred.image_id, [])
|
217
|
+
|
218
|
+
# 在误检集中,所有检测都是False Positive
|
219
|
+
# 但如果提供了negative_gt,可以更精确地判断
|
220
|
+
if self.negative_gt is not None:
|
221
|
+
# 找到对应的GT图像
|
222
|
+
gt_img = next(
|
223
|
+
(img for img in self.negative_gt.images
|
224
|
+
if img.image_id == img_pred.image_id),
|
225
|
+
None
|
226
|
+
)
|
227
|
+
|
228
|
+
if gt_img is not None:
|
229
|
+
# 获取GT标注
|
230
|
+
gt_anns = [
|
231
|
+
ann for ann in gt_img.annotations
|
232
|
+
if class_id is None or ann.category_id == class_id
|
233
|
+
]
|
234
|
+
|
235
|
+
# 匹配预测和GT,未匹配的是FP
|
236
|
+
matched = set()
|
237
|
+
fp_count = 0
|
238
|
+
for pred in img_preds:
|
239
|
+
is_matched = False
|
240
|
+
for i, gt_ann in enumerate(gt_anns):
|
241
|
+
if i not in matched:
|
242
|
+
iou = self._calculate_iou(pred.bbox, gt_ann.bbox)
|
243
|
+
if iou >= self.iou_threshold:
|
244
|
+
matched.add(i)
|
245
|
+
is_matched = True
|
246
|
+
break
|
247
|
+
if not is_matched:
|
248
|
+
fp_count += 1
|
249
|
+
|
250
|
+
total_fp += fp_count
|
251
|
+
fp_per_image.append(fp_count)
|
252
|
+
else:
|
253
|
+
# 没有对应的GT,所有检测都是FP
|
254
|
+
fp_count = len(img_preds)
|
255
|
+
total_fp += fp_count
|
256
|
+
fp_per_image.append(fp_count)
|
257
|
+
else:
|
258
|
+
# 没有提供negative_gt,假设所有检测都是FP
|
259
|
+
fp_count = len(img_preds)
|
260
|
+
total_fp += fp_count
|
261
|
+
fp_per_image.append(fp_count)
|
262
|
+
|
263
|
+
# 3. 计算FPPI
|
264
|
+
num_images = len(self.negative_pred.images)
|
265
|
+
fppi = total_fp / num_images if num_images > 0 else 0.0
|
266
|
+
|
267
|
+
return {
|
268
|
+
'fppi': fppi,
|
269
|
+
'total_false_positives': total_fp,
|
270
|
+
'negative_set_size': num_images,
|
271
|
+
'fppi_note': f"基于{num_images}张误检图像",
|
272
|
+
'max_fp_per_image': max(fp_per_image) if fp_per_image else 0,
|
273
|
+
'min_fp_per_image': min(fp_per_image) if fp_per_image else 0,
|
274
|
+
'avg_fp_per_image': total_fp / num_images if num_images > 0 else 0.0
|
275
|
+
}
|
276
|
+
|
277
|
+
def _filter_predictions(
|
278
|
+
self,
|
279
|
+
pred_dataset: Dataset,
|
280
|
+
confidence_threshold: float,
|
281
|
+
class_id: Optional[int] = None
|
282
|
+
) -> Dict[str, List[Annotation]]:
|
283
|
+
"""
|
284
|
+
根据置信度阈值和类别过滤预测结果
|
285
|
+
|
286
|
+
Args:
|
287
|
+
pred_dataset: 预测数据集
|
288
|
+
confidence_threshold: 置信度阈值
|
289
|
+
class_id: 指定类别ID
|
290
|
+
|
291
|
+
Returns:
|
292
|
+
{image_id: [annotations]} 字典
|
293
|
+
"""
|
294
|
+
filtered = {}
|
295
|
+
for img in pred_dataset.images:
|
296
|
+
img_preds = [
|
297
|
+
ann for ann in img.annotations
|
298
|
+
if ann.confidence >= confidence_threshold
|
299
|
+
and (class_id is None or ann.category_id == class_id)
|
300
|
+
]
|
301
|
+
if img_preds:
|
302
|
+
filtered[img.image_id] = img_preds
|
303
|
+
return filtered
|
304
|
+
|
305
|
+
def _calculate_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
|
306
|
+
"""
|
307
|
+
计算两个边界框的IoU (Intersection over Union)
|
308
|
+
|
309
|
+
Args:
|
310
|
+
bbox1, bbox2: [x_min, y_min, width, height] 格式的边界框
|
311
|
+
|
312
|
+
Returns:
|
313
|
+
IoU值 (0.0 到 1.0)
|
314
|
+
"""
|
315
|
+
# bbox格式: [x_min, y_min, width, height]
|
316
|
+
x1_min, y1_min, w1, h1 = bbox1
|
317
|
+
x2_min, y2_min, w2, h2 = bbox2
|
318
|
+
|
319
|
+
x1_max = x1_min + w1
|
320
|
+
y1_max = y1_min + h1
|
321
|
+
x2_max = x2_min + w2
|
322
|
+
y2_max = y2_min + h2
|
323
|
+
|
324
|
+
# 计算交集
|
325
|
+
inter_x_min = max(x1_min, x2_min)
|
326
|
+
inter_y_min = max(y1_min, y2_min)
|
327
|
+
inter_x_max = min(x1_max, x2_max)
|
328
|
+
inter_y_max = min(y1_max, y2_max)
|
329
|
+
|
330
|
+
if inter_x_max <= inter_x_min or inter_y_max <= inter_y_min:
|
331
|
+
return 0.0
|
332
|
+
|
333
|
+
inter_area = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min)
|
334
|
+
|
335
|
+
# 计算并集
|
336
|
+
area1 = w1 * h1
|
337
|
+
area2 = w2 * h2
|
338
|
+
union_area = area1 + area2 - inter_area
|
339
|
+
|
340
|
+
return inter_area / union_area if union_area > 0 else 0.0
|
341
|
+
|
342
|
+
def calculate_pr_curve(
|
343
|
+
self,
|
344
|
+
thresholds: Optional[List[float]] = None,
|
345
|
+
class_id: Optional[int] = None
|
346
|
+
) -> List[Dict]:
|
347
|
+
"""
|
348
|
+
计算PR曲线(不同置信度阈值下的Precision-Recall)
|
349
|
+
|
350
|
+
Args:
|
351
|
+
thresholds: 要测试的置信度阈值列表,None则使用默认值
|
352
|
+
class_id: 指定类别ID
|
353
|
+
|
354
|
+
Returns:
|
355
|
+
每个阈值对应的指标列表
|
356
|
+
"""
|
357
|
+
if thresholds is None:
|
358
|
+
thresholds = [i/10 for i in range(1, 10)] # 0.1 到 0.9
|
359
|
+
|
360
|
+
pr_points = []
|
361
|
+
for threshold in thresholds:
|
362
|
+
# 对同一份数据,使用不同阈值计算指标
|
363
|
+
metrics = self.calculate_metrics(
|
364
|
+
confidence_threshold=threshold,
|
365
|
+
class_id=class_id,
|
366
|
+
calculate_fppi=False # PR曲线不需要FPPI
|
367
|
+
)
|
368
|
+
pr_points.append({
|
369
|
+
'threshold': threshold,
|
370
|
+
'precision': metrics['precision'],
|
371
|
+
'recall': metrics['recall'],
|
372
|
+
'f1': metrics['f1'],
|
373
|
+
'tp': metrics['tp'],
|
374
|
+
'fp': metrics['fp'],
|
375
|
+
'fn': metrics['fn']
|
376
|
+
})
|
377
|
+
|
378
|
+
return pr_points
|
379
|
+
|
380
|
+
def find_optimal_threshold(
|
381
|
+
self,
|
382
|
+
metric: str = 'f1',
|
383
|
+
class_id: Optional[int] = None,
|
384
|
+
thresholds: Optional[List[float]] = None
|
385
|
+
) -> Dict:
|
386
|
+
"""
|
387
|
+
找到使指定指标最优的置信度阈值
|
388
|
+
|
389
|
+
Args:
|
390
|
+
metric: 优化目标指标 ('precision', 'recall', 'f1')
|
391
|
+
class_id: 指定类别ID
|
392
|
+
thresholds: 要测试的阈值列表
|
393
|
+
|
394
|
+
Returns:
|
395
|
+
最优阈值及对应的所有指标
|
396
|
+
"""
|
397
|
+
if metric not in ['precision', 'recall', 'f1']:
|
398
|
+
raise ValueError(f"不支持的指标: {metric}")
|
399
|
+
|
400
|
+
pr_curve = self.calculate_pr_curve(thresholds, class_id)
|
401
|
+
|
402
|
+
# 找到指定指标最大的点
|
403
|
+
best_point = max(pr_curve, key=lambda x: x[metric])
|
404
|
+
|
405
|
+
return {
|
406
|
+
'optimal_threshold': best_point['threshold'],
|
407
|
+
'optimized_metric': metric,
|
408
|
+
'metrics': best_point
|
409
|
+
}
|
410
|
+
|
411
|
+
def find_threshold_with_constraint(
|
412
|
+
self,
|
413
|
+
target_metric: str,
|
414
|
+
constraint_metric: str,
|
415
|
+
constraint_value: float,
|
416
|
+
class_id: Optional[int] = None,
|
417
|
+
thresholds: Optional[List[float]] = None
|
418
|
+
) -> Optional[Dict]:
|
419
|
+
"""
|
420
|
+
在约束条件下找到最优阈值
|
421
|
+
|
422
|
+
例如:在FPPI < 0.01的约束下,找到Recall最高的阈值
|
423
|
+
|
424
|
+
Args:
|
425
|
+
target_metric: 要优化的目标指标 ('precision', 'recall', 'f1')
|
426
|
+
constraint_metric: 约束指标 ('fppi', 'precision', 'recall')
|
427
|
+
constraint_value: 约束值(如 fppi < 0.01)
|
428
|
+
class_id: 指定类别ID
|
429
|
+
thresholds: 要测试的阈值列表
|
430
|
+
|
431
|
+
Returns:
|
432
|
+
最优阈值及对应的指标,如果无满足约束的阈值则返回None
|
433
|
+
"""
|
434
|
+
if thresholds is None:
|
435
|
+
thresholds = [i/100 for i in range(1, 100)] # 0.01 到 0.99
|
436
|
+
|
437
|
+
best_threshold = None
|
438
|
+
best_value = 0
|
439
|
+
best_metrics = None
|
440
|
+
|
441
|
+
for threshold in thresholds:
|
442
|
+
metrics = self.calculate_metrics(
|
443
|
+
confidence_threshold=threshold,
|
444
|
+
class_id=class_id,
|
445
|
+
calculate_fppi=(constraint_metric == 'fppi')
|
446
|
+
)
|
447
|
+
|
448
|
+
# 检查约束条件
|
449
|
+
constraint_satisfied = False
|
450
|
+
if constraint_metric == 'fppi':
|
451
|
+
if metrics['fppi'] is not None and metrics['fppi'] <= constraint_value:
|
452
|
+
constraint_satisfied = True
|
453
|
+
elif constraint_metric in metrics:
|
454
|
+
if metrics[constraint_metric] >= constraint_value:
|
455
|
+
constraint_satisfied = True
|
456
|
+
|
457
|
+
# 如果满足约束,检查是否是最优的
|
458
|
+
if constraint_satisfied:
|
459
|
+
if metrics[target_metric] > best_value:
|
460
|
+
best_value = metrics[target_metric]
|
461
|
+
best_threshold = threshold
|
462
|
+
best_metrics = metrics
|
463
|
+
|
464
|
+
if best_threshold is None:
|
465
|
+
return None
|
466
|
+
|
467
|
+
return {
|
468
|
+
'optimal_threshold': best_threshold,
|
469
|
+
'target_metric': target_metric,
|
470
|
+
'target_value': best_value,
|
471
|
+
'constraint': f"{constraint_metric} <= {constraint_value}",
|
472
|
+
'metrics': best_metrics
|
473
|
+
}
|
474
|
+
|
475
|
+
def generate_report(
|
476
|
+
self,
|
477
|
+
confidence_threshold: float = 0.5,
|
478
|
+
class_id: Optional[int] = None
|
479
|
+
) -> str:
|
480
|
+
"""
|
481
|
+
生成评估报告
|
482
|
+
|
483
|
+
Args:
|
484
|
+
confidence_threshold: 置信度阈值
|
485
|
+
class_id: 指定类别ID
|
486
|
+
|
487
|
+
Returns:
|
488
|
+
格式化的评估报告字符串
|
489
|
+
"""
|
490
|
+
metrics = self.calculate_metrics(confidence_threshold, class_id)
|
491
|
+
|
492
|
+
class_name = "所有类别"
|
493
|
+
if class_id is not None and class_id in self.positive_gt.categories:
|
494
|
+
class_name = f"类别 {class_id} ({self.positive_gt.categories[class_id]})"
|
495
|
+
|
496
|
+
report = f"""
|
497
|
+
{'='*60}
|
498
|
+
评估报告 - {class_name}
|
499
|
+
{'='*60}
|
500
|
+
|
501
|
+
配置信息:
|
502
|
+
置信度阈值: {metrics['confidence_threshold']}
|
503
|
+
IoU阈值: {metrics['iou_threshold']}
|
504
|
+
正检集大小: {metrics['positive_set_size']} 张图像
|
505
|
+
误检集大小: {metrics['negative_set_size']} 张图像
|
506
|
+
|
507
|
+
正检集指标 (Precision & Recall):
|
508
|
+
True Positives (TP): {metrics['tp']}
|
509
|
+
False Positives (FP): {metrics['fp']}
|
510
|
+
False Negatives (FN): {metrics['fn']}
|
511
|
+
|
512
|
+
Precision: {metrics['precision']:.4f} ({metrics['precision']*100:.2f}%)
|
513
|
+
Recall: {metrics['recall']:.4f} ({metrics['recall']*100:.2f}%)
|
514
|
+
F1-Score: {metrics['f1']:.4f}
|
515
|
+
"""
|
516
|
+
|
517
|
+
if metrics['fppi'] is not None:
|
518
|
+
report += f"""
|
519
|
+
误检集指标 (FPPI):
|
520
|
+
FPPI (False Positives Per Image): {metrics['fppi']:.6f}
|
521
|
+
总误检数: {metrics.get('total_false_positives', 'N/A')}
|
522
|
+
平均每图误检数: {metrics.get('avg_fp_per_image', 'N/A'):.2f}
|
523
|
+
单图最大误检数: {metrics.get('max_fp_per_image', 'N/A')}
|
524
|
+
单图最小误检数: {metrics.get('min_fp_per_image', 'N/A')}
|
525
|
+
"""
|
526
|
+
else:
|
527
|
+
report += f"""
|
528
|
+
误检集指标 (FPPI):
|
529
|
+
{metrics['fppi_note']}
|
530
|
+
"""
|
531
|
+
|
532
|
+
report += f"\n{'='*60}\n"
|
533
|
+
|
534
|
+
return report
|
535
|
+
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: dataset-toolkit
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: 一个用于加载、处理和导出计算机视觉数据集的工具包
|
5
5
|
Home-page: https://github.com/yourusername/dataset-toolkit
|
6
6
|
Author: wenxiang.han
|
@@ -42,6 +42,7 @@ Dynamic: requires-python
|
|
42
42
|
- 📤 **灵活导出**:导出为 COCO JSON、TXT 等多种格式
|
43
43
|
- 🛠️ **工具函数**:提供坐标转换等实用工具
|
44
44
|
- 📦 **标准化数据模型**:统一的内部数据表示,方便扩展
|
45
|
+
- 📊 **模型评估**:完整的目标检测模型评估系统(v0.2.0+)
|
45
46
|
|
46
47
|
## 📦 安装
|
47
48
|
|
@@ -121,6 +122,43 @@ result = (pipeline
|
|
121
122
|
.execute())
|
122
123
|
```
|
123
124
|
|
125
|
+
### 模型评估(v0.2.0+)
|
126
|
+
|
127
|
+
```python
|
128
|
+
from dataset_toolkit import (
|
129
|
+
load_yolo_from_local,
|
130
|
+
load_predictions_from_streamlined,
|
131
|
+
Evaluator
|
132
|
+
)
|
133
|
+
|
134
|
+
# 1. 加载GT和预测结果
|
135
|
+
gt_dataset = load_yolo_from_local("/data/test/labels", {0: 'parcel'})
|
136
|
+
pred_dataset = load_predictions_from_streamlined(
|
137
|
+
"/results/predictions",
|
138
|
+
categories={0: 'parcel'},
|
139
|
+
image_dir="/data/test/images"
|
140
|
+
)
|
141
|
+
|
142
|
+
# 2. 创建评估器
|
143
|
+
evaluator = Evaluator(
|
144
|
+
positive_gt=gt_dataset,
|
145
|
+
positive_pred=pred_dataset,
|
146
|
+
iou_threshold=0.5
|
147
|
+
)
|
148
|
+
|
149
|
+
# 3. 计算指标
|
150
|
+
metrics = evaluator.calculate_metrics(confidence_threshold=0.5)
|
151
|
+
print(f"Precision: {metrics['precision']:.4f}")
|
152
|
+
print(f"Recall: {metrics['recall']:.4f}")
|
153
|
+
print(f"F1-Score: {metrics['f1']:.4f}")
|
154
|
+
|
155
|
+
# 4. 寻找最优阈值
|
156
|
+
optimal = evaluator.find_optimal_threshold(metric='f1')
|
157
|
+
print(f"最优阈值: {optimal['optimal_threshold']}")
|
158
|
+
```
|
159
|
+
|
160
|
+
详细文档请参考 [EVALUATION_GUIDE.md](EVALUATION_GUIDE.md)
|
161
|
+
|
124
162
|
## 📚 API 文档
|
125
163
|
|
126
164
|
### 数据加载器
|
@@ -19,10 +19,12 @@ dataset_toolkit/exporters/yolo_exporter.py
|
|
19
19
|
dataset_toolkit/loaders/__init__.py
|
20
20
|
dataset_toolkit/loaders/local_loader.py
|
21
21
|
dataset_toolkit/processors/__init__.py
|
22
|
+
dataset_toolkit/processors/evaluator.py
|
22
23
|
dataset_toolkit/processors/merger.py
|
23
24
|
dataset_toolkit/utils/__init__.py
|
24
25
|
dataset_toolkit/utils/coords.py
|
25
26
|
examples/basic_usage.py
|
27
|
+
examples/evaluation_example.py
|
26
28
|
tests/__init__.py
|
27
29
|
tests/conftest.py
|
28
30
|
tests/test_exporters.py
|
@@ -0,0 +1,250 @@
|
|
1
|
+
"""
|
2
|
+
评估系统使用示例
|
3
|
+
|
4
|
+
演示如何使用dataset_toolkit进行模型评估:
|
5
|
+
1. 加载GT数据集(正检集和误检集)
|
6
|
+
2. 加载预测结果
|
7
|
+
3. 计算评估指标(Precision, Recall, F1, FPPI)
|
8
|
+
4. 测试不同置信度阈值
|
9
|
+
5. 找到最优阈值
|
10
|
+
"""
|
11
|
+
|
12
|
+
import sys
|
13
|
+
from pathlib import Path
|
14
|
+
|
15
|
+
# 添加项目路径
|
16
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
17
|
+
|
18
|
+
from dataset_toolkit import (
|
19
|
+
load_yolo_from_local,
|
20
|
+
load_predictions_from_streamlined,
|
21
|
+
Evaluator
|
22
|
+
)
|
23
|
+
|
24
|
+
|
25
|
+
def main():
|
26
|
+
print("="*60)
|
27
|
+
print("评估系统使用示例")
|
28
|
+
print("="*60)
|
29
|
+
|
30
|
+
# ============================================================
|
31
|
+
# 1. 配置路径(请根据实际情况修改)
|
32
|
+
# ============================================================
|
33
|
+
|
34
|
+
# 正检集路径
|
35
|
+
positive_gt_path = "/opt/dlami/nvme/workspace_wenxiang/parcel/test_val/labels"
|
36
|
+
positive_pred_path = "/opt/dlami/nvme/workspace_wenxiang/ai_train/onnx_infer/detections/results/streamlined_test"
|
37
|
+
positive_image_path = "/opt/dlami/nvme/workspace_wenxiang/parcel/test_val/images"
|
38
|
+
|
39
|
+
# 误检集路径(可选)
|
40
|
+
negative_pred_path = None # 如果有误检集,设置路径
|
41
|
+
negative_image_path = None
|
42
|
+
|
43
|
+
# 类别映射
|
44
|
+
categories = {0: 'parcel'}
|
45
|
+
|
46
|
+
# ============================================================
|
47
|
+
# 2. 加载数据集
|
48
|
+
# ============================================================
|
49
|
+
|
50
|
+
print("\n步骤1: 加载数据集...")
|
51
|
+
print("-" * 60)
|
52
|
+
|
53
|
+
# 加载正检集GT
|
54
|
+
print("\n加载正检集GT...")
|
55
|
+
gt_positive = load_yolo_from_local(
|
56
|
+
positive_gt_path,
|
57
|
+
categories=categories
|
58
|
+
)
|
59
|
+
gt_positive.dataset_type = "gt"
|
60
|
+
gt_positive.metadata = {
|
61
|
+
"test_purpose": "positive",
|
62
|
+
"description": "包含目标物体的测试集"
|
63
|
+
}
|
64
|
+
print(f"✓ 正检集GT: {len(gt_positive.images)} 张图像")
|
65
|
+
|
66
|
+
# 加载正检集预测
|
67
|
+
print("\n加载正检集预测结果...")
|
68
|
+
pred_positive = load_predictions_from_streamlined(
|
69
|
+
positive_pred_path,
|
70
|
+
categories=categories,
|
71
|
+
image_dir=positive_image_path
|
72
|
+
)
|
73
|
+
pred_positive.dataset_type = "pred"
|
74
|
+
pred_positive.metadata = {
|
75
|
+
"test_purpose": "positive",
|
76
|
+
"model_name": "yolov8_parcel"
|
77
|
+
}
|
78
|
+
print(f"✓ 正检集Pred: {len(pred_positive.images)} 张图像")
|
79
|
+
|
80
|
+
# 加载误检集预测(如果有)
|
81
|
+
pred_negative = None
|
82
|
+
if negative_pred_path:
|
83
|
+
print("\n加载误检集预测结果...")
|
84
|
+
pred_negative = load_predictions_from_streamlined(
|
85
|
+
negative_pred_path,
|
86
|
+
categories=categories,
|
87
|
+
image_dir=negative_image_path
|
88
|
+
)
|
89
|
+
pred_negative.dataset_type = "pred"
|
90
|
+
pred_negative.metadata = {
|
91
|
+
"test_purpose": "negative",
|
92
|
+
"model_name": "yolov8_parcel"
|
93
|
+
}
|
94
|
+
print(f"✓ 误检集Pred: {len(pred_negative.images)} 张图像")
|
95
|
+
else:
|
96
|
+
print("\n未提供误检集,将只计算Precision/Recall/F1")
|
97
|
+
|
98
|
+
# ============================================================
|
99
|
+
# 3. 创建评估器
|
100
|
+
# ============================================================
|
101
|
+
|
102
|
+
print("\n步骤2: 创建评估器...")
|
103
|
+
print("-" * 60)
|
104
|
+
|
105
|
+
evaluator = Evaluator(
|
106
|
+
positive_gt=gt_positive,
|
107
|
+
positive_pred=pred_positive,
|
108
|
+
negative_pred=pred_negative,
|
109
|
+
iou_threshold=0.5
|
110
|
+
)
|
111
|
+
print("✓ 评估器创建成功")
|
112
|
+
|
113
|
+
# ============================================================
|
114
|
+
# 4. 计算单个阈值的指标
|
115
|
+
# ============================================================
|
116
|
+
|
117
|
+
print("\n步骤3: 计算评估指标(置信度阈值=0.5)...")
|
118
|
+
print("-" * 60)
|
119
|
+
|
120
|
+
metrics = evaluator.calculate_metrics(
|
121
|
+
confidence_threshold=0.5,
|
122
|
+
class_id=0 # 只评估parcel类别
|
123
|
+
)
|
124
|
+
|
125
|
+
print(f"\n正检集指标:")
|
126
|
+
print(f" TP: {metrics['tp']}, FP: {metrics['fp']}, FN: {metrics['fn']}")
|
127
|
+
print(f" Precision: {metrics['precision']:.4f} ({metrics['precision']*100:.2f}%)")
|
128
|
+
print(f" Recall: {metrics['recall']:.4f} ({metrics['recall']*100:.2f}%)")
|
129
|
+
print(f" F1-Score: {metrics['f1']:.4f}")
|
130
|
+
|
131
|
+
if metrics['fppi'] is not None:
|
132
|
+
print(f"\n误检集指标:")
|
133
|
+
print(f" FPPI: {metrics['fppi']:.6f}")
|
134
|
+
print(f" 总误检数: {metrics['total_false_positives']}")
|
135
|
+
|
136
|
+
# ============================================================
|
137
|
+
# 5. 测试多个阈值
|
138
|
+
# ============================================================
|
139
|
+
|
140
|
+
print("\n步骤4: 测试多个置信度阈值...")
|
141
|
+
print("-" * 60)
|
142
|
+
|
143
|
+
test_thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]
|
144
|
+
|
145
|
+
print(f"\n{'阈值':<10} {'Precision':<12} {'Recall':<12} {'F1':<12} {'FPPI':<12}")
|
146
|
+
print("-" * 60)
|
147
|
+
|
148
|
+
for threshold in test_thresholds:
|
149
|
+
m = evaluator.calculate_metrics(
|
150
|
+
confidence_threshold=threshold,
|
151
|
+
class_id=0
|
152
|
+
)
|
153
|
+
fppi_str = f"{m['fppi']:.6f}" if m['fppi'] is not None else "N/A"
|
154
|
+
print(f"{threshold:<10.2f} {m['precision']:<12.4f} {m['recall']:<12.4f} "
|
155
|
+
f"{m['f1']:<12.4f} {fppi_str:<12}")
|
156
|
+
|
157
|
+
# ============================================================
|
158
|
+
# 6. 找到最优阈值
|
159
|
+
# ============================================================
|
160
|
+
|
161
|
+
print("\n步骤5: 寻找最优阈值...")
|
162
|
+
print("-" * 60)
|
163
|
+
|
164
|
+
# 找到F1最高的阈值
|
165
|
+
optimal = evaluator.find_optimal_threshold(
|
166
|
+
metric='f1',
|
167
|
+
class_id=0
|
168
|
+
)
|
169
|
+
|
170
|
+
print(f"\n最优阈值(F1最大):")
|
171
|
+
print(f" 阈值: {optimal['optimal_threshold']}")
|
172
|
+
print(f" Precision: {optimal['metrics']['precision']:.4f}")
|
173
|
+
print(f" Recall: {optimal['metrics']['recall']:.4f}")
|
174
|
+
print(f" F1-Score: {optimal['metrics']['f1']:.4f}")
|
175
|
+
|
176
|
+
# 如果有误检集,找到FPPI约束下的最优阈值
|
177
|
+
if pred_negative:
|
178
|
+
constrained = evaluator.find_threshold_with_constraint(
|
179
|
+
target_metric='recall',
|
180
|
+
constraint_metric='fppi',
|
181
|
+
constraint_value=0.01, # FPPI < 0.01
|
182
|
+
class_id=0
|
183
|
+
)
|
184
|
+
|
185
|
+
if constrained:
|
186
|
+
print(f"\n最优阈值(FPPI < 0.01约束下,Recall最大):")
|
187
|
+
print(f" 阈值: {constrained['optimal_threshold']}")
|
188
|
+
print(f" Recall: {constrained['target_value']:.4f}")
|
189
|
+
print(f" FPPI: {constrained['metrics']['fppi']:.6f}")
|
190
|
+
else:
|
191
|
+
print(f"\n警告: 无法找到满足 FPPI < 0.01 约束的阈值")
|
192
|
+
|
193
|
+
# ============================================================
|
194
|
+
# 7. 生成完整报告
|
195
|
+
# ============================================================
|
196
|
+
|
197
|
+
print("\n步骤6: 生成完整评估报告...")
|
198
|
+
print("-" * 60)
|
199
|
+
|
200
|
+
report = evaluator.generate_report(
|
201
|
+
confidence_threshold=0.5,
|
202
|
+
class_id=0
|
203
|
+
)
|
204
|
+
print(report)
|
205
|
+
|
206
|
+
# ============================================================
|
207
|
+
# 8. 计算PR曲线数据(可用于绘图)
|
208
|
+
# ============================================================
|
209
|
+
|
210
|
+
print("\n步骤7: 计算PR曲线数据...")
|
211
|
+
print("-" * 60)
|
212
|
+
|
213
|
+
pr_curve = evaluator.calculate_pr_curve(
|
214
|
+
thresholds=[i/10 for i in range(1, 10)],
|
215
|
+
class_id=0
|
216
|
+
)
|
217
|
+
|
218
|
+
print(f"\nPR曲线数据点: {len(pr_curve)} 个")
|
219
|
+
print(f"{'阈值':<10} {'Precision':<12} {'Recall':<12}")
|
220
|
+
print("-" * 40)
|
221
|
+
for point in pr_curve[:5]: # 只显示前5个
|
222
|
+
print(f"{point['threshold']:<10.2f} {point['precision']:<12.4f} {point['recall']:<12.4f}")
|
223
|
+
print("...")
|
224
|
+
|
225
|
+
print("\n" + "="*60)
|
226
|
+
print("评估完成!")
|
227
|
+
print("="*60)
|
228
|
+
|
229
|
+
# 可以将PR曲线数据保存或绘图
|
230
|
+
# import matplotlib.pyplot as plt
|
231
|
+
# precisions = [p['precision'] for p in pr_curve]
|
232
|
+
# recalls = [p['recall'] for p in pr_curve]
|
233
|
+
# plt.plot(recalls, precisions)
|
234
|
+
# plt.xlabel('Recall')
|
235
|
+
# plt.ylabel('Precision')
|
236
|
+
# plt.title('PR Curve')
|
237
|
+
# plt.savefig('pr_curve.png')
|
238
|
+
|
239
|
+
|
240
|
+
if __name__ == '__main__':
|
241
|
+
try:
|
242
|
+
main()
|
243
|
+
except FileNotFoundError as e:
|
244
|
+
print(f"\n错误: {e}")
|
245
|
+
print("\n请修改脚本中的路径配置,指向实际的数据集位置。")
|
246
|
+
except Exception as e:
|
247
|
+
print(f"\n发生错误: {e}")
|
248
|
+
import traceback
|
249
|
+
traceback.print_exc()
|
250
|
+
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit.egg-info/dependency_links.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|