megadetector 5.0.5-py3-none-any.whl → 5.0.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
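Wheel files are ordinary zip archives, so a file-level comparison like the one on this page can be reproduced locally with only the Python standard library. A minimal sketch, assuming both wheels have been downloaded to the working directory (the size-based "changed" check is a cheap proxy and misses equal-size edits):

import zipfile

def wheel_members(path):
    # Map each archive member to its uncompressed size.
    with zipfile.ZipFile(path) as z:
        return {i.filename: i.file_size for i in z.infolist()}

old = wheel_members('megadetector-5.0.5-py3-none-any.whl')
new = wheel_members('megadetector-5.0.7-py3-none-any.whl')

print('added:  ', sorted(set(new) - set(old)))
print('removed:', sorted(set(old) - set(new)))
# Size change is only a proxy; edits that keep the size identical are missed.
print('changed:', sorted(k for k in set(old) & set(new) if old[k] != new[k]))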

Potentially problematic release: this version of megadetector might be problematic.

Files changed (132)
  1. api/batch_processing/data_preparation/manage_local_batch.py +302 -263
  2. api/batch_processing/data_preparation/manage_video_batch.py +81 -2
  3. api/batch_processing/postprocessing/add_max_conf.py +1 -0
  4. api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
  5. api/batch_processing/postprocessing/compare_batch_results.py +110 -60
  6. api/batch_processing/postprocessing/load_api_results.py +56 -70
  7. api/batch_processing/postprocessing/md_to_coco.py +1 -1
  8. api/batch_processing/postprocessing/md_to_labelme.py +2 -1
  9. api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
  10. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
  11. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
  12. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
  13. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
  14. api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
  15. api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
  16. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
  17. classification/prepare_classification_script.py +191 -191
  18. data_management/coco_to_yolo.py +68 -45
  19. data_management/databases/integrity_check_json_db.py +7 -5
  20. data_management/generate_crops_from_cct.py +3 -3
  21. data_management/get_image_sizes.py +8 -6
  22. data_management/importers/add_timestamps_to_icct.py +79 -0
  23. data_management/importers/animl_results_to_md_results.py +160 -0
  24. data_management/importers/auckland_doc_test_to_json.py +4 -4
  25. data_management/importers/auckland_doc_to_json.py +1 -1
  26. data_management/importers/awc_to_json.py +5 -5
  27. data_management/importers/bellevue_to_json.py +5 -5
  28. data_management/importers/carrizo_shrubfree_2018.py +5 -5
  29. data_management/importers/carrizo_trail_cam_2017.py +5 -5
  30. data_management/importers/cct_field_adjustments.py +2 -3
  31. data_management/importers/channel_islands_to_cct.py +4 -4
  32. data_management/importers/ena24_to_json.py +5 -5
  33. data_management/importers/helena_to_cct.py +10 -10
  34. data_management/importers/idaho-camera-traps.py +12 -12
  35. data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
  36. data_management/importers/jb_csv_to_json.py +4 -4
  37. data_management/importers/missouri_to_json.py +1 -1
  38. data_management/importers/noaa_seals_2019.py +1 -1
  39. data_management/importers/pc_to_json.py +5 -5
  40. data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
  41. data_management/importers/prepare_zsl_imerit.py +5 -5
  42. data_management/importers/rspb_to_json.py +4 -4
  43. data_management/importers/save_the_elephants_survey_A.py +5 -5
  44. data_management/importers/save_the_elephants_survey_B.py +6 -6
  45. data_management/importers/snapshot_safari_importer.py +9 -9
  46. data_management/importers/snapshot_serengeti_lila.py +9 -9
  47. data_management/importers/timelapse_csv_set_to_json.py +5 -7
  48. data_management/importers/ubc_to_json.py +4 -4
  49. data_management/importers/umn_to_json.py +4 -4
  50. data_management/importers/wellington_to_json.py +1 -1
  51. data_management/importers/wi_to_json.py +2 -2
  52. data_management/importers/zamba_results_to_md_results.py +181 -0
  53. data_management/labelme_to_coco.py +35 -7
  54. data_management/labelme_to_yolo.py +229 -0
  55. data_management/lila/add_locations_to_island_camera_traps.py +1 -1
  56. data_management/lila/add_locations_to_nacti.py +147 -0
  57. data_management/lila/create_lila_blank_set.py +474 -0
  58. data_management/lila/create_lila_test_set.py +2 -1
  59. data_management/lila/create_links_to_md_results_files.py +106 -0
  60. data_management/lila/download_lila_subset.py +46 -21
  61. data_management/lila/generate_lila_per_image_labels.py +23 -14
  62. data_management/lila/get_lila_annotation_counts.py +17 -11
  63. data_management/lila/lila_common.py +14 -11
  64. data_management/lila/test_lila_metadata_urls.py +116 -0
  65. data_management/ocr_tools.py +829 -0
  66. data_management/resize_coco_dataset.py +13 -11
  67. data_management/yolo_output_to_md_output.py +84 -12
  68. data_management/yolo_to_coco.py +38 -20
  69. detection/process_video.py +36 -14
  70. detection/pytorch_detector.py +23 -8
  71. detection/run_detector.py +76 -19
  72. detection/run_detector_batch.py +178 -63
  73. detection/run_inference_with_yolov5_val.py +326 -57
  74. detection/run_tiled_inference.py +153 -43
  75. detection/video_utils.py +34 -8
  76. md_utils/ct_utils.py +172 -1
  77. md_utils/md_tests.py +372 -51
  78. md_utils/path_utils.py +167 -39
  79. md_utils/process_utils.py +26 -7
  80. md_utils/split_locations_into_train_val.py +215 -0
  81. md_utils/string_utils.py +10 -0
  82. md_utils/url_utils.py +0 -2
  83. md_utils/write_html_image_list.py +9 -26
  84. md_visualization/plot_utils.py +12 -8
  85. md_visualization/visualization_utils.py +106 -7
  86. md_visualization/visualize_db.py +16 -8
  87. md_visualization/visualize_detector_output.py +208 -97
  88. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
  89. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
  90. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
  91. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
  92. taxonomy_mapping/map_new_lila_datasets.py +43 -39
  93. taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
  94. taxonomy_mapping/preview_lila_taxonomy.py +27 -27
  95. taxonomy_mapping/species_lookup.py +33 -13
  96. taxonomy_mapping/taxonomy_csv_checker.py +7 -5
  97. api/synchronous/api_core/yolov5/detect.py +0 -252
  98. api/synchronous/api_core/yolov5/export.py +0 -607
  99. api/synchronous/api_core/yolov5/hubconf.py +0 -146
  100. api/synchronous/api_core/yolov5/models/__init__.py +0 -0
  101. api/synchronous/api_core/yolov5/models/common.py +0 -738
  102. api/synchronous/api_core/yolov5/models/experimental.py +0 -104
  103. api/synchronous/api_core/yolov5/models/tf.py +0 -574
  104. api/synchronous/api_core/yolov5/models/yolo.py +0 -338
  105. api/synchronous/api_core/yolov5/train.py +0 -670
  106. api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
  107. api/synchronous/api_core/yolov5/utils/activations.py +0 -103
  108. api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
  109. api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
  110. api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
  111. api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
  112. api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
  113. api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
  114. api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
  115. api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
  116. api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
  117. api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
  118. api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
  119. api/synchronous/api_core/yolov5/utils/general.py +0 -1018
  120. api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
  121. api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
  122. api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
  123. api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
  124. api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
  125. api/synchronous/api_core/yolov5/utils/loss.py +0 -234
  126. api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
  127. api/synchronous/api_core/yolov5/utils/plots.py +0 -489
  128. api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
  129. api/synchronous/api_core/yolov5/val.py +0 -394
  130. md_utils/matlab_porting_tools.py +0 -97
  131. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
  132. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
api/synchronous/api_core/yolov5/utils/loss.py (deleted)
@@ -1,234 +0,0 @@
- # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
- """
- Loss functions
- """
-
- import torch
- import torch.nn as nn
-
- from utils.metrics import bbox_iou
- from utils.torch_utils import de_parallel
-
-
- def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
-     # return positive, negative label smoothing BCE targets
-     return 1.0 - 0.5 * eps, 0.5 * eps
-
-
- class BCEBlurWithLogitsLoss(nn.Module):
-     # BCEwithLogitLoss() with reduced missing label effects.
-     def __init__(self, alpha=0.05):
-         super().__init__()
-         self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none')  # must be nn.BCEWithLogitsLoss()
-         self.alpha = alpha
-
-     def forward(self, pred, true):
-         loss = self.loss_fcn(pred, true)
-         pred = torch.sigmoid(pred)  # prob from logits
-         dx = pred - true  # reduce only missing label effects
-         # dx = (pred - true).abs()  # reduce missing label and false label effects
-         alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
-         loss *= alpha_factor
-         return loss.mean()
-
-
- class FocalLoss(nn.Module):
-     # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
-     def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
-         super().__init__()
-         self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
-         self.gamma = gamma
-         self.alpha = alpha
-         self.reduction = loss_fcn.reduction
-         self.loss_fcn.reduction = 'none'  # required to apply FL to each element
-
-     def forward(self, pred, true):
-         loss = self.loss_fcn(pred, true)
-         # p_t = torch.exp(-loss)
-         # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
-
-         # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
-         pred_prob = torch.sigmoid(pred)  # prob from logits
-         p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
-         alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
-         modulating_factor = (1.0 - p_t) ** self.gamma
-         loss *= alpha_factor * modulating_factor
-
-         if self.reduction == 'mean':
-             return loss.mean()
-         elif self.reduction == 'sum':
-             return loss.sum()
-         else:  # 'none'
-             return loss
-
-
- class QFocalLoss(nn.Module):
-     # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
-     def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
-         super().__init__()
-         self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
-         self.gamma = gamma
-         self.alpha = alpha
-         self.reduction = loss_fcn.reduction
-         self.loss_fcn.reduction = 'none'  # required to apply FL to each element
-
-     def forward(self, pred, true):
-         loss = self.loss_fcn(pred, true)
-
-         pred_prob = torch.sigmoid(pred)  # prob from logits
-         alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
-         modulating_factor = torch.abs(true - pred_prob) ** self.gamma
-         loss *= alpha_factor * modulating_factor
-
-         if self.reduction == 'mean':
-             return loss.mean()
-         elif self.reduction == 'sum':
-             return loss.sum()
-         else:  # 'none'
-             return loss
-
-
- class ComputeLoss:
-     sort_obj_iou = False
-
-     # Compute losses
-     def __init__(self, model, autobalance=False):
-         device = next(model.parameters()).device  # get model device
-         h = model.hyp  # hyperparameters
-
-         # Define criteria
-         BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
-         BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
-
-         # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
-         self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  # positive, negative BCE targets
-
-         # Focal loss
-         g = h['fl_gamma']  # focal loss gamma
-         if g > 0:
-             BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
-
-         m = de_parallel(model).model[-1]  # Detect() module
-         self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02])  # P3-P7
-         self.ssi = list(m.stride).index(16) if autobalance else 0  # stride 16 index
-         self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
-         self.na = m.na  # number of anchors
-         self.nc = m.nc  # number of classes
-         self.nl = m.nl  # number of layers
-         self.anchors = m.anchors
-         self.device = device
-
-     def __call__(self, p, targets):  # predictions, targets
-         lcls = torch.zeros(1, device=self.device)  # class loss
-         lbox = torch.zeros(1, device=self.device)  # box loss
-         lobj = torch.zeros(1, device=self.device)  # object loss
-         tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets
-
-         # Losses
-         for i, pi in enumerate(p):  # layer index, layer predictions
-             b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
-             tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device)  # target obj
-
-             n = b.shape[0]  # number of targets
-             if n:
-                 # pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1)  # faster, requires torch 1.8.0
-                 pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1)  # target-subset of predictions
-
-                 # Regression
-                 pxy = pxy.sigmoid() * 2 - 0.5
-                 pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
-                 pbox = torch.cat((pxy, pwh), 1)  # predicted box
-                 iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze()  # iou(prediction, target)
-                 lbox += (1.0 - iou).mean()  # iou loss
-
-                 # Objectness
-                 iou = iou.detach().clamp(0).type(tobj.dtype)
-                 if self.sort_obj_iou:
-                     j = iou.argsort()
-                     b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
-                 if self.gr < 1:
-                     iou = (1.0 - self.gr) + self.gr * iou
-                 tobj[b, a, gj, gi] = iou  # iou ratio
-
-                 # Classification
-                 if self.nc > 1:  # cls loss (only if multiple classes)
-                     t = torch.full_like(pcls, self.cn, device=self.device)  # targets
-                     t[range(n), tcls[i]] = self.cp
-                     lcls += self.BCEcls(pcls, t)  # BCE
-
-                 # Append targets to text file
-                 # with open('targets.txt', 'a') as file:
-                 #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
-
-             obji = self.BCEobj(pi[..., 4], tobj)
-             lobj += obji * self.balance[i]  # obj loss
-             if self.autobalance:
-                 self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
-
-         if self.autobalance:
-             self.balance = [x / self.balance[self.ssi] for x in self.balance]
-         lbox *= self.hyp['box']
-         lobj *= self.hyp['obj']
-         lcls *= self.hyp['cls']
-         bs = tobj.shape[0]  # batch size
-
-         return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()
-
-     def build_targets(self, p, targets):
-         # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
-         na, nt = self.na, targets.shape[0]  # number of anchors, targets
-         tcls, tbox, indices, anch = [], [], [], []
-         gain = torch.ones(7, device=self.device)  # normalized to gridspace gain
-         ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
-         targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2)  # append anchor indices
-
-         g = 0.5  # bias
-         off = torch.tensor(
-             [
-                 [0, 0],
-                 [1, 0],
-                 [0, 1],
-                 [-1, 0],
-                 [0, -1],  # j,k,l,m
-                 # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
-             ],
-             device=self.device).float() * g  # offsets
-
-         for i in range(self.nl):
-             anchors, shape = self.anchors[i], p[i].shape
-             gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain
-
-             # Match targets to anchors
-             t = targets * gain  # shape(3,n,7)
-             if nt:
-                 # Matches
-                 r = t[..., 4:6] / anchors[:, None]  # wh ratio
-                 j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']  # compare
-                 # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
-                 t = t[j]  # filter
-
-                 # Offsets
-                 gxy = t[:, 2:4]  # grid xy
-                 gxi = gain[[2, 3]] - gxy  # inverse
-                 j, k = ((gxy % 1 < g) & (gxy > 1)).T
-                 l, m = ((gxi % 1 < g) & (gxi > 1)).T
-                 j = torch.stack((torch.ones_like(j), j, k, l, m))
-                 t = t.repeat((5, 1, 1))[j]
-                 offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
-             else:
-                 t = targets[0]
-                 offsets = 0
-
-             # Define
-             bc, gxy, gwh, a = t.chunk(4, 1)  # (image, class), grid xy, grid wh, anchors
-             a, (b, c) = a.long().view(-1), bc.long().T  # anchors, image, class
-             gij = (gxy - offsets).long()
-             gi, gj = gij.T  # grid indices
-
-             # Append
-             indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image, anchor, grid
-             tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
-             anch.append(anchors[a])  # anchors
-             tcls.append(c)  # class
-
-         return tcls, tbox, indices, anch
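Among the removed code above, FocalLoss is the piece most often asked about: it wraps BCE-with-logits and scales each element's loss by alpha_factor * (1 - p_t) ** gamma, so confidently-correct predictions contribute less and training concentrates on hard examples. A minimal standalone sketch of that modulation, not part of the package (the function name and values below are illustrative):

import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss(reduction='none')  # per-element loss, as FocalLoss requires

def focal_bce(pred, true, gamma=1.5, alpha=0.25):
    loss = bce(pred, true)
    p = torch.sigmoid(pred)                        # prob from logits
    p_t = true * p + (1 - true) * (1 - p)          # prob assigned to the true class
    alpha_factor = true * alpha + (1 - true) * (1 - alpha)
    return (loss * alpha_factor * (1 - p_t) ** gamma).mean()

pred = torch.tensor([2.0, -2.0, 0.1])   # logits: two easy cases, one hard
true = torch.tensor([1.0, 0.0, 1.0])
print(focal_bce(pred, true))            # easy examples are down-weighted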
api/synchronous/api_core/yolov5/utils/metrics.py (deleted)
@@ -1,355 +0,0 @@
- # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
- """
- Model validation metrics
- """
-
- import math
- import warnings
- from pathlib import Path
-
- import matplotlib.pyplot as plt
- import numpy as np
- import torch
-
-
- def fitness(x):
-     # Model fitness as a weighted combination of metrics
-     w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
-     return (x[:, :4] * w).sum(1)
-
-
- def smooth(y, f=0.05):
-     # Box filter of fraction f
-     nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
-     p = np.ones(nf // 2)  # ones padding
-     yp = np.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
-     return np.convolve(yp, np.ones(nf) / nf, mode='valid')  # y-smoothed
-
-
- def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16):
-     """ Compute the average precision, given the recall and precision curves.
-     Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
-     # Arguments
-         tp: True positives (nparray, nx1 or nx10).
-         conf: Objectness value from 0-1 (nparray).
-         pred_cls: Predicted object classes (nparray).
-         target_cls: True object classes (nparray).
-         plot: Plot precision-recall curve at mAP@0.5
-         save_dir: Plot save directory
-     # Returns
-         The average precision as computed in py-faster-rcnn.
-     """
-
-     # Sort by objectness
-     i = np.argsort(-conf)
-     tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
-
-     # Find unique classes
-     unique_classes, nt = np.unique(target_cls, return_counts=True)
-     nc = unique_classes.shape[0]  # number of classes, number of detections
-
-     # Create Precision-Recall curve and compute AP for each class
-     px, py = np.linspace(0, 1, 1000), []  # for plotting
-     ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
-     for ci, c in enumerate(unique_classes):
-         i = pred_cls == c
-         n_l = nt[ci]  # number of labels
-         n_p = i.sum()  # number of predictions
-         if n_p == 0 or n_l == 0:
-             continue
-
-         # Accumulate FPs and TPs
-         fpc = (1 - tp[i]).cumsum(0)
-         tpc = tp[i].cumsum(0)
-
-         # Recall
-         recall = tpc / (n_l + eps)  # recall curve
-         r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases
-
-         # Precision
-         precision = tpc / (tpc + fpc)  # precision curve
-         p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score
-
-         # AP from recall-precision curve
-         for j in range(tp.shape[1]):
-             ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
-             if plot and j == 0:
-                 py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5
-
-     # Compute F1 (harmonic mean of precision and recall)
-     f1 = 2 * p * r / (p + r + eps)
-     names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
-     names = dict(enumerate(names))  # to dict
-     if plot:
-         plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
-         plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
-         plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
-         plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
-
-     i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
-     p, r, f1 = p[:, i], r[:, i], f1[:, i]
-     tp = (r * nt).round()  # true positives
-     fp = (tp / (p + eps) - tp).round()  # false positives
-     return tp, fp, p, r, f1, ap, unique_classes.astype(int)
-
-
- def compute_ap(recall, precision):
-     """ Compute the average precision, given the recall and precision curves
-     # Arguments
-         recall: The recall curve (list)
-         precision: The precision curve (list)
-     # Returns
-         Average precision, precision curve, recall curve
-     """
-
-     # Append sentinel values to beginning and end
-     mrec = np.concatenate(([0.0], recall, [1.0]))
-     mpre = np.concatenate(([1.0], precision, [0.0]))
-
-     # Compute the precision envelope
-     mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
-
-     # Integrate area under curve
-     method = 'interp'  # methods: 'continuous', 'interp'
-     if method == 'interp':
-         x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
-         ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
-     else:  # 'continuous'
-         i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
-         ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve
-
-     return ap, mpre, mrec
-
-
- class ConfusionMatrix:
-     # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
-     def __init__(self, nc, conf=0.25, iou_thres=0.45):
-         self.matrix = np.zeros((nc + 1, nc + 1))
-         self.nc = nc  # number of classes
-         self.conf = conf
-         self.iou_thres = iou_thres
-
-     def process_batch(self, detections, labels):
-         """
-         Return intersection-over-union (Jaccard index) of boxes.
-         Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
-         Arguments:
-             detections (Array[N, 6]), x1, y1, x2, y2, conf, class
-             labels (Array[M, 5]), class, x1, y1, x2, y2
-         Returns:
-             None, updates confusion matrix accordingly
-         """
-         detections = detections[detections[:, 4] > self.conf]
-         gt_classes = labels[:, 0].int()
-         detection_classes = detections[:, 5].int()
-         iou = box_iou(labels[:, 1:], detections[:, :4])
-
-         x = torch.where(iou > self.iou_thres)
-         if x[0].shape[0]:
-             matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
-             if x[0].shape[0] > 1:
-                 matches = matches[matches[:, 2].argsort()[::-1]]
-                 matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
-                 matches = matches[matches[:, 2].argsort()[::-1]]
-                 matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
-         else:
-             matches = np.zeros((0, 3))
-
-         n = matches.shape[0] > 0
-         m0, m1, _ = matches.transpose().astype(int)
-         for i, gc in enumerate(gt_classes):
-             j = m0 == i
-             if n and sum(j) == 1:
-                 self.matrix[detection_classes[m1[j]], gc] += 1  # correct
-             else:
-                 self.matrix[self.nc, gc] += 1  # background FP
-
-         if n:
-             for i, dc in enumerate(detection_classes):
-                 if not any(m1 == i):
-                     self.matrix[dc, self.nc] += 1  # background FN
-
-     def matrix(self):
-         return self.matrix
-
-     def tp_fp(self):
-         tp = self.matrix.diagonal()  # true positives
-         fp = self.matrix.sum(1) - tp  # false positives
-         # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
-         return tp[:-1], fp[:-1]  # remove background class
-
-     def plot(self, normalize=True, save_dir='', names=()):
-         try:
-             import seaborn as sn
-
-             array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1)  # normalize columns
-             array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)
-
-             fig = plt.figure(figsize=(12, 9), tight_layout=True)
-             nc, nn = self.nc, len(names)  # number of classes, names
-             sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size
-             labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
-             with warnings.catch_warnings():
-                 warnings.simplefilter('ignore')  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
-                 sn.heatmap(array,
-                            annot=nc < 30,
-                            annot_kws={
-                                "size": 8},
-                            cmap='Blues',
-                            fmt='.2f',
-                            square=True,
-                            vmin=0.0,
-                            xticklabels=names + ['background FP'] if labels else "auto",
-                            yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
-             fig.axes[0].set_xlabel('True')
-             fig.axes[0].set_ylabel('Predicted')
-             fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
-             plt.close()
-         except Exception as e:
-             print(f'WARNING: ConfusionMatrix plot failure: {e}')
-
-     def print(self):
-         for i in range(self.nc + 1):
-             print(' '.join(map(str, self.matrix[i])))
-
-
- def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
-     # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
-
-     # Get the coordinates of bounding boxes
-     if xywh:  # transform from xywh to xyxy
-         (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, 1), box2.chunk(4, 1)
-         w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
-         b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
-         b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
-     else:  # x1, y1, x2, y2 = box1
-         b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1)
-         b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1)
-         w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
-         w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
-
-     # Intersection area
-     inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
-             (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
-
-     # Union Area
-     union = w1 * h1 + w2 * h2 - inter + eps
-
-     # IoU
-     iou = inter / union
-     if CIoU or DIoU or GIoU:
-         cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
-         ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
-         if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
-             c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
-             rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
-             if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
-                 v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
-                 with torch.no_grad():
-                     alpha = v / (v - iou + (1 + eps))
-                 return iou - (rho2 / c2 + v * alpha)  # CIoU
-             return iou - rho2 / c2  # DIoU
-         c_area = cw * ch + eps  # convex area
-         return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
-     return iou  # IoU
-
-
- def box_area(box):
-     # box = xyxy(4,n)
-     return (box[2] - box[0]) * (box[3] - box[1])
-
-
- def box_iou(box1, box2):
-     # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
-     """
-     Return intersection-over-union (Jaccard index) of boxes.
-     Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
-     Arguments:
-         box1 (Tensor[N, 4])
-         box2 (Tensor[M, 4])
-     Returns:
-         iou (Tensor[N, M]): the NxM matrix containing the pairwise
-             IoU values for every element in boxes1 and boxes2
-     """
-
-     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
-     (a1, a2), (b1, b2) = box1[:, None].chunk(2, 2), box2.chunk(2, 1)
-     inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
-
-     # IoU = inter / (area1 + area2 - inter)
-     return inter / (box_area(box1.T)[:, None] + box_area(box2.T) - inter)
-
-
- def bbox_ioa(box1, box2, eps=1E-7):
-     """ Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
-     box1: np.array of shape(4)
-     box2: np.array of shape(nx4)
-     returns: np.array of shape(n)
-     """
-
-     # Get the coordinates of bounding boxes
-     b1_x1, b1_y1, b1_x2, b1_y2 = box1
-     b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
-
-     # Intersection area
-     inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
-                  (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
-
-     # box2 area
-     box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
-
-     # Intersection over box2 area
-     return inter_area / box2_area
-
-
- def wh_iou(wh1, wh2):
-     # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
-     wh1 = wh1[:, None]  # [N,1,2]
-     wh2 = wh2[None]  # [1,M,2]
-     inter = torch.min(wh1, wh2).prod(2)  # [N,M]
-     return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)
-
-
- # Plots ----------------------------------------------------------------------------------------------------------------
-
-
- def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
-     # Precision-recall curve
-     fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
-     py = np.stack(py, axis=1)
-
-     if 0 < len(names) < 21:  # display per-class legend if < 21 classes
-         for i, y in enumerate(py.T):
-             ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
-     else:
-         ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)
-
-     ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
-     ax.set_xlabel('Recall')
-     ax.set_ylabel('Precision')
-     ax.set_xlim(0, 1)
-     ax.set_ylim(0, 1)
-     plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
-     fig.savefig(save_dir, dpi=250)
-     plt.close()
-
-
- def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
-     # Metric-confidence curve
-     fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
-
-     if 0 < len(names) < 21:  # display per-class legend if < 21 classes
-         for i, y in enumerate(py):
-             ax.plot(px, y, linewidth=1, label=f'{names[i]}')  # plot(confidence, metric)
-     else:
-         ax.plot(px, py.T, linewidth=1, color='grey')  # plot(confidence, metric)
-
-     y = smooth(py.mean(0), 0.05)
-     ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
-     ax.set_xlabel(xlabel)
-     ax.set_ylabel(ylabel)
-     ax.set_xlim(0, 1)
-     ax.set_ylim(0, 1)
-     plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
-     fig.savefig(save_dir, dpi=250)
-     plt.close()
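The removed compute_ap() integrates precision over recall after replacing the raw precision curve with its monotone envelope, sampled on a COCO-style 101-point grid. A minimal standalone sketch of the same computation on a toy curve, not part of the package (the function name and numbers are illustrative):

import numpy as np

def ap_interp101(recall, precision):
    mrec = np.concatenate(([0.0], recall, [1.0]))         # sentinel values
    mpre = np.concatenate(([1.0], precision, [0.0]))
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))  # precision envelope
    x = np.linspace(0, 1, 101)                            # 101-point grid (COCO)
    return np.trapz(np.interp(x, mrec, mpre), x)          # integrate under the envelope

recall = np.array([0.2, 0.4, 0.4, 0.8])      # toy, monotone recall curve
precision = np.array([1.0, 1.0, 0.67, 0.75])
print(f'AP = {ap_interp101(recall, precision):.3f}')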