megadetector-5.0.5-py3-none-any.whl → megadetector-5.0.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (132)
  1. api/batch_processing/data_preparation/manage_local_batch.py +302 -263
  2. api/batch_processing/data_preparation/manage_video_batch.py +81 -2
  3. api/batch_processing/postprocessing/add_max_conf.py +1 -0
  4. api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
  5. api/batch_processing/postprocessing/compare_batch_results.py +110 -60
  6. api/batch_processing/postprocessing/load_api_results.py +56 -70
  7. api/batch_processing/postprocessing/md_to_coco.py +1 -1
  8. api/batch_processing/postprocessing/md_to_labelme.py +2 -1
  9. api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
  10. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
  11. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
  12. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
  13. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
  14. api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
  15. api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
  16. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
  17. classification/prepare_classification_script.py +191 -191
  18. data_management/coco_to_yolo.py +68 -45
  19. data_management/databases/integrity_check_json_db.py +7 -5
  20. data_management/generate_crops_from_cct.py +3 -3
  21. data_management/get_image_sizes.py +8 -6
  22. data_management/importers/add_timestamps_to_icct.py +79 -0
  23. data_management/importers/animl_results_to_md_results.py +160 -0
  24. data_management/importers/auckland_doc_test_to_json.py +4 -4
  25. data_management/importers/auckland_doc_to_json.py +1 -1
  26. data_management/importers/awc_to_json.py +5 -5
  27. data_management/importers/bellevue_to_json.py +5 -5
  28. data_management/importers/carrizo_shrubfree_2018.py +5 -5
  29. data_management/importers/carrizo_trail_cam_2017.py +5 -5
  30. data_management/importers/cct_field_adjustments.py +2 -3
  31. data_management/importers/channel_islands_to_cct.py +4 -4
  32. data_management/importers/ena24_to_json.py +5 -5
  33. data_management/importers/helena_to_cct.py +10 -10
  34. data_management/importers/idaho-camera-traps.py +12 -12
  35. data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
  36. data_management/importers/jb_csv_to_json.py +4 -4
  37. data_management/importers/missouri_to_json.py +1 -1
  38. data_management/importers/noaa_seals_2019.py +1 -1
  39. data_management/importers/pc_to_json.py +5 -5
  40. data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
  41. data_management/importers/prepare_zsl_imerit.py +5 -5
  42. data_management/importers/rspb_to_json.py +4 -4
  43. data_management/importers/save_the_elephants_survey_A.py +5 -5
  44. data_management/importers/save_the_elephants_survey_B.py +6 -6
  45. data_management/importers/snapshot_safari_importer.py +9 -9
  46. data_management/importers/snapshot_serengeti_lila.py +9 -9
  47. data_management/importers/timelapse_csv_set_to_json.py +5 -7
  48. data_management/importers/ubc_to_json.py +4 -4
  49. data_management/importers/umn_to_json.py +4 -4
  50. data_management/importers/wellington_to_json.py +1 -1
  51. data_management/importers/wi_to_json.py +2 -2
  52. data_management/importers/zamba_results_to_md_results.py +181 -0
  53. data_management/labelme_to_coco.py +35 -7
  54. data_management/labelme_to_yolo.py +229 -0
  55. data_management/lila/add_locations_to_island_camera_traps.py +1 -1
  56. data_management/lila/add_locations_to_nacti.py +147 -0
  57. data_management/lila/create_lila_blank_set.py +474 -0
  58. data_management/lila/create_lila_test_set.py +2 -1
  59. data_management/lila/create_links_to_md_results_files.py +106 -0
  60. data_management/lila/download_lila_subset.py +46 -21
  61. data_management/lila/generate_lila_per_image_labels.py +23 -14
  62. data_management/lila/get_lila_annotation_counts.py +17 -11
  63. data_management/lila/lila_common.py +14 -11
  64. data_management/lila/test_lila_metadata_urls.py +116 -0
  65. data_management/ocr_tools.py +829 -0
  66. data_management/resize_coco_dataset.py +13 -11
  67. data_management/yolo_output_to_md_output.py +84 -12
  68. data_management/yolo_to_coco.py +38 -20
  69. detection/process_video.py +36 -14
  70. detection/pytorch_detector.py +23 -8
  71. detection/run_detector.py +76 -19
  72. detection/run_detector_batch.py +178 -63
  73. detection/run_inference_with_yolov5_val.py +326 -57
  74. detection/run_tiled_inference.py +153 -43
  75. detection/video_utils.py +34 -8
  76. md_utils/ct_utils.py +172 -1
  77. md_utils/md_tests.py +372 -51
  78. md_utils/path_utils.py +167 -39
  79. md_utils/process_utils.py +26 -7
  80. md_utils/split_locations_into_train_val.py +215 -0
  81. md_utils/string_utils.py +10 -0
  82. md_utils/url_utils.py +0 -2
  83. md_utils/write_html_image_list.py +9 -26
  84. md_visualization/plot_utils.py +12 -8
  85. md_visualization/visualization_utils.py +106 -7
  86. md_visualization/visualize_db.py +16 -8
  87. md_visualization/visualize_detector_output.py +208 -97
  88. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
  89. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
  90. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
  91. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
  92. taxonomy_mapping/map_new_lila_datasets.py +43 -39
  93. taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
  94. taxonomy_mapping/preview_lila_taxonomy.py +27 -27
  95. taxonomy_mapping/species_lookup.py +33 -13
  96. taxonomy_mapping/taxonomy_csv_checker.py +7 -5
  97. api/synchronous/api_core/yolov5/detect.py +0 -252
  98. api/synchronous/api_core/yolov5/export.py +0 -607
  99. api/synchronous/api_core/yolov5/hubconf.py +0 -146
  100. api/synchronous/api_core/yolov5/models/__init__.py +0 -0
  101. api/synchronous/api_core/yolov5/models/common.py +0 -738
  102. api/synchronous/api_core/yolov5/models/experimental.py +0 -104
  103. api/synchronous/api_core/yolov5/models/tf.py +0 -574
  104. api/synchronous/api_core/yolov5/models/yolo.py +0 -338
  105. api/synchronous/api_core/yolov5/train.py +0 -670
  106. api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
  107. api/synchronous/api_core/yolov5/utils/activations.py +0 -103
  108. api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
  109. api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
  110. api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
  111. api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
  112. api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
  113. api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
  114. api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
  115. api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
  116. api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
  117. api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
  118. api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
  119. api/synchronous/api_core/yolov5/utils/general.py +0 -1018
  120. api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
  121. api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
  122. api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
  123. api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
  124. api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
  125. api/synchronous/api_core/yolov5/utils/loss.py +0 -234
  126. api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
  127. api/synchronous/api_core/yolov5/utils/plots.py +0 -489
  128. api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
  129. api/synchronous/api_core/yolov5/val.py +0 -394
  130. md_utils/matlab_porting_tools.py +0 -97
  131. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
  132. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
api/synchronous/api_core/yolov5/models/common.py
@@ -1,738 +0,0 @@
- # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
- """
- Common modules
- """
-
- import json
- import math
- import platform
- import warnings
- from collections import OrderedDict, namedtuple
- from copy import copy
- from pathlib import Path
-
- import cv2
- import numpy as np
- import pandas as pd
- import requests
- import torch
- import torch.nn as nn
- import yaml
- from PIL import Image
- from torch.cuda import amp
-
- from utils.dataloaders import exif_transpose, letterbox
- from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
-                            make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
- from utils.plots import Annotator, colors, save_one_box
- from utils.torch_utils import copy_attr, time_sync
-
-
- def autopad(k, p=None): # kernel, padding
-     # Pad to 'same'
-     if p is None:
-         p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
-     return p
-
-
- class Conv(nn.Module):
-     # Standard convolution
-     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
-         super().__init__()
-         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
-         self.bn = nn.BatchNorm2d(c2)
-         self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
-
-     def forward(self, x):
-         return self.act(self.bn(self.conv(x)))
-
-     def forward_fuse(self, x):
-         return self.act(self.conv(x))
-
-
- class DWConv(Conv):
-     # Depth-wise convolution class
-     def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
-         super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
-
-
- class DWConvTranspose2d(nn.ConvTranspose2d):
-     # Depth-wise transpose convolution class
-     def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
-         super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
-
-
- class TransformerLayer(nn.Module):
-     # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
-     def __init__(self, c, num_heads):
-         super().__init__()
-         self.q = nn.Linear(c, c, bias=False)
-         self.k = nn.Linear(c, c, bias=False)
-         self.v = nn.Linear(c, c, bias=False)
-         self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
-         self.fc1 = nn.Linear(c, c, bias=False)
-         self.fc2 = nn.Linear(c, c, bias=False)
-
-     def forward(self, x):
-         x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
-         x = self.fc2(self.fc1(x)) + x
-         return x
-
-
- class TransformerBlock(nn.Module):
-     # Vision Transformer https://arxiv.org/abs/2010.11929
-     def __init__(self, c1, c2, num_heads, num_layers):
-         super().__init__()
-         self.conv = None
-         if c1 != c2:
-             self.conv = Conv(c1, c2)
-         self.linear = nn.Linear(c2, c2) # learnable position embedding
-         self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
-         self.c2 = c2
-
-     def forward(self, x):
-         if self.conv is not None:
-             x = self.conv(x)
-         b, _, w, h = x.shape
-         p = x.flatten(2).permute(2, 0, 1)
-         return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
-
-
- class Bottleneck(nn.Module):
-     # Standard bottleneck
-     def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
-         super().__init__()
-         c_ = int(c2 * e) # hidden channels
-         self.cv1 = Conv(c1, c_, 1, 1)
-         self.cv2 = Conv(c_, c2, 3, 1, g=g)
-         self.add = shortcut and c1 == c2
-
-     def forward(self, x):
-         return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
-
-
- class BottleneckCSP(nn.Module):
-     # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
-     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
-         super().__init__()
-         c_ = int(c2 * e) # hidden channels
-         self.cv1 = Conv(c1, c_, 1, 1)
-         self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
-         self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
-         self.cv4 = Conv(2 * c_, c2, 1, 1)
-         self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
-         self.act = nn.SiLU()
-         self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
-
-     def forward(self, x):
-         y1 = self.cv3(self.m(self.cv1(x)))
-         y2 = self.cv2(x)
-         return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
-
-
- class CrossConv(nn.Module):
-     # Cross Convolution Downsample
-     def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
-         # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
-         super().__init__()
-         c_ = int(c2 * e) # hidden channels
-         self.cv1 = Conv(c1, c_, (1, k), (1, s))
-         self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
-         self.add = shortcut and c1 == c2
-
-     def forward(self, x):
-         return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
-
-
- class C3(nn.Module):
-     # CSP Bottleneck with 3 convolutions
-     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
-         super().__init__()
-         c_ = int(c2 * e) # hidden channels
-         self.cv1 = Conv(c1, c_, 1, 1)
-         self.cv2 = Conv(c1, c_, 1, 1)
-         self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
-         self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
-
-     def forward(self, x):
-         return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
-
-
- class C3x(C3):
-     # C3 module with cross-convolutions
-     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
-         super().__init__(c1, c2, n, shortcut, g, e)
-         c_ = int(c2 * e)
-         self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))
-
-
- class C3TR(C3):
-     # C3 module with TransformerBlock()
-     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
-         super().__init__(c1, c2, n, shortcut, g, e)
-         c_ = int(c2 * e)
-         self.m = TransformerBlock(c_, c_, 4, n)
-
-
- class C3SPP(C3):
-     # C3 module with SPP()
-     def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
-         super().__init__(c1, c2, n, shortcut, g, e)
-         c_ = int(c2 * e)
-         self.m = SPP(c_, c_, k)
-
-
- class C3Ghost(C3):
-     # C3 module with GhostBottleneck()
-     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
-         super().__init__(c1, c2, n, shortcut, g, e)
-         c_ = int(c2 * e) # hidden channels
-         self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
-
-
- class SPP(nn.Module):
-     # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
-     def __init__(self, c1, c2, k=(5, 9, 13)):
-         super().__init__()
-         c_ = c1 // 2 # hidden channels
-         self.cv1 = Conv(c1, c_, 1, 1)
-         self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
-         self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
-
-     def forward(self, x):
-         x = self.cv1(x)
-         with warnings.catch_warnings():
-             warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
-             return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
-
-
- class SPPF(nn.Module):
-     # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
-     def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
-         super().__init__()
-         c_ = c1 // 2 # hidden channels
-         self.cv1 = Conv(c1, c_, 1, 1)
-         self.cv2 = Conv(c_ * 4, c2, 1, 1)
-         self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
-
-     def forward(self, x):
-         x = self.cv1(x)
-         with warnings.catch_warnings():
-             warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
-             y1 = self.m(x)
-             y2 = self.m(y1)
-             return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
-
-
- class Focus(nn.Module):
-     # Focus wh information into c-space
-     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
-         super().__init__()
-         self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
-         # self.contract = Contract(gain=2)
-
-     def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
-         return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
-         # return self.conv(self.contract(x))
-
-
- class GhostConv(nn.Module):
-     # Ghost Convolution https://github.com/huawei-noah/ghostnet
-     def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
-         super().__init__()
-         c_ = c2 // 2 # hidden channels
-         self.cv1 = Conv(c1, c_, k, s, None, g, act)
-         self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
-
-     def forward(self, x):
-         y = self.cv1(x)
-         return torch.cat((y, self.cv2(y)), 1)
-
-
- class GhostBottleneck(nn.Module):
-     # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
-     def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
-         super().__init__()
-         c_ = c2 // 2
-         self.conv = nn.Sequential(
-             GhostConv(c1, c_, 1, 1), # pw
-             DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
-             GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
-         self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
-                                                                             act=False)) if s == 2 else nn.Identity()
-
-     def forward(self, x):
-         return self.conv(x) + self.shortcut(x)
-
-
- class Contract(nn.Module):
-     # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
-     def __init__(self, gain=2):
-         super().__init__()
-         self.gain = gain
-
-     def forward(self, x):
-         b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
-         s = self.gain
-         x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
-         x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
-         return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
-
-
- class Expand(nn.Module):
-     # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
-     def __init__(self, gain=2):
-         super().__init__()
-         self.gain = gain
-
-     def forward(self, x):
-         b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
-         s = self.gain
-         x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
-         x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
-         return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
-
-
- class Concat(nn.Module):
-     # Concatenate a list of tensors along dimension
-     def __init__(self, dimension=1):
-         super().__init__()
-         self.d = dimension
-
-     def forward(self, x):
-         return torch.cat(x, self.d)
-
-
- class DetectMultiBackend(nn.Module):
-     # YOLOv5 MultiBackend class for python inference on various backends
-     def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
-         # Usage:
-         # PyTorch: weights = *.pt
-         # TorchScript: *.torchscript
-         # ONNX Runtime: *.onnx
-         # ONNX OpenCV DNN: *.onnx with --dnn
-         # OpenVINO: *.xml
-         # CoreML: *.mlmodel
-         # TensorRT: *.engine
-         # TensorFlow SavedModel: *_saved_model
-         # TensorFlow GraphDef: *.pb
-         # TensorFlow Lite: *.tflite
-         # TensorFlow Edge TPU: *_edgetpu.tflite
-         from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
-
-         super().__init__()
-         w = str(weights[0] if isinstance(weights, list) else weights)
-         pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend
-         w = attempt_download(w) # download if not local
-         fp16 &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16
-         stride, names = 32, [f'class{i}' for i in range(1000)] # assign defaults
-         if data: # assign class names (optional)
-             with open(data, errors='ignore') as f:
-                 names = yaml.safe_load(f)['names']
-
-         if pt: # PyTorch
-             model = attempt_load(weights if isinstance(weights, list) else w, device=device)
-             stride = max(int(model.stride.max()), 32) # model stride
-             names = model.module.names if hasattr(model, 'module') else model.names # get class names
-             model.half() if fp16 else model.float()
-             self.model = model # explicitly assign for to(), cpu(), cuda(), half()
-         elif jit: # TorchScript
-             LOGGER.info(f'Loading {w} for TorchScript inference...')
-             extra_files = {'config.txt': ''} # model metadata
-             model = torch.jit.load(w, _extra_files=extra_files)
-             model.half() if fp16 else model.float()
-             if extra_files['config.txt']:
-                 d = json.loads(extra_files['config.txt']) # extra_files dict
-                 stride, names = int(d['stride']), d['names']
-         elif dnn: # ONNX OpenCV DNN
-             LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
-             check_requirements(('opencv-python>=4.5.4',))
-             net = cv2.dnn.readNetFromONNX(w)
-         elif onnx: # ONNX Runtime
-             LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
-             cuda = torch.cuda.is_available()
-             check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
-             import onnxruntime
-             providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
-             session = onnxruntime.InferenceSession(w, providers=providers)
-             meta = session.get_modelmeta().custom_metadata_map # metadata
-             if 'stride' in meta:
-                 stride, names = int(meta['stride']), eval(meta['names'])
-         elif xml: # OpenVINO
-             LOGGER.info(f'Loading {w} for OpenVINO inference...')
-             check_requirements(('openvino',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
-             from openvino.runtime import Core
-             ie = Core()
-             if not Path(w).is_file(): # if not *.xml
-                 w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
-             network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
-             executable_network = ie.compile_model(model=network, device_name="CPU")
-             output_layer = next(iter(executable_network.outputs))
-             meta = Path(w).with_suffix('.yaml')
-             if meta.exists():
-                 stride, names = self._load_metadata(meta) # load metadata
-         elif engine: # TensorRT
-             LOGGER.info(f'Loading {w} for TensorRT inference...')
-             import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
-             check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
-             Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
-             logger = trt.Logger(trt.Logger.INFO)
-             with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
-                 model = runtime.deserialize_cuda_engine(f.read())
-             bindings = OrderedDict()
-             fp16 = False # default updated below
-             for index in range(model.num_bindings):
-                 name = model.get_binding_name(index)
-                 dtype = trt.nptype(model.get_binding_dtype(index))
-                 shape = tuple(model.get_binding_shape(index))
-                 data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
-                 bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
-                 if model.binding_is_input(index) and dtype == np.float16:
-                     fp16 = True
-             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
-             context = model.create_execution_context()
-             batch_size = bindings['images'].shape[0]
-         elif coreml: # CoreML
-             LOGGER.info(f'Loading {w} for CoreML inference...')
-             import coremltools as ct
-             model = ct.models.MLModel(w)
-         else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
-             if saved_model: # SavedModel
-                 LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
-                 import tensorflow as tf
-                 keras = False # assume TF1 saved_model
-                 model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
-             elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
-                 LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
-                 import tensorflow as tf
-
-                 def wrap_frozen_graph(gd, inputs, outputs):
-                     x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
-                     ge = x.graph.as_graph_element
-                     return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
-
-                 gd = tf.Graph().as_graph_def() # graph_def
-                 with open(w, 'rb') as f:
-                     gd.ParseFromString(f.read())
-                 frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
-             elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
-                 try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
-                     from tflite_runtime.interpreter import Interpreter, load_delegate
-                 except ImportError:
-                     import tensorflow as tf
-                     Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
-                 if edgetpu: # Edge TPU https://coral.ai/software/#edgetpu-runtime
-                     LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
-                     delegate = {
-                         'Linux': 'libedgetpu.so.1',
-                         'Darwin': 'libedgetpu.1.dylib',
-                         'Windows': 'edgetpu.dll'}[platform.system()]
-                     interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
-                 else: # Lite
-                     LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
-                     interpreter = Interpreter(model_path=w) # load TFLite model
-                 interpreter.allocate_tensors() # allocate
-                 input_details = interpreter.get_input_details() # inputs
-                 output_details = interpreter.get_output_details() # outputs
-             elif tfjs:
-                 raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
-         self.__dict__.update(locals()) # assign all variables to self
-
-     def forward(self, im, augment=False, visualize=False, val=False):
-         # YOLOv5 MultiBackend inference
-         b, ch, h, w = im.shape # batch, channel, height, width
-         if self.pt: # PyTorch
-             y = self.model(im, augment=augment, visualize=visualize)[0]
-         elif self.jit: # TorchScript
-             y = self.model(im)[0]
-         elif self.dnn: # ONNX OpenCV DNN
-             im = im.cpu().numpy() # torch to numpy
-             self.net.setInput(im)
-             y = self.net.forward()
-         elif self.onnx: # ONNX Runtime
-             im = im.cpu().numpy() # torch to numpy
-             y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
-         elif self.xml: # OpenVINO
-             im = im.cpu().numpy() # FP32
-             y = self.executable_network([im])[self.output_layer]
-         elif self.engine: # TensorRT
-             assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
-             self.binding_addrs['images'] = int(im.data_ptr())
-             self.context.execute_v2(list(self.binding_addrs.values()))
-             y = self.bindings['output'].data
-         elif self.coreml: # CoreML
-             im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
-             im = Image.fromarray((im[0] * 255).astype('uint8'))
-             # im = im.resize((192, 320), Image.ANTIALIAS)
-             y = self.model.predict({'image': im}) # coordinates are xywh normalized
-             if 'confidence' in y:
-                 box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
-                 conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
-                 y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
-             else:
-                 k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1]) # output key
-                 y = y[k] # output
-         else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
-             im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
-             if self.saved_model: # SavedModel
-                 y = (self.model(im, training=False) if self.keras else self.model(im)).numpy()
-             elif self.pb: # GraphDef
-                 y = self.frozen_func(x=self.tf.constant(im)).numpy()
-             else: # Lite or Edge TPU
-                 input, output = self.input_details[0], self.output_details[0]
-                 int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
-                 if int8:
-                     scale, zero_point = input['quantization']
-                     im = (im / scale + zero_point).astype(np.uint8) # de-scale
-                 self.interpreter.set_tensor(input['index'], im)
-                 self.interpreter.invoke()
-                 y = self.interpreter.get_tensor(output['index'])
-                 if int8:
-                     scale, zero_point = output['quantization']
-                     y = (y.astype(np.float32) - zero_point) * scale # re-scale
-             y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
-
-         if isinstance(y, np.ndarray):
-             y = torch.tensor(y, device=self.device)
-         return (y, []) if val else y
-
-     def warmup(self, imgsz=(1, 3, 640, 640)):
-         # Warmup model by running inference once
-         warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
-         if any(warmup_types) and self.device.type != 'cpu':
-             im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
-             for _ in range(2 if self.jit else 1): #
-                 self.forward(im) # warmup
-
-     @staticmethod
-     def model_type(p='path/to/model.pt'):
-         # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
-         from export import export_formats
-         suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
-         check_suffix(p, suffixes) # checks
-         p = Path(p).name # eliminate trailing separators
-         pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
-         xml |= xml2 # *_openvino_model or *.xml
-         tflite &= not edgetpu # *.tflite
-         return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
-
-     @staticmethod
-     def _load_metadata(f='path/to/meta.yaml'):
-         # Load metadata from meta.yaml if it exists
-         with open(f, errors='ignore') as f:
-             d = yaml.safe_load(f)
-         return d['stride'], d['names'] # assign stride, names
-
-
- class AutoShape(nn.Module):
-     # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
-     conf = 0.25 # NMS confidence threshold
-     iou = 0.45 # NMS IoU threshold
-     agnostic = False # NMS class-agnostic
-     multi_label = False # NMS multiple labels per box
-     classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
-     max_det = 1000 # maximum number of detections per image
-     amp = False # Automatic Mixed Precision (AMP) inference
-
-     def __init__(self, model, verbose=True):
-         super().__init__()
-         if verbose:
-             LOGGER.info('Adding AutoShape... ')
-         copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
-         self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
-         self.pt = not self.dmb or model.pt # PyTorch model
-         self.model = model.eval()
-
-     def _apply(self, fn):
-         # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
-         self = super()._apply(fn)
-         if self.pt:
-             m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
-             m.stride = fn(m.stride)
-             m.grid = list(map(fn, m.grid))
-             if isinstance(m.anchor_grid, list):
-                 m.anchor_grid = list(map(fn, m.anchor_grid))
-         return self
-
-     @torch.no_grad()
-     def forward(self, imgs, size=640, augment=False, profile=False):
-         # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
-         # file: imgs = 'data/images/zidane.jpg' # str or PosixPath
-         # URI: = 'https://ultralytics.com/images/zidane.jpg'
-         # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
-         # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
-         # numpy: = np.zeros((640,1280,3)) # HWC
-         # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
-         # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
-
-         t = [time_sync()]
-         p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # for device, type
-         autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
-         if isinstance(imgs, torch.Tensor): # torch
-             with amp.autocast(autocast):
-                 return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
-
-         # Pre-process
-         n, imgs = (len(imgs), list(imgs)) if isinstance(imgs, (list, tuple)) else (1, [imgs]) # number, list of images
-         shape0, shape1, files = [], [], [] # image and inference shapes, filenames
-         for i, im in enumerate(imgs):
-             f = f'image{i}' # filename
-             if isinstance(im, (str, Path)): # filename or uri
-                 im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
-                 im = np.asarray(exif_transpose(im))
-             elif isinstance(im, Image.Image): # PIL Image
-                 im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
-             files.append(Path(f).with_suffix('.jpg').name)
-             if im.shape[0] < 5: # image in CHW
-                 im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
-             im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3) # enforce 3ch input
-             s = im.shape[:2] # HWC
-             shape0.append(s) # image shape
-             g = (size / max(s)) # gain
-             shape1.append([y * g for y in s])
-             imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
-         shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape
-         x = [letterbox(im, shape1, auto=False)[0] for im in imgs] # pad
-         x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
-         x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
-         t.append(time_sync())
-
-         with amp.autocast(autocast):
-             # Inference
-             y = self.model(x, augment, profile) # forward
-             t.append(time_sync())
-
-             # Post-process
-             y = non_max_suppression(y if self.dmb else y[0],
-                                     self.conf,
-                                     self.iou,
-                                     self.classes,
-                                     self.agnostic,
-                                     self.multi_label,
-                                     max_det=self.max_det) # NMS
-             for i in range(n):
-                 scale_coords(shape1, y[i][:, :4], shape0[i])
-
-             t.append(time_sync())
-             return Detections(imgs, y, files, t, self.names, x.shape)
-
-
- class Detections:
-     # YOLOv5 detections class for inference results
-     def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
-         super().__init__()
-         d = pred[0].device # device
-         gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations
-         self.imgs = imgs # list of images as numpy arrays
-         self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
-         self.names = names # class names
-         self.files = files # image filenames
-         self.times = times # profiling times
-         self.xyxy = pred # xyxy pixels
-         self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
-         self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
-         self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
-         self.n = len(self.pred) # number of images (batch size)
-         self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
-         self.s = shape # inference BCHW shape
-
-     def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
-         crops = []
-         for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
-             s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
-             if pred.shape[0]:
-                 for c in pred[:, -1].unique():
-                     n = (pred[:, -1] == c).sum() # detections per class
-                     s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
-                 if show or save or render or crop:
-                     annotator = Annotator(im, example=str(self.names))
-                     for *box, conf, cls in reversed(pred): # xyxy, confidence, class
-                         label = f'{self.names[int(cls)]} {conf:.2f}'
-                         if crop:
-                             file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
-                             crops.append({
-                                 'box': box,
-                                 'conf': conf,
-                                 'cls': cls,
-                                 'label': label,
-                                 'im': save_one_box(box, im, file=file, save=save)})
-                         else: # all others
-                             annotator.box_label(box, label if labels else '', color=colors(cls))
-                     im = annotator.im
-             else:
-                 s += '(no detections)'
-
-             im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
-             if pprint:
-                 print(s.rstrip(', '))
-             if show:
-                 im.show(self.files[i]) # show
-             if save:
-                 f = self.files[i]
-                 im.save(save_dir / f) # save
-                 if i == self.n - 1:
-                     LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
-             if render:
-                 self.imgs[i] = np.asarray(im)
-         if crop:
-             if save:
-                 LOGGER.info(f'Saved results to {save_dir}\n')
-             return crops
-
-     def print(self):
-         self.display(pprint=True) # print results
-         print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)
-
-     def show(self, labels=True):
-         self.display(show=True, labels=labels) # show results
-
-     def save(self, labels=True, save_dir='runs/detect/exp'):
-         save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) # increment save_dir
-         self.display(save=True, labels=labels, save_dir=save_dir) # save results
-
-     def crop(self, save=True, save_dir='runs/detect/exp'):
-         save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
-         return self.display(crop=True, save=save, save_dir=save_dir) # crop results
-
-     def render(self, labels=True):
-         self.display(render=True, labels=labels) # render results
-         return self.imgs
-
-     def pandas(self):
-         # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
-         new = copy(self) # return copy
-         ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
-         cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
-         for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
-             a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
-             setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
-         return new
-
-     def tolist(self):
-         # return a list of Detections objects, i.e. 'for result in results.tolist():'
-         r = range(self.n) # iterable
-         x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
-         # for d in x:
-         #     for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
-         #         setattr(d, k, getattr(d, k)[0]) # pop out of list
-         return x
-
-     def __len__(self):
-         return self.n # override len(results)
-
-     def __str__(self):
-         self.print() # override print(results)
-         return ''
-
-
- class Classify(nn.Module):
-     # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
-     def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
-         super().__init__()
-         self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
-         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
-         self.flat = nn.Flatten()
-
-     def forward(self, x):
-         z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
-         return self.flat(self.conv(z)) # flatten to x(b,c2)