megadetector 5.0.5__py3-none-any.whl → 5.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megadetector might be problematic; see the registry's release page for details.
- api/batch_processing/data_preparation/manage_local_batch.py +302 -263
- api/batch_processing/data_preparation/manage_video_batch.py +81 -2
- api/batch_processing/postprocessing/add_max_conf.py +1 -0
- api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
- api/batch_processing/postprocessing/compare_batch_results.py +110 -60
- api/batch_processing/postprocessing/load_api_results.py +56 -70
- api/batch_processing/postprocessing/md_to_coco.py +1 -1
- api/batch_processing/postprocessing/md_to_labelme.py +2 -1
- api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
- api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
- api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
- api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
- api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
- api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
- api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
- api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
- classification/prepare_classification_script.py +191 -191
- data_management/coco_to_yolo.py +68 -45
- data_management/databases/integrity_check_json_db.py +7 -5
- data_management/generate_crops_from_cct.py +3 -3
- data_management/get_image_sizes.py +8 -6
- data_management/importers/add_timestamps_to_icct.py +79 -0
- data_management/importers/animl_results_to_md_results.py +160 -0
- data_management/importers/auckland_doc_test_to_json.py +4 -4
- data_management/importers/auckland_doc_to_json.py +1 -1
- data_management/importers/awc_to_json.py +5 -5
- data_management/importers/bellevue_to_json.py +5 -5
- data_management/importers/carrizo_shrubfree_2018.py +5 -5
- data_management/importers/carrizo_trail_cam_2017.py +5 -5
- data_management/importers/cct_field_adjustments.py +2 -3
- data_management/importers/channel_islands_to_cct.py +4 -4
- data_management/importers/ena24_to_json.py +5 -5
- data_management/importers/helena_to_cct.py +10 -10
- data_management/importers/idaho-camera-traps.py +12 -12
- data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
- data_management/importers/jb_csv_to_json.py +4 -4
- data_management/importers/missouri_to_json.py +1 -1
- data_management/importers/noaa_seals_2019.py +1 -1
- data_management/importers/pc_to_json.py +5 -5
- data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
- data_management/importers/prepare_zsl_imerit.py +5 -5
- data_management/importers/rspb_to_json.py +4 -4
- data_management/importers/save_the_elephants_survey_A.py +5 -5
- data_management/importers/save_the_elephants_survey_B.py +6 -6
- data_management/importers/snapshot_safari_importer.py +9 -9
- data_management/importers/snapshot_serengeti_lila.py +9 -9
- data_management/importers/timelapse_csv_set_to_json.py +5 -7
- data_management/importers/ubc_to_json.py +4 -4
- data_management/importers/umn_to_json.py +4 -4
- data_management/importers/wellington_to_json.py +1 -1
- data_management/importers/wi_to_json.py +2 -2
- data_management/importers/zamba_results_to_md_results.py +181 -0
- data_management/labelme_to_coco.py +35 -7
- data_management/labelme_to_yolo.py +229 -0
- data_management/lila/add_locations_to_island_camera_traps.py +1 -1
- data_management/lila/add_locations_to_nacti.py +147 -0
- data_management/lila/create_lila_blank_set.py +474 -0
- data_management/lila/create_lila_test_set.py +2 -1
- data_management/lila/create_links_to_md_results_files.py +106 -0
- data_management/lila/download_lila_subset.py +46 -21
- data_management/lila/generate_lila_per_image_labels.py +23 -14
- data_management/lila/get_lila_annotation_counts.py +17 -11
- data_management/lila/lila_common.py +14 -11
- data_management/lila/test_lila_metadata_urls.py +116 -0
- data_management/ocr_tools.py +829 -0
- data_management/resize_coco_dataset.py +13 -11
- data_management/yolo_output_to_md_output.py +84 -12
- data_management/yolo_to_coco.py +38 -20
- detection/process_video.py +36 -14
- detection/pytorch_detector.py +23 -8
- detection/run_detector.py +76 -19
- detection/run_detector_batch.py +178 -63
- detection/run_inference_with_yolov5_val.py +326 -57
- detection/run_tiled_inference.py +153 -43
- detection/video_utils.py +34 -8
- md_utils/ct_utils.py +172 -1
- md_utils/md_tests.py +372 -51
- md_utils/path_utils.py +167 -39
- md_utils/process_utils.py +26 -7
- md_utils/split_locations_into_train_val.py +215 -0
- md_utils/string_utils.py +10 -0
- md_utils/url_utils.py +0 -2
- md_utils/write_html_image_list.py +9 -26
- md_visualization/plot_utils.py +12 -8
- md_visualization/visualization_utils.py +106 -7
- md_visualization/visualize_db.py +16 -8
- md_visualization/visualize_detector_output.py +208 -97
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
- taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
- taxonomy_mapping/map_new_lila_datasets.py +43 -39
- taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
- taxonomy_mapping/preview_lila_taxonomy.py +27 -27
- taxonomy_mapping/species_lookup.py +33 -13
- taxonomy_mapping/taxonomy_csv_checker.py +7 -5
- api/synchronous/api_core/yolov5/detect.py +0 -252
- api/synchronous/api_core/yolov5/export.py +0 -607
- api/synchronous/api_core/yolov5/hubconf.py +0 -146
- api/synchronous/api_core/yolov5/models/__init__.py +0 -0
- api/synchronous/api_core/yolov5/models/common.py +0 -738
- api/synchronous/api_core/yolov5/models/experimental.py +0 -104
- api/synchronous/api_core/yolov5/models/tf.py +0 -574
- api/synchronous/api_core/yolov5/models/yolo.py +0 -338
- api/synchronous/api_core/yolov5/train.py +0 -670
- api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
- api/synchronous/api_core/yolov5/utils/activations.py +0 -103
- api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
- api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
- api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
- api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
- api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
- api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
- api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
- api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
- api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
- api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
- api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
- api/synchronous/api_core/yolov5/utils/general.py +0 -1018
- api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
- api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
- api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
- api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
- api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
- api/synchronous/api_core/yolov5/utils/loss.py +0 -234
- api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
- api/synchronous/api_core/yolov5/utils/plots.py +0 -489
- api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
- api/synchronous/api_core/yolov5/val.py +0 -394
- md_utils/matlab_porting_tools.py +0 -97
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
|
@@ -1,607 +0,0 @@
|
|
|
1
|
-
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
2
|
-
"""
|
|
3
|
-
Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
|
|
4
|
-
|
|
5
|
-
Format | `export.py --include` | Model
|
|
6
|
-
--- | --- | ---
|
|
7
|
-
PyTorch | - | yolov5s.pt
|
|
8
|
-
TorchScript | `torchscript` | yolov5s.torchscript
|
|
9
|
-
ONNX | `onnx` | yolov5s.onnx
|
|
10
|
-
OpenVINO | `openvino` | yolov5s_openvino_model/
|
|
11
|
-
TensorRT | `engine` | yolov5s.engine
|
|
12
|
-
CoreML | `coreml` | yolov5s.mlmodel
|
|
13
|
-
TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/
|
|
14
|
-
TensorFlow GraphDef | `pb` | yolov5s.pb
|
|
15
|
-
TensorFlow Lite | `tflite` | yolov5s.tflite
|
|
16
|
-
TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
|
|
17
|
-
TensorFlow.js | `tfjs` | yolov5s_web_model/
|
|
18
|
-
|
|
19
|
-
Requirements:
|
|
20
|
-
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
|
|
21
|
-
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
|
|
22
|
-
|
|
23
|
-
Usage:
|
|
24
|
-
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ...
|
|
25
|
-
|
|
26
|
-
Inference:
|
|
27
|
-
$ python path/to/detect.py --weights yolov5s.pt # PyTorch
|
|
28
|
-
yolov5s.torchscript # TorchScript
|
|
29
|
-
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
|
30
|
-
yolov5s.xml # OpenVINO
|
|
31
|
-
yolov5s.engine # TensorRT
|
|
32
|
-
yolov5s.mlmodel # CoreML (macOS-only)
|
|
33
|
-
yolov5s_saved_model # TensorFlow SavedModel
|
|
34
|
-
yolov5s.pb # TensorFlow GraphDef
|
|
35
|
-
yolov5s.tflite # TensorFlow Lite
|
|
36
|
-
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
|
|
37
|
-
|
|
38
|
-
TensorFlow.js:
|
|
39
|
-
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
|
40
|
-
$ npm install
|
|
41
|
-
$ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
|
|
42
|
-
$ npm start
|
|
43
|
-
"""
|
|
44
|
-
|
|
45
|
-
import argparse
|
|
46
|
-
import json
|
|
47
|
-
import os
|
|
48
|
-
import platform
|
|
49
|
-
import subprocess
|
|
50
|
-
import sys
|
|
51
|
-
import time
|
|
52
|
-
import warnings
|
|
53
|
-
from pathlib import Path
|
|
54
|
-
|
|
55
|
-
import pandas as pd
|
|
56
|
-
import torch
|
|
57
|
-
import yaml
|
|
58
|
-
from torch.utils.mobile_optimizer import optimize_for_mobile
|
|
59
|
-
|
|
60
|
-
FILE = Path(__file__).resolve()
ROOT = FILE.parent  # YOLOv5 root directory: the folder that contains this script
# Make the sibling 'models' / 'utils' packages importable when this file is run directly
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.append(str(ROOT))  # add ROOT to PATH
# Outside Windows, express ROOT relative to the current working directory
if platform.system() != 'Windows':
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
|
|
66
|
-
|
|
67
|
-
from models.experimental import attempt_load
|
|
68
|
-
from models.yolo import Detect
|
|
69
|
-
from utils.dataloaders import LoadImages
|
|
70
|
-
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
|
|
71
|
-
file_size, print_args, url2file)
|
|
72
|
-
from utils.torch_utils import select_device
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def export_formats():
    """Return a table describing every export format this script knows about.

    Returns:
        pandas.DataFrame with one row per format and columns:
          'Format'   - human-readable name,
          'Argument' - value accepted by ``--include`` ('-' for native PyTorch weights),
          'Suffix'   - suffix of the produced file or directory,
          'GPU'      - boolean flag carried over from upstream YOLOv5
                       (presumably "export works on a GPU device" — TODO confirm).
    """
    rows = (
        ('PyTorch', '-', '.pt', True),
        ('TorchScript', 'torchscript', '.torchscript', True),
        ('ONNX', 'onnx', '.onnx', True),
        ('OpenVINO', 'openvino', '_openvino_model', False),
        ('TensorRT', 'engine', '.engine', True),
        ('CoreML', 'coreml', '.mlmodel', False),
        ('TensorFlow SavedModel', 'saved_model', '_saved_model', True),
        ('TensorFlow GraphDef', 'pb', '.pb', True),
        ('TensorFlow Lite', 'tflite', '.tflite', False),
        ('TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False),
        ('TensorFlow.js', 'tfjs', '_web_model', False),
    )
    header = ['Format', 'Argument', 'Suffix', 'GPU']
    return pd.DataFrame([list(row) for row in rows], columns=header)
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
    """Trace *model* with TorchScript and save the result next to *file*.

    Args:
        model: PyTorch model to trace; its ``stride`` and ``names`` are stored as metadata.
        im: example input used for tracing; its shape is stored as metadata.
        file: pathlib.Path of the source weights; the output uses the '.torchscript' suffix.
        optimize: if True, save through ``optimize_for_mobile`` for the lite interpreter.
        prefix: log-message prefix.

    Returns:
        Path of the saved file, or None on failure (failures are logged, not raised).
    """
    try:
        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
        out_path = file.with_suffix('.torchscript')

        traced = torch.jit.trace(model, im, strict=False)
        # Metadata travels inside the archive as an extra file named config.txt
        metadata = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
        extras = {'config.txt': json.dumps(metadata)}  # torch._C.ExtraFilesMap()
        # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
        save = optimize_for_mobile(traced)._save_for_lite_interpreter if optimize else traced.save
        save(str(out_path), _extra_files=extras)

        LOGGER.info(f'{prefix} export success, saved as {out_path} ({file_size(out_path):.1f} MB)')
        return out_path
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
    """Export *model* to ONNX next to *file*, attach metadata, and optionally simplify.

    Args:
        model: PyTorch model; ``stride`` and ``names`` are written as ONNX metadata props.
        im: example input tensor passed to ``torch.onnx.export``.
        file: pathlib.Path of the source weights; the output uses the '.onnx' suffix.
        opset: ONNX opset version.
        train: export in TRAINING mode (and skip constant folding) when True.
        dynamic: mark batch/height/width (input) and batch/anchors (output) as dynamic axes.
        simplify: run onnx-simplifier afterwards; its failure is logged but not fatal.
        prefix: log-message prefix.

    Returns:
        Path of the .onnx file, or None on failure (failures are logged, not raised).
    """
    try:
        check_requirements(('onnx',))
        import onnx

        LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
        out_path = file.with_suffix('.onnx')

        axes = None
        if dynamic:
            axes = {
                'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                'output': {0: 'batch', 1: 'anchors'},  # shape(1,25200,85)
            }
        mode = torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL
        torch.onnx.export(model,
                          im,
                          out_path,
                          verbose=False,
                          opset_version=opset,
                          training=mode,
                          do_constant_folding=not train,
                          input_names=['images'],
                          output_names=['output'],
                          dynamic_axes=axes)

        # Validate the written graph
        proto = onnx.load(out_path)
        onnx.checker.check_model(proto)

        # Attach stride / class names as string metadata properties
        for key, value in (('stride', int(max(model.stride))), ('names', model.names)):
            entry = proto.metadata_props.add()
            entry.key, entry.value = key, str(value)
        onnx.save(proto, out_path)

        if simplify:
            try:
                check_requirements(('onnx-simplifier',))
                import onnxsim

                LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                shapes = {'images': list(im.shape)} if dynamic else None
                proto, ok = onnxsim.simplify(proto, dynamic_input_shape=dynamic, input_shapes=shapes)
                assert ok, 'assert check failed'
                onnx.save(proto, out_path)
            except Exception as e:
                LOGGER.info(f'{prefix} simplifier failure: {e}')
        LOGGER.info(f'{prefix} export success, saved as {out_path} ({file_size(out_path):.1f} MB)')
        return out_path
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
    """Convert the ONNX export of *file* to OpenVINO IR with the ``mo`` model optimizer.

    Expects ``file.with_suffix('.onnx')`` to exist already; this function only runs ``mo``
    on it. The IR goes into ``<stem>_openvino_model/`` next to *file*, together with a
    ``<stem>.yaml`` file holding the model stride and class names.

    Args:
        model: PyTorch model; only ``stride`` and ``names`` are read.
        file: pathlib.Path of the source weights.
        half: request FP16 weights from ``mo`` when True, FP32 otherwise.
        prefix: log-message prefix.

    Returns:
        Output directory as a string ending in ``os.sep``, or None on failure
        (failures are logged, not raised).
    """
    try:
        check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
        import openvino.inference_engine as ie

        LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
        # Derive the directory from the file stem only: str(file).replace('.pt', ...) would also
        # rewrite any '.pt' occurring in a parent directory name (or inside a '.pth' suffix)
        f = str(file.with_name(f'{file.stem}_openvino_model')) + os.sep

        # Pass an argv list (not a split string) so paths containing spaces stay intact
        cmd = [
            'mo', '--input_model', str(file.with_suffix('.onnx')), '--output_dir', f,
            '--data_type', 'FP16' if half else 'FP32']
        subprocess.check_output(cmd)  # export
        with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
            yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g)  # add metadata.yaml

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
    """Trace *model* and convert it to a CoreML ``.mlmodel`` next to *file*.

    Args:
        model: PyTorch model to trace.
        im: example input tensor; its shape defines the CoreML image input.
        file: pathlib.Path of the source weights; the output uses the '.mlmodel' suffix.
        int8: quantize weights to 8-bit k-means LUT (macOS only).
        half: quantize weights to 16-bit linear when *int8* is False (macOS only).
        prefix: log-message prefix.

    Returns:
        ``(ct_model, path)`` on success, ``(None, None)`` on failure
        (failures are logged, not raised).
    """
    try:
        check_requirements(('coremltools',))
        import coremltools as ct

        LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
        f = file.with_suffix('.mlmodel')

        ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
        # The CoreML image input multiplies pixel values by 1/255 with zero bias
        ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
        bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
        if bits < 32:
            if platform.system() == 'Darwin':  # quantization only supported on macOS
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warning
                    ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
            else:
                # Route through LOGGER like every other message in this module (previously print)
                LOGGER.info(f'{prefix} quantization only supported on macOS, skipping...')
        ct_model.save(f)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return ct_model, f
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')
        return None, None
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    """Build a serialized TensorRT engine from an ONNX export of *model*.

    Exports ONNX first via ``export_onnx``: opset 12 for TensorRT 7, opset 13 for TensorRT >= 8.
    The ONNX file is then parsed and the engine is written to ``file.with_suffix('.engine')``.

    Args:
        model: PyTorch model (on GPU); for TensorRT 7 its last layer's ``anchor_grid`` is
            temporarily sliced during ONNX export and restored afterwards.
        im: example input tensor; must not be on the CPU.
        file: pathlib.Path of the source weights.
        train, simplify: forwarded to ``export_onnx``.
        half: build an FP16 engine when the platform reports fast FP16.
        workspace: builder workspace size in GB.
        verbose: raise the TensorRT logger to VERBOSE.
        prefix: log-message prefix.

    Returns:
        Path of the .engine file, or None on failure (failures are logged, not raised).
    """
    try:
        assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
        try:
            import tensorrt as trt
        except Exception:
            if platform.system() == 'Linux':
                check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
            import tensorrt as trt

        # Compare the full major component; __version__[0] would only inspect the first character
        if trt.__version__.split('.')[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
            grid = model.model[-1].anchor_grid
            model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
            try:
                export_onnx(model, im, file, 12, train, False, simplify)  # opset 12
            finally:
                model.model[-1].anchor_grid = grid  # always restore the caller's model
        else:  # TensorRT >= 8
            check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
            export_onnx(model, im, file, 13, train, False, simplify)  # opset 13
        onnx_path = file.with_suffix('.onnx')

        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
        assert onnx_path.exists(), f'failed to export ONNX file: {onnx_path}'
        f = file.with_suffix('.engine')  # TensorRT engine file
        logger = trt.Logger(trt.Logger.INFO)
        if verbose:
            logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        # '*' binds tighter than '<<', so this is (workspace * 1) << 30, i.e. workspace GB in bytes
        config.max_workspace_size = workspace * 1 << 30
        # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(onnx_path)):
            raise RuntimeError(f'failed to load ONNX file: {onnx_path}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        LOGGER.info(f'{prefix} Network Description:')
        for inp in inputs:
            LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

        use_fp16 = builder.platform_has_fast_fp16 and half
        LOGGER.info(f'{prefix} building FP{16 if use_fp16 else 32} engine in {f}')
        if use_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        engine = builder.build_engine(network, config)
        # build_engine returns None on failure; report that explicitly instead of an AttributeError
        if engine is None:
            raise RuntimeError(f'failed to build TensorRT engine: {f}')
        with engine as built, open(f, 'wb') as t:
            t.write(built.serialize())
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
def export_saved_model(model,
                       im,
                       file,
                       dynamic,
                       tf_nms=False,
                       agnostic_nms=False,
                       topk_per_class=100,
                       topk_all=100,
                       iou_thres=0.45,
                       conf_thres=0.25,
                       keras=False,
                       prefix=colorstr('TensorFlow SavedModel:')):
    """Rebuild *model* in Keras (via ``models.tf.TFModel``) and save it as a TF SavedModel.

    Args:
        model: PyTorch model; ``yaml`` and ``nc`` are read and it is passed to ``TFModel``.
        im: example input in BCHW order; only its shape is used.
        file: pathlib.Path of the source weights; output directory is ``<stem>_saved_model``.
        dynamic: leave the Keras input batch size unspecified when True.
        tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres:
            forwarded to ``TFModel.predict``.
        keras: save the Keras model directly when True, otherwise a frozen ``tf.Module``.
        prefix: log-message prefix.

    Returns:
        ``(keras_model, output_dir)`` on success, ``(None, None)`` on failure
        (failures are logged, not raised).
    """
    try:
        import tensorflow as tf
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

        from models.tf import TFModel

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        # Derive the directory from the stem only; str(file).replace('.pt', ...) would also
        # rewrite any '.pt' occurring elsewhere in the path
        f = str(file.with_name(f'{file.stem}_saved_model'))
        batch_size, ch, *imgsz = list(im.shape)  # BCHW

        tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
        im = tf.zeros((batch_size, *imgsz, ch))  # BHWC order for TensorFlow
        _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
        inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
        outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
        keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
        keras_model.trainable = False
        keras_model.summary()
        if keras:
            keras_model.save(f, save_format='tf')
        else:
            spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
            m = tf.function(lambda x: keras_model(x))  # full model
            m = m.get_concrete_function(spec)
            frozen_func = convert_variables_to_constants_v2(m)
            tfm = tf.Module()
            # With NMS the first four outputs are returned, otherwise only the first
            tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec])
            tfm.__call__(im)
            tf.saved_model.save(tfm,
                                f,
                                options=tf.saved_model.SaveOptions(experimental_custom_gradients=False)
                                if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return keras_model, f
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')
        return None, None
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
    """Freeze *keras_model* into a binary TensorFlow GraphDef (``.pb``) next to *file*.

    See https://github.com/leimao/Frozen_Graph_TensorFlow for the freezing approach.

    Args:
        keras_model: Keras model whose first input defines the traced signature.
        file: pathlib.Path of the source weights; the output uses the '.pb' suffix.
        prefix: log-message prefix.

    Returns:
        Path of the .pb file, or None on failure (failures are logged, not raised).
    """
    try:
        import tensorflow as tf
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        pb_path = file.with_suffix('.pb')

        first_input = keras_model.inputs[0]
        input_spec = tf.TensorSpec(first_input.shape, first_input.dtype)
        concrete = tf.function(lambda x: keras_model(x)).get_concrete_function(input_spec)  # full model
        frozen = convert_variables_to_constants_v2(concrete)
        frozen.graph.as_graph_def()  # return value unused; retained to keep behavior unchanged
        tf.io.write_graph(graph_or_graph_def=frozen.graph,
                          logdir=str(pb_path.parent),
                          name=pb_path.name,
                          as_text=False)

        LOGGER.info(f'{prefix} export success, saved as {pb_path} ({file_size(pb_path):.1f} MB)')
        return pb_path
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
    """Convert *keras_model* to TensorFlow Lite (FP16 by default, full INT8 when *int8*).

    Args:
        keras_model: Keras model to convert.
        im: example input in BCHW order; its spatial size is used for INT8 calibration images.
        file: pathlib.Path of the source weights.
        int8: quantize to INT8 (uint8 I/O) using ~100 calibration images from the
            'train' split of dataset *data*.
        data: dataset YAML path, only read when *int8* is True.
        nms, agnostic_nms: allow SELECT_TF_OPS so NMS ops baked into the model convert.
        prefix: log-message prefix.

    Returns:
        Output path as a string ('-fp16.tflite' or '-int8.tflite'), or None on failure
        (failures are logged, not raised).
    """
    try:
        import tensorflow as tf

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        _, _, *imgsz = list(im.shape)  # BCHW; only the spatial size is needed
        # NOTE(review): str.replace rewrites every '.pt' in the path; export_edgetpu derives
        # its input name the same way, so change both together if this is ever fixed.
        f = str(file).replace('.pt', '-fp16.tflite')

        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
        converter.target_spec.supported_types = [tf.float16]
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        if int8:
            from models.tf import representative_dataset_gen
            dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False)  # representative data
            converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
            converter.target_spec.supported_types = []
            converter.inference_input_type = tf.uint8  # or tf.int8
            converter.inference_output_type = tf.uint8  # or tf.int8
            converter.experimental_new_quantizer = True
            f = str(file).replace('.pt', '-int8.tflite')
        if nms or agnostic_nms:
            converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)

        tflite_model = converter.convert()
        # Close the handle deterministically (the original bare open().write() leaked it)
        with open(f, "wb") as out:
            out.write(tflite_model)
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
    """Compile the INT8 TFLite export of *file* for the Coral Edge TPU.

    See https://coral.ai/docs/edgetpu/models-intro/. Linux only. If ``edgetpu_compiler``
    is missing, it is installed via apt, using sudo when available.

    Args:
        file: pathlib.Path of the source weights; ``<...>-int8.tflite`` must already exist.
        prefix: log-message prefix.

    Returns:
        Path string of ``<...>-int8_edgetpu.tflite``, or None on failure
        (failures are logged, not raised).
    """
    try:
        cmd = 'edgetpu_compiler --version'
        help_url = 'https://coral.ai/docs/edgetpu/compiler/'
        assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
        if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
            LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
            sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on system
            for c in (
                    'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
                    'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
                    'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
                subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
        ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

        LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
        # Names must match what export_tflite writes (same str.replace derivation)
        f = str(file).replace('.pt', '-int8_edgetpu.tflite')  # Edge TPU model
        f_tfl = str(file).replace('.pt', '-int8.tflite')  # TFLite model

        # Pass an argv list (not a split string) so paths containing spaces stay intact
        subprocess.run(['edgetpu_compiler', '-s', '-o', str(file.parent), f_tfl], check=True)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
    """Convert the frozen GraphDef of *file* to a TensorFlow.js web model.

    Expects ``file.with_suffix('.pb')`` to exist. Runs ``tensorflowjs_converter`` into
    ``<stem>_web_model/``. It then rewrites the ``outputs`` section of ``model.json`` so the
    four Identity outputs appear in ascending order.

    Args:
        file: pathlib.Path of the source weights.
        prefix: log-message prefix.

    Returns:
        Output directory as a string, or None on failure (failures are logged, not raised).
    """
    try:
        check_requirements(('tensorflowjs',))
        import re

        import tensorflowjs as tfjs

        LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
        # Derive the directory from the stem only; str(file).replace('.pt', ...) would also
        # rewrite any '.pt' occurring elsewhere in the path
        f = str(file.with_name(f'{file.stem}_web_model'))  # js dir
        f_pb = file.with_suffix('.pb')  # *.pb path
        f_json = f'{f}/model.json'  # *.json path

        # argv list keeps paths with spaces intact; check=True surfaces converter failures
        cmd = [
            'tensorflowjs_converter', '--input_format=tf_frozen_model',
            '--output_node_names=Identity,Identity_1,Identity_2,Identity_3', str(f_pb), f]
        subprocess.run(cmd, check=True)

        # Local name must not shadow the module-level 'json' import
        with open(f_json) as j:
            model_json = j.read()
        # Compute the substitution before truncating the file for writing
        subst = re.sub(
            r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
            r'"Identity.?.?": {"name": "Identity.?.?"}, '
            r'"Identity.?.?": {"name": "Identity.?.?"}, '
            r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
            r'"Identity_1": {"name": "Identity_1"}, '
            r'"Identity_2": {"name": "Identity_2"}, '
            r'"Identity_3": {"name": "Identity_3"}}}', model_json)
        with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
            j.write(subst)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
@torch.no_grad()
def run(
        data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
        weights=ROOT / 'yolov5s.pt',  # weights path
        imgsz=(640, 640),  # image (height, width)
        batch_size=1,  # batch size
        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        include=('torchscript', 'onnx'),  # include formats
        half=False,  # FP16 half-precision export
        inplace=False,  # set YOLOv5 Detect() inplace=True
        train=False,  # model.train() mode
        keras=False,  # use Keras
        optimize=False,  # TorchScript: optimize for mobile
        int8=False,  # CoreML/TF INT8 quantization
        dynamic=False,  # ONNX/TF: dynamic axes
        simplify=False,  # ONNX: simplify model
        opset=12,  # ONNX: opset version
        verbose=False,  # TensorRT: verbose log
        workspace=4,  # TensorRT: workspace size (GB)
        nms=False,  # TF: add NMS to model
        agnostic_nms=False,  # TF: add agnostic NMS to model
        topk_per_class=100,  # TF.js NMS: topk per class to keep
        topk_all=100,  # TF.js NMS: topk for all classes to keep
        iou_thres=0.45,  # TF.js NMS: IoU threshold
        conf_thres=0.25,  # TF.js NMS: confidence threshold
):
    """Export a YOLOv5 PyTorch checkpoint to every format named in ``include``.

    The model is loaded with ``attempt_load`` and run twice on a zero tensor of shape
    ``(batch_size, 3, h, w)``. It is then handed to each requested exporter. Ordering
    matters:
    - TensorRT is exported before ONNX.
    - ONNX is also produced when OpenVINO is requested.
    - The TF SavedModel is built before the ``pb``/TFLite/EdgeTPU/TF.js exports that
      depend on it.

    Args:
        include: Iterable of format names (case-insensitive). Each name must appear in
            ``export_formats()['Argument'][1:]``. Repeating a name is allowed.
        imgsz: ``(h, w)``, a 1-element sequence ``(s,)``, or an int ``s``. The last two
            are expanded to ``(s, s)``. The caller's sequence is not modified.
        Other arguments are documented inline in the signature.

    Returns:
        A list of the exported file/dir paths as ``str``. Exports that failed are
        omitted from the list.

    Raises:
        AssertionError: an unknown format is requested; ``half`` is used on CPU without
            CoreML/OpenVINO; ``half`` is combined with ``dynamic``; TFLite and TF.js are
            requested together; or the model's class count disagrees with its names.
    """
    t = time.time()
    include = [x.lower() for x in include]  # to lowercase
    fmts = tuple(export_formats()['Argument'][1:])  # --include arguments
    # Validate by membership so that a format listed more than once is not treated
    # as invalid.
    assert all(x in fmts for x in include), \
        f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
    flags = [x in include for x in fmts]
    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleans
    file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights

    # Load PyTorch model
    device = select_device(device)
    if half:
        assert device.type != 'cpu' or coreml or xml, '--half only compatible with GPU export, i.e. use --device 0'
        assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
    model = attempt_load(weights, device=device, inplace=True, fuse=True)  # load FP32 model
    nc, names = model.nc, model.names  # number of classes, class names

    # Checks
    # Work on a copy so the caller's list is never mutated in place. main() passes the
    # same argparse list on every iteration over weights. A bare int means a square image.
    imgsz = [imgsz] if isinstance(imgsz, int) else list(imgsz)
    if len(imgsz) == 1:
        imgsz *= 2  # expand (s,) -> (s, s)
    assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}'

    # Input
    gs = int(max(model.stride))  # grid size (max stride)
    imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
    im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection

    # Update model
    if half and not coreml and not xml:
        im, model = im.half(), model.half()  # to FP16
    model.train() if train else model.eval()  # training mode = no Detect() layer grid construction
    for k, m in model.named_modules():
        if isinstance(m, Detect):
            m.inplace = inplace
            m.onnx_dynamic = dynamic
            m.export = True

    for _ in range(2):
        y = model(im)  # dry runs
    shape = tuple(y[0].shape)  # model output shape
    LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

    # Exports
    f = [''] * 10  # exported filenames
    warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
    if jit:
        f[0] = export_torchscript(model, im, file, optimize)
    if engine:  # TensorRT required before ONNX
        f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
    if onnx or xml:  # OpenVINO requires ONNX
        f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
    if xml:  # OpenVINO
        f[3] = export_openvino(model, file, half)
    if coreml:
        _, f[4] = export_coreml(model, im, file, int8, half)

    # TensorFlow Exports
    if any((saved_model, pb, tflite, edgetpu, tfjs)):
        if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
            check_requirements(('flatbuffers==1.12',))  # required before `import tensorflow`
        assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
        # `model` is rebound to whatever export_saved_model returns; the TF exporters
        # below consume that object rather than the PyTorch model.
        model, f[5] = export_saved_model(model.cpu(),
                                         im,
                                         file,
                                         dynamic,
                                         tf_nms=nms or agnostic_nms or tfjs,
                                         agnostic_nms=agnostic_nms or tfjs,
                                         topk_per_class=topk_per_class,
                                         topk_all=topk_all,
                                         iou_thres=iou_thres,
                                         conf_thres=conf_thres,
                                         keras=keras)
        if pb or tfjs:  # pb prerequisite to tfjs
            f[6] = export_pb(model, file)
        if tflite or edgetpu:
            f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
        if edgetpu:
            f[8] = export_edgetpu(file)
        if tfjs:
            f[9] = export_tfjs(file)

    # Finish
    f = [str(x) for x in f if x]  # filter out '' and None
    if any(f):
        LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
                    f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
                    f"\nDetect: python detect.py --weights {f[-1]}"
                    f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"
                    f"\nValidate: python val.py --weights {f[-1]}"
                    f"\nVisualize: https://netron.app")
    return f  # return list of exported files/dirs
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
def parse_opt():
    """Build the export command-line interface, parse it and print the parsed options.

    Returns:
        An ``argparse.Namespace``. Its attribute names are used directly as keyword
        arguments to ``run()`` by ``main()``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument  # shorthand; options are registered in the original order

    add('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
    add('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)')
    add('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
    add('--batch-size', type=int, default=1, help='batch size')
    add('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

    # Contiguous run of boolean switches
    for switch, description in (
            ('--half', 'FP16 half-precision export'),
            ('--inplace', 'set YOLOv5 Detect() inplace=True'),
            ('--train', 'model.train() mode'),
            ('--keras', 'TF: use Keras'),
            ('--optimize', 'TorchScript: optimize for mobile'),
            ('--int8', 'CoreML/TF INT8 quantization'),
            ('--dynamic', 'ONNX/TF: dynamic axes'),
            ('--simplify', 'ONNX: simplify model'),
    ):
        add(switch, action='store_true', help=description)

    add('--opset', type=int, default=12, help='ONNX: opset version')
    add('--verbose', action='store_true', help='TensorRT: verbose log')
    add('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
    add('--nms', action='store_true', help='TF: add NMS to model')
    add('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
    add('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
    add('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
    add('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
    add('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
    add('--include',
        nargs='+',
        default=['torchscript', 'onnx'],
        help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')

    parsed = parser.parse_args()
    print_args(vars(parsed))
    return parsed
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
def main(opt):
    """Call ``run()`` once for each weights entry in ``opt.weights``.

    ``opt.weights`` may be a single value or a list. ``opt`` is updated in place so that
    each ``run(**vars(opt))`` call receives exactly one weights entry.
    """
    candidates = opt.weights if isinstance(opt.weights, list) else [opt.weights]
    for weights_path in candidates:
        opt.weights = weights_path
        run(**vars(opt))
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
if __name__ == '__main__':
    # Script entry point: parse the command line, then export each requested weights file.
    opt = parse_opt()
    main(opt)
|