megadetector-5.0.5-py3-none-any.whl → megadetector-5.0.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- api/batch_processing/data_preparation/manage_local_batch.py +302 -263
- api/batch_processing/data_preparation/manage_video_batch.py +81 -2
- api/batch_processing/postprocessing/add_max_conf.py +1 -0
- api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
- api/batch_processing/postprocessing/compare_batch_results.py +110 -60
- api/batch_processing/postprocessing/load_api_results.py +56 -70
- api/batch_processing/postprocessing/md_to_coco.py +1 -1
- api/batch_processing/postprocessing/md_to_labelme.py +2 -1
- api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
- api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
- api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
- api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
- api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
- api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
- api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
- api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
- classification/prepare_classification_script.py +191 -191
- data_management/coco_to_yolo.py +68 -45
- data_management/databases/integrity_check_json_db.py +7 -5
- data_management/generate_crops_from_cct.py +3 -3
- data_management/get_image_sizes.py +8 -6
- data_management/importers/add_timestamps_to_icct.py +79 -0
- data_management/importers/animl_results_to_md_results.py +160 -0
- data_management/importers/auckland_doc_test_to_json.py +4 -4
- data_management/importers/auckland_doc_to_json.py +1 -1
- data_management/importers/awc_to_json.py +5 -5
- data_management/importers/bellevue_to_json.py +5 -5
- data_management/importers/carrizo_shrubfree_2018.py +5 -5
- data_management/importers/carrizo_trail_cam_2017.py +5 -5
- data_management/importers/cct_field_adjustments.py +2 -3
- data_management/importers/channel_islands_to_cct.py +4 -4
- data_management/importers/ena24_to_json.py +5 -5
- data_management/importers/helena_to_cct.py +10 -10
- data_management/importers/idaho-camera-traps.py +12 -12
- data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
- data_management/importers/jb_csv_to_json.py +4 -4
- data_management/importers/missouri_to_json.py +1 -1
- data_management/importers/noaa_seals_2019.py +1 -1
- data_management/importers/pc_to_json.py +5 -5
- data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
- data_management/importers/prepare_zsl_imerit.py +5 -5
- data_management/importers/rspb_to_json.py +4 -4
- data_management/importers/save_the_elephants_survey_A.py +5 -5
- data_management/importers/save_the_elephants_survey_B.py +6 -6
- data_management/importers/snapshot_safari_importer.py +9 -9
- data_management/importers/snapshot_serengeti_lila.py +9 -9
- data_management/importers/timelapse_csv_set_to_json.py +5 -7
- data_management/importers/ubc_to_json.py +4 -4
- data_management/importers/umn_to_json.py +4 -4
- data_management/importers/wellington_to_json.py +1 -1
- data_management/importers/wi_to_json.py +2 -2
- data_management/importers/zamba_results_to_md_results.py +181 -0
- data_management/labelme_to_coco.py +35 -7
- data_management/labelme_to_yolo.py +229 -0
- data_management/lila/add_locations_to_island_camera_traps.py +1 -1
- data_management/lila/add_locations_to_nacti.py +147 -0
- data_management/lila/create_lila_blank_set.py +474 -0
- data_management/lila/create_lila_test_set.py +2 -1
- data_management/lila/create_links_to_md_results_files.py +106 -0
- data_management/lila/download_lila_subset.py +46 -21
- data_management/lila/generate_lila_per_image_labels.py +23 -14
- data_management/lila/get_lila_annotation_counts.py +17 -11
- data_management/lila/lila_common.py +14 -11
- data_management/lila/test_lila_metadata_urls.py +116 -0
- data_management/ocr_tools.py +829 -0
- data_management/resize_coco_dataset.py +13 -11
- data_management/yolo_output_to_md_output.py +84 -12
- data_management/yolo_to_coco.py +38 -20
- detection/process_video.py +36 -14
- detection/pytorch_detector.py +23 -8
- detection/run_detector.py +76 -19
- detection/run_detector_batch.py +178 -63
- detection/run_inference_with_yolov5_val.py +326 -57
- detection/run_tiled_inference.py +153 -43
- detection/video_utils.py +34 -8
- md_utils/ct_utils.py +172 -1
- md_utils/md_tests.py +372 -51
- md_utils/path_utils.py +167 -39
- md_utils/process_utils.py +26 -7
- md_utils/split_locations_into_train_val.py +215 -0
- md_utils/string_utils.py +10 -0
- md_utils/url_utils.py +0 -2
- md_utils/write_html_image_list.py +9 -26
- md_visualization/plot_utils.py +12 -8
- md_visualization/visualization_utils.py +106 -7
- md_visualization/visualize_db.py +16 -8
- md_visualization/visualize_detector_output.py +208 -97
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
- taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
- taxonomy_mapping/map_new_lila_datasets.py +43 -39
- taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
- taxonomy_mapping/preview_lila_taxonomy.py +27 -27
- taxonomy_mapping/species_lookup.py +33 -13
- taxonomy_mapping/taxonomy_csv_checker.py +7 -5
- api/synchronous/api_core/yolov5/detect.py +0 -252
- api/synchronous/api_core/yolov5/export.py +0 -607
- api/synchronous/api_core/yolov5/hubconf.py +0 -146
- api/synchronous/api_core/yolov5/models/__init__.py +0 -0
- api/synchronous/api_core/yolov5/models/common.py +0 -738
- api/synchronous/api_core/yolov5/models/experimental.py +0 -104
- api/synchronous/api_core/yolov5/models/tf.py +0 -574
- api/synchronous/api_core/yolov5/models/yolo.py +0 -338
- api/synchronous/api_core/yolov5/train.py +0 -670
- api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
- api/synchronous/api_core/yolov5/utils/activations.py +0 -103
- api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
- api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
- api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
- api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
- api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
- api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
- api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
- api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
- api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
- api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
- api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
- api/synchronous/api_core/yolov5/utils/general.py +0 -1018
- api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
- api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
- api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
- api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
- api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
- api/synchronous/api_core/yolov5/utils/loss.py +0 -234
- api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
- api/synchronous/api_core/yolov5/utils/plots.py +0 -489
- api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
- api/synchronous/api_core/yolov5/val.py +0 -394
- md_utils/matlab_porting_tools.py +0 -97
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
api/synchronous/api_core/yolov5/utils/__init__.py
@@ -1,36 +0,0 @@
-# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
-"""
-utils/initialization
-"""
-
-
-def notebook_init(verbose=True):
-    # Check system software and hardware
-    print('Checking setup...')
-
-    import os
-    import shutil
-
-    from utils.general import check_requirements, emojis, is_colab
-    from utils.torch_utils import select_device  # imports
-
-    check_requirements(('psutil', 'IPython'))
-    import psutil
-    from IPython import display  # to display images and clear console output
-
-    if is_colab():
-        shutil.rmtree('/content/sample_data', ignore_errors=True)  # remove colab /sample_data directory
-
-    # System info
-    if verbose:
-        gb = 1 << 30  # bytes to GiB (1024 ** 3)
-        ram = psutil.virtual_memory().total
-        total, used, free = shutil.disk_usage("/")
-        display.clear_output()
-        s = f'({os.cpu_count()} CPUs, {ram / gb:.1f} GB RAM, {(total - free) / gb:.1f}/{total / gb:.1f} GB disk)'
-    else:
-        s = ''
-
-    select_device(newline=False)
-    print(emojis(f'Setup complete ✅ {s}'))
-    return display
api/synchronous/api_core/yolov5/utils/activations.py
@@ -1,103 +0,0 @@
-# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
-"""
-Activation functions
-"""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class SiLU(nn.Module):
-    # SiLU activation https://arxiv.org/pdf/1606.08415.pdf
-    @staticmethod
-    def forward(x):
-        return x * torch.sigmoid(x)
-
-
-class Hardswish(nn.Module):
-    # Hard-SiLU activation
-    @staticmethod
-    def forward(x):
-        # return x * F.hardsigmoid(x)  # for TorchScript and CoreML
-        return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0  # for TorchScript, CoreML and ONNX
-
-
-class Mish(nn.Module):
-    # Mish activation https://github.com/digantamisra98/Mish
-    @staticmethod
-    def forward(x):
-        return x * F.softplus(x).tanh()
-
-
-class MemoryEfficientMish(nn.Module):
-    # Mish activation memory-efficient
-    class F(torch.autograd.Function):
-
-        @staticmethod
-        def forward(ctx, x):
-            ctx.save_for_backward(x)
-            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
-
-        @staticmethod
-        def backward(ctx, grad_output):
-            x = ctx.saved_tensors[0]
-            sx = torch.sigmoid(x)
-            fx = F.softplus(x).tanh()
-            return grad_output * (fx + x * sx * (1 - fx * fx))
-
-    def forward(self, x):
-        return self.F.apply(x)
-
-
-class FReLU(nn.Module):
-    # FReLU activation https://arxiv.org/abs/2007.11824
-    def __init__(self, c1, k=3):  # ch_in, kernel
-        super().__init__()
-        self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1, bias=False)
-        self.bn = nn.BatchNorm2d(c1)
-
-    def forward(self, x):
-        return torch.max(x, self.bn(self.conv(x)))
-
-
-class AconC(nn.Module):
-    r""" ACON activation (activate or not)
-    AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
-    according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
-    """
-
-    def __init__(self, c1):
-        super().__init__()
-        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
-        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
-        self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))
-
-    def forward(self, x):
-        dpx = (self.p1 - self.p2) * x
-        return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x
-
-
-class MetaAconC(nn.Module):
-    r""" ACON activation (activate or not)
-    MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
-    according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
-    """
-
-    def __init__(self, c1, k=1, s=1, r=16):  # ch_in, kernel, stride, r
-        super().__init__()
-        c2 = max(r, c1 // r)
-        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
-        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
-        self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True)
-        self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True)
-        # self.bn1 = nn.BatchNorm2d(c2)
-        # self.bn2 = nn.BatchNorm2d(c1)
-
-    def forward(self, x):
-        y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
-        # batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891
-        # beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))  # bug/unstable
-        beta = torch.sigmoid(self.fc2(self.fc1(y)))  # bug patch BN layers removed
-        dpx = (self.p1 - self.p2) * x
-        return dpx * torch.sigmoid(beta * dpx) + self.p2 * x
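For context on the removal above: each of these classes is a drop-in nn.Module activation. A minimal sketch of standalone use, assuming the AconC definition from the deleted activations.py has been copied into scope (hypothetical setup, not part of this diff):

import torch

act = AconC(c1=64)              # learnable p1, p2, beta per channel (assumes AconC is in scope)
x = torch.randn(2, 64, 32, 32)  # NCHW feature map
y = act(x)                      # output has the same shape as x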
api/synchronous/api_core/yolov5/utils/augmentations.py
@@ -1,284 +0,0 @@
-# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
-"""
-Image augmentation functions
-"""
-
-import math
-import random
-
-import cv2
-import numpy as np
-
-from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box
-from utils.metrics import bbox_ioa
-
-
-class Albumentations:
-    # YOLOv5 Albumentations class (optional, only used if package is installed)
-    def __init__(self):
-        self.transform = None
-        try:
-            import albumentations as A
-            check_version(A.__version__, '1.0.3', hard=True)  # version requirement
-
-            T = [
-                A.Blur(p=0.01),
-                A.MedianBlur(p=0.01),
-                A.ToGray(p=0.01),
-                A.CLAHE(p=0.01),
-                A.RandomBrightnessContrast(p=0.0),
-                A.RandomGamma(p=0.0),
-                A.ImageCompression(quality_lower=75, p=0.0)]  # transforms
-            self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
-
-            LOGGER.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms if x.p))
-        except ImportError:  # package not installed, skip
-            pass
-        except Exception as e:
-            LOGGER.info(colorstr('albumentations: ') + f'{e}')
-
-    def __call__(self, im, labels, p=1.0):
-        if self.transform and random.random() < p:
-            new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0])  # transformed
-            im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
-        return im, labels
-
-
-def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
-    # HSV color-space augmentation
-    if hgain or sgain or vgain:
-        r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
-        hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
-        dtype = im.dtype  # uint8
-
-        x = np.arange(0, 256, dtype=r.dtype)
-        lut_hue = ((x * r[0]) % 180).astype(dtype)
-        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
-        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
-
-        im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
-        cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im)  # no return needed
-
-
-def hist_equalize(im, clahe=True, bgr=False):
-    # Equalize histogram on BGR image 'im' with im.shape(n,m,3) and range 0-255
-    yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
-    if clahe:
-        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
-        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
-    else:
-        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
-    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB
-
-
-def replicate(im, labels):
-    # Replicate labels
-    h, w = im.shape[:2]
-    boxes = labels[:, 1:].astype(int)
-    x1, y1, x2, y2 = boxes.T
-    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
-    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
-        x1b, y1b, x2b, y2b = boxes[i]
-        bh, bw = y2b - y1b, x2b - x1b
-        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
-        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
-        im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b]  # im4[ymin:ymax, xmin:xmax]
-        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
-
-    return im, labels
-
-
-def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
-    # Resize and pad image while meeting stride-multiple constraints
-    shape = im.shape[:2]  # current shape [height, width]
-    if isinstance(new_shape, int):
-        new_shape = (new_shape, new_shape)
-
-    # Scale ratio (new / old)
-    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
-    if not scaleup:  # only scale down, do not scale up (for better val mAP)
-        r = min(r, 1.0)
-
-    # Compute padding
-    ratio = r, r  # width, height ratios
-    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
-    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
-    if auto:  # minimum rectangle
-        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
-    elif scaleFill:  # stretch
-        dw, dh = 0.0, 0.0
-        new_unpad = (new_shape[1], new_shape[0])
-        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
-
-    dw /= 2  # divide padding into 2 sides
-    dh /= 2
-
-    if shape[::-1] != new_unpad:  # resize
-        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
-    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
-    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
-    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
-    return im, ratio, (dw, dh)
-
-
-def random_perspective(im,
-                       targets=(),
-                       segments=(),
-                       degrees=10,
-                       translate=.1,
-                       scale=.1,
-                       shear=10,
-                       perspective=0.0,
-                       border=(0, 0)):
-    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
-    # targets = [cls, xyxy]
-
-    height = im.shape[0] + border[0] * 2  # shape(h,w,c)
-    width = im.shape[1] + border[1] * 2
-
-    # Center
-    C = np.eye(3)
-    C[0, 2] = -im.shape[1] / 2  # x translation (pixels)
-    C[1, 2] = -im.shape[0] / 2  # y translation (pixels)
-
-    # Perspective
-    P = np.eye(3)
-    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
-    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)
-
-    # Rotation and Scale
-    R = np.eye(3)
-    a = random.uniform(-degrees, degrees)
-    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
-    s = random.uniform(1 - scale, 1 + scale)
-    # s = 2 ** random.uniform(-scale, scale)
-    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
-
-    # Shear
-    S = np.eye(3)
-    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
-    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)
-
-    # Translation
-    T = np.eye(3)
-    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
-    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)
-
-    # Combined rotation matrix
-    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
-    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
-        if perspective:
-            im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
-        else:  # affine
-            im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
-
-    # Visualize
-    # import matplotlib.pyplot as plt
-    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
-    # ax[0].imshow(im[:, :, ::-1])  # base
-    # ax[1].imshow(im2[:, :, ::-1])  # warped
-
-    # Transform label coordinates
-    n = len(targets)
-    if n:
-        use_segments = any(x.any() for x in segments)
-        new = np.zeros((n, 4))
-        if use_segments:  # warp segments
-            segments = resample_segments(segments)  # upsample
-            for i, segment in enumerate(segments):
-                xy = np.ones((len(segment), 3))
-                xy[:, :2] = segment
-                xy = xy @ M.T  # transform
-                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine
-
-                # clip
-                new[i] = segment2box(xy, width, height)
-
-        else:  # warp boxes
-            xy = np.ones((n * 4, 3))
-            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
-            xy = xy @ M.T  # transform
-            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine
-
-            # create new boxes
-            x = xy[:, [0, 2, 4, 6]]
-            y = xy[:, [1, 3, 5, 7]]
-            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
-
-            # clip
-            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
-            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
-
-        # filter candidates
-        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
-        targets = targets[i]
-        targets[:, 1:5] = new[i]
-
-    return im, targets
-
-
-def copy_paste(im, labels, segments, p=0.5):
-    # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
-    n = len(segments)
-    if p and n:
-        h, w, c = im.shape  # height, width, channels
-        im_new = np.zeros(im.shape, np.uint8)
-        for j in random.sample(range(n), k=round(p * n)):
-            l, s = labels[j], segments[j]
-            box = w - l[3], l[2], w - l[1], l[4]
-            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
-            if (ioa < 0.30).all():  # allow 30% obscuration of existing labels
-                labels = np.concatenate((labels, [[l[0], *box]]), 0)
-                segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
-                cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
-
-        result = cv2.bitwise_and(src1=im, src2=im_new)
-        result = cv2.flip(result, 1)  # augment segments (flip left-right)
-        i = result > 0  # pixels to replace
-        # i[:, :] = result.max(2).reshape(h, w, 1)  # act over ch
-        im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug
-
-    return im, labels, segments
-
-
-def cutout(im, labels, p=0.5):
-    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
-    if random.random() < p:
-        h, w = im.shape[:2]
-        scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
-        for s in scales:
-            mask_h = random.randint(1, int(h * s))  # create random masks
-            mask_w = random.randint(1, int(w * s))
-
-            # box
-            xmin = max(0, random.randint(0, w) - mask_w // 2)
-            ymin = max(0, random.randint(0, h) - mask_h // 2)
-            xmax = min(w, xmin + mask_w)
-            ymax = min(h, ymin + mask_h)
-
-            # apply random color mask
-            im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
-
-            # return unobscured labels
-            if len(labels) and s > 0.03:
-                box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
-                ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
-                labels = labels[ioa < 0.60]  # remove >60% obscured labels
-
-    return labels
-
-
-def mixup(im, labels, im2, labels2):
-    # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
-    r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
-    im = (im * r + im2 * (1 - r)).astype(np.uint8)
-    labels = np.concatenate((labels, labels2), 0)
-    return im, labels
-
-
-def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
-    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
-    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
-    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
-    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
-    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
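As a reading aid for the letterbox() removal above, a hedged usage sketch; it assumes the function from the deleted augmentations.py is available in scope and that 'example.jpg' is a placeholder path:

import cv2

im = cv2.imread('example.jpg')  # HWC, BGR (hypothetical input image)
im_lb, ratio, (dw, dh) = letterbox(im, new_shape=640, auto=False)
# With auto=False the result is exactly 640x640: the image is resized to fit,
# then padded with gray (114, 114, 114). ratio and (dw, dh) let callers map
# detection coordinates back to the original image.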
api/synchronous/api_core/yolov5/utils/autoanchor.py
@@ -1,170 +0,0 @@
-# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
-"""
-AutoAnchor utils
-"""
-
-import random
-
-import numpy as np
-import torch
-import yaml
-from tqdm import tqdm
-
-from utils.general import LOGGER, colorstr, emojis
-
-PREFIX = colorstr('AutoAnchor: ')
-
-
-def check_anchor_order(m):
-    # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
-    a = m.anchors.prod(-1).mean(-1).view(-1)  # mean anchor area per output layer
-    da = a[-1] - a[0]  # delta a
-    ds = m.stride[-1] - m.stride[0]  # delta s
-    if da and (da.sign() != ds.sign()):  # same order
-        LOGGER.info(f'{PREFIX}Reversing anchor order')
-        m.anchors[:] = m.anchors.flip(0)
-
-
-def check_anchors(dataset, model, thr=4.0, imgsz=640):
-    # Check anchor fit to data, recompute if necessary
-    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect()
-    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
-    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
-    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh
-
-    def metric(k):  # compute metric
-        r = wh[:, None] / k[None]
-        x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
-        best = x.max(1)[0]  # best_x
-        aat = (x > 1 / thr).float().sum(1).mean()  # anchors above threshold
-        bpr = (best > 1 / thr).float().mean()  # best possible recall
-        return bpr, aat
-
-    stride = m.stride.to(m.anchors.device).view(-1, 1, 1)  # model strides
-    anchors = m.anchors.clone() * stride  # current anchors
-    bpr, aat = metric(anchors.cpu().view(-1, 2))
-    s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
-    if bpr > 0.98:  # threshold to recompute
-        LOGGER.info(emojis(f'{s}Current anchors are a good fit to dataset ✅'))
-    else:
-        LOGGER.info(emojis(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...'))
-        na = m.anchors.numel() // 2  # number of anchors
-        try:
-            anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
-        except Exception as e:
-            LOGGER.info(f'{PREFIX}ERROR: {e}')
-        new_bpr = metric(anchors)[0]
-        if new_bpr > bpr:  # replace anchors
-            anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
-            m.anchors[:] = anchors.clone().view_as(m.anchors)
-            check_anchor_order(m)  # must be in pixel-space (not grid-space)
-            m.anchors /= stride
-            s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)'
-        else:
-            s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)'
-        LOGGER.info(emojis(s))
-
-
-def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
-    """ Creates kmeans-evolved anchors from training dataset
-
-        Arguments:
-            dataset: path to data.yaml, or a loaded dataset
-            n: number of anchors
-            img_size: image size used for training
-            thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
-            gen: generations to evolve anchors using genetic algorithm
-            verbose: print all results
-
-        Return:
-            k: kmeans evolved anchors
-
-        Usage:
-            from utils.autoanchor import *; _ = kmean_anchors()
-    """
-    from scipy.cluster.vq import kmeans
-
-    npr = np.random
-    thr = 1 / thr
-
-    def metric(k, wh):  # compute metrics
-        r = wh[:, None] / k[None]
-        x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
-        # x = wh_iou(wh, torch.tensor(k))  # iou metric
-        return x, x.max(1)[0]  # x, best_x
-
-    def anchor_fitness(k):  # mutation fitness
-        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
-        return (best * (best > thr).float()).mean()  # fitness
-
-    def print_results(k, verbose=True):
-        k = k[np.argsort(k.prod(1))]  # sort small to large
-        x, best = metric(k, wh0)
-        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
-        s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
-            f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
-            f'past_thr={x[x > thr].mean():.3f}-mean: '
-        for x in k:
-            s += '%i,%i, ' % (round(x[0]), round(x[1]))
-        if verbose:
-            LOGGER.info(s[:-2])
-        return k
-
-    if isinstance(dataset, str):  # *.yaml file
-        with open(dataset, errors='ignore') as f:
-            data_dict = yaml.safe_load(f)  # model dict
-        from utils.dataloaders import LoadImagesAndLabels
-        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
-
-    # Get label wh
-    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
-    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh
-
-    # Filter
-    i = (wh0 < 3.0).any(1).sum()
-    if i:
-        LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size')
-    wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels
-    # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1)  # multiply by random scale 0-1
-
-    # Kmeans init
-    try:
-        LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...')
-        assert n <= len(wh)  # apply overdetermined constraint
-        s = wh.std(0)  # sigmas for whitening
-        k = kmeans(wh / s, n, iter=30)[0] * s  # points
-        assert n == len(k)  # kmeans may return fewer points than requested if wh is insufficient or too similar
-    except Exception:
-        LOGGER.warning(f'{PREFIX}WARNING: switching strategies from kmeans to random init')
-        k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size  # random init
-    wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0))
-    k = print_results(k, verbose=False)
-
-    # Plot
-    # k, d = [None] * 20, [None] * 20
-    # for i in tqdm(range(1, 21)):
-    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
-    # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
-    # ax = ax.ravel()
-    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
-    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
-    # ax[0].hist(wh[wh[:, 0]<100, 0], 400)
-    # ax[1].hist(wh[wh[:, 1]<100, 1], 400)
-    # fig.savefig('wh.png', dpi=200)
-
-    # Evolve
-    f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, generations, mutation prob, sigma
-    pbar = tqdm(range(gen), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
-    for _ in pbar:
-        v = np.ones(sh)
-        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
-            v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
-        kg = (k.copy() * v).clip(min=2.0)
-        fg = anchor_fitness(kg)
-        if fg > f:
-            f, k = fg, kg.copy()
-            pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
-            if verbose:
-                print_results(k, verbose)
-
-    return print_results(k)
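The anchor-fit check above boils down to a width/height ratio metric. A self-contained sketch with made-up label and anchor sizes (illustrative numbers only, not from this package):

import torch

wh = torch.tensor([[30., 60.], [100., 80.], [12., 12.]])  # hypothetical label wh (pixels)
k = torch.tensor([[10., 13.], [33., 23.], [62., 45.]])    # hypothetical anchor wh (pixels)
thr = 4.0                                                 # anchor-label ratio threshold

r = wh[:, None] / k[None]              # per label/anchor wh ratios
x = torch.min(r, 1 / r).min(2)[0]      # worst-dimension ratio metric
best = x.max(1)[0]                     # best anchor per label
bpr = (best > 1 / thr).float().mean()  # best possible recall
print(f'BPR = {bpr:.3f}')              # 1.000 for these numbers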
api/synchronous/api_core/yolov5/utils/autobatch.py
@@ -1,66 +0,0 @@
-# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
-"""
-Auto-batch utils
-"""
-
-from copy import deepcopy
-
-import numpy as np
-import torch
-
-from utils.general import LOGGER, colorstr, emojis
-from utils.torch_utils import profile
-
-
-def check_train_batch_size(model, imgsz=640, amp=True):
-    # Check YOLOv5 training batch size
-    with torch.cuda.amp.autocast(amp):
-        return autobatch(deepcopy(model).train(), imgsz)  # compute optimal batch size
-
-
-def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
-    # Automatically estimate best batch size to use `fraction` of available CUDA memory
-    # Usage:
-    #     import torch
-    #     from utils.autobatch import autobatch
-    #     model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
-    #     print(autobatch(model))
-
-    # Check device
-    prefix = colorstr('AutoBatch: ')
-    LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
-    device = next(model.parameters()).device  # get model device
-    if device.type == 'cpu':
-        LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
-        return batch_size
-
-    # Inspect CUDA memory
-    gb = 1 << 30  # bytes to GiB (1024 ** 3)
-    d = str(device).upper()  # 'CUDA:0'
-    properties = torch.cuda.get_device_properties(device)  # device properties
-    t = properties.total_memory / gb  # GiB total
-    r = torch.cuda.memory_reserved(device) / gb  # GiB reserved
-    a = torch.cuda.memory_allocated(device) / gb  # GiB allocated
-    f = t - (r + a)  # GiB free
-    LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
-
-    # Profile batch sizes
-    batch_sizes = [1, 2, 4, 8, 16]
-    try:
-        img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
-        results = profile(img, model, n=3, device=device)
-    except Exception as e:
-        LOGGER.warning(f'{prefix}{e}')
-
-    # Fit a solution
-    y = [x[2] for x in results if x]  # memory [2]
-    p = np.polyfit(batch_sizes[:len(y)], y, deg=1)  # first degree polynomial fit
-    b = int((f * fraction - p[1]) / p[0])  # y intercept (optimal batch size)
-    if None in results:  # some sizes failed
-        i = results.index(None)  # first fail index
-        if b >= batch_sizes[i]:  # y intercept above failure point
-            b = batch_sizes[max(i - 1, 0)]  # select prior safe point
-
-    fraction = np.polyval(p, b) / t  # actual fraction predicted
-    LOGGER.info(emojis(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅'))
-    return b
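The heart of autobatch() above is a degree-1 polynomial fit of measured memory use against batch size, solved for the batch size that fills a target fraction of free memory. A standalone sketch of that arithmetic with hypothetical profiling numbers:

import numpy as np

batch_sizes = [1, 2, 4, 8, 16]       # profiled batch sizes
mem_gib = [1.1, 1.6, 2.7, 4.9, 9.3]  # hypothetical measured memory use (GiB)
free_gib = 14.5                      # assumed free GPU memory (GiB)
fraction = 0.9                       # target utilization, as in autobatch()

p = np.polyfit(batch_sizes, mem_gib, deg=1)   # fit mem ≈ p[0] * batch + p[1]
b = int((free_gib * fraction - p[1]) / p[0])  # batch size hitting the memory target
print(f'estimated batch size: {b}')           # 22 with these numbers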