megadetector 5.0.5-py3-none-any.whl → 5.0.7-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- api/batch_processing/data_preparation/manage_local_batch.py +302 -263
- api/batch_processing/data_preparation/manage_video_batch.py +81 -2
- api/batch_processing/postprocessing/add_max_conf.py +1 -0
- api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
- api/batch_processing/postprocessing/compare_batch_results.py +110 -60
- api/batch_processing/postprocessing/load_api_results.py +56 -70
- api/batch_processing/postprocessing/md_to_coco.py +1 -1
- api/batch_processing/postprocessing/md_to_labelme.py +2 -1
- api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
- api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
- api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
- api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
- api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
- api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
- api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
- api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
- classification/prepare_classification_script.py +191 -191
- data_management/coco_to_yolo.py +68 -45
- data_management/databases/integrity_check_json_db.py +7 -5
- data_management/generate_crops_from_cct.py +3 -3
- data_management/get_image_sizes.py +8 -6
- data_management/importers/add_timestamps_to_icct.py +79 -0
- data_management/importers/animl_results_to_md_results.py +160 -0
- data_management/importers/auckland_doc_test_to_json.py +4 -4
- data_management/importers/auckland_doc_to_json.py +1 -1
- data_management/importers/awc_to_json.py +5 -5
- data_management/importers/bellevue_to_json.py +5 -5
- data_management/importers/carrizo_shrubfree_2018.py +5 -5
- data_management/importers/carrizo_trail_cam_2017.py +5 -5
- data_management/importers/cct_field_adjustments.py +2 -3
- data_management/importers/channel_islands_to_cct.py +4 -4
- data_management/importers/ena24_to_json.py +5 -5
- data_management/importers/helena_to_cct.py +10 -10
- data_management/importers/idaho-camera-traps.py +12 -12
- data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
- data_management/importers/jb_csv_to_json.py +4 -4
- data_management/importers/missouri_to_json.py +1 -1
- data_management/importers/noaa_seals_2019.py +1 -1
- data_management/importers/pc_to_json.py +5 -5
- data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
- data_management/importers/prepare_zsl_imerit.py +5 -5
- data_management/importers/rspb_to_json.py +4 -4
- data_management/importers/save_the_elephants_survey_A.py +5 -5
- data_management/importers/save_the_elephants_survey_B.py +6 -6
- data_management/importers/snapshot_safari_importer.py +9 -9
- data_management/importers/snapshot_serengeti_lila.py +9 -9
- data_management/importers/timelapse_csv_set_to_json.py +5 -7
- data_management/importers/ubc_to_json.py +4 -4
- data_management/importers/umn_to_json.py +4 -4
- data_management/importers/wellington_to_json.py +1 -1
- data_management/importers/wi_to_json.py +2 -2
- data_management/importers/zamba_results_to_md_results.py +181 -0
- data_management/labelme_to_coco.py +35 -7
- data_management/labelme_to_yolo.py +229 -0
- data_management/lila/add_locations_to_island_camera_traps.py +1 -1
- data_management/lila/add_locations_to_nacti.py +147 -0
- data_management/lila/create_lila_blank_set.py +474 -0
- data_management/lila/create_lila_test_set.py +2 -1
- data_management/lila/create_links_to_md_results_files.py +106 -0
- data_management/lila/download_lila_subset.py +46 -21
- data_management/lila/generate_lila_per_image_labels.py +23 -14
- data_management/lila/get_lila_annotation_counts.py +17 -11
- data_management/lila/lila_common.py +14 -11
- data_management/lila/test_lila_metadata_urls.py +116 -0
- data_management/ocr_tools.py +829 -0
- data_management/resize_coco_dataset.py +13 -11
- data_management/yolo_output_to_md_output.py +84 -12
- data_management/yolo_to_coco.py +38 -20
- detection/process_video.py +36 -14
- detection/pytorch_detector.py +23 -8
- detection/run_detector.py +76 -19
- detection/run_detector_batch.py +178 -63
- detection/run_inference_with_yolov5_val.py +326 -57
- detection/run_tiled_inference.py +153 -43
- detection/video_utils.py +34 -8
- md_utils/ct_utils.py +172 -1
- md_utils/md_tests.py +372 -51
- md_utils/path_utils.py +167 -39
- md_utils/process_utils.py +26 -7
- md_utils/split_locations_into_train_val.py +215 -0
- md_utils/string_utils.py +10 -0
- md_utils/url_utils.py +0 -2
- md_utils/write_html_image_list.py +9 -26
- md_visualization/plot_utils.py +12 -8
- md_visualization/visualization_utils.py +106 -7
- md_visualization/visualize_db.py +16 -8
- md_visualization/visualize_detector_output.py +208 -97
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
- taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
- taxonomy_mapping/map_new_lila_datasets.py +43 -39
- taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
- taxonomy_mapping/preview_lila_taxonomy.py +27 -27
- taxonomy_mapping/species_lookup.py +33 -13
- taxonomy_mapping/taxonomy_csv_checker.py +7 -5
- api/synchronous/api_core/yolov5/detect.py +0 -252
- api/synchronous/api_core/yolov5/export.py +0 -607
- api/synchronous/api_core/yolov5/hubconf.py +0 -146
- api/synchronous/api_core/yolov5/models/__init__.py +0 -0
- api/synchronous/api_core/yolov5/models/common.py +0 -738
- api/synchronous/api_core/yolov5/models/experimental.py +0 -104
- api/synchronous/api_core/yolov5/models/tf.py +0 -574
- api/synchronous/api_core/yolov5/models/yolo.py +0 -338
- api/synchronous/api_core/yolov5/train.py +0 -670
- api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
- api/synchronous/api_core/yolov5/utils/activations.py +0 -103
- api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
- api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
- api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
- api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
- api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
- api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
- api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
- api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
- api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
- api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
- api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
- api/synchronous/api_core/yolov5/utils/general.py +0 -1018
- api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
- api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
- api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
- api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
- api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
- api/synchronous/api_core/yolov5/utils/loss.py +0 -234
- api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
- api/synchronous/api_core/yolov5/utils/plots.py +0 -489
- api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
- api/synchronous/api_core/yolov5/val.py +0 -394
- md_utils/matlab_porting_tools.py +0 -97
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
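The headline change in this release is visible in the deletions above: the entire vendored YOLOv5 source tree under api/synchronous/api_core/yolov5/ is gone, along with md_utils/matlab_porting_tools.py. Since wheels are ordinary zip archives, a file-level comparison like the one above can be reproduced locally with nothing but the standard library. A minimal sketch, assuming both wheels have already been downloaded (e.g. via `pip download megadetector==5.0.5 --no-deps`):

```python
# Minimal sketch: file-level diff of two wheels (wheels are plain zip archives).
# Assumes megadetector-5.0.5-py3-none-any.whl and megadetector-5.0.7-py3-none-any.whl
# are in the working directory.
import zipfile

def wheel_members(path):
    with zipfile.ZipFile(path) as zf:
        return set(zf.namelist())

old = wheel_members('megadetector-5.0.5-py3-none-any.whl')
new = wheel_members('megadetector-5.0.7-py3-none-any.whl')

for name in sorted(old - new):
    print('removed:', name)   # e.g. the vendored api/synchronous/api_core/yolov5/ tree
for name in sorted(new - old):
    print('added:  ', name)   # e.g. data_management/ocr_tools.py
```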
The one per-file diff reproduced on this page is for the removed module api/synchronous/api_core/yolov5/models/yolo.py (+0 -338):

```diff
@@ -1,338 +0,0 @@
-# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
-"""
-YOLO-specific modules
-
-Usage:
-    $ python path/to/models/yolo.py --cfg yolov5s.yaml
-"""
-
-import argparse
-import os
-import platform
-import sys
-from copy import deepcopy
-from pathlib import Path
-
-FILE = Path(__file__).resolve()
-ROOT = FILE.parents[1]  # YOLOv5 root directory
-if str(ROOT) not in sys.path:
-    sys.path.append(str(ROOT))  # add ROOT to PATH
-if platform.system() != 'Windows':
-    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
-
-from models.common import *
-from models.experimental import *
-from utils.autoanchor import check_anchor_order
-from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
-from utils.plots import feature_visualization
-from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
-                               time_sync)
-
-try:
-    import thop  # for FLOPs computation
-except ImportError:
-    thop = None
-
-
-class Detect(nn.Module):
-    stride = None  # strides computed during build
-    onnx_dynamic = False  # ONNX export parameter
-    export = False  # export mode
-
-    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
-        super().__init__()
-        self.nc = nc  # number of classes
-        self.no = nc + 5  # number of outputs per anchor
-        self.nl = len(anchors)  # number of detection layers
-        self.na = len(anchors[0]) // 2  # number of anchors
-        self.grid = [torch.zeros(1)] * self.nl  # init grid
-        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
-        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
-        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
-        self.inplace = inplace  # use in-place ops (e.g. slice assignment)
-
-    def forward(self, x):
-        z = []  # inference output
-        for i in range(self.nl):
-            x[i] = self.m[i](x[i])  # conv
-            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
-            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
-
-            if not self.training:  # inference
-                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
-                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
-
-                y = x[i].sigmoid()
-                if self.inplace:
-                    y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i]  # xy
-                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
-                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
-                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
-                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
-                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
-                    y = torch.cat((xy, wh, conf), 4)
-                z.append(y.view(bs, -1, self.no))
-
-        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
-
-    def _make_grid(self, nx=20, ny=20, i=0):
-        d = self.anchors[i].device
-        t = self.anchors[i].dtype
-        shape = 1, self.na, ny, nx, 2  # grid shape
-        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
-        if check_version(torch.__version__, '1.10.0'):  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
-            yv, xv = torch.meshgrid(y, x, indexing='ij')
-        else:
-            yv, xv = torch.meshgrid(y, x)
-        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
-        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
-        return grid, anchor_grid
-
-
```
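As an aside (not part of the diff): the inference branch of `Detect.forward` decodes raw network outputs into pixel-space boxes as xy = (2*sigmoid(t) - 0.5 + cell) * stride and wh = (2*sigmoid(t))^2 * anchor, where the -0.5 offset is baked into the grid by `_make_grid`. A standalone sketch of that arithmetic, with made-up values:

```python
# Illustration of the Detect.forward box decoding above; all numbers are made up.
import torch

stride = 8.0                                 # stride of one detection layer
grid_xy = torch.tensor([4.0, 3.0]) - 0.5     # cell (4, 3), with the -0.5 offset from _make_grid
anchor_wh = torch.tensor([10.0, 13.0])       # anchor size in pixels (anchors * stride)

y = torch.sigmoid(torch.tensor([0.2, -0.1, 0.5, 0.4]))  # raw tx, ty, tw, th after sigmoid
xy = (y[:2] * 2 + grid_xy) * stride          # box center in pixels
wh = (y[2:] * 2) ** 2 * anchor_wh            # box size in pixels, bounded to (0, 4 * anchor)
print(xy, wh)
```

The diff continues with the Model class: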
```diff
-class Model(nn.Module):
-    # YOLOv5 model
-    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
-        super().__init__()
-        if isinstance(cfg, dict):
-            self.yaml = cfg  # model dict
-        else:  # is *.yaml
-            import yaml  # for torch hub
-            self.yaml_file = Path(cfg).name
-            with open(cfg, encoding='ascii', errors='ignore') as f:
-                self.yaml = yaml.safe_load(f)  # model dict
-
-        # Define model
-        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
-        if nc and nc != self.yaml['nc']:
-            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
-            self.yaml['nc'] = nc  # override yaml value
-        if anchors:
-            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
-            self.yaml['anchors'] = round(anchors)  # override yaml value
-        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
-        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
-        self.inplace = self.yaml.get('inplace', True)
-
-        # Build strides, anchors
-        m = self.model[-1]  # Detect()
-        if isinstance(m, Detect):
-            s = 256  # 2x min stride
-            m.inplace = self.inplace
-            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
-            check_anchor_order(m)  # must be in pixel-space (not grid-space)
-            m.anchors /= m.stride.view(-1, 1, 1)
-            self.stride = m.stride
-            self._initialize_biases()  # only run once
-
-        # Init weights, biases
-        initialize_weights(self)
-        self.info()
-        LOGGER.info('')
-
-    def forward(self, x, augment=False, profile=False, visualize=False):
-        if augment:
-            return self._forward_augment(x)  # augmented inference, None
-        return self._forward_once(x, profile, visualize)  # single-scale inference, train
-
-    def _forward_augment(self, x):
-        img_size = x.shape[-2:]  # height, width
-        s = [1, 0.83, 0.67]  # scales
-        f = [None, 3, None]  # flips (2-ud, 3-lr)
-        y = []  # outputs
-        for si, fi in zip(s, f):
-            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
-            yi = self._forward_once(xi)[0]  # forward
-            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
-            yi = self._descale_pred(yi, fi, si, img_size)
-            y.append(yi)
-        y = self._clip_augmented(y)  # clip augmented tails
-        return torch.cat(y, 1), None  # augmented inference, train
-
-    def _forward_once(self, x, profile=False, visualize=False):
-        y, dt = [], []  # outputs
-        for m in self.model:
-            if m.f != -1:  # if not from previous layer
-                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
-            if profile:
-                self._profile_one_layer(m, x, dt)
-            x = m(x)  # run
-            y.append(x if m.i in self.save else None)  # save output
-            if visualize:
-                feature_visualization(x, m.type, m.i, save_dir=visualize)
-        return x
-
-    def _descale_pred(self, p, flips, scale, img_size):
-        # de-scale predictions following augmented inference (inverse operation)
-        if self.inplace:
-            p[..., :4] /= scale  # de-scale
-            if flips == 2:
-                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
-            elif flips == 3:
-                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
-        else:
-            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
-            if flips == 2:
-                y = img_size[0] - y  # de-flip ud
-            elif flips == 3:
-                x = img_size[1] - x  # de-flip lr
-            p = torch.cat((x, y, wh, p[..., 4:]), -1)
-        return p
-
-    def _clip_augmented(self, y):
-        # Clip YOLOv5 augmented inference tails
-        nl = self.model[-1].nl  # number of detection layers (P3-P5)
-        g = sum(4 ** x for x in range(nl))  # grid points
-        e = 1  # exclude layer count
-        i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e))  # indices
-        y[0] = y[0][:, :-i]  # large
-        i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
-        y[-1] = y[-1][:, i:]  # small
-        return y
-
-    def _profile_one_layer(self, m, x, dt):
-        c = isinstance(m, Detect)  # is final layer, copy input as inplace fix
-        o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
-        t = time_sync()
-        for _ in range(10):
-            m(x.copy() if c else x)
-        dt.append((time_sync() - t) * 100)
-        if m == self.model[0]:
-            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
-        LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
-        if c:
-            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")
-
-    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
-        # https://arxiv.org/abs/1708.02002 section 3.3
-        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
-        m = self.model[-1]  # Detect() module
-        for mi, s in zip(m.m, m.stride):  # from
-            b = mi.bias.view(m.na, -1).detach()  # conv.bias(255) to (3,85)
-            b[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
-            b[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # cls
-            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-
-    def _print_biases(self):
-        m = self.model[-1]  # Detect() module
-        for mi in m.m:  # from
-            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
-            LOGGER.info(
-                ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
-
-    # def _print_weights(self):
-    #     for m in self.model.modules():
-    #         if type(m) is Bottleneck:
-    #             LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights
-
-    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
-        LOGGER.info('Fusing layers... ')
-        for m in self.model.modules():
-            if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
-                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
-                delattr(m, 'bn')  # remove batchnorm
-                m.forward = m.forward_fuse  # update forward
-        self.info()
-        return self
-
-    def info(self, verbose=False, img_size=640):  # print model information
-        model_info(self, verbose, img_size)
-
-    def _apply(self, fn):
-        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
-        self = super()._apply(fn)
-        m = self.model[-1]  # Detect()
-        if isinstance(m, Detect):
-            m.stride = fn(m.stride)
-            m.grid = list(map(fn, m.grid))
-            if isinstance(m.anchor_grid, list):
-                m.anchor_grid = list(map(fn, m.anchor_grid))
-        return self
-
-
```
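`Model.fuse()` above delegates to `fuse_conv_and_bn` to fold each BatchNorm into its preceding convolution, which is the textbook identity w' = w * gamma/sqrt(var+eps), b' = beta + (b - mean) * gamma/sqrt(var+eps). A self-contained sketch of that algebra (the package's own implementation lives in YOLOv5's utils/torch_utils.py and may differ in detail):

```python
# Sketch of standard Conv+BatchNorm fusion; not the package's exact code.
import torch
import torch.nn as nn

def fuse(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)            # gamma / sqrt(var + eps)
    with torch.no_grad():
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))   # fold scale into weights
        b = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
        fused.bias.copy_(bn.bias + (b - bn.running_mean) * scale)      # fold shift into bias
    return fused

conv, bn = nn.Conv2d(3, 8, 3, bias=False), nn.BatchNorm2d(8).eval()    # eval(): use running stats
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fuse(conv, bn)(x), atol=1e-5))       # True
```

The diff continues with parse_model():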
```diff
-def parse_model(d, ch):  # model_dict, input_channels(3)
-    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
-    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
-    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
-    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
-
-    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
-    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
-        m = eval(m) if isinstance(m, str) else m  # eval strings
-        for j, a in enumerate(args):
-            try:
-                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
-            except NameError:
-                pass
-
-        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
-        if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
-                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x):
-            c1, c2 = ch[f], args[0]
-            if c2 != no:  # if not output
-                c2 = make_divisible(c2 * gw, 8)
-
-            args = [c1, c2, *args[1:]]
-            if m in [BottleneckCSP, C3, C3TR, C3Ghost, C3x]:
-                args.insert(2, n)  # number of repeats
-                n = 1
-        elif m is nn.BatchNorm2d:
-            args = [ch[f]]
-        elif m is Concat:
-            c2 = sum(ch[x] for x in f)
-        elif m is Detect:
-            args.append([ch[x] for x in f])
-            if isinstance(args[1], int):  # number of anchors
-                args[1] = [list(range(args[1] * 2))] * len(f)
-        elif m is Contract:
-            c2 = ch[f] * args[0] ** 2
-        elif m is Expand:
-            c2 = ch[f] // args[0] ** 2
-        else:
-            c2 = ch[f]
-
-        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
-        t = str(m)[8:-2].replace('__main__.', '')  # module type
-        np = sum(x.numel() for x in m_.parameters())  # number params
-        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
-        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
-        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
-        layers.append(m_)
-        if i == 0:
-            ch = []
-        ch.append(c2)
-    return nn.Sequential(*layers), sorted(save)
-
-
```
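`parse_model()` above is also where the model-size variants come from: each config row's repeat count is scaled by depth_multiple (gd) and each output channel count by width_multiple (gw), rounded up to a multiple of 8. A worked example in plain Python (the gd/gw constants are taken from the standard yolov5l and yolov5s configs):

```python
# Worked example of the depth/width scaling in parse_model(); plain Python, no torch needed.
import math

def make_divisible(x, divisor=8):           # same rounding rule as utils.general.make_divisible
    return math.ceil(x / divisor) * divisor

def scale(n_repeats, c_out, gd, gw):
    n = max(round(n_repeats * gd), 1) if n_repeats > 1 else n_repeats  # depth gain
    c = make_divisible(c_out * gw, 8)                                  # width gain
    return n, c

print(scale(9, 512, gd=1.0, gw=1.0))    # yolov5l: (9, 512)
print(scale(9, 512, gd=0.33, gw=0.50))  # yolov5s: (3, 256)
```

The diff closes with the module's command-line entry point: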
```diff
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
-    parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
-    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
-    parser.add_argument('--profile', action='store_true', help='profile model speed')
-    parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
-    parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
-    opt = parser.parse_args()
-    opt.cfg = check_yaml(opt.cfg)  # check YAML
-    print_args(vars(opt))
-    device = select_device(opt.device)
-
-    # Create model
-    im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
-    model = Model(opt.cfg).to(device)
-
-    # Options
-    if opt.line_profile:  # profile layer by layer
-        _ = model(im, profile=True)
-
-    elif opt.profile:  # profile forward-backward
-        results = profile(input=im, ops=[model], n=3)
-
-    elif opt.test:  # test all models
-        for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
-            try:
-                _ = Model(cfg)
-            except Exception as e:
-                print(f'Error in {cfg}: {e}')
-
-    else:  # report fused model summary
-        model.fuse()
```