dataset-toolkit 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,7 @@ Dataset Toolkit - computer vision dataset processing toolkit
  >>> export_to_coco(dataset, "output.json")
  """

- __version__ = "0.1.1"
+ __version__ = "0.2.0"
  __author__ = "wenxiang.han"
  __email__ = "wenxiang.han@anker-in.com"

@@ -28,13 +28,18 @@ from dataset_toolkit.models import (

  from dataset_toolkit.loaders.local_loader import (
      load_yolo_from_local,
-     load_csv_result_from_local
+     load_csv_result_from_local,
+     load_predictions_from_streamlined
  )

  from dataset_toolkit.processors.merger import (
      merge_datasets
  )

+ from dataset_toolkit.processors.evaluator import (
+     Evaluator
+ )
+
  from dataset_toolkit.exporters.coco_exporter import (
      export_to_coco
  )
@@ -43,6 +48,11 @@ from dataset_toolkit.exporters.txt_exporter import (
      export_to_txt
  )

+ from dataset_toolkit.exporters.yolo_exporter import (
+     export_to_yolo_format,
+     export_to_yolo_and_txt
+ )
+
  from dataset_toolkit.utils.coords import (
      yolo_to_absolute_bbox
  )
@@ -64,13 +74,17 @@ __all__ = [
      # Loaders
      "load_yolo_from_local",
      "load_csv_result_from_local",
+     "load_predictions_from_streamlined",

      # Processors
      "merge_datasets",
+     "Evaluator",

      # Exporters
      "export_to_coco",
      "export_to_txt",
+     "export_to_yolo_format",
+     "export_to_yolo_and_txt",

      # Utility functions
      "yolo_to_absolute_bbox",
dataset_toolkit/exporters/yolo_exporter.py ADDED
@@ -0,0 +1,157 @@
+ # dataset_toolkit/exporters/yolo_exporter.py
+ """
+ Export to YOLO format (a complete images/ + labels/ directory structure)
+ """
+ import os
+ import shutil
+ from pathlib import Path
+ from typing import Optional
+
+
+ def export_to_yolo_format(
+     dataset,
+     output_dir: str,
+     use_symlinks: bool = True,
+     overwrite: bool = False
+ ):
+     """
+     Export a dataset as a complete YOLO-format directory structure.
+
+     Args:
+         dataset: Dataset object
+         output_dir: output directory path
+         use_symlinks: use symlinks (True) or copy files (False)
+         overwrite: whether to overwrite existing files
+
+     Output structure:
+         output_dir/
+         ├── images/
+         │   ├── img1.jpg
+         │   └── img2.jpg
+         └── labels/
+             ├── img1.txt
+             └── img2.txt
+     """
+     output_path = Path(output_dir)
+     images_dir = output_path / 'images'
+     labels_dir = output_path / 'labels'
+
+     # Create the directories
+     images_dir.mkdir(parents=True, exist_ok=True)
+     labels_dir.mkdir(parents=True, exist_ok=True)
+
+     print(f"Exporting YOLO format to: {output_path}")
+     print(f"  Using symlinks: {use_symlinks}")
+
+     success_count = 0
+     error_count = 0
+
+     for img in dataset.images:
+         try:
+             # Get the image file name (and the stem without extension)
+             img_path = Path(img.path)
+             img_name = img_path.name
+             stem = img_path.stem
+
+             # 1. Handle the image (symlink or copy)
+             target_img_path = images_dir / img_name
+
+             if target_img_path.exists() and not overwrite:
+                 # File already exists; skip it
+                 pass
+             else:
+                 if use_symlinks:
+                     # Use a symlink
+                     if target_img_path.exists():
+                         target_img_path.unlink()
+                     target_img_path.symlink_to(img_path.resolve())
+                 else:
+                     # Copy the file
+                     shutil.copy2(img_path, target_img_path)
+
+             # 2. Generate the label file
+             label_path = labels_dir / f"{stem}.txt"
+
+             with open(label_path, 'w') as f:
+                 for ann in img.annotations:
+                     # Internal format: [x_min, y_min, width, height] (absolute pixels)
+                     # YOLO format: class_id x_center y_center width height (normalized)
+
+                     x_min, y_min, width, height = ann.bbox
+
+                     # Convert to normalized YOLO format
+                     x_center = (x_min + width / 2) / img.width
+                     y_center = (y_min + height / 2) / img.height
+                     norm_width = width / img.width
+                     norm_height = height / img.height
+
+                     # Write: class_id x_center y_center width height
+                     f.write(f"{ann.category_id} {x_center:.6f} {y_center:.6f} {norm_width:.6f} {norm_height:.6f}\n")
+
+             success_count += 1
+
+         except Exception as e:
+             print(f"Warning: failed to process image {img.path}: {e}")
+             error_count += 1
+             continue
+
+     print("✓ Export complete:")
+     print(f"  Succeeded: {success_count} images")
+     if error_count > 0:
+         print(f"  Failed: {error_count} images")
+     print(f"  Images directory: {images_dir}")
+     print(f"  Labels directory: {labels_dir}")
+
+     return output_path
+
+
+ def export_to_yolo_and_txt(
+     dataset,
+     yolo_dir: str,
+     txt_file: str,
+     use_symlinks: bool = True,
+     use_relative_paths: bool = False
+ ):
+     """
+     Export to YOLO format and generate a matching txt list file.
+
+     Args:
+         dataset: Dataset object
+         yolo_dir: YOLO-format output directory
+         txt_file: path of the txt list file
+         use_symlinks: whether to use symlinks
+         use_relative_paths: whether the txt file uses relative paths
+
+     Returns:
+         yolo_dir_path: path of the YOLO directory
+     """
+     # 1. Export to YOLO format
+     yolo_path = export_to_yolo_format(dataset, yolo_dir, use_symlinks=use_symlinks)
+
+     # 2. Generate the txt list file (pointing at images/ inside the YOLO directory)
+     images_dir = yolo_path / 'images'
+     txt_path = Path(txt_file)
+     txt_path.parent.mkdir(parents=True, exist_ok=True)
+
+     print(f"\nGenerating txt list: {txt_file}")
+
+     with open(txt_file, 'w') as f:
+         for img in dataset.images:
+             img_name = Path(img.path).name
+             # Point at the image inside the YOLO directory (possibly a symlink)
+             img_in_yolo = images_dir / img_name
+
+             if use_relative_paths:
+                 # Path relative to the txt file
+                 rel_path = os.path.relpath(img_in_yolo, txt_path.parent)
+                 f.write(f"{rel_path}\n")
+             else:
+                 # Absolute path (normalized, but without resolving symlinks);
+                 # os.path.normpath removes '..' segments etc.
+                 normalized_path = os.path.normpath(str(img_in_yolo.absolute()))
+                 f.write(f"{normalized_path}\n")
+
+     print(f"✓ txt list generated: {len(dataset.images)} lines")
+
+     return yolo_path
+
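
A minimal usage sketch for the two new exporters; the category map and the input/output paths below are illustrative, not part of the package:

```python
from dataset_toolkit import (
    load_yolo_from_local,
    export_to_yolo_format,
    export_to_yolo_and_txt,
)

# Illustrative inputs: any loader that returns a Dataset works here.
dataset = load_yolo_from_local("/data/train/labels", {0: "parcel"})

# Build a symlinked images/ + labels/ tree.
export_to_yolo_format(dataset, "/tmp/yolo_out", use_symlinks=True)

# Same tree, plus a list file whose entries are relative to the txt location.
export_to_yolo_and_txt(dataset, "/tmp/yolo_out", "/tmp/train.txt",
                       use_relative_paths=True)
```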
dataset_toolkit/loaders/local_loader.py CHANGED
@@ -186,4 +186,149 @@ def load_csv_result_from_local(dataset_path: str, categories: Dict[int, str] = N

      print(f"Loading finished. Found {image_count} images and {len(dataset.categories)} categories.")
      print(f"Category mapping: {dataset.categories}")
+     return dataset
+
+
+ def load_predictions_from_streamlined(
+     predictions_dir: str,
+     categories: Dict[int, str],
+     image_dir: str = None
+ ) -> Dataset:
+     """
+     Load a prediction dataset from a streamlined inference output directory.
+
+     Prediction file format (one detection per line):
+         class_id,confidence,center_x,center_y,width,height
+     Example: 0,0.934679,354.00,388.00,274.00,102.00
+
+     Args:
+         predictions_dir: directory containing the prediction txt files
+         categories: category mapping {class_id: class_name}
+         image_dir: image directory (optional; used to read image sizes).
+                    If omitted, sibling directories of the prediction files are searched.
+
+     Returns:
+         Dataset: prediction dataset object with dataset_type='pred'
+     """
+     pred_path = Path(predictions_dir)
+
+     if not pred_path.is_dir():
+         raise FileNotFoundError(f"Predictions directory does not exist: {pred_path}")
+
+     # Try to locate the image directory automatically
+     if image_dir is None:
+         # Try common image directory locations
+         possible_image_dirs = [
+             pred_path.parent / 'images',
+             pred_path.parent.parent / 'images',
+         ]
+         for possible_dir in possible_image_dirs:
+             if possible_dir.is_dir():
+                 image_dir = str(possible_dir)
+                 print(f"Automatically found image directory: {image_dir}")
+                 break
+
+     dataset = Dataset(
+         name=pred_path.name,
+         categories=categories,
+         dataset_type="pred"
+     )
+
+     supported_extensions = ['.jpg', '.jpeg', '.png']
+     txt_files = list(pred_path.glob('*.txt'))
+
+     print(f"Loading predictions: {pred_path.name}...")
+     print(f"Found {len(txt_files)} prediction files")
+
+     loaded_count = 0
+     skipped_count = 0
+
+     for txt_file in txt_files:
+         # Image file name matching the prediction file (assumed to share the stem)
+         image_base_name = txt_file.stem
+
+         # Try to find the matching image file
+         image_path = None
+         img_width, img_height = None, None
+
+         if image_dir:
+             image_dir_path = Path(image_dir)
+             for ext in supported_extensions:
+                 potential_image = image_dir_path / (image_base_name + ext)
+                 if potential_image.exists():
+                     image_path = str(potential_image.resolve())
+                     try:
+                         with Image.open(potential_image) as img:
+                             img_width, img_height = img.size
+                     except IOError:
+                         print(f"Warning: cannot open image {potential_image}")
+                         image_path = None  # fall back to the default size below
+                     break
+
+         # If no image was found, use defaults
+         if image_path is None:
+             # Assume a default image path and size
+             image_path = f"unknown/{image_base_name}.jpg"
+             img_width, img_height = 640, 640  # default size
+             if image_dir:
+                 skipped_count += 1
+
+         # Create the image annotation object
+         image_annotation = ImageAnnotation(
+             image_id=image_base_name + '.jpg',
+             path=image_path,
+             width=img_width,
+             height=img_height
+         )
+
+         # Read the predictions
+         try:
+             with open(txt_file, 'r') as f:
+                 for line in f:
+                     line = line.strip()
+                     if not line:
+                         continue
+
+                     # Parse format: class_id,confidence,center_x,center_y,width,height
+                     parts = line.split(',')
+                     if len(parts) != 6:
+                         print(f"Warning: malformed line skipped: {txt_file} -> '{line}'")
+                         continue
+
+                     try:
+                         class_id = int(parts[0])
+                         confidence = float(parts[1])
+                         center_x = float(parts[2])
+                         center_y = float(parts[3])
+                         width = float(parts[4])
+                         height = float(parts[5])
+
+                         # Convert to [x_min, y_min, width, height] format
+                         x_min = center_x - width / 2
+                         y_min = center_y - height / 2
+
+                         annotation = Annotation(
+                             category_id=class_id,
+                             bbox=[x_min, y_min, width, height],
+                             confidence=confidence
+                         )
+                         image_annotation.annotations.append(annotation)
+
+                     except (ValueError, IndexError) as e:
+                         print(f"Warning: parse error, line skipped: {txt_file} -> '{line}' ({e})")
+                         continue
+
+         except Exception as e:
+             print(f"Warning: failed to read file, skipped: {txt_file} ({e})")
+             continue
+
+         dataset.images.append(image_annotation)
+         loaded_count += 1
+
+     print(f"Loading finished. Successfully loaded {loaded_count} prediction files")
+     if skipped_count > 0:
+         print(f"Warning: {skipped_count} files had no matching image; default size used")
+
+     total_detections = sum(len(img.annotations) for img in dataset.images)
+     print(f"Total detections: {total_detections}")
+
      return dataset
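
A sketch of the on-disk layout the loader expects; the file names are illustrative. When `image_dir` is omitted, the loader probes for a sibling `images/` directory as shown above:

```python
from dataset_toolkit import load_predictions_from_streamlined

# Illustrative layout:
#   results/
#   ├── images/              # auto-discovered when image_dir is omitted
#   │   └── frame_001.jpg
#   └── predictions/
#       └── frame_001.txt    # "0,0.934679,354.00,388.00,274.00,102.00"
preds = load_predictions_from_streamlined(
    "results/predictions",
    categories={0: "parcel"},
)
assert preds.dataset_type == "pred"
```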
dataset_toolkit/models.py CHANGED
@@ -24,4 +24,6 @@ class Dataset:
      """Represents a complete dataset object - the standardized internal representation."""
      name: str
      images: List[ImageAnnotation] = field(default_factory=list)
-     categories: Dict[int, str] = field(default_factory=dict)
+     categories: Dict[int, str] = field(default_factory=dict)
+     dataset_type: str = "train"  # 'train', 'gt', 'pred'
+     metadata: Dict = field(default_factory=dict)  # descriptive information only; no processing parameters
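
A minimal sketch of the two new fields; the dataset name and metadata key are illustrative, since `metadata` carries free-form descriptive information:

```python
from dataset_toolkit.models import Dataset

gt = Dataset(name="warehouse_test", categories={0: "parcel"}, dataset_type="gt")
gt.metadata["source"] = "2024-06 collection"  # descriptive only; no processing parameters
```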
dataset_toolkit/processors/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # dataset_toolkit/processors/__init__.py
+ from .merger import merge_datasets
+ from .evaluator import Evaluator
+
+ __all__ = [
+     'merge_datasets',
+     'Evaluator',
+ ]
+
dataset_toolkit/processors/evaluator.py ADDED
@@ -0,0 +1,535 @@
+ # dataset_toolkit/processors/evaluator.py
+ from typing import Dict, List, Optional, Tuple
+ from dataset_toolkit.models import Dataset, Annotation, ImageAnnotation
+
+
+ class Evaluator:
+     """
+     Evaluator with separate positive and negative (false-positive) test sets.
+
+     Compares GT and Pred datasets and computes Precision, Recall, F1, FPPI, etc.
+     Supports re-evaluation at different confidence thresholds without reloading data.
+     """
+
+     def __init__(
+         self,
+         positive_gt: Dataset,
+         positive_pred: Dataset,
+         negative_gt: Optional[Dataset] = None,
+         negative_pred: Optional[Dataset] = None,
+         iou_threshold: float = 0.5
+     ):
+         """
+         Initialize the evaluator.
+
+         Args:
+             positive_gt: positive-set GT (required) - test images containing target objects
+             positive_pred: positive-set predictions (required) - predictions on the positive set
+             negative_gt: negative-set GT (optional) - GT for background images without targets
+             negative_pred: negative-set predictions (optional) - predictions on the negative set
+             iou_threshold: IoU threshold for matching a detection to GT (default 0.5)
+         """
+         self.positive_gt = positive_gt
+         self.positive_pred = positive_pred
+         self.negative_gt = negative_gt
+         self.negative_pred = negative_pred
+         self.iou_threshold = iou_threshold
+
+         # Validate the datasets
+         self._validate_datasets()
+
+     def _validate_datasets(self):
+         """Validate that the datasets are usable."""
+         if self.positive_gt is None or self.positive_pred is None:
+             raise ValueError("Positive-set GT and Pred are required")
+
+         if len(self.positive_gt.images) == 0:
+             raise ValueError("Positive-set GT is empty")
+
+         if len(self.positive_pred.images) == 0:
+             raise ValueError("Positive-set Pred is empty")
+
+         # Check that the category mappings agree
+         if self.positive_gt.categories != self.positive_pred.categories:
+             print("Warning: GT and Pred category mappings differ")
+             print(f"  GT categories: {self.positive_gt.categories}")
+             print(f"  Pred categories: {self.positive_pred.categories}")
+
+     def calculate_metrics(
+         self,
+         confidence_threshold: float = 0.5,
+         class_id: Optional[int] = None,
+         calculate_fppi: bool = True
+     ) -> Dict:
+         """
+         Compute the combined evaluation metrics.
+
+         Args:
+             confidence_threshold: confidence threshold (passed in dynamically, not stored on the dataset)
+             class_id: class ID to evaluate; None means all classes
+             calculate_fppi: whether to compute FPPI (requires negative_pred)
+
+         Returns:
+             Dictionary with all metrics:
+             - tp, fp, fn: true positives, false positives, false negatives
+             - precision, recall, f1
+             - fppi: false positives per image (if computed)
+             - confidence_threshold, iou_threshold: thresholds used
+             - positive_set_size, negative_set_size: dataset sizes
+         """
+         metrics = {}
+
+         # 1. Precision, Recall, F1 from the positive set
+         positive_metrics = self._calculate_positive_metrics(
+             confidence_threshold, class_id
+         )
+         metrics.update(positive_metrics)
+
+         # 2. FPPI from the negative set
+         if calculate_fppi and self.negative_pred is not None:
+             fppi_metrics = self._calculate_fppi_metrics(
+                 confidence_threshold, class_id
+             )
+             metrics.update(fppi_metrics)
+         else:
+             metrics['fppi'] = None
+             metrics['fppi_note'] = "no negative set provided" if self.negative_pred is None else "FPPI not computed"
+
+         # 3. Attach configuration info
+         metrics['confidence_threshold'] = confidence_threshold
+         metrics['iou_threshold'] = self.iou_threshold
+         metrics['positive_set_size'] = len(self.positive_gt.images)
+         metrics['negative_set_size'] = (
+             len(self.negative_pred.images)
+             if self.negative_pred else 0
+         )
+
+         return metrics
+
+     def _calculate_positive_metrics(
+         self,
+         confidence_threshold: float,
+         class_id: Optional[int] = None
+     ) -> Dict:
+         """
+         Compute Precision, Recall and F1 from the positive set.
+
+         Args:
+             confidence_threshold: confidence threshold
+             class_id: class ID to evaluate
+
+         Returns:
+             Dictionary with TP, FP, FN, Precision, Recall, F1
+         """
+         # 1. Filter the predictions
+         filtered_preds = self._filter_predictions(
+             self.positive_pred,
+             confidence_threshold,
+             class_id
+         )
+
+         # 2. Match GT against Pred
+         tp = 0  # true positives
+         fp = 0  # false positives
+         fn = 0  # false negatives
+
+         matched_gt = set()  # matched GT entries, keyed as (image_id, gt_index)
+
+         # Iterate over the images
+         for img_gt in self.positive_gt.images:
+             # GT annotations for this image
+             gt_anns = [
+                 ann for ann in img_gt.annotations
+                 if class_id is None or ann.category_id == class_id
+             ]
+
+             # Predictions for this image
+             img_preds = filtered_preds.get(img_gt.image_id, [])
+
+             # Match predictions against GT
+             for pred in img_preds:
+                 best_iou = 0
+                 best_gt_idx = -1
+
+                 for i, gt_ann in enumerate(gt_anns):
+                     if (img_gt.image_id, i) not in matched_gt:
+                         iou = self._calculate_iou(pred.bbox, gt_ann.bbox)
+                         if iou > best_iou:
+                             best_iou = iou
+                             best_gt_idx = i
+
+                 if best_iou >= self.iou_threshold:
+                     tp += 1
+                     matched_gt.add((img_gt.image_id, best_gt_idx))
+                 else:
+                     fp += 1
+
+             # Count unmatched GT (false negatives)
+             fn += len([
+                 i for i in range(len(gt_anns))
+                 if (img_gt.image_id, i) not in matched_gt
+             ])
+
+         # 3. Compute the metrics
+         precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+         recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+         f1 = (2 * precision * recall / (precision + recall)
+               if (precision + recall) > 0 else 0.0)
+
+         return {
+             'tp': tp,
+             'fp': fp,
+             'fn': fn,
+             'precision': precision,
+             'recall': recall,
+             'f1': f1,
+             'positive_set_note': f"based on {len(self.positive_gt.images)} positive-set images"
+         }
+
+     def _calculate_fppi_metrics(
+         self,
+         confidence_threshold: float,
+         class_id: Optional[int] = None
+     ) -> Dict:
+         """
+         Compute FPPI (False Positives Per Image) from the negative set.
+
+         Args:
+             confidence_threshold: confidence threshold
+             class_id: class ID to evaluate
+
+         Returns:
+             Dictionary with the FPPI-related metrics
+         """
+         # 1. Filter the predictions
+         filtered_preds = self._filter_predictions(
+             self.negative_pred,
+             confidence_threshold,
+             class_id
+         )
+
+         # 2. Count the false positives
+         total_fp = 0
+         fp_per_image = []
+
+         for img_pred in self.negative_pred.images:
+             img_preds = filtered_preds.get(img_pred.image_id, [])
+
+             # On the negative set, every detection is a false positive,
+             # but if negative_gt is provided the decision can be more precise.
+             if self.negative_gt is not None:
+                 # Find the matching GT image
+                 gt_img = next(
+                     (img for img in self.negative_gt.images
+                      if img.image_id == img_pred.image_id),
+                     None
+                 )
+
+                 if gt_img is not None:
+                     # GT annotations
+                     gt_anns = [
+                         ann for ann in gt_img.annotations
+                         if class_id is None or ann.category_id == class_id
+                     ]
+
+                     # Match predictions against GT; unmatched detections are FP
+                     matched = set()
+                     fp_count = 0
+                     for pred in img_preds:
+                         is_matched = False
+                         for i, gt_ann in enumerate(gt_anns):
+                             if i not in matched:
+                                 iou = self._calculate_iou(pred.bbox, gt_ann.bbox)
+                                 if iou >= self.iou_threshold:
+                                     matched.add(i)
+                                     is_matched = True
+                                     break
+                         if not is_matched:
+                             fp_count += 1
+
+                     total_fp += fp_count
+                     fp_per_image.append(fp_count)
+                 else:
+                     # No matching GT image; every detection is a FP
+                     fp_count = len(img_preds)
+                     total_fp += fp_count
+                     fp_per_image.append(fp_count)
+             else:
+                 # No negative_gt provided; assume every detection is a FP
+                 fp_count = len(img_preds)
+                 total_fp += fp_count
+                 fp_per_image.append(fp_count)
+
+         # 3. Compute FPPI
+         num_images = len(self.negative_pred.images)
+         fppi = total_fp / num_images if num_images > 0 else 0.0
+
+         return {
+             'fppi': fppi,
+             'total_false_positives': total_fp,
+             'negative_set_size': num_images,
+             'fppi_note': f"based on {num_images} negative-set images",
+             'max_fp_per_image': max(fp_per_image) if fp_per_image else 0,
+             'min_fp_per_image': min(fp_per_image) if fp_per_image else 0,
+             'avg_fp_per_image': total_fp / num_images if num_images > 0 else 0.0
+         }
+
+     def _filter_predictions(
+         self,
+         pred_dataset: Dataset,
+         confidence_threshold: float,
+         class_id: Optional[int] = None
+     ) -> Dict[str, List[Annotation]]:
+         """
+         Filter predictions by confidence threshold and class.
+
+         Args:
+             pred_dataset: prediction dataset
+             confidence_threshold: confidence threshold
+             class_id: class ID to keep
+
+         Returns:
+             {image_id: [annotations]} dictionary
+         """
+         filtered = {}
+         for img in pred_dataset.images:
+             img_preds = [
+                 ann for ann in img.annotations
+                 if ann.confidence >= confidence_threshold
+                 and (class_id is None or ann.category_id == class_id)
+             ]
+             if img_preds:
+                 filtered[img.image_id] = img_preds
+         return filtered
+
+     def _calculate_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
+         """
+         Compute the IoU (Intersection over Union) of two bounding boxes.
+
+         Args:
+             bbox1, bbox2: bounding boxes in [x_min, y_min, width, height] format
+
+         Returns:
+             IoU value (0.0 to 1.0)
+         """
+         # bbox format: [x_min, y_min, width, height]
+         x1_min, y1_min, w1, h1 = bbox1
+         x2_min, y2_min, w2, h2 = bbox2
+
+         x1_max = x1_min + w1
+         y1_max = y1_min + h1
+         x2_max = x2_min + w2
+         y2_max = y2_min + h2
+
+         # Intersection
+         inter_x_min = max(x1_min, x2_min)
+         inter_y_min = max(y1_min, y2_min)
+         inter_x_max = min(x1_max, x2_max)
+         inter_y_max = min(y1_max, y2_max)
+
+         if inter_x_max <= inter_x_min or inter_y_max <= inter_y_min:
+             return 0.0
+
+         inter_area = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min)
+
+         # Union
+         area1 = w1 * h1
+         area2 = w2 * h2
+         union_area = area1 + area2 - inter_area
+
+         return inter_area / union_area if union_area > 0 else 0.0
+
+     def calculate_pr_curve(
+         self,
+         thresholds: Optional[List[float]] = None,
+         class_id: Optional[int] = None
+     ) -> List[Dict]:
+         """
+         Compute a PR curve (Precision-Recall at different confidence thresholds).
+
+         Args:
+             thresholds: confidence thresholds to test; None uses the defaults
+             class_id: class ID to evaluate
+
+         Returns:
+             List of metrics, one entry per threshold
+         """
+         if thresholds is None:
+             thresholds = [i/10 for i in range(1, 10)]  # 0.1 to 0.9
+
+         pr_points = []
+         for threshold in thresholds:
+             # Evaluate the same data at each threshold
+             metrics = self.calculate_metrics(
+                 confidence_threshold=threshold,
+                 class_id=class_id,
+                 calculate_fppi=False  # the PR curve does not need FPPI
+             )
+             pr_points.append({
+                 'threshold': threshold,
+                 'precision': metrics['precision'],
+                 'recall': metrics['recall'],
+                 'f1': metrics['f1'],
+                 'tp': metrics['tp'],
+                 'fp': metrics['fp'],
+                 'fn': metrics['fn']
+             })
+
+         return pr_points
+
+     def find_optimal_threshold(
+         self,
+         metric: str = 'f1',
+         class_id: Optional[int] = None,
+         thresholds: Optional[List[float]] = None
+     ) -> Dict:
+         """
+         Find the confidence threshold that maximizes the given metric.
+
+         Args:
+             metric: metric to optimize ('precision', 'recall', 'f1')
+             class_id: class ID to evaluate
+             thresholds: thresholds to test
+
+         Returns:
+             The optimal threshold and all metrics at that threshold
+         """
+         if metric not in ['precision', 'recall', 'f1']:
+             raise ValueError(f"Unsupported metric: {metric}")
+
+         pr_curve = self.calculate_pr_curve(thresholds, class_id)
+
+         # Find the point where the given metric is largest
+         best_point = max(pr_curve, key=lambda x: x[metric])
+
+         return {
+             'optimal_threshold': best_point['threshold'],
+             'optimized_metric': metric,
+             'metrics': best_point
+         }
+
+     def find_threshold_with_constraint(
+         self,
+         target_metric: str,
+         constraint_metric: str,
+         constraint_value: float,
+         class_id: Optional[int] = None,
+         thresholds: Optional[List[float]] = None
+     ) -> Optional[Dict]:
+         """
+         Find the optimal threshold subject to a constraint.
+
+         For example: the threshold with the highest Recall subject to FPPI < 0.01.
+
+         Args:
+             target_metric: metric to optimize ('precision', 'recall', 'f1')
+             constraint_metric: constrained metric ('fppi', 'precision', 'recall')
+             constraint_value: constraint value (e.g. fppi < 0.01)
+             class_id: class ID to evaluate
+             thresholds: thresholds to test
+
+         Returns:
+             The optimal threshold and its metrics, or None if no threshold satisfies the constraint
+         """
+         if thresholds is None:
+             thresholds = [i/100 for i in range(1, 100)]  # 0.01 to 0.99
+
+         best_threshold = None
+         best_value = 0
+         best_metrics = None
+
+         for threshold in thresholds:
+             metrics = self.calculate_metrics(
+                 confidence_threshold=threshold,
+                 class_id=class_id,
+                 calculate_fppi=(constraint_metric == 'fppi')
+             )
+
+             # Check the constraint
+             constraint_satisfied = False
+             if constraint_metric == 'fppi':
+                 if metrics['fppi'] is not None and metrics['fppi'] <= constraint_value:
+                     constraint_satisfied = True
+             elif constraint_metric in metrics:
+                 if metrics[constraint_metric] >= constraint_value:
+                     constraint_satisfied = True
+
+             # If the constraint holds, check whether this is the best so far
+             if constraint_satisfied:
+                 if metrics[target_metric] > best_value:
+                     best_value = metrics[target_metric]
+                     best_threshold = threshold
+                     best_metrics = metrics
+
+         if best_threshold is None:
+             return None
+
+         return {
+             'optimal_threshold': best_threshold,
+             'target_metric': target_metric,
+             'target_value': best_value,
+             'constraint': f"{constraint_metric} <= {constraint_value}",
+             'metrics': best_metrics
+         }
+
+     def generate_report(
+         self,
+         confidence_threshold: float = 0.5,
+         class_id: Optional[int] = None
+     ) -> str:
+         """
+         Generate an evaluation report.
+
+         Args:
+             confidence_threshold: confidence threshold
+             class_id: class ID to evaluate
+
+         Returns:
+             Formatted evaluation report string
+         """
+         metrics = self.calculate_metrics(confidence_threshold, class_id)
+
+         class_name = "all classes"
+         if class_id is not None and class_id in self.positive_gt.categories:
+             class_name = f"class {class_id} ({self.positive_gt.categories[class_id]})"
+
+         report = f"""
+ {'='*60}
+ Evaluation Report - {class_name}
+ {'='*60}
+
+ Configuration:
+   Confidence threshold: {metrics['confidence_threshold']}
+   IoU threshold: {metrics['iou_threshold']}
+   Positive set size: {metrics['positive_set_size']} images
+   Negative set size: {metrics['negative_set_size']} images
+
+ Positive-set metrics (Precision & Recall):
+   True Positives (TP): {metrics['tp']}
+   False Positives (FP): {metrics['fp']}
+   False Negatives (FN): {metrics['fn']}
+
+   Precision: {metrics['precision']:.4f} ({metrics['precision']*100:.2f}%)
+   Recall: {metrics['recall']:.4f} ({metrics['recall']*100:.2f}%)
+   F1-Score: {metrics['f1']:.4f}
+ """
+
+         if metrics['fppi'] is not None:
+             report += f"""
+ Negative-set metrics (FPPI):
+   FPPI (False Positives Per Image): {metrics['fppi']:.6f}
+   Total false positives: {metrics.get('total_false_positives', 'N/A')}
+   Avg FP per image: {metrics.get('avg_fp_per_image', 'N/A'):.2f}
+   Max FP per image: {metrics.get('max_fp_per_image', 'N/A')}
+   Min FP per image: {metrics.get('min_fp_per_image', 'N/A')}
+ """
+         else:
+             report += f"""
+ Negative-set metrics (FPPI):
+   {metrics['fppi_note']}
+ """
+
+         report += f"\n{'='*60}\n"
+
+         return report
+
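
Beyond the single-threshold README example below, a sketch of the threshold-sweep APIs; the loader paths and category map are illustrative:

```python
from dataset_toolkit import (
    load_yolo_from_local,
    load_predictions_from_streamlined,
    Evaluator,
)

# Illustrative inputs.
gt = load_yolo_from_local("/data/test/labels", {0: "parcel"})
preds = load_predictions_from_streamlined("/results/predictions", {0: "parcel"},
                                          image_dir="/data/test/images")

evaluator = Evaluator(positive_gt=gt, positive_pred=preds, iou_threshold=0.5)

# Precision/recall/F1 swept over the default thresholds 0.1 .. 0.9.
for point in evaluator.calculate_pr_curve():
    print(f"thr={point['threshold']:.1f}  "
          f"P={point['precision']:.3f}  R={point['recall']:.3f}")

# Highest recall subject to FPPI <= 0.01. This needs negative_pred at
# construction time; without it, fppi is None and the result is None.
best = evaluator.find_threshold_with_constraint(
    target_metric="recall",
    constraint_metric="fppi",
    constraint_value=0.01,
)
if best is not None:
    print(best["optimal_threshold"], best["metrics"]["recall"])
```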
dataset_toolkit-0.2.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dataset-toolkit
- Version: 0.1.1
+ Version: 0.2.0
  Summary: A toolkit for loading, processing, and exporting computer vision datasets
  Home-page: https://github.com/yourusername/dataset-toolkit
  Author: wenxiang.han
@@ -42,6 +42,7 @@ Dynamic: requires-python
  - 📤 **Flexible export**: export to COCO JSON, TXT, and other formats
  - 🛠️ **Utility functions**: coordinate conversion and other helpers
  - 📦 **Standardized data model**: a unified internal representation that is easy to extend
+ - 📊 **Model evaluation**: a complete object-detection model evaluation system (v0.2.0+)

  ## 📦 Installation

@@ -121,6 +122,43 @@ result = (pipeline
      .execute())
  ```

+ ### Model evaluation (v0.2.0+)
+
+ ```python
+ from dataset_toolkit import (
+     load_yolo_from_local,
+     load_predictions_from_streamlined,
+     Evaluator
+ )
+
+ # 1. Load the GT and the prediction results
+ gt_dataset = load_yolo_from_local("/data/test/labels", {0: 'parcel'})
+ pred_dataset = load_predictions_from_streamlined(
+     "/results/predictions",
+     categories={0: 'parcel'},
+     image_dir="/data/test/images"
+ )
+
+ # 2. Create the evaluator
+ evaluator = Evaluator(
+     positive_gt=gt_dataset,
+     positive_pred=pred_dataset,
+     iou_threshold=0.5
+ )
+
+ # 3. Compute the metrics
+ metrics = evaluator.calculate_metrics(confidence_threshold=0.5)
+ print(f"Precision: {metrics['precision']:.4f}")
+ print(f"Recall: {metrics['recall']:.4f}")
+ print(f"F1-Score: {metrics['f1']:.4f}")
+
+ # 4. Find the optimal threshold
+ optimal = evaluator.find_optimal_threshold(metric='f1')
+ print(f"Optimal threshold: {optimal['optimal_threshold']}")
+ ```
+
+ See [EVALUATION_GUIDE.md](EVALUATION_GUIDE.md) for detailed documentation.
+
  ## 📚 API Documentation

  ### Data loaders
dataset_toolkit-0.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
+ dataset_toolkit/__init__.py,sha256=S9o2DdR4QEdWkItk8cHlosu60EqpDGOxXBsCsupq9UE,2011
+ dataset_toolkit/models.py,sha256=q9Ud8GpDM4zXlj2cOaqNvHSgSC7eXswcku5VWypqA8U,985
+ dataset_toolkit/pipeline.py,sha256=iBJD7SemEVFTwzHxRQrjpUIQQcVdPSZnD4sB_y56Md0,5697
+ dataset_toolkit/exporters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dataset_toolkit/exporters/coco_exporter.py,sha256=l5sfj7rOcvcMC0-4LNOEJ4PeklGQORDflU_um5GGnxA,2120
+ dataset_toolkit/exporters/txt_exporter.py,sha256=9nTWs6M89MdKJhlODtmfzeZqWkliXac9NMWPgVUrE7c,1246
+ dataset_toolkit/exporters/yolo_exporter.py,sha256=xz1XxgwNwtKUaa88BNjPr6NB4NNmkOQU4fyxIWG5gNk,5253
+ dataset_toolkit/loaders/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dataset_toolkit/loaders/local_loader.py,sha256=fjvLtp5NSPomVaqLv_XDa7Ugdk_jz_UW-aR3rn2sPHg,12872
+ dataset_toolkit/processors/__init__.py,sha256=cUJhs7joukr5YiTlnxHRwxligMvkY0s07nX2vOoYU-g,164
+ dataset_toolkit/processors/evaluator.py,sha256=UADzgS-7vSYWtORxQkqxTeiS3RJy3L7mBrw0YdkQYaM,18782
+ dataset_toolkit/processors/merger.py,sha256=h8qQNgSmkPrhoQ3QiWEyIl11CmmjT5K1-8TzNb7_jbk,2834
+ dataset_toolkit/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dataset_toolkit/utils/coords.py,sha256=GtTQz2gFyFQfXhKfecI8tzqWFjraJY6Xo85-kRXYAYc,614
+ dataset_toolkit-0.2.0.dist-info/licenses/LICENSE,sha256=8_up1FX6vk2DRcusQEZ4pWJGkgkjvEkD14xB1hdLe3c,1067
+ dataset_toolkit-0.2.0.dist-info/METADATA,sha256=T0jHVp0sBDu4_7sp1G9BX2-Ibnu0il8Y7E0XDvguUKo,8260
+ dataset_toolkit-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dataset_toolkit-0.2.0.dist-info/top_level.txt,sha256=B4D5vMLjUNJBZDdL7Utc0FYIfYoWbzyIGBMVYaeMd3U,16
+ dataset_toolkit-0.2.0.dist-info/RECORD,,
dataset_toolkit-0.1.1.dist-info/RECORD DELETED
@@ -1,17 +0,0 @@
- dataset_toolkit/__init__.py,sha256=BhgTMzT5onSjI-Sd2bFSQGYyo9GwcLZUyowlyx7lMyU,1676
- dataset_toolkit/models.py,sha256=9HD2lAOPuEytFb1qRejODLJAD-uKHc8Ya1n9nbGhRpg,830
- dataset_toolkit/pipeline.py,sha256=iBJD7SemEVFTwzHxRQrjpUIQQcVdPSZnD4sB_y56Md0,5697
- dataset_toolkit/exporters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dataset_toolkit/exporters/coco_exporter.py,sha256=l5sfj7rOcvcMC0-4LNOEJ4PeklGQORDflU_um5GGnxA,2120
- dataset_toolkit/exporters/txt_exporter.py,sha256=9nTWs6M89MdKJhlODtmfzeZqWkliXac9NMWPgVUrE7c,1246
- dataset_toolkit/loaders/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dataset_toolkit/loaders/local_loader.py,sha256=SCOYG5pursEIL_m3QYGcm-2skXoapiOA4yhqqa2wrDM,7468
- dataset_toolkit/processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dataset_toolkit/processors/merger.py,sha256=h8qQNgSmkPrhoQ3QiWEyIl11CmmjT5K1-8TzNb7_jbk,2834
- dataset_toolkit/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dataset_toolkit/utils/coords.py,sha256=GtTQz2gFyFQfXhKfecI8tzqWFjraJY6Xo85-kRXYAYc,614
- dataset_toolkit-0.1.1.dist-info/licenses/LICENSE,sha256=8_up1FX6vk2DRcusQEZ4pWJGkgkjvEkD14xB1hdLe3c,1067
- dataset_toolkit-0.1.1.dist-info/METADATA,sha256=l3COSL22yVvDDZL_c_N5uJNjAPpKE0o2BasMso_Ntss,7236
- dataset_toolkit-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dataset_toolkit-0.1.1.dist-info/top_level.txt,sha256=B4D5vMLjUNJBZDdL7Utc0FYIfYoWbzyIGBMVYaeMd3U,16
- dataset_toolkit-0.1.1.dist-info/RECORD,,