bplusplus 0.1.0-py3-none-any.whl → 1.1.0-py3-none-any.whl
This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Potentially problematic release.
This version of bplusplus might be problematic.
- bplusplus/__init__.py +5 -3
- bplusplus/{collect_images.py → collect.py} +3 -3
- bplusplus/prepare.py +573 -0
- bplusplus/train_validate.py +8 -64
- bplusplus/yolov5detect/__init__.py +1 -0
- bplusplus/yolov5detect/detect.py +444 -0
- bplusplus/yolov5detect/export.py +1530 -0
- bplusplus/yolov5detect/insect.yaml +8 -0
- bplusplus/yolov5detect/models/__init__.py +0 -0
- bplusplus/yolov5detect/models/common.py +1109 -0
- bplusplus/yolov5detect/models/experimental.py +130 -0
- bplusplus/yolov5detect/models/hub/anchors.yaml +56 -0
- bplusplus/yolov5detect/models/hub/yolov3-spp.yaml +52 -0
- bplusplus/yolov5detect/models/hub/yolov3-tiny.yaml +42 -0
- bplusplus/yolov5detect/models/hub/yolov3.yaml +52 -0
- bplusplus/yolov5detect/models/hub/yolov5-bifpn.yaml +49 -0
- bplusplus/yolov5detect/models/hub/yolov5-fpn.yaml +43 -0
- bplusplus/yolov5detect/models/hub/yolov5-p2.yaml +55 -0
- bplusplus/yolov5detect/models/hub/yolov5-p34.yaml +42 -0
- bplusplus/yolov5detect/models/hub/yolov5-p6.yaml +57 -0
- bplusplus/yolov5detect/models/hub/yolov5-p7.yaml +68 -0
- bplusplus/yolov5detect/models/hub/yolov5-panet.yaml +49 -0
- bplusplus/yolov5detect/models/hub/yolov5l6.yaml +61 -0
- bplusplus/yolov5detect/models/hub/yolov5m6.yaml +61 -0
- bplusplus/yolov5detect/models/hub/yolov5n6.yaml +61 -0
- bplusplus/yolov5detect/models/hub/yolov5s-LeakyReLU.yaml +50 -0
- bplusplus/yolov5detect/models/hub/yolov5s-ghost.yaml +49 -0
- bplusplus/yolov5detect/models/hub/yolov5s-transformer.yaml +49 -0
- bplusplus/yolov5detect/models/hub/yolov5s6.yaml +61 -0
- bplusplus/yolov5detect/models/hub/yolov5x6.yaml +61 -0
- bplusplus/yolov5detect/models/segment/yolov5l-seg.yaml +49 -0
- bplusplus/yolov5detect/models/segment/yolov5m-seg.yaml +49 -0
- bplusplus/yolov5detect/models/segment/yolov5n-seg.yaml +49 -0
- bplusplus/yolov5detect/models/segment/yolov5s-seg.yaml +49 -0
- bplusplus/yolov5detect/models/segment/yolov5x-seg.yaml +49 -0
- bplusplus/yolov5detect/models/tf.py +797 -0
- bplusplus/yolov5detect/models/yolo.py +495 -0
- bplusplus/yolov5detect/models/yolov5l.yaml +49 -0
- bplusplus/yolov5detect/models/yolov5m.yaml +49 -0
- bplusplus/yolov5detect/models/yolov5n.yaml +49 -0
- bplusplus/yolov5detect/models/yolov5s.yaml +49 -0
- bplusplus/yolov5detect/models/yolov5x.yaml +49 -0
- bplusplus/yolov5detect/utils/__init__.py +97 -0
- bplusplus/yolov5detect/utils/activations.py +134 -0
- bplusplus/yolov5detect/utils/augmentations.py +448 -0
- bplusplus/yolov5detect/utils/autoanchor.py +175 -0
- bplusplus/yolov5detect/utils/autobatch.py +70 -0
- bplusplus/yolov5detect/utils/aws/__init__.py +0 -0
- bplusplus/yolov5detect/utils/aws/mime.sh +26 -0
- bplusplus/yolov5detect/utils/aws/resume.py +41 -0
- bplusplus/yolov5detect/utils/aws/userdata.sh +27 -0
- bplusplus/yolov5detect/utils/callbacks.py +72 -0
- bplusplus/yolov5detect/utils/dataloaders.py +1385 -0
- bplusplus/yolov5detect/utils/docker/Dockerfile +73 -0
- bplusplus/yolov5detect/utils/docker/Dockerfile-arm64 +40 -0
- bplusplus/yolov5detect/utils/docker/Dockerfile-cpu +42 -0
- bplusplus/yolov5detect/utils/downloads.py +136 -0
- bplusplus/yolov5detect/utils/flask_rest_api/README.md +70 -0
- bplusplus/yolov5detect/utils/flask_rest_api/example_request.py +17 -0
- bplusplus/yolov5detect/utils/flask_rest_api/restapi.py +49 -0
- bplusplus/yolov5detect/utils/general.py +1294 -0
- bplusplus/yolov5detect/utils/google_app_engine/Dockerfile +25 -0
- bplusplus/yolov5detect/utils/google_app_engine/additional_requirements.txt +6 -0
- bplusplus/yolov5detect/utils/google_app_engine/app.yaml +16 -0
- bplusplus/yolov5detect/utils/loggers/__init__.py +476 -0
- bplusplus/yolov5detect/utils/loggers/clearml/README.md +222 -0
- bplusplus/yolov5detect/utils/loggers/clearml/__init__.py +0 -0
- bplusplus/yolov5detect/utils/loggers/clearml/clearml_utils.py +230 -0
- bplusplus/yolov5detect/utils/loggers/clearml/hpo.py +90 -0
- bplusplus/yolov5detect/utils/loggers/comet/README.md +250 -0
- bplusplus/yolov5detect/utils/loggers/comet/__init__.py +551 -0
- bplusplus/yolov5detect/utils/loggers/comet/comet_utils.py +151 -0
- bplusplus/yolov5detect/utils/loggers/comet/hpo.py +126 -0
- bplusplus/yolov5detect/utils/loggers/comet/optimizer_config.json +135 -0
- bplusplus/yolov5detect/utils/loggers/wandb/__init__.py +0 -0
- bplusplus/yolov5detect/utils/loggers/wandb/wandb_utils.py +210 -0
- bplusplus/yolov5detect/utils/loss.py +259 -0
- bplusplus/yolov5detect/utils/metrics.py +381 -0
- bplusplus/yolov5detect/utils/plots.py +517 -0
- bplusplus/yolov5detect/utils/segment/__init__.py +0 -0
- bplusplus/yolov5detect/utils/segment/augmentations.py +100 -0
- bplusplus/yolov5detect/utils/segment/dataloaders.py +366 -0
- bplusplus/yolov5detect/utils/segment/general.py +160 -0
- bplusplus/yolov5detect/utils/segment/loss.py +198 -0
- bplusplus/yolov5detect/utils/segment/metrics.py +225 -0
- bplusplus/yolov5detect/utils/segment/plots.py +152 -0
- bplusplus/yolov5detect/utils/torch_utils.py +482 -0
- bplusplus/yolov5detect/utils/triton.py +90 -0
- bplusplus-1.1.0.dist-info/METADATA +179 -0
- bplusplus-1.1.0.dist-info/RECORD +92 -0
- bplusplus/build_model.py +0 -38
- bplusplus-0.1.0.dist-info/METADATA +0 -91
- bplusplus-0.1.0.dist-info/RECORD +0 -8
- {bplusplus-0.1.0.dist-info → bplusplus-1.1.0.dist-info}/LICENSE +0 -0
- {bplusplus-0.1.0.dist-info → bplusplus-1.1.0.dist-info}/WHEEL +0 -0
bplusplus/yolov5detect/utils/loss.py (+259 -0)

```diff
@@ -0,0 +1,259 @@
+# Ultralytics YOLOv5 🚀, AGPL-3.0 license
+"""Loss functions."""
+
+import torch
+import torch.nn as nn
+
+from utils.metrics import bbox_iou
+from utils.torch_utils import de_parallel
+
+
+def smooth_BCE(eps=0.1):
+    """Returns label smoothing BCE targets for reducing overfitting; pos: `1.0 - 0.5*eps`, neg: `0.5*eps`. For details see https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441."""
+    return 1.0 - 0.5 * eps, 0.5 * eps
+
+
+class BCEBlurWithLogitsLoss(nn.Module):
+    """Modified BCEWithLogitsLoss to reduce missing label effects in YOLOv5 training with optional alpha smoothing."""
+
+    def __init__(self, alpha=0.05):
+        """Initializes a modified BCEWithLogitsLoss with reduced missing label effects, taking optional alpha smoothing
+        parameter.
+        """
+        super().__init__()
+        self.loss_fcn = nn.BCEWithLogitsLoss(reduction="none")  # must be nn.BCEWithLogitsLoss()
+        self.alpha = alpha
+
+    def forward(self, pred, true):
+        """Computes modified BCE loss for YOLOv5 with reduced missing label effects, taking pred and true tensors,
+        returns mean loss.
+        """
+        loss = self.loss_fcn(pred, true)
+        pred = torch.sigmoid(pred)  # prob from logits
+        dx = pred - true  # reduce only missing label effects
+        # dx = (pred - true).abs()  # reduce missing label and false label effects
+        alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
+        loss *= alpha_factor
+        return loss.mean()
+
+
+class FocalLoss(nn.Module):
+    """Applies focal loss to address class imbalance by modifying BCEWithLogitsLoss with gamma and alpha parameters."""
+
+    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
+        """Initializes FocalLoss with specified loss function, gamma, and alpha values; modifies loss reduction to
+        'none'.
+        """
+        super().__init__()
+        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
+        self.gamma = gamma
+        self.alpha = alpha
+        self.reduction = loss_fcn.reduction
+        self.loss_fcn.reduction = "none"  # required to apply FL to each element
+
+    def forward(self, pred, true):
+        """Calculates the focal loss between predicted and true labels using a modified BCEWithLogitsLoss."""
+        loss = self.loss_fcn(pred, true)
+        # p_t = torch.exp(-loss)
+        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
+
+        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
+        pred_prob = torch.sigmoid(pred)  # prob from logits
+        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
+        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
+        modulating_factor = (1.0 - p_t) ** self.gamma
+        loss *= alpha_factor * modulating_factor
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        else:  # 'none'
+            return loss
+
+
+class QFocalLoss(nn.Module):
+    """Implements Quality Focal Loss to address class imbalance by modulating loss based on prediction confidence."""
+
+    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
+        """Initializes Quality Focal Loss with given loss function, gamma, alpha; modifies reduction to 'none'."""
+        super().__init__()
+        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
+        self.gamma = gamma
+        self.alpha = alpha
+        self.reduction = loss_fcn.reduction
+        self.loss_fcn.reduction = "none"  # required to apply FL to each element
+
+    def forward(self, pred, true):
+        """Computes the focal loss between `pred` and `true` using BCEWithLogitsLoss, adjusting for imbalance with
+        `gamma` and `alpha`.
+        """
+        loss = self.loss_fcn(pred, true)
+
+        pred_prob = torch.sigmoid(pred)  # prob from logits
+        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
+        modulating_factor = torch.abs(true - pred_prob) ** self.gamma
+        loss *= alpha_factor * modulating_factor
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        else:  # 'none'
+            return loss
+
+
+class ComputeLoss:
+    """Computes the total loss for YOLOv5 model predictions, including classification, box, and objectness losses."""
+
+    sort_obj_iou = False
+
+    # Compute losses
+    def __init__(self, model, autobalance=False):
+        """Initializes ComputeLoss with model and autobalance option, autobalances losses if True."""
+        device = next(model.parameters()).device  # get model device
+        h = model.hyp  # hyperparameters
+
+        # Define criteria
+        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h["cls_pw"]], device=device))
+        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h["obj_pw"]], device=device))
+
+        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
+        self.cp, self.cn = smooth_BCE(eps=h.get("label_smoothing", 0.0))  # positive, negative BCE targets
+
+        # Focal loss
+        g = h["fl_gamma"]  # focal loss gamma
+        if g > 0:
+            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
+
+        m = de_parallel(model).model[-1]  # Detect() module
+        self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02])  # P3-P7
+        self.ssi = list(m.stride).index(16) if autobalance else 0  # stride 16 index
+        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
+        self.na = m.na  # number of anchors
+        self.nc = m.nc  # number of classes
+        self.nl = m.nl  # number of layers
+        self.anchors = m.anchors
+        self.device = device
+
+    def __call__(self, p, targets):  # predictions, targets
+        """Performs forward pass, calculating class, box, and object loss for given predictions and targets."""
+        lcls = torch.zeros(1, device=self.device)  # class loss
+        lbox = torch.zeros(1, device=self.device)  # box loss
+        lobj = torch.zeros(1, device=self.device)  # object loss
+        tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets
+
+        # Losses
+        for i, pi in enumerate(p):  # layer index, layer predictions
+            b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
+            tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device)  # target obj
+
+            n = b.shape[0]  # number of targets
+            if n:
+                # pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1)  # faster, requires torch 1.8.0
+                pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1)  # target-subset of predictions
+
+                # Regression
+                pxy = pxy.sigmoid() * 2 - 0.5
+                pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
+                pbox = torch.cat((pxy, pwh), 1)  # predicted box
+                iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze()  # iou(prediction, target)
+                lbox += (1.0 - iou).mean()  # iou loss
+
+                # Objectness
+                iou = iou.detach().clamp(0).type(tobj.dtype)
+                if self.sort_obj_iou:
+                    j = iou.argsort()
+                    b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
+                if self.gr < 1:
+                    iou = (1.0 - self.gr) + self.gr * iou
+                tobj[b, a, gj, gi] = iou  # iou ratio
+
+                # Classification
+                if self.nc > 1:  # cls loss (only if multiple classes)
+                    t = torch.full_like(pcls, self.cn, device=self.device)  # targets
+                    t[range(n), tcls[i]] = self.cp
+                    lcls += self.BCEcls(pcls, t)  # BCE
+
+                # Append targets to text file
+                # with open('targets.txt', 'a') as file:
+                #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
+
+            obji = self.BCEobj(pi[..., 4], tobj)
+            lobj += obji * self.balance[i]  # obj loss
+            if self.autobalance:
+                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
+
+        if self.autobalance:
+            self.balance = [x / self.balance[self.ssi] for x in self.balance]
+        lbox *= self.hyp["box"]
+        lobj *= self.hyp["obj"]
+        lcls *= self.hyp["cls"]
+        bs = tobj.shape[0]  # batch size
+
+        return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()
+
+    def build_targets(self, p, targets):
+        """Prepares model targets from input targets (image,class,x,y,w,h) for loss computation, returning class, box,
+        indices, and anchors.
+        """
+        na, nt = self.na, targets.shape[0]  # number of anchors, targets
+        tcls, tbox, indices, anch = [], [], [], []
+        gain = torch.ones(7, device=self.device)  # normalized to gridspace gain
+        ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
+        targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2)  # append anchor indices
+
+        g = 0.5  # bias
+        off = (
+            torch.tensor(
+                [
+                    [0, 0],
+                    [1, 0],
+                    [0, 1],
+                    [-1, 0],
+                    [0, -1],  # j,k,l,m
+                    # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
+                ],
+                device=self.device,
+            ).float()
+            * g
+        )  # offsets
+
+        for i in range(self.nl):
+            anchors, shape = self.anchors[i], p[i].shape
+            gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain
+
+            # Match targets to anchors
+            t = targets * gain  # shape(3,n,7)
+            if nt:
+                # Matches
+                r = t[..., 4:6] / anchors[:, None]  # wh ratio
+                j = torch.max(r, 1 / r).max(2)[0] < self.hyp["anchor_t"]  # compare
+                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
+                t = t[j]  # filter
+
+                # Offsets
+                gxy = t[:, 2:4]  # grid xy
+                gxi = gain[[2, 3]] - gxy  # inverse
+                j, k = ((gxy % 1 < g) & (gxy > 1)).T
+                l, m = ((gxi % 1 < g) & (gxi > 1)).T
+                j = torch.stack((torch.ones_like(j), j, k, l, m))
+                t = t.repeat((5, 1, 1))[j]
+                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
+            else:
+                t = targets[0]
+                offsets = 0
+
+            # Define
+            bc, gxy, gwh, a = t.chunk(4, 1)  # (image, class), grid xy, grid wh, anchors
+            a, (b, c) = a.long().view(-1), bc.long().T  # anchors, image, class
+            gij = (gxy - offsets).long()
+            gi, gj = gij.T  # grid indices
+
+            # Append
+            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image, anchor, grid
+            tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
+            anch.append(anchors[a])  # anchors
+            tcls.append(c)  # class
+
+        return tcls, tbox, indices, anch
```
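The hunk above vendors YOLOv5's loss utilities verbatim. As a quick sanity check for reviewers, here is a minimal, self-contained exercise of `FocalLoss` and `smooth_BCE` from that file. The import path is an assumption based on the vendored layout: the file itself does `from utils.metrics import bbox_iou`, which only resolves if `bplusplus/yolov5detect` is placed on `sys.path`; the tensors are invented for illustration.

```python
# Minimal sketch: exercising FocalLoss and smooth_BCE from the vendored loss.py.
# Assumes bplusplus/yolov5detect is on sys.path so the file's own `utils.*`
# imports (and this one) resolve; inputs are arbitrary illustrative tensors.
import torch
import torch.nn as nn

from utils.loss import FocalLoss, smooth_BCE

bce = nn.BCEWithLogitsLoss()                   # FocalLoss requires a BCEWithLogitsLoss
focal = FocalLoss(bce, gamma=1.5, alpha=0.25)  # keeps bce's original 'mean' reduction

logits = torch.tensor([[2.0, -1.0], [0.5, 0.0]])  # raw model outputs
targets = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # binary targets
print(focal(logits, targets))  # scalar: mean focal loss over all elements
print(smooth_BCE(eps=0.1))     # (0.95, 0.05): smoothed positive/negative targets
```

`ComputeLoss` itself is not exercised here: it reads anchors, strides, and hyperparameters from a full YOLOv5 `DetectionModel` with a `Detect()` head, so it only runs inside a real training setup.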
bplusplus/yolov5detect/utils/metrics.py (+381 -0)

```diff
@@ -0,0 +1,381 @@
+# Ultralytics YOLOv5 🚀, AGPL-3.0 license
+"""Model validation metrics."""
+
+import math
+import warnings
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from utils import TryExcept, threaded
+
+
+def fitness(x):
+    """Calculates fitness of a model using weighted sum of metrics P, R, mAP@0.5, mAP@0.5:0.95."""
+    w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
+    return (x[:, :4] * w).sum(1)
+
+
+def smooth(y, f=0.05):
+    """Applies box filter smoothing to array `y` with fraction `f`, yielding a smoothed array."""
+    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
+    p = np.ones(nf // 2)  # ones padding
+    yp = np.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
+    return np.convolve(yp, np.ones(nf) / nf, mode="valid")  # y-smoothed
+
+
+def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=".", names=(), eps=1e-16, prefix=""):
+    """
+    Compute the average precision, given the recall and precision curves.
+
+    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
+    # Arguments
+        tp: True positives (nparray, nx1 or nx10).
+        conf: Objectness value from 0-1 (nparray).
+        pred_cls: Predicted object classes (nparray).
+        target_cls: True object classes (nparray).
+        plot: Plot precision-recall curve at mAP@0.5
+        save_dir: Plot save directory
+    # Returns
+        The average precision as computed in py-faster-rcnn.
+    """
+    # Sort by objectness
+    i = np.argsort(-conf)
+    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
+
+    # Find unique classes
+    unique_classes, nt = np.unique(target_cls, return_counts=True)
+    nc = unique_classes.shape[0]  # number of classes, number of detections
+
+    # Create Precision-Recall curve and compute AP for each class
+    px, py = np.linspace(0, 1, 1000), []  # for plotting
+    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
+    for ci, c in enumerate(unique_classes):
+        i = pred_cls == c
+        n_l = nt[ci]  # number of labels
+        n_p = i.sum()  # number of predictions
+        if n_p == 0 or n_l == 0:
+            continue
+
+        # Accumulate FPs and TPs
+        fpc = (1 - tp[i]).cumsum(0)
+        tpc = tp[i].cumsum(0)
+
+        # Recall
+        recall = tpc / (n_l + eps)  # recall curve
+        r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases
+
+        # Precision
+        precision = tpc / (tpc + fpc)  # precision curve
+        p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score
+
+        # AP from recall-precision curve
+        for j in range(tp.shape[1]):
+            ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
+            if plot and j == 0:
+                py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5
+
+    # Compute F1 (harmonic mean of precision and recall)
+    f1 = 2 * p * r / (p + r + eps)
+    names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
+    names = dict(enumerate(names))  # to dict
+    if plot:
+        plot_pr_curve(px, py, ap, Path(save_dir) / f"{prefix}PR_curve.png", names)
+        plot_mc_curve(px, f1, Path(save_dir) / f"{prefix}F1_curve.png", names, ylabel="F1")
+        plot_mc_curve(px, p, Path(save_dir) / f"{prefix}P_curve.png", names, ylabel="Precision")
+        plot_mc_curve(px, r, Path(save_dir) / f"{prefix}R_curve.png", names, ylabel="Recall")
+
+    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
+    p, r, f1 = p[:, i], r[:, i], f1[:, i]
+    tp = (r * nt).round()  # true positives
+    fp = (tp / (p + eps) - tp).round()  # false positives
+    return tp, fp, p, r, f1, ap, unique_classes.astype(int)
+
+
+def compute_ap(recall, precision):
+    """Compute the average precision, given the recall and precision curves
+    # Arguments
+        recall: The recall curve (list)
+        precision: The precision curve (list)
+    # Returns
+        Average precision, precision curve, recall curve.
+    """
+    # Append sentinel values to beginning and end
+    mrec = np.concatenate(([0.0], recall, [1.0]))
+    mpre = np.concatenate(([1.0], precision, [0.0]))
+
+    # Compute the precision envelope
+    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
+
+    # Integrate area under curve
+    method = "interp"  # methods: 'continuous', 'interp'
+    if method == "interp":
+        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
+        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
+    else:  # 'continuous'
+        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
+        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve
+
+    return ap, mpre, mrec
+
+
+class ConfusionMatrix:
+    """Generates and visualizes a confusion matrix for evaluating object detection classification performance."""
+
+    def __init__(self, nc, conf=0.25, iou_thres=0.45):
+        """Initializes ConfusionMatrix with given number of classes, confidence, and IoU threshold."""
+        self.matrix = np.zeros((nc + 1, nc + 1))
+        self.nc = nc  # number of classes
+        self.conf = conf
+        self.iou_thres = iou_thres
+
+    def process_batch(self, detections, labels):
+        """
+        Return intersection-over-union (Jaccard index) of boxes.
+
+        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+
+        Arguments:
+            detections (Array[N, 6]), x1, y1, x2, y2, conf, class
+            labels (Array[M, 5]), class, x1, y1, x2, y2
+        Returns:
+            None, updates confusion matrix accordingly
+        """
+        if detections is None:
+            gt_classes = labels.int()
+            for gc in gt_classes:
+                self.matrix[self.nc, gc] += 1  # background FN
+            return
+
+        detections = detections[detections[:, 4] > self.conf]
+        gt_classes = labels[:, 0].int()
+        detection_classes = detections[:, 5].int()
+        iou = box_iou(labels[:, 1:], detections[:, :4])
+
+        x = torch.where(iou > self.iou_thres)
+        if x[0].shape[0]:
+            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
+            if x[0].shape[0] > 1:
+                matches = matches[matches[:, 2].argsort()[::-1]]
+                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+                matches = matches[matches[:, 2].argsort()[::-1]]
+                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+        else:
+            matches = np.zeros((0, 3))
+
+        n = matches.shape[0] > 0
+        m0, m1, _ = matches.transpose().astype(int)
+        for i, gc in enumerate(gt_classes):
+            j = m0 == i
+            if n and sum(j) == 1:
+                self.matrix[detection_classes[m1[j]], gc] += 1  # correct
+            else:
+                self.matrix[self.nc, gc] += 1  # true background
+
+        if n:
+            for i, dc in enumerate(detection_classes):
+                if not any(m1 == i):
+                    self.matrix[dc, self.nc] += 1  # predicted background
+
+    def tp_fp(self):
+        """Calculates true positives (tp) and false positives (fp) excluding the background class from the confusion
+        matrix.
+        """
+        tp = self.matrix.diagonal()  # true positives
+        fp = self.matrix.sum(1) - tp  # false positives
+        # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
+        return tp[:-1], fp[:-1]  # remove background class
+
+    @TryExcept("WARNING ⚠️ ConfusionMatrix plot failure")
+    def plot(self, normalize=True, save_dir="", names=()):
+        """Plots confusion matrix using seaborn, optional normalization; can save plot to specified directory."""
+        import seaborn as sn
+
+        array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)  # normalize columns
+        array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)
+
+        fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
+        nc, nn = self.nc, len(names)  # number of classes, names
+        sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size
+        labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
+        ticklabels = (names + ["background"]) if labels else "auto"
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
+            sn.heatmap(
+                array,
+                ax=ax,
+                annot=nc < 30,
+                annot_kws={"size": 8},
+                cmap="Blues",
+                fmt=".2f",
+                square=True,
+                vmin=0.0,
+                xticklabels=ticklabels,
+                yticklabels=ticklabels,
+            ).set_facecolor((1, 1, 1))
+        ax.set_xlabel("True")
+        ax.set_ylabel("Predicted")
+        ax.set_title("Confusion Matrix")
+        fig.savefig(Path(save_dir) / "confusion_matrix.png", dpi=250)
+        plt.close(fig)
+
+    def print(self):
+        """Prints the confusion matrix row-wise, with each class and its predictions separated by spaces."""
+        for i in range(self.nc + 1):
+            print(" ".join(map(str, self.matrix[i])))
+
+
+def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
+    """
+    Calculates IoU, GIoU, DIoU, or CIoU between two boxes, supporting xywh/xyxy formats.
+
+    Input shapes are box1(1,4) to box2(n,4).
+    """
+    # Get the coordinates of bounding boxes
+    if xywh:  # transform from xywh to xyxy
+        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
+        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
+        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
+        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
+    else:  # x1, y1, x2, y2 = box1
+        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
+        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
+        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
+        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)
+
+    # Intersection area
+    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * (
+        b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
+    ).clamp(0)
+
+    # Union Area
+    union = w1 * h1 + w2 * h2 - inter + eps
+
+    # IoU
+    iou = inter / union
+    if CIoU or DIoU or GIoU:
+        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
+        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
+        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+            c2 = cw**2 + ch**2 + eps  # convex diagonal squared
+            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
+            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
+                v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
+                with torch.no_grad():
+                    alpha = v / (v - iou + (1 + eps))
+                return iou - (rho2 / c2 + v * alpha)  # CIoU
+            return iou - rho2 / c2  # DIoU
+        c_area = cw * ch + eps  # convex area
+        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
+    return iou  # IoU
+
+
+def box_iou(box1, box2, eps=1e-7):
+    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+    """
+    Return intersection-over-union (Jaccard index) of boxes.
+
+    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+
+    Arguments:
+        box1 (Tensor[N, 4])
+        box2 (Tensor[M, 4])
+
+    Returns:
+        iou (Tensor[N, M]): the NxM matrix containing the pairwise
+            IoU values for every element in boxes1 and boxes2
+    """
+    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
+    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
+
+    # IoU = inter / (area1 + area2 - inter)
+    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
+
+
+def bbox_ioa(box1, box2, eps=1e-7):
+    """
+    Returns the intersection over box2 area given box1, box2.
+
+    Boxes are x1y1x2y2
+    box1: np.array of shape(4)
+    box2: np.array of shape(nx4)
+    returns: np.array of shape(n)
+    """
+    # Get the coordinates of bounding boxes
+    b1_x1, b1_y1, b1_x2, b1_y2 = box1
+    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
+
+    # Intersection area
+    inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * (
+        np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)
+    ).clip(0)
+
+    # box2 area
+    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
+
+    # Intersection over box2 area
+    return inter_area / box2_area
+
+
+def wh_iou(wh1, wh2, eps=1e-7):
+    """Calculates the Intersection over Union (IoU) for two sets of widths and heights; `wh1` and `wh2` should be nx2
+    and mx2 tensors.
+    """
+    wh1 = wh1[:, None]  # [N,1,2]
+    wh2 = wh2[None]  # [1,M,2]
+    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
+    return inter / (wh1.prod(2) + wh2.prod(2) - inter + eps)  # iou = inter / (area1 + area2 - inter)
+
+
+# Plots ----------------------------------------------------------------------------------------------------------------
+
+
+@threaded
+def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=()):
+    """Plots precision-recall curve, optionally per class, saving to `save_dir`; `px`, `py` are lists, `ap` is Nx2
+    array, `names` optional.
+    """
+    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+    py = np.stack(py, axis=1)
+
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
+        for i, y in enumerate(py.T):
+            ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}")  # plot(recall, precision)
+    else:
+        ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)
+
+    ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
+    ax.set_xlabel("Recall")
+    ax.set_ylabel("Precision")
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+    ax.set_title("Precision-Recall Curve")
+    fig.savefig(save_dir, dpi=250)
+    plt.close(fig)
+
+
+@threaded
+def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric"):
+    """Plots a metric-confidence curve for model predictions, supporting per-class visualization and smoothing."""
+    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
+        for i, y in enumerate(py):
+            ax.plot(px, y, linewidth=1, label=f"{names[i]}")  # plot(confidence, metric)
+    else:
+        ax.plot(px, py.T, linewidth=1, color="grey")  # plot(confidence, metric)
+
+    y = smooth(py.mean(0), 0.05)
+    ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+    ax.set_title(f"{ylabel}-Confidence Curve")
+    fig.savefig(save_dir, dpi=250)
+    plt.close(fig)
```
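Likewise for the metrics hunk: a small, self-contained check of the IoU helpers, under the same `sys.path` assumption as the loss example above; the box coordinates are invented for illustration.

```python
# Minimal sketch: checking bbox_iou and box_iou from the vendored metrics.py.
# Same assumption as above: bplusplus/yolov5detect is on sys.path.
import torch

from utils.metrics import bbox_iou, box_iou

b1 = torch.tensor([[50.0, 50.0, 20.0, 20.0]])   # one box in xywh (center) format
b2 = torch.tensor([[50.0, 50.0, 20.0, 20.0],
                   [60.0, 60.0, 20.0, 20.0]])   # two candidate boxes in xywh
print(bbox_iou(b1, b2))             # plain IoU per pair: ~[[1.00], [0.14]]
print(bbox_iou(b1, b2, CIoU=True))  # CIoU, the variant ComputeLoss uses for box loss

xy1 = torch.tensor([[40.0, 40.0, 60.0, 60.0]])  # same first box in xyxy
xy2 = torch.tensor([[50.0, 50.0, 70.0, 70.0]])
print(box_iou(xy1, xy2))            # pairwise NxM IoU matrix: ~[[0.14]]
```

Note that `bbox_iou` broadcasts one box against n candidates (shapes (1,4) vs (n,4)), while `box_iou` returns the full NxM pairwise matrix, which is what `ConfusionMatrix.process_batch` relies on.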