EasyMetrics 0.1.5__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -37,7 +37,7 @@ class MeanAveragePrecision(Metric):
37
37
  确保输入数据是 numpy 数组,如果不是则转换为 numpy 数组。
38
38
 
39
39
  参数:
40
- data: 输入数据,可以是列表、numpy 数组或其他可转换类型。
40
+ data: 输入数据,可以是列表、numpy 数组、标量或其他可转换类型。
41
41
  dtype: 目标数据类型,默认为 None(自动推断)。
42
42
 
43
43
  返回:
@@ -47,11 +47,28 @@ class MeanAveragePrecision(Metric):
47
47
  if isinstance(data, np.ndarray):
48
48
  if dtype is not None and data.dtype != dtype:
49
49
  return data.astype(dtype)
50
+ # 特殊处理:如果是单个框的坐标 [x1, y1, x2, y2],转换为 (1, 4) 形状
51
+ if data.ndim == 1 and len(data) == 4:
52
+ return data.reshape(1, 4)
53
+ # 确保返回至少一维数组
54
+ if data.ndim == 0:
55
+ return data.reshape(1)
50
56
  return data
51
57
  elif data is None:
52
58
  return np.array([], dtype=dtype)
53
59
  else:
54
- return np.array(data, dtype=dtype)
60
+ # 如果输入是标量,转换为一维数组
61
+ try:
62
+ # 尝试迭代,如果是标量会抛出 TypeError
63
+ iter(data)
64
+ arr = np.array(data, dtype=dtype)
65
+ # 特殊处理:如果是单个框的坐标 [x1, y1, x2, y2],转换为 (1, 4) 形状
66
+ if arr.ndim == 1 and len(arr) == 4:
67
+ return arr.reshape(1, 4)
68
+ return arr
69
+ except TypeError:
70
+ # 输入是标量,转换为一维数组
71
+ return np.array([data], dtype=dtype)
55
72
 
56
73
  def update(self, preds: List[Dict[str, Any]], target: List[Dict[str, Any]]):
57
74
  """
@@ -62,21 +79,44 @@ class MeanAveragePrecision(Metric):
62
79
  # 自动转换 preds 中的数据类型
63
80
  converted_preds = []
64
81
  for pred in preds:
65
- converted_pred = {
66
- 'boxes': self._ensure_numpy_array(pred['boxes'], dtype=float),
67
- 'scores': self._ensure_numpy_array(pred['scores'], dtype=float),
68
- 'labels': self._ensure_numpy_array(pred['labels'], dtype=int)
69
- }
70
- converted_preds.append(converted_pred)
82
+ # 确保 boxes 是二维数组
83
+ boxes = self._ensure_numpy_array(pred['boxes'], dtype=float)
84
+ if boxes.ndim == 1 and len(boxes) == 4:
85
+ boxes = boxes.reshape(1, 4)
86
+
87
+ # 确保 scores 是一维数组
88
+ scores = self._ensure_numpy_array(pred['scores'], dtype=float)
89
+ if scores.ndim == 0:
90
+ scores = scores.reshape(1)
91
+
92
+ # 确保 labels 是一维数组
93
+ labels = self._ensure_numpy_array(pred['labels'], dtype=int)
94
+ if labels.ndim == 0:
95
+ labels = labels.reshape(1)
96
+
97
+ converted_preds.append({
98
+ 'boxes': boxes,
99
+ 'scores': scores,
100
+ 'labels': labels
101
+ })
71
102
 
72
103
  # 自动转换 targets 中的数据类型
73
104
  converted_targets = []
74
105
  for tgt in target:
75
- converted_target = {
76
- 'boxes': self._ensure_numpy_array(tgt['boxes'], dtype=float),
77
- 'labels': self._ensure_numpy_array(tgt['labels'], dtype=int)
78
- }
79
- converted_targets.append(converted_target)
106
+ # 确保 boxes 是二维数组
107
+ boxes = self._ensure_numpy_array(tgt['boxes'], dtype=float)
108
+ if boxes.ndim == 1 and len(boxes) == 4:
109
+ boxes = boxes.reshape(1, 4)
110
+
111
+ # 确保 labels 是一维数组
112
+ labels = self._ensure_numpy_array(tgt['labels'], dtype=int)
113
+ if labels.ndim == 0:
114
+ labels = labels.reshape(1)
115
+
116
+ converted_targets.append({
117
+ 'boxes': boxes,
118
+ 'labels': labels
119
+ })
80
120
 
81
121
  self.preds.extend(converted_preds)
82
122
  self.targets.extend(converted_targets)
@@ -254,6 +294,25 @@ class MeanAveragePrecision(Metric):
254
294
  mask = target['labels'] == cls_id
255
295
  boxes = target['boxes'][mask]
256
296
 
297
+ # 检查 boxes 是否为空
298
+ if len(boxes) == 0:
299
+ n_pos += 0
300
+ class_gt[img_idx] = {
301
+ 'boxes': np.array([]),
302
+ 'used': np.array([], dtype=bool)
303
+ }
304
+ continue
305
+
306
+ # 确保 boxes 是二维数组
307
+ if boxes.ndim == 1:
308
+ # 如果是一维数组,说明是空的或者格式不对,跳过
309
+ n_pos += 0
310
+ class_gt[img_idx] = {
311
+ 'boxes': np.array([]),
312
+ 'used': np.array([], dtype=bool)
313
+ }
314
+ continue
315
+
257
316
  areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
258
317
  valid_area_mask = (areas >= min_area) & (areas < max_area)
259
318
  valid_boxes = boxes[valid_area_mask]
@@ -272,6 +331,15 @@ class MeanAveragePrecision(Metric):
272
331
  scores = pred['scores'][mask]
273
332
  boxes = pred['boxes'][mask]
274
333
 
334
+ # 检查 boxes 是否为空
335
+ if len(boxes) == 0:
336
+ continue
337
+
338
+ # 确保 boxes 是二维数组
339
+ if boxes.ndim == 1:
340
+ # 如果是一维数组,说明是空的或者格式不对,跳过
341
+ continue
342
+
275
343
  areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
276
344
  valid_area_mask = (areas >= min_area) & (areas < max_area)
277
345
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: EasyMetrics
3
- Version: 0.1.5
3
+ Version: 0.3.0
4
4
  Summary: 轻量级、零依赖的机器学习指标评估平台
5
5
  Author-email: EasyMetrics Team <team@easymetrics.example>
6
6
  License: MIT License
@@ -5,11 +5,11 @@ easyMetrics/tasks/__init__.py,sha256=Cwc3WM7ZgblWq1jfIjrHt_bBrJRobx4q_WbHn3LaN68
5
5
  easyMetrics/tasks/detection/__init__.py,sha256=74SsA-fJbamX7WGIWJJmeTXjjHnogYU1jwYse65d45U,80
6
6
  easyMetrics/tasks/detection/format_converter.py,sha256=80vM8B3exuH6EHlmYwvMrpslTDm-FMWzuccQKka42dc,12632
7
7
  easyMetrics/tasks/detection/interface.py,sha256=ZbR9aJVXsRiSNGrhcVZAEGkfH2SFX0kJi0XY_rUEbMA,2568
8
- easyMetrics/tasks/detection/map.py,sha256=F2qVDTQR1qr1SHEf3w3Xg7yvUytzfQtxXdF6p25CcwA,17471
8
+ easyMetrics/tasks/detection/map.py,sha256=UhdBkuGMCq0Bib6hvxFXidLSn3aV9lu_thPMgPSGdVQ,20052
9
9
  easyMetrics/tasks/detection/matcher.py,sha256=C-4hEVBJseQKgcHupBLlaDNFGhKxOxN-YdPNnpuit9I,2443
10
10
  easyMetrics/tasks/detection/utils.py,sha256=jNDPHxWzGo7478QDI1f8uVGK6HUlwvHlPbfwzEk-H0A,2770
11
- easymetrics-0.1.5.dist-info/licenses/LICENSE,sha256=Z-thSEoGCfxHrpC8rpgrqY-LSoJuyHbgPj5Gt8BuG_0,1073
12
- easymetrics-0.1.5.dist-info/METADATA,sha256=2iaYmKt03dFhBsxudv6nFfiOLrmQ23Vsqm9mIwaCqMs,5265
13
- easymetrics-0.1.5.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
14
- easymetrics-0.1.5.dist-info/top_level.txt,sha256=2ytZxZ9LMTdKpeZunCE1rpvkpDAhWeQLiiXTrFy9niw,12
15
- easymetrics-0.1.5.dist-info/RECORD,,
11
+ easymetrics-0.3.0.dist-info/licenses/LICENSE,sha256=Z-thSEoGCfxHrpC8rpgrqY-LSoJuyHbgPj5Gt8BuG_0,1073
12
+ easymetrics-0.3.0.dist-info/METADATA,sha256=dmxyFGewhHcYNhnZPCPYUN4U6QjsRHnhS7Wnqo5WRso,5265
13
+ easymetrics-0.3.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
14
+ easymetrics-0.3.0.dist-info/top_level.txt,sha256=2ytZxZ9LMTdKpeZunCE1rpvkpDAhWeQLiiXTrFy9niw,12
15
+ easymetrics-0.3.0.dist-info/RECORD,,