ultralytics-8.0.237-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +34 -0
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +5 -0
- ultralytics/data/explorer/explorer.py +170 -97
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +146 -76
- ultralytics/data/explorer/utils.py +87 -25
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +63 -40
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -12
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +80 -58
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +67 -59
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +22 -15
- ultralytics/solutions/heatmap.py +76 -54
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -151
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +39 -29
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.237.dist-info/RECORD +0 -187
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/utils/metrics.py
CHANGED
|
@@ -11,7 +11,10 @@ import torch
|
|
|
11
11
|
|
|
12
12
|
from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings
|
|
13
13
|
|
|
14
|
-
OKS_SIGMA =
|
|
14
|
+
OKS_SIGMA = (
|
|
15
|
+
np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
|
|
16
|
+
/ 10.0
|
|
17
|
+
)
|
|
15
18
|
|
|
16
19
|
|
|
17
20
|
def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
|
@@ -33,8 +36,9 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
|
|
33
36
|
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
|
|
34
37
|
|
|
35
38
|
# Intersection area
|
|
36
|
-
inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) *
|
|
37
|
-
|
|
39
|
+
inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (
|
|
40
|
+
np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)
|
|
41
|
+
).clip(0)
|
|
38
42
|
|
|
39
43
|
# Box2 area
|
|
40
44
|
area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
|
|
@@ -99,8 +103,9 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
|
|
|
99
103
|
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
|
100
104
|
|
|
101
105
|
# Intersection area
|
|
102
|
-
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) *
|
|
103
|
-
|
|
106
|
+
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
|
|
107
|
+
b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
|
|
108
|
+
).clamp_(0)
|
|
104
109
|
|
|
105
110
|
# Union Area
|
|
106
111
|
union = w1 * h1 + w2 * h2 - inter + eps
|
|
@@ -111,10 +116,10 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
|
|
|
111
116
|
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
|
|
112
117
|
ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
|
|
113
118
|
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
|
|
114
|
-
c2 = cw
|
|
119
|
+
c2 = cw**2 + ch**2 + eps # convex diagonal squared
|
|
115
120
|
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
|
|
116
121
|
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
|
|
117
|
-
v = (4 / math.pi
|
|
122
|
+
v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
|
|
118
123
|
with torch.no_grad():
|
|
119
124
|
alpha = v / (v - iou + (1 + eps))
|
|
120
125
|
return iou - (rho2 / c2 + v * alpha) # CIoU
|
|
@@ -202,12 +207,19 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
|
|
|
202
207
|
a1, b1, c1 = _get_covariance_matrix(obb1)
|
|
203
208
|
a2, b2, c2 = _get_covariance_matrix(obb2)
|
|
204
209
|
|
|
205
|
-
t1 = (
|
|
206
|
-
|
|
210
|
+
t1 = (
|
|
211
|
+
((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2)))
|
|
212
|
+
/ ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)
|
|
213
|
+
) * 0.25
|
|
207
214
|
t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
|
|
208
|
-
t3 =
|
|
209
|
-
|
|
210
|
-
|
|
215
|
+
t3 = (
|
|
216
|
+
torch.log(
|
|
217
|
+
((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)))
|
|
218
|
+
/ (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps)
|
|
219
|
+
+ eps
|
|
220
|
+
)
|
|
221
|
+
* 0.5
|
|
222
|
+
)
|
|
211
223
|
bd = t1 + t2 + t3
|
|
212
224
|
bd = torch.clamp(bd, eps, 100.0)
|
|
213
225
|
hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
|
|
@@ -215,7 +227,7 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
|
|
|
215
227
|
if CIoU: # only include the wh aspect ratio part
|
|
216
228
|
w1, h1 = obb1[..., 2:4].split(1, dim=-1)
|
|
217
229
|
w2, h2 = obb2[..., 2:4].split(1, dim=-1)
|
|
218
|
-
v = (4 / math.pi
|
|
230
|
+
v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
|
|
219
231
|
with torch.no_grad():
|
|
220
232
|
alpha = v / (v - iou + (1 + eps))
|
|
221
233
|
return iou - v * alpha # CIoU
|
|
@@ -239,12 +251,19 @@ def batch_probiou(obb1, obb2, eps=1e-7):
|
|
|
239
251
|
a1, b1, c1 = _get_covariance_matrix(obb1)
|
|
240
252
|
a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))
|
|
241
253
|
|
|
242
|
-
t1 = (
|
|
243
|
-
|
|
254
|
+
t1 = (
|
|
255
|
+
((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2)))
|
|
256
|
+
/ ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)
|
|
257
|
+
) * 0.25
|
|
244
258
|
t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
|
|
245
|
-
t3 =
|
|
246
|
-
|
|
247
|
-
|
|
259
|
+
t3 = (
|
|
260
|
+
torch.log(
|
|
261
|
+
((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)))
|
|
262
|
+
/ (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps)
|
|
263
|
+
+ eps
|
|
264
|
+
)
|
|
265
|
+
* 0.5
|
|
266
|
+
)
|
|
248
267
|
bd = t1 + t2 + t3
|
|
249
268
|
bd = torch.clamp(bd, eps, 100.0)
|
|
250
269
|
hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
|
|
@@ -279,10 +298,10 @@ class ConfusionMatrix:
|
|
|
279
298
|
iou_thres (float): The Intersection over Union threshold.
|
|
280
299
|
"""
|
|
281
300
|
|
|
282
|
-
def __init__(self, nc, conf=0.25, iou_thres=0.45, task=
|
|
301
|
+
def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
|
|
283
302
|
"""Initialize attributes for the YOLO model."""
|
|
284
303
|
self.task = task
|
|
285
|
-
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task ==
|
|
304
|
+
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
|
|
286
305
|
self.nc = nc # number of classes
|
|
287
306
|
self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed
|
|
288
307
|
self.iou_thres = iou_thres
|
|
@@ -361,11 +380,11 @@ class ConfusionMatrix:
|
|
|
361
380
|
tp = self.matrix.diagonal() # true positives
|
|
362
381
|
fp = self.matrix.sum(1) - tp # false positives
|
|
363
382
|
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
|
364
|
-
return (tp[:-1], fp[:-1]) if self.task ==
|
|
383
|
+
return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp) # remove background class if task=detect
|
|
365
384
|
|
|
366
|
-
@TryExcept(
|
|
385
|
+
@TryExcept("WARNING ⚠️ ConfusionMatrix plot failure")
|
|
367
386
|
@plt_settings()
|
|
368
|
-
def plot(self, normalize=True, save_dir=
|
|
387
|
+
def plot(self, normalize=True, save_dir="", names=(), on_plot=None):
|
|
369
388
|
"""
|
|
370
389
|
Plot the confusion matrix using seaborn and save it to a file.
|
|
371
390
|
|
|
@@ -377,30 +396,31 @@ class ConfusionMatrix:
|
|
|
377
396
|
"""
|
|
378
397
|
import seaborn as sn
|
|
379
398
|
|
|
380
|
-
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) +
|
|
399
|
+
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
|
|
381
400
|
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
|
382
401
|
|
|
383
402
|
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
|
384
403
|
nc, nn = self.nc, len(names) # number of classes, names
|
|
385
404
|
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
|
386
405
|
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
|
387
|
-
ticklabels = (list(names) + [
|
|
406
|
+
ticklabels = (list(names) + ["background"]) if labels else "auto"
|
|
388
407
|
with warnings.catch_warnings():
|
|
389
|
-
warnings.simplefilter(
|
|
390
|
-
sn.heatmap(
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
ax.
|
|
408
|
+
warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
|
409
|
+
sn.heatmap(
|
|
410
|
+
array,
|
|
411
|
+
ax=ax,
|
|
412
|
+
annot=nc < 30,
|
|
413
|
+
annot_kws={"size": 8},
|
|
414
|
+
cmap="Blues",
|
|
415
|
+
fmt=".2f" if normalize else ".0f",
|
|
416
|
+
square=True,
|
|
417
|
+
vmin=0.0,
|
|
418
|
+
xticklabels=ticklabels,
|
|
419
|
+
yticklabels=ticklabels,
|
|
420
|
+
).set_facecolor((1, 1, 1))
|
|
421
|
+
title = "Confusion Matrix" + " Normalized" * normalize
|
|
422
|
+
ax.set_xlabel("True")
|
|
423
|
+
ax.set_ylabel("Predicted")
|
|
404
424
|
ax.set_title(title)
|
|
405
425
|
plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
|
|
406
426
|
fig.savefig(plot_fname, dpi=250)
|
|
@@ -411,7 +431,7 @@ class ConfusionMatrix:
|
|
|
411
431
|
def print(self):
|
|
412
432
|
"""Print the confusion matrix to the console."""
|
|
413
433
|
for i in range(self.nc + 1):
|
|
414
|
-
LOGGER.info(
|
|
434
|
+
LOGGER.info(" ".join(map(str, self.matrix[i])))
|
|
415
435
|
|
|
416
436
|
|
|
417
437
|
def smooth(y, f=0.05):
|
|
@@ -419,28 +439,28 @@ def smooth(y, f=0.05):
|
|
|
419
439
|
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
|
|
420
440
|
p = np.ones(nf // 2) # ones padding
|
|
421
441
|
yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
|
|
422
|
-
return np.convolve(yp, np.ones(nf) / nf, mode=
|
|
442
|
+
return np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed
|
|
423
443
|
|
|
424
444
|
|
|
425
445
|
@plt_settings()
|
|
426
|
-
def plot_pr_curve(px, py, ap, save_dir=Path(
|
|
446
|
+
def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None):
|
|
427
447
|
"""Plots a precision-recall curve."""
|
|
428
448
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
|
429
449
|
py = np.stack(py, axis=1)
|
|
430
450
|
|
|
431
451
|
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
|
432
452
|
for i, y in enumerate(py.T):
|
|
433
|
-
ax.plot(px, y, linewidth=1, label=f
|
|
453
|
+
ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
|
|
434
454
|
else:
|
|
435
|
-
ax.plot(px, py, linewidth=1, color=
|
|
455
|
+
ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
|
|
436
456
|
|
|
437
|
-
ax.plot(px, py.mean(1), linewidth=3, color=
|
|
438
|
-
ax.set_xlabel(
|
|
439
|
-
ax.set_ylabel(
|
|
457
|
+
ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
|
|
458
|
+
ax.set_xlabel("Recall")
|
|
459
|
+
ax.set_ylabel("Precision")
|
|
440
460
|
ax.set_xlim(0, 1)
|
|
441
461
|
ax.set_ylim(0, 1)
|
|
442
|
-
ax.legend(bbox_to_anchor=(1.04, 1), loc=
|
|
443
|
-
ax.set_title(
|
|
462
|
+
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
|
463
|
+
ax.set_title("Precision-Recall Curve")
|
|
444
464
|
fig.savefig(save_dir, dpi=250)
|
|
445
465
|
plt.close(fig)
|
|
446
466
|
if on_plot:
|
|
@@ -448,24 +468,24 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=N
|
|
|
448
468
|
|
|
449
469
|
|
|
450
470
|
@plt_settings()
|
|
451
|
-
def plot_mc_curve(px, py, save_dir=Path(
|
|
471
|
+
def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None):
|
|
452
472
|
"""Plots a metric-confidence curve."""
|
|
453
473
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
|
454
474
|
|
|
455
475
|
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
|
456
476
|
for i, y in enumerate(py):
|
|
457
|
-
ax.plot(px, y, linewidth=1, label=f
|
|
477
|
+
ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
|
|
458
478
|
else:
|
|
459
|
-
ax.plot(px, py.T, linewidth=1, color=
|
|
479
|
+
ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
|
|
460
480
|
|
|
461
481
|
y = smooth(py.mean(0), 0.05)
|
|
462
|
-
ax.plot(px, y, linewidth=3, color=
|
|
482
|
+
ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
|
|
463
483
|
ax.set_xlabel(xlabel)
|
|
464
484
|
ax.set_ylabel(ylabel)
|
|
465
485
|
ax.set_xlim(0, 1)
|
|
466
486
|
ax.set_ylim(0, 1)
|
|
467
|
-
ax.legend(bbox_to_anchor=(1.04, 1), loc=
|
|
468
|
-
ax.set_title(f
|
|
487
|
+
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
|
488
|
+
ax.set_title(f"{ylabel}-Confidence Curve")
|
|
469
489
|
fig.savefig(save_dir, dpi=250)
|
|
470
490
|
plt.close(fig)
|
|
471
491
|
if on_plot:
|
|
@@ -494,8 +514,8 @@ def compute_ap(recall, precision):
|
|
|
494
514
|
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
|
|
495
515
|
|
|
496
516
|
# Integrate area under curve
|
|
497
|
-
method =
|
|
498
|
-
if method ==
|
|
517
|
+
method = "interp" # methods: 'continuous', 'interp'
|
|
518
|
+
if method == "interp":
|
|
499
519
|
x = np.linspace(0, 1, 101) # 101-point interp (COCO)
|
|
500
520
|
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
|
|
501
521
|
else: # 'continuous'
|
|
@@ -505,16 +525,9 @@ def compute_ap(recall, precision):
|
|
|
505
525
|
return ap, mpre, mrec
|
|
506
526
|
|
|
507
527
|
|
|
508
|
-
def ap_per_class(
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
target_cls,
|
|
512
|
-
plot=False,
|
|
513
|
-
on_plot=None,
|
|
514
|
-
save_dir=Path(),
|
|
515
|
-
names=(),
|
|
516
|
-
eps=1e-16,
|
|
517
|
-
prefix=''):
|
|
528
|
+
def ap_per_class(
|
|
529
|
+
tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix=""
|
|
530
|
+
):
|
|
518
531
|
"""
|
|
519
532
|
Computes the average precision per class for object detection evaluation.
|
|
520
533
|
|
|
@@ -591,10 +604,10 @@ def ap_per_class(tp,
|
|
|
591
604
|
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
|
|
592
605
|
names = dict(enumerate(names)) # to dict
|
|
593
606
|
if plot:
|
|
594
|
-
plot_pr_curve(x, prec_values, ap, save_dir / f
|
|
595
|
-
plot_mc_curve(x, f1_curve, save_dir / f
|
|
596
|
-
plot_mc_curve(x, p_curve, save_dir / f
|
|
597
|
-
plot_mc_curve(x, r_curve, save_dir / f
|
|
607
|
+
plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
|
|
608
|
+
plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
|
|
609
|
+
plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot)
|
|
610
|
+
plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot)
|
|
598
611
|
|
|
599
612
|
i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index
|
|
600
613
|
p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values
|
|
@@ -746,8 +759,18 @@ class Metric(SimpleClass):
|
|
|
746
759
|
Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` based
|
|
747
760
|
on the values provided in the `results` tuple.
|
|
748
761
|
"""
|
|
749
|
-
(
|
|
750
|
-
|
|
762
|
+
(
|
|
763
|
+
self.p,
|
|
764
|
+
self.r,
|
|
765
|
+
self.f1,
|
|
766
|
+
self.all_ap,
|
|
767
|
+
self.ap_class_index,
|
|
768
|
+
self.p_curve,
|
|
769
|
+
self.r_curve,
|
|
770
|
+
self.f1_curve,
|
|
771
|
+
self.px,
|
|
772
|
+
self.prec_values,
|
|
773
|
+
) = results
|
|
751
774
|
|
|
752
775
|
@property
|
|
753
776
|
def curves(self):
|
|
@@ -757,8 +780,12 @@ class Metric(SimpleClass):
|
|
|
757
780
|
@property
|
|
758
781
|
def curves_results(self):
|
|
759
782
|
"""Returns a list of curves for accessing specific metrics curves."""
|
|
760
|
-
return [
|
|
761
|
-
|
|
783
|
+
return [
|
|
784
|
+
[self.px, self.prec_values, "Recall", "Precision"],
|
|
785
|
+
[self.px, self.f1_curve, "Confidence", "F1"],
|
|
786
|
+
[self.px, self.p_curve, "Confidence", "Precision"],
|
|
787
|
+
[self.px, self.r_curve, "Confidence", "Recall"],
|
|
788
|
+
]
|
|
762
789
|
|
|
763
790
|
|
|
764
791
|
class DetMetrics(SimpleClass):
|
|
@@ -793,33 +820,35 @@ class DetMetrics(SimpleClass):
|
|
|
793
820
|
curves_results: TODO
|
|
794
821
|
"""
|
|
795
822
|
|
|
796
|
-
def __init__(self, save_dir=Path(
|
|
823
|
+
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
|
797
824
|
"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
|
|
798
825
|
self.save_dir = save_dir
|
|
799
826
|
self.plot = plot
|
|
800
827
|
self.on_plot = on_plot
|
|
801
828
|
self.names = names
|
|
802
829
|
self.box = Metric()
|
|
803
|
-
self.speed = {
|
|
804
|
-
self.task =
|
|
830
|
+
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
|
831
|
+
self.task = "detect"
|
|
805
832
|
|
|
806
833
|
def process(self, tp, conf, pred_cls, target_cls):
|
|
807
834
|
"""Process predicted results for object detection and update metrics."""
|
|
808
|
-
results = ap_per_class(
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
835
|
+
results = ap_per_class(
|
|
836
|
+
tp,
|
|
837
|
+
conf,
|
|
838
|
+
pred_cls,
|
|
839
|
+
target_cls,
|
|
840
|
+
plot=self.plot,
|
|
841
|
+
save_dir=self.save_dir,
|
|
842
|
+
names=self.names,
|
|
843
|
+
on_plot=self.on_plot,
|
|
844
|
+
)[2:]
|
|
816
845
|
self.box.nc = len(self.names)
|
|
817
846
|
self.box.update(results)
|
|
818
847
|
|
|
819
848
|
@property
|
|
820
849
|
def keys(self):
|
|
821
850
|
"""Returns a list of keys for accessing specific metrics."""
|
|
822
|
-
return [
|
|
851
|
+
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
|
823
852
|
|
|
824
853
|
def mean_results(self):
|
|
825
854
|
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
|
|
@@ -847,12 +876,12 @@ class DetMetrics(SimpleClass):
|
|
|
847
876
|
@property
|
|
848
877
|
def results_dict(self):
|
|
849
878
|
"""Returns dictionary of computed performance metrics and statistics."""
|
|
850
|
-
return dict(zip(self.keys + [
|
|
879
|
+
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
|
851
880
|
|
|
852
881
|
@property
|
|
853
882
|
def curves(self):
|
|
854
883
|
"""Returns a list of curves for accessing specific metrics curves."""
|
|
855
|
-
return [
|
|
884
|
+
return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
|
|
856
885
|
|
|
857
886
|
@property
|
|
858
887
|
def curves_results(self):
|
|
@@ -889,7 +918,7 @@ class SegmentMetrics(SimpleClass):
|
|
|
889
918
|
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
|
|
890
919
|
"""
|
|
891
920
|
|
|
892
|
-
def __init__(self, save_dir=Path(
|
|
921
|
+
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
|
893
922
|
"""Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""
|
|
894
923
|
self.save_dir = save_dir
|
|
895
924
|
self.plot = plot
|
|
@@ -897,8 +926,8 @@ class SegmentMetrics(SimpleClass):
|
|
|
897
926
|
self.names = names
|
|
898
927
|
self.box = Metric()
|
|
899
928
|
self.seg = Metric()
|
|
900
|
-
self.speed = {
|
|
901
|
-
self.task =
|
|
929
|
+
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
|
930
|
+
self.task = "segment"
|
|
902
931
|
|
|
903
932
|
def process(self, tp, tp_m, conf, pred_cls, target_cls):
|
|
904
933
|
"""
|
|
@@ -912,26 +941,30 @@ class SegmentMetrics(SimpleClass):
|
|
|
912
941
|
target_cls (list): List of target classes.
|
|
913
942
|
"""
|
|
914
943
|
|
|
915
|
-
results_mask = ap_per_class(
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
944
|
+
results_mask = ap_per_class(
|
|
945
|
+
tp_m,
|
|
946
|
+
conf,
|
|
947
|
+
pred_cls,
|
|
948
|
+
target_cls,
|
|
949
|
+
plot=self.plot,
|
|
950
|
+
on_plot=self.on_plot,
|
|
951
|
+
save_dir=self.save_dir,
|
|
952
|
+
names=self.names,
|
|
953
|
+
prefix="Mask",
|
|
954
|
+
)[2:]
|
|
924
955
|
self.seg.nc = len(self.names)
|
|
925
956
|
self.seg.update(results_mask)
|
|
926
|
-
results_box = ap_per_class(
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
957
|
+
results_box = ap_per_class(
|
|
958
|
+
tp,
|
|
959
|
+
conf,
|
|
960
|
+
pred_cls,
|
|
961
|
+
target_cls,
|
|
962
|
+
plot=self.plot,
|
|
963
|
+
on_plot=self.on_plot,
|
|
964
|
+
save_dir=self.save_dir,
|
|
965
|
+
names=self.names,
|
|
966
|
+
prefix="Box",
|
|
967
|
+
)[2:]
|
|
935
968
|
self.box.nc = len(self.names)
|
|
936
969
|
self.box.update(results_box)
|
|
937
970
|
|
|
@@ -939,8 +972,15 @@ class SegmentMetrics(SimpleClass):
|
|
|
939
972
|
def keys(self):
|
|
940
973
|
"""Returns a list of keys for accessing metrics."""
|
|
941
974
|
return [
|
|
942
|
-
|
|
943
|
-
|
|
975
|
+
"metrics/precision(B)",
|
|
976
|
+
"metrics/recall(B)",
|
|
977
|
+
"metrics/mAP50(B)",
|
|
978
|
+
"metrics/mAP50-95(B)",
|
|
979
|
+
"metrics/precision(M)",
|
|
980
|
+
"metrics/recall(M)",
|
|
981
|
+
"metrics/mAP50(M)",
|
|
982
|
+
"metrics/mAP50-95(M)",
|
|
983
|
+
]
|
|
944
984
|
|
|
945
985
|
def mean_results(self):
|
|
946
986
|
"""Return the mean metrics for bounding box and segmentation results."""
|
|
@@ -968,14 +1008,21 @@ class SegmentMetrics(SimpleClass):
|
|
|
968
1008
|
@property
|
|
969
1009
|
def results_dict(self):
|
|
970
1010
|
"""Returns results of object detection model for evaluation."""
|
|
971
|
-
return dict(zip(self.keys + [
|
|
1011
|
+
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
|
972
1012
|
|
|
973
1013
|
@property
|
|
974
1014
|
def curves(self):
|
|
975
1015
|
"""Returns a list of curves for accessing specific metrics curves."""
|
|
976
1016
|
return [
|
|
977
|
-
|
|
978
|
-
|
|
1017
|
+
"Precision-Recall(B)",
|
|
1018
|
+
"F1-Confidence(B)",
|
|
1019
|
+
"Precision-Confidence(B)",
|
|
1020
|
+
"Recall-Confidence(B)",
|
|
1021
|
+
"Precision-Recall(M)",
|
|
1022
|
+
"F1-Confidence(M)",
|
|
1023
|
+
"Precision-Confidence(M)",
|
|
1024
|
+
"Recall-Confidence(M)",
|
|
1025
|
+
]
|
|
979
1026
|
|
|
980
1027
|
@property
|
|
981
1028
|
def curves_results(self):
|
|
@@ -1012,7 +1059,7 @@ class PoseMetrics(SegmentMetrics):
|
|
|
1012
1059
|
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
|
|
1013
1060
|
"""
|
|
1014
1061
|
|
|
1015
|
-
def __init__(self, save_dir=Path(
|
|
1062
|
+
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
|
1016
1063
|
"""Initialize the PoseMetrics class with directory path, class names, and plotting options."""
|
|
1017
1064
|
super().__init__(save_dir, plot, names)
|
|
1018
1065
|
self.save_dir = save_dir
|
|
@@ -1021,8 +1068,8 @@ class PoseMetrics(SegmentMetrics):
|
|
|
1021
1068
|
self.names = names
|
|
1022
1069
|
self.box = Metric()
|
|
1023
1070
|
self.pose = Metric()
|
|
1024
|
-
self.speed = {
|
|
1025
|
-
self.task =
|
|
1071
|
+
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
|
1072
|
+
self.task = "pose"
|
|
1026
1073
|
|
|
1027
1074
|
def process(self, tp, tp_p, conf, pred_cls, target_cls):
|
|
1028
1075
|
"""
|
|
@@ -1036,26 +1083,30 @@ class PoseMetrics(SegmentMetrics):
|
|
|
1036
1083
|
target_cls (list): List of target classes.
|
|
1037
1084
|
"""
|
|
1038
1085
|
|
|
1039
|
-
results_pose = ap_per_class(
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1086
|
+
results_pose = ap_per_class(
|
|
1087
|
+
tp_p,
|
|
1088
|
+
conf,
|
|
1089
|
+
pred_cls,
|
|
1090
|
+
target_cls,
|
|
1091
|
+
plot=self.plot,
|
|
1092
|
+
on_plot=self.on_plot,
|
|
1093
|
+
save_dir=self.save_dir,
|
|
1094
|
+
names=self.names,
|
|
1095
|
+
prefix="Pose",
|
|
1096
|
+
)[2:]
|
|
1048
1097
|
self.pose.nc = len(self.names)
|
|
1049
1098
|
self.pose.update(results_pose)
|
|
1050
|
-
results_box = ap_per_class(
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1099
|
+
results_box = ap_per_class(
|
|
1100
|
+
tp,
|
|
1101
|
+
conf,
|
|
1102
|
+
pred_cls,
|
|
1103
|
+
target_cls,
|
|
1104
|
+
plot=self.plot,
|
|
1105
|
+
on_plot=self.on_plot,
|
|
1106
|
+
save_dir=self.save_dir,
|
|
1107
|
+
names=self.names,
|
|
1108
|
+
prefix="Box",
|
|
1109
|
+
)[2:]
|
|
1059
1110
|
self.box.nc = len(self.names)
|
|
1060
1111
|
self.box.update(results_box)
|
|
1061
1112
|
|
|
@@ -1063,8 +1114,15 @@ class PoseMetrics(SegmentMetrics):
|
|
|
1063
1114
|
def keys(self):
|
|
1064
1115
|
"""Returns list of evaluation metric keys."""
|
|
1065
1116
|
return [
|
|
1066
|
-
|
|
1067
|
-
|
|
1117
|
+
"metrics/precision(B)",
|
|
1118
|
+
"metrics/recall(B)",
|
|
1119
|
+
"metrics/mAP50(B)",
|
|
1120
|
+
"metrics/mAP50-95(B)",
|
|
1121
|
+
"metrics/precision(P)",
|
|
1122
|
+
"metrics/recall(P)",
|
|
1123
|
+
"metrics/mAP50(P)",
|
|
1124
|
+
"metrics/mAP50-95(P)",
|
|
1125
|
+
]
|
|
1068
1126
|
|
|
1069
1127
|
def mean_results(self):
|
|
1070
1128
|
"""Return the mean results of box and pose."""
|
|
@@ -1088,8 +1146,15 @@ class PoseMetrics(SegmentMetrics):
|
|
|
1088
1146
|
def curves(self):
|
|
1089
1147
|
"""Returns a list of curves for accessing specific metrics curves."""
|
|
1090
1148
|
return [
|
|
1091
|
-
|
|
1092
|
-
|
|
1149
|
+
"Precision-Recall(B)",
|
|
1150
|
+
"F1-Confidence(B)",
|
|
1151
|
+
"Precision-Confidence(B)",
|
|
1152
|
+
"Recall-Confidence(B)",
|
|
1153
|
+
"Precision-Recall(P)",
|
|
1154
|
+
"F1-Confidence(P)",
|
|
1155
|
+
"Precision-Confidence(P)",
|
|
1156
|
+
"Recall-Confidence(P)",
|
|
1157
|
+
]
|
|
1093
1158
|
|
|
1094
1159
|
@property
|
|
1095
1160
|
def curves_results(self):
|
|
@@ -1119,8 +1184,8 @@ class ClassifyMetrics(SimpleClass):
|
|
|
1119
1184
|
"""Initialize a ClassifyMetrics instance."""
|
|
1120
1185
|
self.top1 = 0
|
|
1121
1186
|
self.top5 = 0
|
|
1122
|
-
self.speed = {
|
|
1123
|
-
self.task =
|
|
1187
|
+
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
|
1188
|
+
self.task = "classify"
|
|
1124
1189
|
|
|
1125
1190
|
def process(self, targets, pred):
|
|
1126
1191
|
"""Target classes and predicted classes."""
|
|
@@ -1137,12 +1202,12 @@ class ClassifyMetrics(SimpleClass):
|
|
|
1137
1202
|
@property
|
|
1138
1203
|
def results_dict(self):
|
|
1139
1204
|
"""Returns a dictionary with model's performance metrics and fitness score."""
|
|
1140
|
-
return dict(zip(self.keys + [
|
|
1205
|
+
return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
|
|
1141
1206
|
|
|
1142
1207
|
@property
|
|
1143
1208
|
def keys(self):
|
|
1144
1209
|
"""Returns a list of keys for the results_dict property."""
|
|
1145
|
-
return [
|
|
1210
|
+
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
|
|
1146
1211
|
|
|
1147
1212
|
@property
|
|
1148
1213
|
def curves(self):
|
|
@@ -1156,32 +1221,33 @@ class ClassifyMetrics(SimpleClass):
|
|
|
1156
1221
|
|
|
1157
1222
|
|
|
1158
1223
|
class OBBMetrics(SimpleClass):
|
|
1159
|
-
|
|
1160
|
-
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
|
1224
|
+
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
|
1161
1225
|
self.save_dir = save_dir
|
|
1162
1226
|
self.plot = plot
|
|
1163
1227
|
self.on_plot = on_plot
|
|
1164
1228
|
self.names = names
|
|
1165
1229
|
self.box = Metric()
|
|
1166
|
-
self.speed = {
|
|
1230
|
+
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
|
1167
1231
|
|
|
1168
1232
|
def process(self, tp, conf, pred_cls, target_cls):
|
|
1169
1233
|
"""Process predicted results for object detection and update metrics."""
|
|
1170
|
-
results = ap_per_class(
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1234
|
+
results = ap_per_class(
|
|
1235
|
+
tp,
|
|
1236
|
+
conf,
|
|
1237
|
+
pred_cls,
|
|
1238
|
+
target_cls,
|
|
1239
|
+
plot=self.plot,
|
|
1240
|
+
save_dir=self.save_dir,
|
|
1241
|
+
names=self.names,
|
|
1242
|
+
on_plot=self.on_plot,
|
|
1243
|
+
)[2:]
|
|
1178
1244
|
self.box.nc = len(self.names)
|
|
1179
1245
|
self.box.update(results)
|
|
1180
1246
|
|
|
1181
1247
|
@property
|
|
1182
1248
|
def keys(self):
|
|
1183
1249
|
"""Returns a list of keys for accessing specific metrics."""
|
|
1184
|
-
return [
|
|
1250
|
+
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
|
1185
1251
|
|
|
1186
1252
|
def mean_results(self):
|
|
1187
1253
|
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
|
|
@@ -1209,7 +1275,7 @@ class OBBMetrics(SimpleClass):
|
|
|
1209
1275
|
@property
|
|
1210
1276
|
def results_dict(self):
|
|
1211
1277
|
"""Returns dictionary of computed performance metrics and statistics."""
|
|
1212
|
-
return dict(zip(self.keys + [
|
|
1278
|
+
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
|
1213
1279
|
|
|
1214
1280
|
@property
|
|
1215
1281
|
def curves(self):
|