megadetector-5.0.5-py3-none-any.whl → megadetector-5.0.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic; the original page linked to further details about this release.

Files changed (132)
  1. api/batch_processing/data_preparation/manage_local_batch.py +302 -263
  2. api/batch_processing/data_preparation/manage_video_batch.py +81 -2
  3. api/batch_processing/postprocessing/add_max_conf.py +1 -0
  4. api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
  5. api/batch_processing/postprocessing/compare_batch_results.py +110 -60
  6. api/batch_processing/postprocessing/load_api_results.py +56 -70
  7. api/batch_processing/postprocessing/md_to_coco.py +1 -1
  8. api/batch_processing/postprocessing/md_to_labelme.py +2 -1
  9. api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
  10. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
  11. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
  12. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
  13. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
  14. api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
  15. api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
  16. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
  17. classification/prepare_classification_script.py +191 -191
  18. data_management/coco_to_yolo.py +68 -45
  19. data_management/databases/integrity_check_json_db.py +7 -5
  20. data_management/generate_crops_from_cct.py +3 -3
  21. data_management/get_image_sizes.py +8 -6
  22. data_management/importers/add_timestamps_to_icct.py +79 -0
  23. data_management/importers/animl_results_to_md_results.py +160 -0
  24. data_management/importers/auckland_doc_test_to_json.py +4 -4
  25. data_management/importers/auckland_doc_to_json.py +1 -1
  26. data_management/importers/awc_to_json.py +5 -5
  27. data_management/importers/bellevue_to_json.py +5 -5
  28. data_management/importers/carrizo_shrubfree_2018.py +5 -5
  29. data_management/importers/carrizo_trail_cam_2017.py +5 -5
  30. data_management/importers/cct_field_adjustments.py +2 -3
  31. data_management/importers/channel_islands_to_cct.py +4 -4
  32. data_management/importers/ena24_to_json.py +5 -5
  33. data_management/importers/helena_to_cct.py +10 -10
  34. data_management/importers/idaho-camera-traps.py +12 -12
  35. data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
  36. data_management/importers/jb_csv_to_json.py +4 -4
  37. data_management/importers/missouri_to_json.py +1 -1
  38. data_management/importers/noaa_seals_2019.py +1 -1
  39. data_management/importers/pc_to_json.py +5 -5
  40. data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
  41. data_management/importers/prepare_zsl_imerit.py +5 -5
  42. data_management/importers/rspb_to_json.py +4 -4
  43. data_management/importers/save_the_elephants_survey_A.py +5 -5
  44. data_management/importers/save_the_elephants_survey_B.py +6 -6
  45. data_management/importers/snapshot_safari_importer.py +9 -9
  46. data_management/importers/snapshot_serengeti_lila.py +9 -9
  47. data_management/importers/timelapse_csv_set_to_json.py +5 -7
  48. data_management/importers/ubc_to_json.py +4 -4
  49. data_management/importers/umn_to_json.py +4 -4
  50. data_management/importers/wellington_to_json.py +1 -1
  51. data_management/importers/wi_to_json.py +2 -2
  52. data_management/importers/zamba_results_to_md_results.py +181 -0
  53. data_management/labelme_to_coco.py +35 -7
  54. data_management/labelme_to_yolo.py +229 -0
  55. data_management/lila/add_locations_to_island_camera_traps.py +1 -1
  56. data_management/lila/add_locations_to_nacti.py +147 -0
  57. data_management/lila/create_lila_blank_set.py +474 -0
  58. data_management/lila/create_lila_test_set.py +2 -1
  59. data_management/lila/create_links_to_md_results_files.py +106 -0
  60. data_management/lila/download_lila_subset.py +46 -21
  61. data_management/lila/generate_lila_per_image_labels.py +23 -14
  62. data_management/lila/get_lila_annotation_counts.py +17 -11
  63. data_management/lila/lila_common.py +14 -11
  64. data_management/lila/test_lila_metadata_urls.py +116 -0
  65. data_management/ocr_tools.py +829 -0
  66. data_management/resize_coco_dataset.py +13 -11
  67. data_management/yolo_output_to_md_output.py +84 -12
  68. data_management/yolo_to_coco.py +38 -20
  69. detection/process_video.py +36 -14
  70. detection/pytorch_detector.py +23 -8
  71. detection/run_detector.py +76 -19
  72. detection/run_detector_batch.py +178 -63
  73. detection/run_inference_with_yolov5_val.py +326 -57
  74. detection/run_tiled_inference.py +153 -43
  75. detection/video_utils.py +34 -8
  76. md_utils/ct_utils.py +172 -1
  77. md_utils/md_tests.py +372 -51
  78. md_utils/path_utils.py +167 -39
  79. md_utils/process_utils.py +26 -7
  80. md_utils/split_locations_into_train_val.py +215 -0
  81. md_utils/string_utils.py +10 -0
  82. md_utils/url_utils.py +0 -2
  83. md_utils/write_html_image_list.py +9 -26
  84. md_visualization/plot_utils.py +12 -8
  85. md_visualization/visualization_utils.py +106 -7
  86. md_visualization/visualize_db.py +16 -8
  87. md_visualization/visualize_detector_output.py +208 -97
  88. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
  89. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
  90. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
  91. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
  92. taxonomy_mapping/map_new_lila_datasets.py +43 -39
  93. taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
  94. taxonomy_mapping/preview_lila_taxonomy.py +27 -27
  95. taxonomy_mapping/species_lookup.py +33 -13
  96. taxonomy_mapping/taxonomy_csv_checker.py +7 -5
  97. api/synchronous/api_core/yolov5/detect.py +0 -252
  98. api/synchronous/api_core/yolov5/export.py +0 -607
  99. api/synchronous/api_core/yolov5/hubconf.py +0 -146
  100. api/synchronous/api_core/yolov5/models/__init__.py +0 -0
  101. api/synchronous/api_core/yolov5/models/common.py +0 -738
  102. api/synchronous/api_core/yolov5/models/experimental.py +0 -104
  103. api/synchronous/api_core/yolov5/models/tf.py +0 -574
  104. api/synchronous/api_core/yolov5/models/yolo.py +0 -338
  105. api/synchronous/api_core/yolov5/train.py +0 -670
  106. api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
  107. api/synchronous/api_core/yolov5/utils/activations.py +0 -103
  108. api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
  109. api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
  110. api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
  111. api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
  112. api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
  113. api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
  114. api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
  115. api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
  116. api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
  117. api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
  118. api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
  119. api/synchronous/api_core/yolov5/utils/general.py +0 -1018
  120. api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
  121. api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
  122. api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
  123. api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
  124. api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
  125. api/synchronous/api_core/yolov5/utils/loss.py +0 -234
  126. api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
  127. api/synchronous/api_core/yolov5/utils/plots.py +0 -489
  128. api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
  129. api/synchronous/api_core/yolov5/val.py +0 -394
  130. md_utils/matlab_porting_tools.py +0 -97
  131. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
  132. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
api/synchronous/api_core/yolov5/models/yolo.py (deleted in 5.0.7; all 338 lines removed)
@@ -1,338 +0,0 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
YOLO-specific modules

Usage:
    $ python path/to/models/yolo.py --cfg yolov5s.yaml
"""

import argparse
import os
import platform
import sys
from copy import deepcopy
from pathlib import Path

# Put the YOLOv5 root (the parent of this file's directory) on sys.path so the
# `models.*` / `utils.*` imports below resolve when this file is run directly.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if not platform.system() == 'Windows':
    # Shorten ROOT to a cwd-relative path; presumably skipped on Windows because
    # os.path.relpath can fail across drives — TODO confirm intent.
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

# The star imports supply torch, nn, math and the building-block modules
# (Conv, C3, Concat, ...) referenced below and by eval() in parse_model.
from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
from utils.plots import feature_visualization
from utils.torch_utils import (
    fuse_conv_and_bn,
    initialize_weights,
    model_info,
    profile,
    scale_img,
    select_device,
    time_sync,
)

# thop is optional: without it, per-layer FLOPs are reported as 0.
try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None


class Detect(nn.Module):
    """YOLOv5 detection head.

    Applies one 1x1 convolution per input feature map, producing
    ``na * (nc + 5)`` channels that are reshaped to ``(bs, na, ny, nx, nc + 5)``.
    In training mode those raw maps are returned. Otherwise they are
    sigmoid-activated and decoded with the grid offsets, per-layer strides and
    anchor grid, then flattened to ``(bs, -1, nc + 5)`` and concatenated.
    """

    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter
    export = False  # export mode

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # outputs per anchor: 4 box values + 1 objectness + nc classes
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # anchors per layer (each layer lists flat w,h pairs)
        # Placeholder grids; real ones are built lazily by _make_grid in forward()
        self.grid = [torch.zeros(1)] * self.nl
        self.anchor_grid = [torch.zeros(1)] * self.nl
        anchor_tensor = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', anchor_tensor)  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(c_in, self.no * self.na, 1) for c_in in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        """Run the head on a list of ``nl`` feature maps.

        ``x`` is modified in place: each entry is replaced by its reshaped
        prediction map.

        Returns:
            - training mode: the list ``x``;
            - export mode: a 1-tuple ``(decoded,)``;
            - otherwise: ``(decoded, x)``.
        """
        z = []  # decoded outputs per layer (inference only)
        for i in range(self.nl):
            feat = self.m[i](x[i])  # conv
            bs, _, ny, nx = feat.shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            feat = feat.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            x[i] = feat

            if self.training:
                continue

            # Rebuild the cached grids when the spatial size changes (always for dynamic ONNX)
            needs_grid = self.onnx_dynamic or self.grid[i].shape[2:4] != feat.shape[2:4]
            if needs_grid:
                self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

            y = feat.sigmoid()
            if self.inplace:
                y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
            else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                xy, wh, conf = y.split((2, 2, self.nc + 1), 4)
                xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                y = torch.cat((xy, wh, conf), 4)
            z.append(y.view(bs, -1, self.no))

        if self.training:
            return x
        if self.export:
            return (torch.cat(z, 1),)
        return torch.cat(z, 1), x

    def _make_grid(self, nx=20, ny=20, i=0):
        """Build the (x, y) cell-offset grid and the scaled anchor grid for layer ``i``.

        Both tensors are expanded to shape ``(1, na, ny, nx, 2)`` on the
        anchors' device and dtype. The offset grid is shifted by -0.5.
        """
        anchors_i = self.anchors[i]
        device, dtype = anchors_i.device, anchors_i.dtype
        shape = (1, self.na, ny, nx, 2)  # grid shape
        ys = torch.arange(ny, device=device, dtype=dtype)
        xs = torch.arange(nx, device=device, dtype=dtype)
        if check_version(torch.__version__, '1.10.0'):  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
            yv, xv = torch.meshgrid(ys, xs, indexing='ij')
        else:
            yv, xv = torch.meshgrid(ys, xs)
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        scaled = anchors_i * self.stride[i]
        anchor_grid = scaled.view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid


class Model(nn.Module):
    """YOLOv5 model assembled from a model definition.

    Args:
        cfg: an already-loaded model dict, or a path to a ``*.yaml`` file read
            with ``yaml.safe_load``.
        ch: input channel count. A ``'ch'`` key in the definition takes
            precedence, and the chosen value is written back into it.
        nc: when truthy and different from the definition's ``'nc'``, replaces it.
        anchors: when truthy, replaces the definition's ``'anchors'`` with
            ``round(anchors)`` (so presumably an anchor count rather than a
            list — TODO confirm).
    """

    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super().__init__()
        if not isinstance(cfg, dict):  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg, encoding='ascii', errors='ignore') as f:
                self.yaml = yaml.safe_load(f)  # model dict
        else:
            self.yaml = cfg  # model dict

        # Define model
        ch = self.yaml.get('ch', ch)  # input channels
        self.yaml['ch'] = ch
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = list(map(str, range(self.yaml['nc'])))  # default names
        self.inplace = self.yaml.get('inplace', True)

        # Build strides and rescale anchors using one dummy forward pass
        head = self.model[-1]  # Detect()
        if isinstance(head, Detect):
            probe = 256  # 2x min stride
            head.inplace = self.inplace
            feature_maps = self.forward(torch.zeros(1, ch, probe, probe))
            head.stride = torch.tensor([probe / fm.shape[-2] for fm in feature_maps])
            check_anchor_order(head)  # must be in pixel-space (not grid-space)
            head.anchors /= head.stride.view(-1, 1, 1)  # pixel-space -> grid-space
            self.stride = head.stride
            self._initialize_biases()  # only run once

        # Init weights, biases
        initialize_weights(self)
        self.info()
        LOGGER.info('')

    def forward(self, x, augment=False, profile=False, visualize=False):
        """Augmented inference returns ``(pred, None)``; otherwise single-scale output."""
        if not augment:
            return self._forward_once(x, profile, visualize)  # single-scale inference, train
        return self._forward_augment(x)  # augmented inference, None

    def _forward_augment(self, x):
        """Run at three scales (the second one left-right flipped) and merge the predictions."""
        img_size = x.shape[-2:]  # height, width
        scales = (1, 0.83, 0.67)
        flips = (None, 3, None)  # flip dims (2-ud, 3-lr)
        max_stride = int(self.stride.max())
        outputs = []
        for scale, flip in zip(scales, flips):
            xi = scale_img(x.flip(flip) if flip else x, scale, gs=max_stride)
            yi = self._forward_once(xi)[0]  # forward
            outputs.append(self._descale_pred(yi, flip, scale, img_size))
        merged = self._clip_augmented(outputs)  # clip augmented tails
        return torch.cat(merged, 1), None  # augmented inference, train

    def _forward_once(self, x, profile=False, visualize=False):
        """Run each layer in order.

        Each layer takes its input from ``layer.f``: -1 means the previous
        output, otherwise an earlier saved output or a list of them.
        """
        saved, timings = [], []
        for layer in self.model:
            if layer.f != -1:  # if not from previous layer
                if isinstance(layer.f, int):
                    x = saved[layer.f]
                else:
                    x = [x if j == -1 else saved[j] for j in layer.f]  # from earlier layers
            if profile:
                self._profile_one_layer(layer, x, timings)
            x = layer(x)  # run
            # Only keep outputs that a later layer references (self.save); None keeps indices aligned
            saved.append(x if layer.i in self.save else None)
            if visualize:
                feature_visualization(x, layer.type, layer.i, save_dir=visualize)
        return x

    def _descale_pred(self, p, flips, scale, img_size):
        """Undo the scaling and flipping applied during augmented inference.

        Modifies ``p`` in place when ``self.inplace`` is set; otherwise builds a
        new tensor.
        """
        if not self.inplace:
            x = p[..., 0:1] / scale  # de-scale
            y = p[..., 1:2] / scale
            wh = p[..., 2:4] / scale
            if flips == 2:
                y = img_size[0] - y  # de-flip ud
            elif flips == 3:
                x = img_size[1] - x  # de-flip lr
            return torch.cat((x, y, wh, p[..., 4:]), -1)

        p[..., :4] /= scale  # de-scale
        if flips == 2:
            p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
        elif flips == 3:
            p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
        return p

    def _clip_augmented(self, y):
        """Trim augmented-inference tails: the end of the first output and the start of the last."""
        nl = self.model[-1].nl  # number of detection layers (P3-P5)
        g = sum(4 ** k for k in range(nl))  # grid points
        e = 1  # exclude layer count
        head_cut = (y[0].shape[1] // g) * sum(4 ** k for k in range(e))
        y[0] = y[0][:, :-head_cut]  # large
        tail_cut = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - k) for k in range(e))
        y[-1] = y[-1][:, tail_cut:]  # small
        return y

    def _profile_one_layer(self, m, x, dt):
        """Time 10 runs of layer ``m``, append the ms-per-run figure to ``dt`` and log it."""
        is_head = isinstance(m, Detect)  # final layer writes into its input list, so pass a copy
        if thop:
            flops = thop.profile(m, inputs=(x.copy() if is_head else x,), verbose=False)[0] / 1E9 * 2  # FLOPs
        else:
            flops = 0
        t0 = time_sync()
        for _ in range(10):
            m(x.copy() if is_head else x)
        dt.append((time_sync() - t0) * 100)  # seconds * 1000 / 10 runs = ms per run
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
        LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}')
        if is_head:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        """Set objectness/class bias priors on the Detect convolutions.

        See https://arxiv.org/abs/1708.02002 section 3.3. ``cf``, if given, is a
        per-class frequency tensor, e.g.
        ``torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.``
        """
        head = self.model[-1]  # Detect() module
        for conv, s in zip(head.m, head.stride):
            b = conv.bias.view(head.na, -1).detach()  # conv.bias(255) to (3,85)
            b[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            if cf is None:
                b[:, 5:] += math.log(0.6 / (head.nc - 0.999999))  # cls
            else:
                b[:, 5:] += torch.log(cf / cf.sum())  # cls
            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def _print_biases(self):
        """Log the mean bias values of each Detect output convolution."""
        head = self.model[-1]  # Detect() module
        for conv in head.m:
            b = conv.bias.detach().view(head.na, -1).T  # conv.bias(255) to (3,85)
            fmt = '%6g Conv2d.bias:' + '%10.3g' * 6
            LOGGER.info(fmt % (conv.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        """Fold BatchNorm into the preceding conv for every Conv/DWConv that still has ``bn``."""
        LOGGER.info('Fusing layers... ')
        for m in self.model.modules():
            if not isinstance(m, (Conv, DWConv)) or not hasattr(m, 'bn'):
                continue
            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
            delattr(m, 'bn')  # remove batchnorm
            m.forward = m.forward_fuse  # update forward
        self.info()
        return self

    def info(self, verbose=False, img_size=640):  # print model information
        model_info(self, verbose, img_size)

    def _apply(self, fn):
        """Also apply ``fn`` (to(), cpu(), cuda(), half()) to Detect's plain-tensor attributes.

        Those attributes (stride, grid, anchor_grid) are neither parameters nor
        registered buffers.
        """
        model = super()._apply(fn)
        head = model.model[-1]  # Detect()
        if isinstance(head, Detect):
            head.stride = fn(head.stride)
            head.grid = [fn(g) for g in head.grid]
            if isinstance(head.anchor_grid, list):
                head.anchor_grid = [fn(a) for a in head.anchor_grid]
        return model


def parse_model(d, ch):  # model_dict, input_channels(3)
    """Build the layer stack described by a YOLOv5 model dict.

    Args:
        d: model definition with keys 'anchors', 'nc', 'depth_multiple',
            'width_multiple', 'backbone' and 'head'. Each backbone/head entry is
            ``[from, number, module, args]``.
        ch: list of channel counts; ``ch[-1]`` is the input to layer 0.

    Returns:
        ``(nn.Sequential of the built layers, sorted list of layer indices
        whose outputs later layers reference)``.

    Security: module names and string arguments from ``d`` are passed to
    ``eval()``, which can execute arbitrary code. Only build models from
    trusted definitions.
    """
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    # NOTE: the local names above (nc, anchors, ...) are visible to eval() below,
    # so config strings such as 'nc' or 'anchors' resolve to them; do not rename.
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        # SECURITY: eval() on config-provided strings; trusted configs only.
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except NameError:
                pass  # deliberate: plain string arguments (e.g. 'nearest') are kept as-is

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x):
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3, C3TR, C3Ghost, C3x]:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m is Detect:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        # Named n_params (not `np`) so the local does not shadow the numpy module,
        # which eval() above would otherwise see as an int after the first layer.
        n_params = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, n_params  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{n_params:10.0f} {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []  # layer 0's input channels are replaced by per-layer output channels
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
    parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--profile', action='store_true', help='profile model speed')
    parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
    parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
    opt = parser.parse_args()
    opt.cfg = check_yaml(opt.cfg)  # check YAML
    print_args(vars(opt))
    device = select_device(opt.device)

    # Create the model and a random input batch on the selected device
    im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
    model = Model(opt.cfg).to(device)

    # Exactly one action runs; the first matching flag wins
    if opt.line_profile:
        model(im, profile=True)  # per-layer timing via Model._profile_one_layer
    elif opt.profile:
        profile(input=im, ops=[model], n=3)  # profile forward-backward
    elif opt.test:
        # Try to build every bundled yolo*.yaml; report failures and keep going
        for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
            try:
                Model(cfg)
            except Exception as e:
                print(f'Error in {cfg}: {e}')
    else:
        model.fuse()  # report fused model summary