megadetector 5.0.5__py3-none-any.whl → 5.0.7__py3-none-any.whl

This diff shows the changes between publicly available package versions as released to their public registry, and is provided for informational purposes only.

Note: this version of megadetector has been flagged as potentially problematic.

Files changed (132):
  1. api/batch_processing/data_preparation/manage_local_batch.py +302 -263
  2. api/batch_processing/data_preparation/manage_video_batch.py +81 -2
  3. api/batch_processing/postprocessing/add_max_conf.py +1 -0
  4. api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
  5. api/batch_processing/postprocessing/compare_batch_results.py +110 -60
  6. api/batch_processing/postprocessing/load_api_results.py +56 -70
  7. api/batch_processing/postprocessing/md_to_coco.py +1 -1
  8. api/batch_processing/postprocessing/md_to_labelme.py +2 -1
  9. api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
  10. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
  11. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
  12. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
  13. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
  14. api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
  15. api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
  16. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
  17. classification/prepare_classification_script.py +191 -191
  18. data_management/coco_to_yolo.py +68 -45
  19. data_management/databases/integrity_check_json_db.py +7 -5
  20. data_management/generate_crops_from_cct.py +3 -3
  21. data_management/get_image_sizes.py +8 -6
  22. data_management/importers/add_timestamps_to_icct.py +79 -0
  23. data_management/importers/animl_results_to_md_results.py +160 -0
  24. data_management/importers/auckland_doc_test_to_json.py +4 -4
  25. data_management/importers/auckland_doc_to_json.py +1 -1
  26. data_management/importers/awc_to_json.py +5 -5
  27. data_management/importers/bellevue_to_json.py +5 -5
  28. data_management/importers/carrizo_shrubfree_2018.py +5 -5
  29. data_management/importers/carrizo_trail_cam_2017.py +5 -5
  30. data_management/importers/cct_field_adjustments.py +2 -3
  31. data_management/importers/channel_islands_to_cct.py +4 -4
  32. data_management/importers/ena24_to_json.py +5 -5
  33. data_management/importers/helena_to_cct.py +10 -10
  34. data_management/importers/idaho-camera-traps.py +12 -12
  35. data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
  36. data_management/importers/jb_csv_to_json.py +4 -4
  37. data_management/importers/missouri_to_json.py +1 -1
  38. data_management/importers/noaa_seals_2019.py +1 -1
  39. data_management/importers/pc_to_json.py +5 -5
  40. data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
  41. data_management/importers/prepare_zsl_imerit.py +5 -5
  42. data_management/importers/rspb_to_json.py +4 -4
  43. data_management/importers/save_the_elephants_survey_A.py +5 -5
  44. data_management/importers/save_the_elephants_survey_B.py +6 -6
  45. data_management/importers/snapshot_safari_importer.py +9 -9
  46. data_management/importers/snapshot_serengeti_lila.py +9 -9
  47. data_management/importers/timelapse_csv_set_to_json.py +5 -7
  48. data_management/importers/ubc_to_json.py +4 -4
  49. data_management/importers/umn_to_json.py +4 -4
  50. data_management/importers/wellington_to_json.py +1 -1
  51. data_management/importers/wi_to_json.py +2 -2
  52. data_management/importers/zamba_results_to_md_results.py +181 -0
  53. data_management/labelme_to_coco.py +35 -7
  54. data_management/labelme_to_yolo.py +229 -0
  55. data_management/lila/add_locations_to_island_camera_traps.py +1 -1
  56. data_management/lila/add_locations_to_nacti.py +147 -0
  57. data_management/lila/create_lila_blank_set.py +474 -0
  58. data_management/lila/create_lila_test_set.py +2 -1
  59. data_management/lila/create_links_to_md_results_files.py +106 -0
  60. data_management/lila/download_lila_subset.py +46 -21
  61. data_management/lila/generate_lila_per_image_labels.py +23 -14
  62. data_management/lila/get_lila_annotation_counts.py +17 -11
  63. data_management/lila/lila_common.py +14 -11
  64. data_management/lila/test_lila_metadata_urls.py +116 -0
  65. data_management/ocr_tools.py +829 -0
  66. data_management/resize_coco_dataset.py +13 -11
  67. data_management/yolo_output_to_md_output.py +84 -12
  68. data_management/yolo_to_coco.py +38 -20
  69. detection/process_video.py +36 -14
  70. detection/pytorch_detector.py +23 -8
  71. detection/run_detector.py +76 -19
  72. detection/run_detector_batch.py +178 -63
  73. detection/run_inference_with_yolov5_val.py +326 -57
  74. detection/run_tiled_inference.py +153 -43
  75. detection/video_utils.py +34 -8
  76. md_utils/ct_utils.py +172 -1
  77. md_utils/md_tests.py +372 -51
  78. md_utils/path_utils.py +167 -39
  79. md_utils/process_utils.py +26 -7
  80. md_utils/split_locations_into_train_val.py +215 -0
  81. md_utils/string_utils.py +10 -0
  82. md_utils/url_utils.py +0 -2
  83. md_utils/write_html_image_list.py +9 -26
  84. md_visualization/plot_utils.py +12 -8
  85. md_visualization/visualization_utils.py +106 -7
  86. md_visualization/visualize_db.py +16 -8
  87. md_visualization/visualize_detector_output.py +208 -97
  88. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
  89. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
  90. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
  91. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
  92. taxonomy_mapping/map_new_lila_datasets.py +43 -39
  93. taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
  94. taxonomy_mapping/preview_lila_taxonomy.py +27 -27
  95. taxonomy_mapping/species_lookup.py +33 -13
  96. taxonomy_mapping/taxonomy_csv_checker.py +7 -5
  97. api/synchronous/api_core/yolov5/detect.py +0 -252
  98. api/synchronous/api_core/yolov5/export.py +0 -607
  99. api/synchronous/api_core/yolov5/hubconf.py +0 -146
  100. api/synchronous/api_core/yolov5/models/__init__.py +0 -0
  101. api/synchronous/api_core/yolov5/models/common.py +0 -738
  102. api/synchronous/api_core/yolov5/models/experimental.py +0 -104
  103. api/synchronous/api_core/yolov5/models/tf.py +0 -574
  104. api/synchronous/api_core/yolov5/models/yolo.py +0 -338
  105. api/synchronous/api_core/yolov5/train.py +0 -670
  106. api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
  107. api/synchronous/api_core/yolov5/utils/activations.py +0 -103
  108. api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
  109. api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
  110. api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
  111. api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
  112. api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
  113. api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
  114. api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
  115. api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
  116. api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
  117. api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
  118. api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
  119. api/synchronous/api_core/yolov5/utils/general.py +0 -1018
  120. api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
  121. api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
  122. api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
  123. api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
  124. api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
  125. api/synchronous/api_core/yolov5/utils/loss.py +0 -234
  126. api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
  127. api/synchronous/api_core/yolov5/utils/plots.py +0 -489
  128. api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
  129. api/synchronous/api_core/yolov5/val.py +0 -394
  130. md_utils/matlab_porting_tools.py +0 -97
  131. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
  132. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
--- a/taxonomy_mapping/species_lookup.py
+++ b/taxonomy_mapping/species_lookup.py
@@ -36,14 +36,23 @@ taxonomy_urls = {
 }
 
 files_to_unzip = {
-    'GBIF': ['backbone/Taxon.tsv', 'backbone/VernacularName.tsv'],
+    # GBIF used to put everything in a "backbone" folder within the zipfile, but as of
+    # 12.2023, this is no longer the case.
+    # 'GBIF': ['backbone/Taxon.tsv', 'backbone/VernacularName.tsv'],
+    'GBIF': ['Taxon.tsv', 'VernacularName.tsv'],
     'iNaturalist': ['taxa.csv']
 }
 
 # As of 2020.05.12:
 #
 # GBIF: ~777MB zipped, ~1.6GB taxonomy
-# iNat: ~2.2GB zipped, ~51MB taxonomy
+# iNat: ~2.2GB zipped, ~51MB taxonomy (most of the zipfile is observations)
+
+# As of 2023.12.29:
+#
+# GBIF: ~948MB zipped, ~2.2GB taxonomy
+# iNat: ~6.7GB zipped, ~62MB taxonomy (most of the zipfile is observations)
+
 
 os.makedirs(taxonomy_download_dir, exist_ok=True)
 for taxonomy_name in taxonomy_urls:
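
For context, the files_to_unzip mapping above drives a member-by-member extraction from each downloaded taxonomy zipfile. Below is a minimal sketch of that step, assuming a hypothetical helper (extract_needed_files and its signature are invented for illustration, not code from the package); matching members by basename makes the presence or absence of the old 'backbone/' prefix irrelevant.

import os
import zipfile

files_to_unzip = {
    'GBIF': ['Taxon.tsv', 'VernacularName.tsv'],
    'iNaturalist': ['taxa.csv']
}

def extract_needed_files(zipfile_path, taxonomy_name, output_dir):
    # Extract only the members we need, matching by basename so it doesn't
    # matter whether the archive uses the old 'backbone/' layout or the new
    # flat layout.
    target_dir = os.path.join(output_dir, taxonomy_name)
    os.makedirs(target_dir, exist_ok=True)
    with zipfile.ZipFile(zipfile_path) as z:
        members = z.namelist()
        for fn in files_to_unzip[taxonomy_name]:
            candidates = [m for m in members if m == fn or m.endswith('/' + fn)]
            assert candidates, f'{fn} not found in {zipfile_path}'
            target_file = os.path.join(target_dir, os.path.basename(fn))
            with z.open(candidates[0]) as src, open(target_file, 'wb') as dst:
                dst.write(src.read())
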
@@ -99,15 +108,16 @@ def initialize_taxonomy_lookup(force_init=False) -> None:
         gbif_taxon_id_to_scientific,\
         gbif_scientific_to_taxon_id
 
+
     ## Load serialized taxonomy info if we've already saved it
 
     if (not force_init) and (inat_taxonomy is not None):
         print('Skipping taxonomy re-init')
         return
 
-    if os.path.isfile(serialized_structures_file):
+    if (not force_init) and (os.path.isfile(serialized_structures_file)):
 
-        print(f'Reading taxonomy data from {serialized_structures_file}')
+        print(f'De-serializing taxonomy data from {serialized_structures_file}')
 
         with open(serialized_structures_file, 'rb') as f:
             structures_to_serialize = pickle.load(f)
@@ -125,6 +135,7 @@ def initialize_taxonomy_lookup(force_init=False) -> None:
         gbif_vernacular_to_taxon_id,\
         gbif_taxon_id_to_scientific,\
         gbif_scientific_to_taxon_id = structures_to_serialize
+
         return
 
 
@@ -135,6 +146,9 @@ def initialize_taxonomy_lookup(force_init=False) -> None:
     for taxonomy_name, zip_url in taxonomy_urls.items():
 
         need_to_download = False
+
+        if force_init:
+            need_to_download = True
 
         # Don't download the zipfile if we've already unzipped what we need
         for fn in files_to_unzip[taxonomy_name]:
@@ -150,11 +164,11 @@ def initialize_taxonomy_lookup(force_init=False) -> None:
         zipfile_path = os.path.join(
             taxonomy_download_dir, zip_url.split('/')[-1])
 
-        # Bypasses download if the file exists already
+        # Bypasses download if the file exists already (unless force_init is set)
         url_utils.download_url(
             zip_url, os.path.join(zipfile_path),
             progress_updater=url_utils.DownloadProgressBar(),
-            verbose=True)
+            verbose=True,force_download=force_init)
 
         # Unzip the files we need
         files_we_need = files_to_unzip[taxonomy_name]
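
The force_download keyword added above follows the usual pattern for cached download helpers: skip the download when the target file already exists, unless the caller forces a refresh. A minimal sketch of that behavior, assuming a generic helper rather than the package's actual url_utils implementation:

import os
import urllib.request

def download_url(url, destination_filename, verbose=False, force_download=False):
    # Skip the download if the target already exists, unless the caller
    # explicitly forces a re-download (the force_init case above).
    if os.path.isfile(destination_filename) and not force_download:
        if verbose:
            print(f'Bypassing download of {url}, file exists')
        return destination_filename
    if verbose:
        print(f'Downloading {url} to {destination_filename}')
    os.makedirs(os.path.dirname(destination_filename) or '.', exist_ok=True)
    urllib.request.urlretrieve(url, destination_filename)
    return destination_filename
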
@@ -166,7 +180,7 @@ def initialize_taxonomy_lookup(force_init=False) -> None:
             target_file = os.path.join(
                 taxonomy_download_dir, taxonomy_name, os.path.basename(fn))
 
-            if os.path.isfile(target_file):
+            if (not force_init) and (os.path.isfile(target_file)):
                 print(f'Bypassing unzip of {target_file}, file exists')
             else:
                 os.makedirs(os.path.basename(target_file),exist_ok=True)
@@ -185,13 +199,16 @@ def initialize_taxonomy_lookup(force_init=False) -> None:
     # name file
 
     # Load iNat taxonomy
-    inat_taxonomy = pd.read_csv(os.path.join(taxonomy_download_dir, 'iNaturalist', 'taxa.csv'))
+    inat_taxonomy_file = os.path.join(taxonomy_download_dir, 'iNaturalist', 'taxa.csv')
+    print('Loading iNat taxonomy from {}'.format(inat_taxonomy_file))
+    inat_taxonomy = pd.read_csv(inat_taxonomy_file)
     inat_taxonomy['scientificName'] = inat_taxonomy['scientificName'].fillna('').str.strip()
     inat_taxonomy['vernacularName'] = inat_taxonomy['vernacularName'].fillna('').str.strip()
 
     # Load GBIF taxonomy
-    gbif_taxonomy = pd.read_csv(os.path.join(
-        taxonomy_download_dir, 'GBIF', 'Taxon.tsv'), sep='\t')
+    gbif_taxonomy_file = os.path.join(taxonomy_download_dir, 'GBIF', 'Taxon.tsv')
+    print('Loading GBIF taxonomy from {}'.format(gbif_taxonomy_file))
+    gbif_taxonomy = pd.read_csv(gbif_taxonomy_file, sep='\t')
     gbif_taxonomy['scientificName'] = gbif_taxonomy['scientificName'].fillna('').str.strip()
     gbif_taxonomy['canonicalName'] = gbif_taxonomy['canonicalName'].fillna('').str.strip()
 
@@ -249,7 +266,8 @@ def initialize_taxonomy_lookup(force_init=False) -> None:
 
     # Build iNat dictionaries
 
-    # row = inat_taxonomy.iloc[0]
+    print('Building lookup dictionaries for iNat taxonomy')
+
     for i_row, row in tqdm(inat_taxonomy.iterrows(), total=len(inat_taxonomy)):
 
         taxon_id = row['taxonID']
@@ -267,6 +285,8 @@ def initialize_taxonomy_lookup(force_init=False) -> None:
 
     # Build GBIF dictionaries
 
+    print('Building lookup dictionaries for GBIF taxonomy')
+
     for i_row, row in tqdm(gbif_taxonomy.iterrows(), total=len(gbif_taxonomy)):
 
         taxon_id = row['taxonID']
@@ -320,13 +340,13 @@ def initialize_taxonomy_lookup(force_init=False) -> None:
         gbif_scientific_to_taxon_id
     ]
 
-    print('Serializing...', end='')
+    print('Serializing to {}...'.format(serialized_structures_file), end='')
     if not os.path.isfile(serialized_structures_file):
         with open(serialized_structures_file, 'wb') as p:
             pickle.dump(structures_to_serialize, p)
     print(' done')
 
-    # ...def initialize_taxonomy_lookup()
+    # ...def initialize_taxonomy_lookup(...)
 
 
 def get_scientific_name_from_row(r):
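
The serialization hunk above caches all of the lookup dictionaries in a single pickle file so that later calls to initialize_taxonomy_lookup() can skip the expensive rebuild. A minimal sketch of that round trip, with hypothetical helper names (save_structures / load_structures are not from the package):

import os
import pickle

def save_structures(structures_to_serialize, serialized_structures_file):
    # Write the cache only if it doesn't exist yet, mirroring the guard in
    # the hunk above.
    if not os.path.isfile(serialized_structures_file):
        with open(serialized_structures_file, 'wb') as f:
            pickle.dump(structures_to_serialize, f)

def load_structures(serialized_structures_file):
    # Load the previously cached list of lookup dictionaries.
    with open(serialized_structures_file, 'rb') as f:
        return pickle.load(f)
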
--- a/taxonomy_mapping/taxonomy_csv_checker.py
+++ b/taxonomy_mapping/taxonomy_csv_checker.py
@@ -45,7 +45,7 @@ def check_taxonomy_csv(csv_path: str) -> None:
     num_taxon_level_errors = 0
     num_scientific_name_errors = 0
 
-    for i, row in taxonomy_df.iterrows():
+    for i_row, row in taxonomy_df.iterrows():
 
         ds = row['dataset_name']
         ds_label = row['query']
@@ -81,14 +81,14 @@ def check_taxonomy_csv(csv_path: str) -> None:
             node.add_id(id_source, int(taxon_id))  # np.int64 -> int
             if j == 0:
                 if level != taxon_level:
-                    print(f'row: {i}, {ds}, {ds_label}')
+                    print(f'row: {i_row}, {ds}, {ds_label}')
                     print(f'- taxonomy_level column: {level}, '
                           f'level from taxonomy_string: {taxon_level}')
                     print()
                     num_taxon_level_errors += 1
 
                 if scientific_name != taxon_name:
-                    print(f'row: {i}, {ds}, {ds_label}')
+                    print(f'row: {i_row}, {ds}, {ds_label}')
                     print(f'- scientific_name column: {scientific_name}, '
                           f'name from taxonomy_string: {taxon_name}')
                     print()
@@ -97,7 +97,7 @@ def check_taxonomy_csv(csv_path: str) -> None:
             taxon_child = node
 
     # ...for each row in the taxonomy file
-
+
     assert nx.is_directed_acyclic_graph(graph)
 
     for node in graph.nodes:
@@ -123,6 +123,8 @@ def check_taxonomy_csv(csv_path: str) -> None:
     except AssertionError as e:
         print(f'At least one node has unresolved ambiguous parents: {e}')
 
+    print('Processed {} rows from {}'.format(len(taxonomy_df),csv_path))
+
     print('num taxon level errors:', num_taxon_level_errors)
     print('num scientific name errors:', num_scientific_name_errors)
 
@@ -154,4 +156,4 @@ if False:
     import os
     csv_path = os.path.expanduser('~/lila/lila-taxonomy-mapping_release.csv')
     check_taxonomy_csv(csv_path)
-
+
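
The checker hunks above mostly rename the loop variable and add a summary line; the underlying pattern is a single pandas iterrows() pass that counts mismatches between columns. A minimal, self-contained sketch of that pattern, with illustrative column names rather than the real taxonomy CSV schema:

import pandas as pd

def check_csv(csv_path):
    # Row-by-row consistency check; 'scientific_name' and 'taxonomy_string'
    # are illustrative column names, not the package's exact schema.
    df = pd.read_csv(csv_path)
    num_errors = 0
    for i_row, row in df.iterrows():
        expected = str(row['scientific_name']).strip().lower()
        actual = str(row['taxonomy_string']).split(';')[0].strip().lower()
        if expected != actual:
            print(f'row: {i_row}, mismatch: {expected} vs {actual}')
            num_errors += 1
    print('Processed {} rows from {}'.format(len(df), csv_path))
    print('num errors:', num_errors)
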
--- a/api/synchronous/api_core/yolov5/detect.py
+++ /dev/null
@@ -1,252 +0,0 @@
-# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
-"""
-Run inference on images, videos, directories, streams, etc.
-
-Usage - sources:
-    $ python path/to/detect.py --weights yolov5s.pt --source 0              # webcam
-                                                             img.jpg        # image
-                                                             vid.mp4        # video
-                                                             path/          # directory
-                                                             path/*.jpg     # glob
-                                                             'https://youtu.be/Zgi9g1ksQHc'  # YouTube
-                                                             'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
-
-Usage - formats:
-    $ python path/to/detect.py --weights yolov5s.pt            # PyTorch
-                                         yolov5s.torchscript   # TorchScript
-                                         yolov5s.onnx          # ONNX Runtime or OpenCV DNN with --dnn
-                                         yolov5s.xml           # OpenVINO
-                                         yolov5s.engine        # TensorRT
-                                         yolov5s.mlmodel       # CoreML (macOS-only)
-                                         yolov5s_saved_model   # TensorFlow SavedModel
-                                         yolov5s.pb            # TensorFlow GraphDef
-                                         yolov5s.tflite        # TensorFlow Lite
-                                         yolov5s_edgetpu.tflite  # TensorFlow Edge TPU
-"""
-
-import argparse
-import os
-import sys
-from pathlib import Path
-
-import torch
-import torch.backends.cudnn as cudnn
-
-FILE = Path(__file__).resolve()
-ROOT = FILE.parents[0]  # YOLOv5 root directory
-if str(ROOT) not in sys.path:
-    sys.path.append(str(ROOT))  # add ROOT to PATH
-ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
-
-from models.common import DetectMultiBackend
-from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
-from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
-                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
-from utils.plots import Annotator, colors, save_one_box
-from utils.torch_utils import select_device, time_sync
-
-
-@torch.no_grad()
-def run(
-        weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
-        source=ROOT / 'data/images',  # file/dir/URL/glob, 0 for webcam
-        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
-        imgsz=(640, 640),  # inference size (height, width)
-        conf_thres=0.25,  # confidence threshold
-        iou_thres=0.45,  # NMS IOU threshold
-        max_det=1000,  # maximum detections per image
-        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
-        view_img=False,  # show results
-        save_txt=False,  # save results to *.txt
-        save_conf=False,  # save confidences in --save-txt labels
-        save_crop=False,  # save cropped prediction boxes
-        nosave=False,  # do not save images/videos
-        classes=None,  # filter by class: --class 0, or --class 0 2 3
-        agnostic_nms=False,  # class-agnostic NMS
-        augment=False,  # augmented inference
-        visualize=False,  # visualize features
-        update=False,  # update all models
-        project=ROOT / 'runs/detect',  # save results to project/name
-        name='exp',  # save results to project/name
-        exist_ok=False,  # existing project/name ok, do not increment
-        line_thickness=3,  # bounding box thickness (pixels)
-        hide_labels=False,  # hide labels
-        hide_conf=False,  # hide confidences
-        half=False,  # use FP16 half-precision inference
-        dnn=False,  # use OpenCV DNN for ONNX inference
-):
-    source = str(source)
-    save_img = not nosave and not source.endswith('.txt')  # save inference images
-    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
-    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
-    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
-    if is_url and is_file:
-        source = check_file(source)  # download
-
-    # Directories
-    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
-    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
-
-    # Load model
-    device = select_device(device)
-    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
-    stride, names, pt = model.stride, model.names, model.pt
-    imgsz = check_img_size(imgsz, s=stride)  # check image size
-
-    # Dataloader
-    if webcam:
-        view_img = check_imshow()
-        cudnn.benchmark = True  # set True to speed up constant image size inference
-        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
-        bs = len(dataset)  # batch_size
-    else:
-        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
-        bs = 1  # batch_size
-    vid_path, vid_writer = [None] * bs, [None] * bs
-
-    # Run inference
-    model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
-    dt, seen = [0.0, 0.0, 0.0], 0
-    for path, im, im0s, vid_cap, s in dataset:
-        t1 = time_sync()
-        im = torch.from_numpy(im).to(device)
-        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
-        im /= 255  # 0 - 255 to 0.0 - 1.0
-        if len(im.shape) == 3:
-            im = im[None]  # expand for batch dim
-        t2 = time_sync()
-        dt[0] += t2 - t1
-
-        # Inference
-        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
-        pred = model(im, augment=augment, visualize=visualize)
-        t3 = time_sync()
-        dt[1] += t3 - t2
-
-        # NMS
-        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
-        dt[2] += time_sync() - t3
-
-        # Second-stage classifier (optional)
-        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
-
-        # Process predictions
-        for i, det in enumerate(pred):  # per image
-            seen += 1
-            if webcam:  # batch_size >= 1
-                p, im0, frame = path[i], im0s[i].copy(), dataset.count
-                s += f'{i}: '
-            else:
-                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
-
-            p = Path(p)  # to Path
-            save_path = str(save_dir / p.name)  # im.jpg
-            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
-            s += '%gx%g ' % im.shape[2:]  # print string
-            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
-            imc = im0.copy() if save_crop else im0  # for save_crop
-            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
-            if len(det):
-                # Rescale boxes from img_size to im0 size
-                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
-
-                # Print results
-                for c in det[:, -1].unique():
-                    n = (det[:, -1] == c).sum()  # detections per class
-                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
-
-                # Write results
-                for *xyxy, conf, cls in reversed(det):
-                    if save_txt:  # Write to file
-                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
-                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
-                        with open(f'{txt_path}.txt', 'a') as f:
-                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
-
-                    if save_img or save_crop or view_img:  # Add bbox to image
-                        c = int(cls)  # integer class
-                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
-                        annotator.box_label(xyxy, label, color=colors(c, True))
-                    if save_crop:
-                        save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
-
-            # Stream results
-            im0 = annotator.result()
-            if view_img:
-                cv2.imshow(str(p), im0)
-                cv2.waitKey(1)  # 1 millisecond
-
-            # Save results (image with detections)
-            if save_img:
-                if dataset.mode == 'image':
-                    cv2.imwrite(save_path, im0)
-                else:  # 'video' or 'stream'
-                    if vid_path[i] != save_path:  # new video
-                        vid_path[i] = save_path
-                        if isinstance(vid_writer[i], cv2.VideoWriter):
-                            vid_writer[i].release()  # release previous video writer
-                        if vid_cap:  # video
-                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
-                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-                        else:  # stream
-                            fps, w, h = 30, im0.shape[1], im0.shape[0]
-                        save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
-                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
-                    vid_writer[i].write(im0)
-
-        # Print time (inference-only)
-        LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')
-
-    # Print results
-    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
-    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
-    if save_txt or save_img:
-        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
-        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
-    if update:
-        strip_optimizer(weights)  # update model (to fix SourceChangeWarning)
-
-
-def parse_opt():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
-    parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
-    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
-    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
-    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
-    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
-    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
-    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
-    parser.add_argument('--view-img', action='store_true', help='show results')
-    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
-    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
-    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
-    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
-    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
-    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
-    parser.add_argument('--augment', action='store_true', help='augmented inference')
-    parser.add_argument('--visualize', action='store_true', help='visualize features')
-    parser.add_argument('--update', action='store_true', help='update all models')
-    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
-    parser.add_argument('--name', default='exp', help='save results to project/name')
-    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
-    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
-    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
-    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
-    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
-    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
-    opt = parser.parse_args()
-    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
-    print_args(vars(opt))
-    return opt
-
-
-def main(opt):
-    check_requirements(exclude=('tensorboard', 'thop'))
-    run(**vars(opt))
-
-
-if __name__ == "__main__":
-    opt = parse_opt()
-    main(opt)