dataset-toolkit 0.1.2__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {dataset_toolkit-0.1.2/dataset_toolkit.egg-info → dataset_toolkit-0.2.0}/PKG-INFO +39 -1
  2. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/README.md +38 -0
  3. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/__init__.py +9 -2
  4. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/exporters/yolo_exporter.py +4 -2
  5. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/loaders/local_loader.py +145 -0
  6. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/models.py +3 -1
  7. dataset_toolkit-0.2.0/dataset_toolkit/processors/__init__.py +9 -0
  8. dataset_toolkit-0.2.0/dataset_toolkit/processors/evaluator.py +535 -0
  9. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0/dataset_toolkit.egg-info}/PKG-INFO +39 -1
  10. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit.egg-info/SOURCES.txt +2 -0
  11. dataset_toolkit-0.2.0/examples/evaluation_example.py +250 -0
  12. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/pyproject.toml +1 -1
  13. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/setup.py +1 -1
  14. dataset_toolkit-0.1.2/dataset_toolkit/utils/__init__.py +0 -0
  15. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/LICENSE +0 -0
  16. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/MANIFEST.in +0 -0
  17. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/exporters/__init__.py +0 -0
  18. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/exporters/coco_exporter.py +0 -0
  19. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/exporters/txt_exporter.py +0 -0
  20. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/loaders/__init__.py +0 -0
  21. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/pipeline.py +0 -0
  22. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/processors/merger.py +0 -0
  23. {dataset_toolkit-0.1.2/dataset_toolkit/processors → dataset_toolkit-0.2.0/dataset_toolkit/utils}/__init__.py +0 -0
  24. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit/utils/coords.py +0 -0
  25. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit.egg-info/dependency_links.txt +0 -0
  26. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit.egg-info/requires.txt +0 -0
  27. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/dataset_toolkit.egg-info/top_level.txt +0 -0
  28. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/examples/basic_usage.py +0 -0
  29. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/requirements.txt +0 -0
  30. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/setup.cfg +0 -0
  31. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/__init__.py +0 -0
  32. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/conftest.py +0 -0
  33. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/test_exporters.py +0 -0
  34. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/test_loaders.py +0 -0
  35. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/test_processors.py +0 -0
  36. {dataset_toolkit-0.1.2 → dataset_toolkit-0.2.0}/tests/test_pypi_test.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dataset-toolkit
- Version: 0.1.2
+ Version: 0.2.0
  Summary: A toolkit for loading, processing, and exporting computer vision datasets
  Home-page: https://github.com/yourusername/dataset-toolkit
  Author: wenxiang.han
@@ -42,6 +42,7 @@ Dynamic: requires-python
  - 📤 **Flexible export**: export to COCO JSON, TXT, and other formats
  - 🛠️ **Utility functions**: practical helpers such as coordinate conversion
  - 📦 **Standardized data model**: a unified internal representation that is easy to extend
+ - 📊 **Model evaluation**: a complete object-detection model evaluation system (v0.2.0+)

  ## 📦 Installation

@@ -121,6 +122,43 @@ result = (pipeline
  .execute())
  ```

+ ### Model evaluation (v0.2.0+)
+
+ ```python
+ from dataset_toolkit import (
+     load_yolo_from_local,
+     load_predictions_from_streamlined,
+     Evaluator
+ )
+
+ # 1. Load the GT and the prediction results
+ gt_dataset = load_yolo_from_local("/data/test/labels", {0: 'parcel'})
+ pred_dataset = load_predictions_from_streamlined(
+     "/results/predictions",
+     categories={0: 'parcel'},
+     image_dir="/data/test/images"
+ )
+
+ # 2. Create the evaluator
+ evaluator = Evaluator(
+     positive_gt=gt_dataset,
+     positive_pred=pred_dataset,
+     iou_threshold=0.5
+ )
+
+ # 3. Compute metrics
+ metrics = evaluator.calculate_metrics(confidence_threshold=0.5)
+ print(f"Precision: {metrics['precision']:.4f}")
+ print(f"Recall: {metrics['recall']:.4f}")
+ print(f"F1-Score: {metrics['f1']:.4f}")
+
+ # 4. Find the optimal threshold
+ optimal = evaluator.find_optimal_threshold(metric='f1')
+ print(f"Optimal threshold: {optimal['optimal_threshold']}")
+ ```
+
+ See [EVALUATION_GUIDE.md](EVALUATION_GUIDE.md) for detailed documentation.
+
  ## 📚 API Documentation

  ### Data loaders
@@ -9,6 +9,7 @@
  - 📤 **Flexible export**: export to COCO JSON, TXT, and other formats
  - 🛠️ **Utility functions**: practical helpers such as coordinate conversion
  - 📦 **Standardized data model**: a unified internal representation that is easy to extend
+ - 📊 **Model evaluation**: a complete object-detection model evaluation system (v0.2.0+)

  ## 📦 Installation

@@ -88,6 +89,43 @@ result = (pipeline
  .execute())
  ```

+ ### Model evaluation (v0.2.0+)
+
+ ```python
+ from dataset_toolkit import (
+     load_yolo_from_local,
+     load_predictions_from_streamlined,
+     Evaluator
+ )
+
+ # 1. Load the GT and the prediction results
+ gt_dataset = load_yolo_from_local("/data/test/labels", {0: 'parcel'})
+ pred_dataset = load_predictions_from_streamlined(
+     "/results/predictions",
+     categories={0: 'parcel'},
+     image_dir="/data/test/images"
+ )
+
+ # 2. Create the evaluator
+ evaluator = Evaluator(
+     positive_gt=gt_dataset,
+     positive_pred=pred_dataset,
+     iou_threshold=0.5
+ )
+
+ # 3. Compute metrics
+ metrics = evaluator.calculate_metrics(confidence_threshold=0.5)
+ print(f"Precision: {metrics['precision']:.4f}")
+ print(f"Recall: {metrics['recall']:.4f}")
+ print(f"F1-Score: {metrics['f1']:.4f}")
+
+ # 4. Find the optimal threshold
+ optimal = evaluator.find_optimal_threshold(metric='f1')
+ print(f"Optimal threshold: {optimal['optimal_threshold']}")
+ ```
+
+ See [EVALUATION_GUIDE.md](EVALUATION_GUIDE.md) for detailed documentation.
+
  ## 📚 API Documentation

  ### Data loaders
@@ -15,7 +15,7 @@ Dataset Toolkit - a computer vision dataset processing toolkit
  >>> export_to_coco(dataset, "output.json")
  """

- __version__ = "0.1.2"
+ __version__ = "0.2.0"
  __author__ = "wenxiang.han"
  __email__ = "wenxiang.han@anker-in.com"

@@ -28,13 +28,18 @@ from dataset_toolkit.models import (

  from dataset_toolkit.loaders.local_loader import (
      load_yolo_from_local,
-     load_csv_result_from_local
+     load_csv_result_from_local,
+     load_predictions_from_streamlined
  )

  from dataset_toolkit.processors.merger import (
      merge_datasets
  )

+ from dataset_toolkit.processors.evaluator import (
+     Evaluator
+ )
+
  from dataset_toolkit.exporters.coco_exporter import (
      export_to_coco
@@ -69,9 +74,11 @@ __all__ = [
      # Loaders
      "load_yolo_from_local",
      "load_csv_result_from_local",
+     "load_predictions_from_streamlined",

      # Processors
      "merge_datasets",
+     "Evaluator",

      # Exporters
      "export_to_coco",
@@ -146,8 +146,10 @@ def export_to_yolo_and_txt(
              rel_path = os.path.relpath(img_in_yolo, txt_path.parent)
              f.write(f"{rel_path}\n")
          else:
-             # Absolute path (points into the YOLO images directory; do not resolve, to keep the YOLO structure)
-             f.write(f"{str(img_in_yolo.absolute())}\n")
+             # Absolute path (normalized, but without resolving symlinks)
+             # os.path.normpath normalizes the path lexically, removing ".." segments etc.
+             normalized_path = os.path.normpath(str(img_in_yolo.absolute()))
+             f.write(f"{normalized_path}\n")

      print(f"✓ txt list written: {len(dataset.images)} lines")

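The distinction drawn in the new comments — normalizing a path without resolving symlinks — can be checked in isolation. A minimal standalone sketch (not part of the package; `data/images` is a made-up path):

```python
import os
from pathlib import Path

p = Path("data/images/../images/img_0001.jpg")

# os.path.normpath collapses "." and ".." purely lexically; it never
# touches the filesystem, so symlinked YOLO image directories survive.
print(os.path.normpath(str(p.absolute())))

# Path.resolve() also normalizes, but additionally follows symlinks on
# the real filesystem - exactly what the exporter wants to avoid.
print(p.resolve())
```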
@@ -186,4 +186,149 @@ def load_csv_result_from_local(dataset_path: str, categories: Dict[int, str] = None

      print(f"Loading complete. Found {image_count} images and {len(dataset.categories)} categories.")
      print(f"Category mapping: {dataset.categories}")
+     return dataset
+
+
+ def load_predictions_from_streamlined(
+     predictions_dir: str,
+     categories: Dict[int, str],
+     image_dir: str = None
+ ) -> Dataset:
+     """
+     Load a prediction dataset from a streamlined inference results directory.
+
+     Prediction file format (one detection per line):
+     class_id,confidence,center_x,center_y,width,height
+     e.g.: 0,0.934679,354.00,388.00,274.00,102.00
+
+     Args:
+         predictions_dir: directory containing the prediction txt files
+         categories: category mapping {class_id: class_name}
+         image_dir: image directory (optional; used to read image sizes).
+             If omitted, sibling directories of the prediction files are probed.
+
+     Returns:
+         Dataset: the prediction dataset object, with dataset_type='pred'
+     """
+     pred_path = Path(predictions_dir)
+
+     if not pred_path.is_dir():
+         raise FileNotFoundError(f"Predictions directory does not exist: {pred_path}")
+
+     # Try to locate the image directory automatically
+     if image_dir is None:
+         # Probe common image directory locations
+         possible_image_dirs = [
+             pred_path.parent / 'images',
+             pred_path.parent.parent / 'images',
+         ]
+         for possible_dir in possible_image_dirs:
+             if possible_dir.is_dir():
+                 image_dir = str(possible_dir)
+                 print(f"Automatically found image directory: {image_dir}")
+                 break
+
+     dataset = Dataset(
+         name=pred_path.name,
+         categories=categories,
+         dataset_type="pred"
+     )
+
+     supported_extensions = ['.jpg', '.jpeg', '.png']
+     txt_files = list(pred_path.glob('*.txt'))
+
+     print(f"Loading predictions: {pred_path.name}...")
+     print(f"Found {len(txt_files)} prediction files")
+
+     loaded_count = 0
+     skipped_count = 0
+
+     for txt_file in txt_files:
+         # Image file name corresponding to the prediction file (assumed to share the stem)
+         image_base_name = txt_file.stem
+
+         # Try to find the corresponding image file
+         image_path = None
+         img_width, img_height = None, None
+
+         if image_dir:
+             image_dir_path = Path(image_dir)
+             for ext in supported_extensions:
+                 potential_image = image_dir_path / (image_base_name + ext)
+                 if potential_image.exists():
+                     image_path = str(potential_image.resolve())
+                     try:
+                         with Image.open(potential_image) as img:
+                             img_width, img_height = img.size
+                     except IOError:
+                         print(f"Warning: could not open image {potential_image}")
+                     break
+
+         # If no image was found, fall back to defaults
+         if image_path is None:
+             # Assume a default image path and size
+             image_path = f"unknown/{image_base_name}.jpg"
+             img_width, img_height = 640, 640  # default size
+             if image_dir:
+                 skipped_count += 1
+
+         # Create the image annotation object
+         image_annotation = ImageAnnotation(
+             image_id=image_base_name + '.jpg',
+             path=image_path,
+             width=img_width,
+             height=img_height
+         )
+
+         # Read the predictions
+         try:
+             with open(txt_file, 'r') as f:
+                 for line in f:
+                     line = line.strip()
+                     if not line:
+                         continue
+
+                     # Parse format: class_id,confidence,center_x,center_y,width,height
+                     parts = line.split(',')
+                     if len(parts) != 6:
+                         print(f"Warning: malformed line skipped: {txt_file} -> '{line}'")
+                         continue
+
+                     try:
+                         class_id = int(parts[0])
+                         confidence = float(parts[1])
+                         center_x = float(parts[2])
+                         center_y = float(parts[3])
+                         width = float(parts[4])
+                         height = float(parts[5])
+
+                         # Convert to [x_min, y_min, width, height] format
+                         x_min = center_x - width / 2
+                         y_min = center_y - height / 2
+
+                         annotation = Annotation(
+                             category_id=class_id,
+                             bbox=[x_min, y_min, width, height],
+                             confidence=confidence
+                         )
+                         image_annotation.annotations.append(annotation)
+
+                     except (ValueError, IndexError) as e:
+                         print(f"Warning: parse error, line skipped: {txt_file} -> '{line}' ({e})")
+                         continue
+
+         except Exception as e:
+             print(f"Warning: failed to read file, skipped: {txt_file} ({e})")
+             continue
+
+         dataset.images.append(image_annotation)
+         loaded_count += 1
+
+     print(f"Loading complete. Successfully loaded {loaded_count} prediction files")
+     if skipped_count > 0:
+         print(f"Warning: {skipped_count} files had no matching image; default size used")
+
+     total_detections = sum(len(img.annotations) for img in dataset.images)
+     print(f"Total detections: {total_detections}")
+
      return dataset
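
To see the loader's input contract end to end, here is a hypothetical sketch, assuming dataset-toolkit 0.2.0 is installed; it writes one prediction file into a temporary directory and loads it back:

```python
import tempfile
from pathlib import Path

from dataset_toolkit import load_predictions_from_streamlined

# One detection per line: class_id,confidence,center_x,center_y,width,height
sample = "0,0.934679,354.00,388.00,274.00,102.00\n"

with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "img_0001.txt").write_text(sample)
    # No image_dir is given and no sibling 'images' directory exists,
    # so the loader falls back to the default 640x640 image size.
    ds = load_predictions_from_streamlined(tmp, categories={0: "parcel"})
    ann = ds.images[0].annotations[0]
    print(ds.dataset_type)  # 'pred'
    print(ann.bbox)         # [217.0, 337.0, 274.0, 102.0] (x_min, y_min, w, h)
    print(ann.confidence)   # 0.934679
```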
@@ -24,4 +24,6 @@ class Dataset:
      """Represents a complete dataset object, the standardized internal representation."""
      name: str
      images: List[ImageAnnotation] = field(default_factory=list)
-     categories: Dict[int, str] = field(default_factory=dict)
+     categories: Dict[int, str] = field(default_factory=dict)
+     dataset_type: str = "train"  # 'train', 'gt', 'pred'
+     metadata: Dict = field(default_factory=dict)  # descriptive info only; no processing parameters
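
The two new fields are ordinary dataclass attributes, so they can be set at construction time or assigned afterwards. A small sketch (values invented for illustration):

```python
from dataset_toolkit.models import Dataset

gt = Dataset(
    name="parcel_test_val",
    categories={0: "parcel"},
    dataset_type="gt",                      # one of 'train', 'gt', 'pred'
    metadata={"test_purpose": "positive"},  # descriptive info only
)
print(gt.dataset_type, gt.metadata)
```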
@@ -0,0 +1,9 @@
+ # dataset_toolkit/processors/__init__.py
+ from .merger import merge_datasets
+ from .evaluator import Evaluator
+
+ __all__ = [
+     'merge_datasets',
+     'Evaluator',
+ ]
+
@@ -0,0 +1,535 @@
+ # dataset_toolkit/processors/evaluator.py
+ from typing import Dict, List, Optional, Tuple
+ from dataset_toolkit.models import Dataset, Annotation, ImageAnnotation
+
+
+ class Evaluator:
+     """
+     Evaluator with separate positive and false-positive (negative) test sets.
+
+     Compares GT and Pred datasets and computes Precision, Recall, F1, FPPI, etc.
+     Supports dynamic evaluation at different confidence thresholds without reloading data.
+     """
+
+     def __init__(
+         self,
+         positive_gt: Dataset,
+         positive_pred: Dataset,
+         negative_gt: Optional[Dataset] = None,
+         negative_pred: Optional[Dataset] = None,
+         iou_threshold: float = 0.5
+     ):
+         """
+         Initialize the evaluator.
+
+         Args:
+             positive_gt: positive-set GT (required) - a test set that contains target objects
+             positive_pred: positive-set predictions (required) - predictions on the positive set
+             negative_gt: negative-set GT (optional) - GT for background images without targets
+             negative_pred: negative-set predictions (optional) - predictions on the negative set
+             iou_threshold: IoU threshold used to decide whether a detection matches a GT box (default 0.5)
+         """
+         self.positive_gt = positive_gt
+         self.positive_pred = positive_pred
+         self.negative_gt = negative_gt
+         self.negative_pred = negative_pred
+         self.iou_threshold = iou_threshold
+
+         # Validate the datasets
+         self._validate_datasets()
+
+     def _validate_datasets(self):
+         """Validate that the datasets are usable."""
+         if self.positive_gt is None or self.positive_pred is None:
+             raise ValueError("GT and Pred for the positive set are required")
+
+         if len(self.positive_gt.images) == 0:
+             raise ValueError("Positive-set GT is empty")
+
+         if len(self.positive_pred.images) == 0:
+             raise ValueError("Positive-set Pred is empty")
+
+         # Check that the category mappings agree
+         if self.positive_gt.categories != self.positive_pred.categories:
+             print("Warning: GT and Pred category mappings differ")
+             print(f"  GT categories: {self.positive_gt.categories}")
+             print(f"  Pred categories: {self.positive_pred.categories}")
+
+     def calculate_metrics(
+         self,
+         confidence_threshold: float = 0.5,
+         class_id: Optional[int] = None,
+         calculate_fppi: bool = True
+     ) -> Dict:
+         """
+         Compute the full set of evaluation metrics.
+
+         Args:
+             confidence_threshold: confidence threshold (passed in dynamically, not stored in the dataset)
+             class_id: category ID to evaluate; None means all categories
+             calculate_fppi: whether to compute FPPI (requires negative_pred)
+
+         Returns:
+             A dictionary with all metrics:
+             - tp, fp, fn: true positives, false positives, false negatives
+             - precision, recall, f1: precision, recall, and F1 score
+             - fppi: false positives per image (if computed)
+             - confidence_threshold, iou_threshold: the thresholds used
+             - positive_set_size, negative_set_size: dataset sizes
+         """
+         metrics = {}
+
+         # 1. Compute Precision, Recall, F1 from the positive set
+         positive_metrics = self._calculate_positive_metrics(
+             confidence_threshold, class_id
+         )
+         metrics.update(positive_metrics)
+
+         # 2. Compute FPPI from the negative set
+         if calculate_fppi and self.negative_pred is not None:
+             fppi_metrics = self._calculate_fppi_metrics(
+                 confidence_threshold, class_id
+             )
+             metrics.update(fppi_metrics)
+         else:
+             metrics['fppi'] = None
+             metrics['fppi_note'] = "No negative set provided" if self.negative_pred is None else "FPPI not computed"
+
+         # 3. Attach configuration info
+         metrics['confidence_threshold'] = confidence_threshold
+         metrics['iou_threshold'] = self.iou_threshold
+         metrics['positive_set_size'] = len(self.positive_gt.images)
+         metrics['negative_set_size'] = (
+             len(self.negative_pred.images)
+             if self.negative_pred else 0
+         )
+
+         return metrics
+
+     def _calculate_positive_metrics(
+         self,
+         confidence_threshold: float,
+         class_id: Optional[int] = None
+     ) -> Dict:
+         """
+         Compute Precision, Recall, and F1 from the positive set.
+
+         Args:
+             confidence_threshold: confidence threshold
+             class_id: category ID to evaluate
+
+         Returns:
+             A dictionary with TP, FP, FN, Precision, Recall, F1
+         """
+         # 1. Filter the predictions
+         filtered_preds = self._filter_predictions(
+             self.positive_pred,
+             confidence_threshold,
+             class_id
+         )
+
+         # 2. Match GT against Pred
+         tp = 0  # True Positives
+         fp = 0  # False Positives
+         fn = 0  # False Negatives
+
+         matched_gt = set()  # already-matched GT boxes, as (image_id, gt_index)
+
+         # Iterate over the images
+         for img_gt in self.positive_gt.images:
+             # GT annotations for this image
+             gt_anns = [
+                 ann for ann in img_gt.annotations
+                 if class_id is None or ann.category_id == class_id
+             ]
+
+             # Predictions for this image
+             img_preds = filtered_preds.get(img_gt.image_id, [])
+
+             # Match predictions to GT
+             for pred in img_preds:
+                 best_iou = 0
+                 best_gt_idx = -1
+
+                 for i, gt_ann in enumerate(gt_anns):
+                     if (img_gt.image_id, i) not in matched_gt:
+                         iou = self._calculate_iou(pred.bbox, gt_ann.bbox)
+                         if iou > best_iou:
+                             best_iou = iou
+                             best_gt_idx = i
+
+                 if best_iou >= self.iou_threshold:
+                     tp += 1
+                     matched_gt.add((img_gt.image_id, best_gt_idx))
+                 else:
+                     fp += 1
+
+             # Count unmatched GT boxes (false negatives)
+             fn += len([
+                 i for i in range(len(gt_anns))
+                 if (img_gt.image_id, i) not in matched_gt
+             ])
+
+         # 3. Compute the metrics
+         precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+         recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+         f1 = (2 * precision * recall / (precision + recall)
+               if (precision + recall) > 0 else 0.0)
+
+         return {
+             'tp': tp,
+             'fp': fp,
+             'fn': fn,
+             'precision': precision,
+             'recall': recall,
+             'f1': f1,
+             'positive_set_note': f"Based on {len(self.positive_gt.images)} positive-set images"
+         }
+
+     def _calculate_fppi_metrics(
+         self,
+         confidence_threshold: float,
+         class_id: Optional[int] = None
+     ) -> Dict:
+         """
+         Compute FPPI (false positives per image) from the negative set.
+
+         Args:
+             confidence_threshold: confidence threshold
+             class_id: category ID to evaluate
+
+         Returns:
+             A dictionary with the FPPI-related metrics
+         """
+         # 1. Filter the predictions
+         filtered_preds = self._filter_predictions(
+             self.negative_pred,
+             confidence_threshold,
+             class_id
+         )
+
+         # 2. Count the false positives
+         total_fp = 0
+         fp_per_image = []
+
+         for img_pred in self.negative_pred.images:
+             img_preds = filtered_preds.get(img_pred.image_id, [])
+
+             # On the negative set, every detection is a false positive;
+             # but if negative_gt is provided, the decision can be more precise
+             if self.negative_gt is not None:
+                 # Find the corresponding GT image
+                 gt_img = next(
+                     (img for img in self.negative_gt.images
+                      if img.image_id == img_pred.image_id),
+                     None
+                 )
+
+                 if gt_img is not None:
+                     # GT annotations
+                     gt_anns = [
+                         ann for ann in gt_img.annotations
+                         if class_id is None or ann.category_id == class_id
+                     ]
+
+                     # Match predictions to GT; unmatched ones are FP
+                     matched = set()
+                     fp_count = 0
+                     for pred in img_preds:
+                         is_matched = False
+                         for i, gt_ann in enumerate(gt_anns):
+                             if i not in matched:
+                                 iou = self._calculate_iou(pred.bbox, gt_ann.bbox)
+                                 if iou >= self.iou_threshold:
+                                     matched.add(i)
+                                     is_matched = True
+                                     break
+                         if not is_matched:
+                             fp_count += 1
+
+                     total_fp += fp_count
+                     fp_per_image.append(fp_count)
+                 else:
+                     # No corresponding GT image; every detection is a FP
+                     fp_count = len(img_preds)
+                     total_fp += fp_count
+                     fp_per_image.append(fp_count)
+             else:
+                 # No negative_gt provided; treat every detection as a FP
+                 fp_count = len(img_preds)
+                 total_fp += fp_count
+                 fp_per_image.append(fp_count)
+
+         # 3. Compute FPPI
+         num_images = len(self.negative_pred.images)
+         fppi = total_fp / num_images if num_images > 0 else 0.0
+
+         return {
+             'fppi': fppi,
+             'total_false_positives': total_fp,
+             'negative_set_size': num_images,
+             'fppi_note': f"Based on {num_images} negative-set images",
+             'max_fp_per_image': max(fp_per_image) if fp_per_image else 0,
+             'min_fp_per_image': min(fp_per_image) if fp_per_image else 0,
+             'avg_fp_per_image': total_fp / num_images if num_images > 0 else 0.0
+         }
+
+     def _filter_predictions(
+         self,
+         pred_dataset: Dataset,
+         confidence_threshold: float,
+         class_id: Optional[int] = None
+     ) -> Dict[str, List[Annotation]]:
+         """
+         Filter predictions by confidence threshold and category.
+
+         Args:
+             pred_dataset: prediction dataset
+             confidence_threshold: confidence threshold
+             class_id: category ID to keep
+
+         Returns:
+             A {image_id: [annotations]} dictionary
+         """
+         filtered = {}
+         for img in pred_dataset.images:
+             img_preds = [
+                 ann for ann in img.annotations
+                 if ann.confidence >= confidence_threshold
+                 and (class_id is None or ann.category_id == class_id)
+             ]
+             if img_preds:
+                 filtered[img.image_id] = img_preds
+         return filtered
+
+     def _calculate_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
+         """
+         Compute the IoU (intersection over union) of two bounding boxes.
+
+         Args:
+             bbox1, bbox2: bounding boxes in [x_min, y_min, width, height] format
+
+         Returns:
+             The IoU value (0.0 to 1.0)
+         """
+         # bbox format: [x_min, y_min, width, height]
+         x1_min, y1_min, w1, h1 = bbox1
+         x2_min, y2_min, w2, h2 = bbox2
+
+         x1_max = x1_min + w1
+         y1_max = y1_min + h1
+         x2_max = x2_min + w2
+         y2_max = y2_min + h2
+
+         # Intersection
+         inter_x_min = max(x1_min, x2_min)
+         inter_y_min = max(y1_min, y2_min)
+         inter_x_max = min(x1_max, x2_max)
+         inter_y_max = min(y1_max, y2_max)
+
+         if inter_x_max <= inter_x_min or inter_y_max <= inter_y_min:
+             return 0.0
+
+         inter_area = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min)
+
+         # Union
+         area1 = w1 * h1
+         area2 = w2 * h2
+         union_area = area1 + area2 - inter_area
+
+         return inter_area / union_area if union_area > 0 else 0.0
+
+     def calculate_pr_curve(
+         self,
+         thresholds: Optional[List[float]] = None,
+         class_id: Optional[int] = None
+     ) -> List[Dict]:
+         """
+         Compute a PR curve (precision/recall at different confidence thresholds).
+
+         Args:
+             thresholds: confidence thresholds to test; None uses the defaults
+             class_id: category ID to evaluate
+
+         Returns:
+             A list of metrics, one entry per threshold
+         """
+         if thresholds is None:
+             thresholds = [i/10 for i in range(1, 10)]  # 0.1 to 0.9
+
+         pr_points = []
+         for threshold in thresholds:
+             # Evaluate the same data at each threshold
+             metrics = self.calculate_metrics(
+                 confidence_threshold=threshold,
+                 class_id=class_id,
+                 calculate_fppi=False  # the PR curve does not need FPPI
+             )
+             pr_points.append({
+                 'threshold': threshold,
+                 'precision': metrics['precision'],
+                 'recall': metrics['recall'],
+                 'f1': metrics['f1'],
+                 'tp': metrics['tp'],
+                 'fp': metrics['fp'],
+                 'fn': metrics['fn']
+             })
+
+         return pr_points
+
+     def find_optimal_threshold(
+         self,
+         metric: str = 'f1',
+         class_id: Optional[int] = None,
+         thresholds: Optional[List[float]] = None
+     ) -> Dict:
+         """
+         Find the confidence threshold that maximizes the given metric.
+
+         Args:
+             metric: metric to optimize ('precision', 'recall', 'f1')
+             class_id: category ID to evaluate
+             thresholds: thresholds to test
+
+         Returns:
+             The optimal threshold and all metrics at that threshold
+         """
+         if metric not in ['precision', 'recall', 'f1']:
+             raise ValueError(f"Unsupported metric: {metric}")
+
+         pr_curve = self.calculate_pr_curve(thresholds, class_id)
+
+         # Find the point with the highest value of the chosen metric
+         best_point = max(pr_curve, key=lambda x: x[metric])
+
+         return {
+             'optimal_threshold': best_point['threshold'],
+             'optimized_metric': metric,
+             'metrics': best_point
+         }
+
+     def find_threshold_with_constraint(
+         self,
+         target_metric: str,
+         constraint_metric: str,
+         constraint_value: float,
+         class_id: Optional[int] = None,
+         thresholds: Optional[List[float]] = None
+     ) -> Optional[Dict]:
+         """
+         Find the best threshold subject to a constraint.
+
+         Example: the threshold with the highest Recall subject to FPPI < 0.01.
+
+         Args:
+             target_metric: metric to optimize ('precision', 'recall', 'f1')
+             constraint_metric: constraint metric ('fppi', 'precision', 'recall')
+             constraint_value: constraint value (e.g. 0.01 for fppi <= 0.01)
+             class_id: category ID to evaluate
+             thresholds: thresholds to test
+
+         Returns:
+             The best threshold and its metrics, or None if no threshold satisfies the constraint
+         """
+         if thresholds is None:
+             thresholds = [i/100 for i in range(1, 100)]  # 0.01 to 0.99
+
+         best_threshold = None
+         best_value = 0
+         best_metrics = None
+
+         for threshold in thresholds:
+             metrics = self.calculate_metrics(
+                 confidence_threshold=threshold,
+                 class_id=class_id,
+                 calculate_fppi=(constraint_metric == 'fppi')
+             )
+
+             # Check the constraint (fppi is an upper bound, other metrics a lower bound)
+             constraint_satisfied = False
+             if constraint_metric == 'fppi':
+                 if metrics['fppi'] is not None and metrics['fppi'] <= constraint_value:
+                     constraint_satisfied = True
+             elif constraint_metric in metrics:
+                 if metrics[constraint_metric] >= constraint_value:
+                     constraint_satisfied = True
+
+             # If the constraint holds, check whether this is the best so far
+             if constraint_satisfied:
+                 if metrics[target_metric] > best_value:
+                     best_value = metrics[target_metric]
+                     best_threshold = threshold
+                     best_metrics = metrics
+
+         if best_threshold is None:
+             return None
+
+         return {
+             'optimal_threshold': best_threshold,
+             'target_metric': target_metric,
+             'target_value': best_value,
+             'constraint': f"{constraint_metric} {'<=' if constraint_metric == 'fppi' else '>='} {constraint_value}",
+             'metrics': best_metrics
+         }
+
+     def generate_report(
+         self,
+         confidence_threshold: float = 0.5,
+         class_id: Optional[int] = None
+     ) -> str:
+         """
+         Generate an evaluation report.
+
+         Args:
+             confidence_threshold: confidence threshold
+             class_id: category ID to evaluate
+
+         Returns:
+             The formatted evaluation report as a string
+         """
+         metrics = self.calculate_metrics(confidence_threshold, class_id)
+
+         class_name = "all categories"
+         if class_id is not None and class_id in self.positive_gt.categories:
+             class_name = f"category {class_id} ({self.positive_gt.categories[class_id]})"
+
+         report = f"""
+ {'='*60}
+ Evaluation report - {class_name}
+ {'='*60}
+
+ Configuration:
+   Confidence threshold: {metrics['confidence_threshold']}
+   IoU threshold: {metrics['iou_threshold']}
+   Positive set size: {metrics['positive_set_size']} images
+   Negative set size: {metrics['negative_set_size']} images
+
+ Positive-set metrics (Precision & Recall):
+   True Positives (TP): {metrics['tp']}
+   False Positives (FP): {metrics['fp']}
+   False Negatives (FN): {metrics['fn']}
+
+   Precision: {metrics['precision']:.4f} ({metrics['precision']*100:.2f}%)
+   Recall: {metrics['recall']:.4f} ({metrics['recall']*100:.2f}%)
+   F1-Score: {metrics['f1']:.4f}
+ """
+
+         if metrics['fppi'] is not None:
+             report += f"""
+ Negative-set metrics (FPPI):
+   FPPI (False Positives Per Image): {metrics['fppi']:.6f}
+   Total false positives: {metrics.get('total_false_positives', 'N/A')}
+   Mean FP per image: {metrics.get('avg_fp_per_image', 'N/A'):.2f}
+   Max FP in one image: {metrics.get('max_fp_per_image', 'N/A')}
+   Min FP in one image: {metrics.get('min_fp_per_image', 'N/A')}
+ """
+         else:
+             report += f"""
+ Negative-set metrics (FPPI):
+   {metrics['fppi_note']}
+ """
+
+         report += f"\n{'='*60}\n"
+
+         return report
+
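The matching logic above can be sanity-checked with a hand-built toy example whose counts are easy to verify on paper. A hypothetical sketch (constructor arguments follow the model definitions shown earlier in this diff; GT confidence is set explicitly, since ground truth carries no score):

```python
from dataset_toolkit.models import Dataset, ImageAnnotation, Annotation
from dataset_toolkit.processors.evaluator import Evaluator

def make_image(image_id, anns):
    im = ImageAnnotation(image_id=image_id, path=f"toy/{image_id}",
                         width=640, height=640)
    im.annotations.extend(anns)
    return im

# GT: one box in each of two images
gt = Dataset(name="toy_gt", categories={0: "parcel"}, dataset_type="gt")
gt.images.append(make_image("a.jpg",
    [Annotation(category_id=0, bbox=[0, 0, 100, 100], confidence=1.0)]))
gt.images.append(make_image("b.jpg",
    [Annotation(category_id=0, bbox=[50, 50, 100, 100], confidence=1.0)]))

# Pred: a perfect hit plus a stray box on a.jpg, nothing on b.jpg
pred = Dataset(name="toy_pred", categories={0: "parcel"}, dataset_type="pred")
pred.images.append(make_image("a.jpg", [
    Annotation(category_id=0, bbox=[0, 0, 100, 100], confidence=0.9),    # IoU 1.0 -> TP
    Annotation(category_id=0, bbox=[400, 400, 50, 50], confidence=0.8),  # IoU 0.0 -> FP
]))

m = Evaluator(positive_gt=gt, positive_pred=pred).calculate_metrics(
    confidence_threshold=0.5)
# Expected: TP=1, FP=1, FN=1 -> precision = recall = f1 = 0.5
print(m["tp"], m["fp"], m["fn"], m["precision"], m["recall"], m["f1"])
```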
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dataset-toolkit
- Version: 0.1.2
+ Version: 0.2.0
  Summary: A toolkit for loading, processing, and exporting computer vision datasets
  Home-page: https://github.com/yourusername/dataset-toolkit
  Author: wenxiang.han
@@ -42,6 +42,7 @@ Dynamic: requires-python
  - 📤 **Flexible export**: export to COCO JSON, TXT, and other formats
  - 🛠️ **Utility functions**: practical helpers such as coordinate conversion
  - 📦 **Standardized data model**: a unified internal representation that is easy to extend
+ - 📊 **Model evaluation**: a complete object-detection model evaluation system (v0.2.0+)

  ## 📦 Installation

@@ -121,6 +122,43 @@ result = (pipeline
  .execute())
  ```

+ ### Model evaluation (v0.2.0+)
+
+ ```python
+ from dataset_toolkit import (
+     load_yolo_from_local,
+     load_predictions_from_streamlined,
+     Evaluator
+ )
+
+ # 1. Load the GT and the prediction results
+ gt_dataset = load_yolo_from_local("/data/test/labels", {0: 'parcel'})
+ pred_dataset = load_predictions_from_streamlined(
+     "/results/predictions",
+     categories={0: 'parcel'},
+     image_dir="/data/test/images"
+ )
+
+ # 2. Create the evaluator
+ evaluator = Evaluator(
+     positive_gt=gt_dataset,
+     positive_pred=pred_dataset,
+     iou_threshold=0.5
+ )
+
+ # 3. Compute metrics
+ metrics = evaluator.calculate_metrics(confidence_threshold=0.5)
+ print(f"Precision: {metrics['precision']:.4f}")
+ print(f"Recall: {metrics['recall']:.4f}")
+ print(f"F1-Score: {metrics['f1']:.4f}")
+
+ # 4. Find the optimal threshold
+ optimal = evaluator.find_optimal_threshold(metric='f1')
+ print(f"Optimal threshold: {optimal['optimal_threshold']}")
+ ```
+
+ See [EVALUATION_GUIDE.md](EVALUATION_GUIDE.md) for detailed documentation.
+
  ## 📚 API Documentation

  ### Data loaders
@@ -19,10 +19,12 @@ dataset_toolkit/exporters/yolo_exporter.py
  dataset_toolkit/loaders/__init__.py
  dataset_toolkit/loaders/local_loader.py
  dataset_toolkit/processors/__init__.py
+ dataset_toolkit/processors/evaluator.py
  dataset_toolkit/processors/merger.py
  dataset_toolkit/utils/__init__.py
  dataset_toolkit/utils/coords.py
  examples/basic_usage.py
+ examples/evaluation_example.py
  tests/__init__.py
  tests/conftest.py
  tests/test_exporters.py
@@ -0,0 +1,250 @@
+ """
+ Evaluation system usage example.
+
+ Demonstrates how to use dataset_toolkit for model evaluation:
+ 1. Load the GT datasets (positive set and negative set)
+ 2. Load the prediction results
+ 3. Compute evaluation metrics (Precision, Recall, F1, FPPI)
+ 4. Test different confidence thresholds
+ 5. Find the optimal threshold
+ """
+
+ import sys
+ from pathlib import Path
+
+ # Add the project path
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+ from dataset_toolkit import (
+     load_yolo_from_local,
+     load_predictions_from_streamlined,
+     Evaluator
+ )
+
+
+ def main():
+     print("="*60)
+     print("Evaluation system usage example")
+     print("="*60)
+
+     # ============================================================
+     # 1. Configure the paths (adjust to your environment)
+     # ============================================================
+
+     # Positive-set paths
+     positive_gt_path = "/opt/dlami/nvme/workspace_wenxiang/parcel/test_val/labels"
+     positive_pred_path = "/opt/dlami/nvme/workspace_wenxiang/ai_train/onnx_infer/detections/results/streamlined_test"
+     positive_image_path = "/opt/dlami/nvme/workspace_wenxiang/parcel/test_val/images"
+
+     # Negative-set paths (optional)
+     negative_pred_path = None  # set a path if a negative set is available
+     negative_image_path = None
+
+     # Category mapping
+     categories = {0: 'parcel'}
+
+     # ============================================================
+     # 2. Load the datasets
+     # ============================================================
+
+     print("\nStep 1: loading datasets...")
+     print("-" * 60)
+
+     # Load the positive-set GT
+     print("\nLoading positive-set GT...")
+     gt_positive = load_yolo_from_local(
+         positive_gt_path,
+         categories=categories
+     )
+     gt_positive.dataset_type = "gt"
+     gt_positive.metadata = {
+         "test_purpose": "positive",
+         "description": "Test set containing target objects"
+     }
+     print(f"✓ Positive-set GT: {len(gt_positive.images)} images")
+
+     # Load the positive-set predictions
+     print("\nLoading positive-set predictions...")
+     pred_positive = load_predictions_from_streamlined(
+         positive_pred_path,
+         categories=categories,
+         image_dir=positive_image_path
+     )
+     pred_positive.dataset_type = "pred"
+     pred_positive.metadata = {
+         "test_purpose": "positive",
+         "model_name": "yolov8_parcel"
+     }
+     print(f"✓ Positive-set Pred: {len(pred_positive.images)} images")
+
+     # Load the negative-set predictions (if any)
+     pred_negative = None
+     if negative_pred_path:
+         print("\nLoading negative-set predictions...")
+         pred_negative = load_predictions_from_streamlined(
+             negative_pred_path,
+             categories=categories,
+             image_dir=negative_image_path
+         )
+         pred_negative.dataset_type = "pred"
+         pred_negative.metadata = {
+             "test_purpose": "negative",
+             "model_name": "yolov8_parcel"
+         }
+         print(f"✓ Negative-set Pred: {len(pred_negative.images)} images")
+     else:
+         print("\nNo negative set provided; only Precision/Recall/F1 will be computed")
+
+     # ============================================================
+     # 3. Create the evaluator
+     # ============================================================
+
+     print("\nStep 2: creating the evaluator...")
+     print("-" * 60)
+
+     evaluator = Evaluator(
+         positive_gt=gt_positive,
+         positive_pred=pred_positive,
+         negative_pred=pred_negative,
+         iou_threshold=0.5
+     )
+     print("✓ Evaluator created")
+
+     # ============================================================
+     # 4. Compute metrics at a single threshold
+     # ============================================================
+
+     print("\nStep 3: computing metrics (confidence threshold = 0.5)...")
+     print("-" * 60)
+
+     metrics = evaluator.calculate_metrics(
+         confidence_threshold=0.5,
+         class_id=0  # evaluate only the parcel category
+     )
+
+     print(f"\nPositive-set metrics:")
+     print(f"  TP: {metrics['tp']}, FP: {metrics['fp']}, FN: {metrics['fn']}")
+     print(f"  Precision: {metrics['precision']:.4f} ({metrics['precision']*100:.2f}%)")
+     print(f"  Recall: {metrics['recall']:.4f} ({metrics['recall']*100:.2f}%)")
+     print(f"  F1-Score: {metrics['f1']:.4f}")
+
+     if metrics['fppi'] is not None:
+         print(f"\nNegative-set metrics:")
+         print(f"  FPPI: {metrics['fppi']:.6f}")
+         print(f"  Total false positives: {metrics['total_false_positives']}")
+
+     # ============================================================
+     # 5. Test several thresholds
+     # ============================================================
+
+     print("\nStep 4: testing several confidence thresholds...")
+     print("-" * 60)
+
+     test_thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]
+
+     print(f"\n{'Threshold':<10} {'Precision':<12} {'Recall':<12} {'F1':<12} {'FPPI':<12}")
+     print("-" * 60)
+
+     for threshold in test_thresholds:
+         m = evaluator.calculate_metrics(
+             confidence_threshold=threshold,
+             class_id=0
+         )
+         fppi_str = f"{m['fppi']:.6f}" if m['fppi'] is not None else "N/A"
+         print(f"{threshold:<10.2f} {m['precision']:<12.4f} {m['recall']:<12.4f} "
+               f"{m['f1']:<12.4f} {fppi_str:<12}")
+
+     # ============================================================
+     # 6. Find the optimal threshold
+     # ============================================================
+
+     print("\nStep 5: finding the optimal threshold...")
+     print("-" * 60)
+
+     # Threshold with the highest F1
+     optimal = evaluator.find_optimal_threshold(
+         metric='f1',
+         class_id=0
+     )
+
+     print(f"\nOptimal threshold (max F1):")
+     print(f"  Threshold: {optimal['optimal_threshold']}")
+     print(f"  Precision: {optimal['metrics']['precision']:.4f}")
+     print(f"  Recall: {optimal['metrics']['recall']:.4f}")
+     print(f"  F1-Score: {optimal['metrics']['f1']:.4f}")
+
+     # If a negative set exists, find the best threshold under a FPPI constraint
+     if pred_negative:
+         constrained = evaluator.find_threshold_with_constraint(
+             target_metric='recall',
+             constraint_metric='fppi',
+             constraint_value=0.01,  # FPPI < 0.01
+             class_id=0
+         )
+
+         if constrained:
+             print(f"\nOptimal threshold (max Recall subject to FPPI < 0.01):")
+             print(f"  Threshold: {constrained['optimal_threshold']}")
+             print(f"  Recall: {constrained['target_value']:.4f}")
+             print(f"  FPPI: {constrained['metrics']['fppi']:.6f}")
+         else:
+             print(f"\nWarning: no threshold satisfies the FPPI < 0.01 constraint")
+
+     # ============================================================
+     # 7. Generate the full report
+     # ============================================================
+
+     print("\nStep 6: generating the full evaluation report...")
+     print("-" * 60)
+
+     report = evaluator.generate_report(
+         confidence_threshold=0.5,
+         class_id=0
+     )
+     print(report)
+
+     # ============================================================
+     # 8. Compute PR-curve data (usable for plotting)
+     # ============================================================
+
+     print("\nStep 7: computing PR-curve data...")
+     print("-" * 60)
+
+     pr_curve = evaluator.calculate_pr_curve(
+         thresholds=[i/10 for i in range(1, 10)],
+         class_id=0
+     )
+
+     print(f"\nPR-curve data points: {len(pr_curve)}")
+     print(f"{'Threshold':<10} {'Precision':<12} {'Recall':<12}")
+     print("-" * 40)
+     for point in pr_curve[:5]:  # show only the first 5
+         print(f"{point['threshold']:<10.2f} {point['precision']:<12.4f} {point['recall']:<12.4f}")
+     print("...")
+
+     print("\n" + "="*60)
+     print("Evaluation finished!")
+     print("="*60)
+
+     # The PR-curve data can be saved or plotted, e.g.:
+     # import matplotlib.pyplot as plt
+     # precisions = [p['precision'] for p in pr_curve]
+     # recalls = [p['recall'] for p in pr_curve]
+     # plt.plot(recalls, precisions)
+     # plt.xlabel('Recall')
+     # plt.ylabel('Precision')
+     # plt.title('PR Curve')
+     # plt.savefig('pr_curve.png')
+
+
+ if __name__ == '__main__':
+     try:
+         main()
+     except FileNotFoundError as e:
+         print(f"\nError: {e}")
+         print("\nPlease edit the path configuration in this script to point at your datasets.")
+     except Exception as e:
+         print(f"\nAn error occurred: {e}")
+         import traceback
+         traceback.print_exc()
+
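Since each PR point is a plain dict, the curve data serializes directly if it needs to leave the script for plotting or regression tracking elsewhere; a small follow-on sketch, assuming the `pr_curve` list produced above:

```python
import json

# Each entry holds threshold, precision, recall, f1, tp, fp, fn
with open("pr_curve.json", "w") as f:
    json.dump(pr_curve, f, indent=2)
```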
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "dataset-toolkit"
- version = "0.1.2"
+ version = "0.2.0"
  description = "A toolkit for loading, processing, and exporting computer vision datasets"
  readme = "README.md"
  requires-python = ">=3.7"
@@ -8,7 +8,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
      name="dataset-toolkit",
-     version="0.1.2",
+     version="0.2.0",
      author="wenxiang.han",
      author_email="wenxiang.han@anker-in.com",
      description="A toolkit for loading, processing, and exporting computer vision datasets",