ultralytics-8.0.238-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic; the original page linked to further details, but that link was not preserved in this copy.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +3 -1
- ultralytics/data/explorer/explorer.py +120 -100
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +123 -89
- ultralytics/data/explorer/utils.py +37 -39
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +61 -41
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -11
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +70 -59
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +60 -52
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +17 -14
- ultralytics/solutions/heatmap.py +57 -55
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -152
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +38 -28
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.238.dist-info/RECORD +0 -188
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
|
|
3
|
-
__version__ =
|
|
3
|
+
__version__ = "8.0.239"
|
|
4
4
|
|
|
5
5
|
from ultralytics.data.explorer.explorer import Explorer
|
|
6
6
|
from ultralytics.models import RTDETR, SAM, YOLO
|
|
@@ -10,4 +10,4 @@ from ultralytics.utils import SETTINGS as settings
|
|
|
10
10
|
from ultralytics.utils.checks import check_yolo as checks
|
|
11
11
|
from ultralytics.utils.downloads import download
|
|
12
12
|
|
|
13
|
-
__all__ =
|
|
13
|
+
__all__ = "__version__", "YOLO", "NAS", "SAM", "FastSAM", "RTDETR", "checks", "download", "settings", "Explorer"
|
ultralytics/cfg/__init__.py
CHANGED
|
@@ -8,34 +8,53 @@ from pathlib import Path
|
|
|
8
8
|
from types import SimpleNamespace
|
|
9
9
|
from typing import Dict, List, Union
|
|
10
10
|
|
|
11
|
-
from ultralytics.utils import (
|
|
12
|
-
|
|
13
|
-
|
|
11
|
+
from ultralytics.utils import (
|
|
12
|
+
ASSETS,
|
|
13
|
+
DEFAULT_CFG,
|
|
14
|
+
DEFAULT_CFG_DICT,
|
|
15
|
+
DEFAULT_CFG_PATH,
|
|
16
|
+
LOGGER,
|
|
17
|
+
RANK,
|
|
18
|
+
ROOT,
|
|
19
|
+
RUNS_DIR,
|
|
20
|
+
SETTINGS,
|
|
21
|
+
SETTINGS_YAML,
|
|
22
|
+
TESTS_RUNNING,
|
|
23
|
+
IterableSimpleNamespace,
|
|
24
|
+
__version__,
|
|
25
|
+
checks,
|
|
26
|
+
colorstr,
|
|
27
|
+
deprecation_warn,
|
|
28
|
+
yaml_load,
|
|
29
|
+
yaml_print,
|
|
30
|
+
)
|
|
14
31
|
|
|
15
32
|
# Define valid tasks and modes
|
|
16
|
-
MODES =
|
|
17
|
-
TASKS =
|
|
33
|
+
MODES = "train", "val", "predict", "export", "track", "benchmark"
|
|
34
|
+
TASKS = "detect", "segment", "classify", "pose", "obb"
|
|
18
35
|
TASK2DATA = {
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
36
|
+
"detect": "coco8.yaml",
|
|
37
|
+
"segment": "coco8-seg.yaml",
|
|
38
|
+
"classify": "imagenet10",
|
|
39
|
+
"pose": "coco8-pose.yaml",
|
|
40
|
+
"obb": "dota8-obb.yaml",
|
|
41
|
+
}
|
|
24
42
|
TASK2MODEL = {
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
43
|
+
"detect": "yolov8n.pt",
|
|
44
|
+
"segment": "yolov8n-seg.pt",
|
|
45
|
+
"classify": "yolov8n-cls.pt",
|
|
46
|
+
"pose": "yolov8n-pose.pt",
|
|
47
|
+
"obb": "yolov8n-obb.pt",
|
|
48
|
+
}
|
|
30
49
|
TASK2METRIC = {
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
50
|
+
"detect": "metrics/mAP50-95(B)",
|
|
51
|
+
"segment": "metrics/mAP50-95(M)",
|
|
52
|
+
"classify": "metrics/accuracy_top1",
|
|
53
|
+
"pose": "metrics/mAP50-95(P)",
|
|
54
|
+
"obb": "metrics/mAP50-95(OBB)",
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
CLI_HELP_MSG = f"""
|
|
39
58
|
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
|
40
59
|
|
|
41
60
|
yolo TASK MODE ARGS
|
|
@@ -74,16 +93,83 @@ CLI_HELP_MSG = \
|
|
|
74
93
|
"""
|
|
75
94
|
|
|
76
95
|
# Define keys for arg type checks
|
|
77
|
-
CFG_FLOAT_KEYS =
|
|
78
|
-
CFG_FRACTION_KEYS = (
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
96
|
+
CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"
|
|
97
|
+
CFG_FRACTION_KEYS = (
|
|
98
|
+
"dropout",
|
|
99
|
+
"iou",
|
|
100
|
+
"lr0",
|
|
101
|
+
"lrf",
|
|
102
|
+
"momentum",
|
|
103
|
+
"weight_decay",
|
|
104
|
+
"warmup_momentum",
|
|
105
|
+
"warmup_bias_lr",
|
|
106
|
+
"label_smoothing",
|
|
107
|
+
"hsv_h",
|
|
108
|
+
"hsv_s",
|
|
109
|
+
"hsv_v",
|
|
110
|
+
"translate",
|
|
111
|
+
"scale",
|
|
112
|
+
"perspective",
|
|
113
|
+
"flipud",
|
|
114
|
+
"fliplr",
|
|
115
|
+
"mosaic",
|
|
116
|
+
"mixup",
|
|
117
|
+
"copy_paste",
|
|
118
|
+
"conf",
|
|
119
|
+
"iou",
|
|
120
|
+
"fraction",
|
|
121
|
+
) # fraction floats 0.0 - 1.0
|
|
122
|
+
CFG_INT_KEYS = (
|
|
123
|
+
"epochs",
|
|
124
|
+
"patience",
|
|
125
|
+
"batch",
|
|
126
|
+
"workers",
|
|
127
|
+
"seed",
|
|
128
|
+
"close_mosaic",
|
|
129
|
+
"mask_ratio",
|
|
130
|
+
"max_det",
|
|
131
|
+
"vid_stride",
|
|
132
|
+
"line_width",
|
|
133
|
+
"workspace",
|
|
134
|
+
"nbs",
|
|
135
|
+
"save_period",
|
|
136
|
+
)
|
|
137
|
+
CFG_BOOL_KEYS = (
|
|
138
|
+
"save",
|
|
139
|
+
"exist_ok",
|
|
140
|
+
"verbose",
|
|
141
|
+
"deterministic",
|
|
142
|
+
"single_cls",
|
|
143
|
+
"rect",
|
|
144
|
+
"cos_lr",
|
|
145
|
+
"overlap_mask",
|
|
146
|
+
"val",
|
|
147
|
+
"save_json",
|
|
148
|
+
"save_hybrid",
|
|
149
|
+
"half",
|
|
150
|
+
"dnn",
|
|
151
|
+
"plots",
|
|
152
|
+
"show",
|
|
153
|
+
"save_txt",
|
|
154
|
+
"save_conf",
|
|
155
|
+
"save_crop",
|
|
156
|
+
"save_frames",
|
|
157
|
+
"show_labels",
|
|
158
|
+
"show_conf",
|
|
159
|
+
"visualize",
|
|
160
|
+
"augment",
|
|
161
|
+
"agnostic_nms",
|
|
162
|
+
"retina_masks",
|
|
163
|
+
"show_boxes",
|
|
164
|
+
"keras",
|
|
165
|
+
"optimize",
|
|
166
|
+
"int8",
|
|
167
|
+
"dynamic",
|
|
168
|
+
"simplify",
|
|
169
|
+
"nms",
|
|
170
|
+
"profile",
|
|
171
|
+
"multi_scale",
|
|
172
|
+
)
|
|
87
173
|
|
|
88
174
|
|
|
89
175
|
def cfg2dict(cfg):
|
|
@@ -119,38 +205,44 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
|
|
119
205
|
# Merge overrides
|
|
120
206
|
if overrides:
|
|
121
207
|
overrides = cfg2dict(overrides)
|
|
122
|
-
if
|
|
123
|
-
overrides.pop(
|
|
208
|
+
if "save_dir" not in cfg:
|
|
209
|
+
overrides.pop("save_dir", None) # special override keys to ignore
|
|
124
210
|
check_dict_alignment(cfg, overrides)
|
|
125
211
|
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
|
|
126
212
|
|
|
127
213
|
# Special handling for numeric project/name
|
|
128
|
-
for k in
|
|
214
|
+
for k in "project", "name":
|
|
129
215
|
if k in cfg and isinstance(cfg[k], (int, float)):
|
|
130
216
|
cfg[k] = str(cfg[k])
|
|
131
|
-
if cfg.get(
|
|
132
|
-
cfg[
|
|
217
|
+
if cfg.get("name") == "model": # assign model to 'name' arg
|
|
218
|
+
cfg["name"] = cfg.get("model", "").split(".")[0]
|
|
133
219
|
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
|
|
134
220
|
|
|
135
221
|
# Type and Value checks
|
|
136
222
|
for k, v in cfg.items():
|
|
137
223
|
if v is not None: # None values may be from optional args
|
|
138
224
|
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
|
|
139
|
-
raise TypeError(
|
|
140
|
-
|
|
225
|
+
raise TypeError(
|
|
226
|
+
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
|
227
|
+
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
|
|
228
|
+
)
|
|
141
229
|
elif k in CFG_FRACTION_KEYS:
|
|
142
230
|
if not isinstance(v, (int, float)):
|
|
143
|
-
raise TypeError(
|
|
144
|
-
|
|
231
|
+
raise TypeError(
|
|
232
|
+
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
|
233
|
+
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
|
|
234
|
+
)
|
|
145
235
|
if not (0.0 <= v <= 1.0):
|
|
146
|
-
raise ValueError(f"'{k}={v}' is an invalid value. "
|
|
147
|
-
f"Valid '{k}' values are between 0.0 and 1.0.")
|
|
236
|
+
raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
|
|
148
237
|
elif k in CFG_INT_KEYS and not isinstance(v, int):
|
|
149
|
-
raise TypeError(
|
|
150
|
-
|
|
238
|
+
raise TypeError(
|
|
239
|
+
f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
|
|
240
|
+
)
|
|
151
241
|
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
|
|
152
|
-
raise TypeError(
|
|
153
|
-
|
|
242
|
+
raise TypeError(
|
|
243
|
+
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
|
244
|
+
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
|
|
245
|
+
)
|
|
154
246
|
|
|
155
247
|
# Return instance
|
|
156
248
|
return IterableSimpleNamespace(**cfg)
|
|
@@ -159,13 +251,13 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
|
|
159
251
|
def get_save_dir(args, name=None):
|
|
160
252
|
"""Return save_dir as created from train/val/predict arguments."""
|
|
161
253
|
|
|
162
|
-
if getattr(args,
|
|
254
|
+
if getattr(args, "save_dir", None):
|
|
163
255
|
save_dir = args.save_dir
|
|
164
256
|
else:
|
|
165
257
|
from ultralytics.utils.files import increment_path
|
|
166
258
|
|
|
167
|
-
project = args.project or (ROOT.parent /
|
|
168
|
-
name = name or args.name or f
|
|
259
|
+
project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task
|
|
260
|
+
name = name or args.name or f"{args.mode}"
|
|
169
261
|
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
|
|
170
262
|
|
|
171
263
|
return Path(save_dir)
|
|
@@ -175,18 +267,18 @@ def _handle_deprecation(custom):
|
|
|
175
267
|
"""Hardcoded function to handle deprecated config keys."""
|
|
176
268
|
|
|
177
269
|
for key in custom.copy().keys():
|
|
178
|
-
if key ==
|
|
179
|
-
deprecation_warn(key,
|
|
180
|
-
custom[
|
|
181
|
-
if key ==
|
|
182
|
-
deprecation_warn(key,
|
|
183
|
-
custom[
|
|
184
|
-
if key ==
|
|
185
|
-
deprecation_warn(key,
|
|
186
|
-
custom[
|
|
187
|
-
if key ==
|
|
188
|
-
deprecation_warn(key,
|
|
189
|
-
custom[
|
|
270
|
+
if key == "boxes":
|
|
271
|
+
deprecation_warn(key, "show_boxes")
|
|
272
|
+
custom["show_boxes"] = custom.pop("boxes")
|
|
273
|
+
if key == "hide_labels":
|
|
274
|
+
deprecation_warn(key, "show_labels")
|
|
275
|
+
custom["show_labels"] = custom.pop("hide_labels") == "False"
|
|
276
|
+
if key == "hide_conf":
|
|
277
|
+
deprecation_warn(key, "show_conf")
|
|
278
|
+
custom["show_conf"] = custom.pop("hide_conf") == "False"
|
|
279
|
+
if key == "line_thickness":
|
|
280
|
+
deprecation_warn(key, "line_width")
|
|
281
|
+
custom["line_width"] = custom.pop("line_thickness")
|
|
190
282
|
|
|
191
283
|
return custom
|
|
192
284
|
|
|
@@ -207,11 +299,11 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
|
|
|
207
299
|
if mismatched:
|
|
208
300
|
from difflib import get_close_matches
|
|
209
301
|
|
|
210
|
-
string =
|
|
302
|
+
string = ""
|
|
211
303
|
for x in mismatched:
|
|
212
304
|
matches = get_close_matches(x, base_keys) # key list
|
|
213
|
-
matches = [f
|
|
214
|
-
match_str = f
|
|
305
|
+
matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches]
|
|
306
|
+
match_str = f"Similar arguments are i.e. {matches}." if matches else ""
|
|
215
307
|
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
|
216
308
|
raise SyntaxError(string + CLI_HELP_MSG) from e
|
|
217
309
|
|
|
@@ -229,13 +321,13 @@ def merge_equals_args(args: List[str]) -> List[str]:
|
|
|
229
321
|
"""
|
|
230
322
|
new_args = []
|
|
231
323
|
for i, arg in enumerate(args):
|
|
232
|
-
if arg ==
|
|
233
|
-
new_args[-1] += f
|
|
324
|
+
if arg == "=" and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
|
|
325
|
+
new_args[-1] += f"={args[i + 1]}"
|
|
234
326
|
del args[i + 1]
|
|
235
|
-
elif arg.endswith(
|
|
236
|
-
new_args.append(f
|
|
327
|
+
elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]: # merge ['arg=', 'val']
|
|
328
|
+
new_args.append(f"{arg}{args[i + 1]}")
|
|
237
329
|
del args[i + 1]
|
|
238
|
-
elif arg.startswith(
|
|
330
|
+
elif arg.startswith("=") and i > 0: # merge ['arg', '=val']
|
|
239
331
|
new_args[-1] += arg
|
|
240
332
|
else:
|
|
241
333
|
new_args.append(arg)
|
|
@@ -259,11 +351,11 @@ def handle_yolo_hub(args: List[str]) -> None:
|
|
|
259
351
|
"""
|
|
260
352
|
from ultralytics import hub
|
|
261
353
|
|
|
262
|
-
if args[0] ==
|
|
263
|
-
key = args[1] if len(args) > 1 else
|
|
354
|
+
if args[0] == "login":
|
|
355
|
+
key = args[1] if len(args) > 1 else ""
|
|
264
356
|
# Log in to Ultralytics HUB using the provided API key
|
|
265
357
|
hub.login(key)
|
|
266
|
-
elif args[0] ==
|
|
358
|
+
elif args[0] == "logout":
|
|
267
359
|
# Log out from Ultralytics HUB
|
|
268
360
|
hub.logout()
|
|
269
361
|
|
|
@@ -283,19 +375,19 @@ def handle_yolo_settings(args: List[str]) -> None:
|
|
|
283
375
|
python my_script.py yolo settings reset
|
|
284
376
|
```
|
|
285
377
|
"""
|
|
286
|
-
url =
|
|
378
|
+
url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings" # help URL
|
|
287
379
|
try:
|
|
288
380
|
if any(args):
|
|
289
|
-
if args[0] ==
|
|
381
|
+
if args[0] == "reset":
|
|
290
382
|
SETTINGS_YAML.unlink() # delete the settings file
|
|
291
383
|
SETTINGS.reset() # create new settings
|
|
292
|
-
LOGGER.info(
|
|
384
|
+
LOGGER.info("Settings reset successfully") # inform the user that settings have been reset
|
|
293
385
|
else: # save a new setting
|
|
294
386
|
new = dict(parse_key_value_pair(a) for a in args)
|
|
295
387
|
check_dict_alignment(SETTINGS, new)
|
|
296
388
|
SETTINGS.update(new)
|
|
297
389
|
|
|
298
|
-
LOGGER.info(f
|
|
390
|
+
LOGGER.info(f"💡 Learn about settings at {url}")
|
|
299
391
|
yaml_print(SETTINGS_YAML) # print the current settings
|
|
300
392
|
except Exception as e:
|
|
301
393
|
LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.")
|
|
@@ -303,13 +395,13 @@ def handle_yolo_settings(args: List[str]) -> None:
|
|
|
303
395
|
|
|
304
396
|
def handle_explorer():
|
|
305
397
|
"""Open the Ultralytics Explorer GUI."""
|
|
306
|
-
checks.check_requirements(
|
|
307
|
-
subprocess.run([
|
|
398
|
+
checks.check_requirements("streamlit")
|
|
399
|
+
subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"])
|
|
308
400
|
|
|
309
401
|
|
|
310
402
|
def parse_key_value_pair(pair):
|
|
311
403
|
"""Parse one 'key=value' pair and return key and value."""
|
|
312
|
-
k, v = pair.split(
|
|
404
|
+
k, v = pair.split("=", 1) # split on first '=' sign
|
|
313
405
|
k, v = k.strip(), v.strip() # remove spaces
|
|
314
406
|
assert v, f"missing '{k}' value"
|
|
315
407
|
return k, smart_value(v)
|
|
@@ -318,11 +410,11 @@ def parse_key_value_pair(pair):
|
|
|
318
410
|
def smart_value(v):
|
|
319
411
|
"""Convert a string to an underlying type such as int, float, bool, etc."""
|
|
320
412
|
v_lower = v.lower()
|
|
321
|
-
if v_lower ==
|
|
413
|
+
if v_lower == "none":
|
|
322
414
|
return None
|
|
323
|
-
elif v_lower ==
|
|
415
|
+
elif v_lower == "true":
|
|
324
416
|
return True
|
|
325
|
-
elif v_lower ==
|
|
417
|
+
elif v_lower == "false":
|
|
326
418
|
return False
|
|
327
419
|
else:
|
|
328
420
|
with contextlib.suppress(Exception):
|
|
@@ -330,7 +422,7 @@ def smart_value(v):
|
|
|
330
422
|
return v
|
|
331
423
|
|
|
332
424
|
|
|
333
|
-
def entrypoint(debug=
|
|
425
|
+
def entrypoint(debug=""):
|
|
334
426
|
"""
|
|
335
427
|
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
|
336
428
|
to the package.
|
|
@@ -345,139 +437,150 @@ def entrypoint(debug=''):
|
|
|
345
437
|
It uses the package's default cfg and initializes it using the passed overrides.
|
|
346
438
|
Then it calls the CLI function with the composed cfg
|
|
347
439
|
"""
|
|
348
|
-
args = (debug.split(
|
|
440
|
+
args = (debug.split(" ") if debug else sys.argv)[1:]
|
|
349
441
|
if not args: # no arguments passed
|
|
350
442
|
LOGGER.info(CLI_HELP_MSG)
|
|
351
443
|
return
|
|
352
444
|
|
|
353
445
|
special = {
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
446
|
+
"help": lambda: LOGGER.info(CLI_HELP_MSG),
|
|
447
|
+
"checks": checks.collect_system_info,
|
|
448
|
+
"version": lambda: LOGGER.info(__version__),
|
|
449
|
+
"settings": lambda: handle_yolo_settings(args[1:]),
|
|
450
|
+
"cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
|
|
451
|
+
"hub": lambda: handle_yolo_hub(args[1:]),
|
|
452
|
+
"login": lambda: handle_yolo_hub(args),
|
|
453
|
+
"copy-cfg": copy_default_cfg,
|
|
454
|
+
"explorer": lambda: handle_explorer(),
|
|
455
|
+
}
|
|
363
456
|
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
|
364
457
|
|
|
365
458
|
# Define common misuses of special commands, i.e. -h, -help, --help
|
|
366
459
|
special.update({k[0]: v for k, v in special.items()}) # singular
|
|
367
|
-
special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith(
|
|
368
|
-
special = {**special, **{f
|
|
460
|
+
special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular
|
|
461
|
+
special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}}
|
|
369
462
|
|
|
370
463
|
overrides = {} # basic overrides, i.e. imgsz=320
|
|
371
464
|
for a in merge_equals_args(args): # merge spaces around '=' sign
|
|
372
|
-
if a.startswith(
|
|
465
|
+
if a.startswith("--"):
|
|
373
466
|
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
|
374
467
|
a = a[2:]
|
|
375
|
-
if a.endswith(
|
|
468
|
+
if a.endswith(","):
|
|
376
469
|
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
|
377
470
|
a = a[:-1]
|
|
378
|
-
if
|
|
471
|
+
if "=" in a:
|
|
379
472
|
try:
|
|
380
473
|
k, v = parse_key_value_pair(a)
|
|
381
|
-
if k ==
|
|
382
|
-
LOGGER.info(f
|
|
383
|
-
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k !=
|
|
474
|
+
if k == "cfg" and v is not None: # custom.yaml passed
|
|
475
|
+
LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}")
|
|
476
|
+
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"}
|
|
384
477
|
else:
|
|
385
478
|
overrides[k] = v
|
|
386
479
|
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
|
387
|
-
check_dict_alignment(full_args_dict, {a:
|
|
480
|
+
check_dict_alignment(full_args_dict, {a: ""}, e)
|
|
388
481
|
|
|
389
482
|
elif a in TASKS:
|
|
390
|
-
overrides[
|
|
483
|
+
overrides["task"] = a
|
|
391
484
|
elif a in MODES:
|
|
392
|
-
overrides[
|
|
485
|
+
overrides["mode"] = a
|
|
393
486
|
elif a.lower() in special:
|
|
394
487
|
special[a.lower()]()
|
|
395
488
|
return
|
|
396
489
|
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
|
|
397
490
|
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
|
|
398
491
|
elif a in DEFAULT_CFG_DICT:
|
|
399
|
-
raise SyntaxError(
|
|
400
|
-
|
|
492
|
+
raise SyntaxError(
|
|
493
|
+
f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
|
494
|
+
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}"
|
|
495
|
+
)
|
|
401
496
|
else:
|
|
402
|
-
check_dict_alignment(full_args_dict, {a:
|
|
497
|
+
check_dict_alignment(full_args_dict, {a: ""})
|
|
403
498
|
|
|
404
499
|
# Check keys
|
|
405
500
|
check_dict_alignment(full_args_dict, overrides)
|
|
406
501
|
|
|
407
502
|
# Mode
|
|
408
|
-
mode = overrides.get(
|
|
503
|
+
mode = overrides.get("mode")
|
|
409
504
|
if mode is None:
|
|
410
|
-
mode = DEFAULT_CFG.mode or
|
|
505
|
+
mode = DEFAULT_CFG.mode or "predict"
|
|
411
506
|
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
|
412
507
|
elif mode not in MODES:
|
|
413
508
|
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
|
|
414
509
|
|
|
415
510
|
# Task
|
|
416
|
-
task = overrides.pop(
|
|
511
|
+
task = overrides.pop("task", None)
|
|
417
512
|
if task:
|
|
418
513
|
if task not in TASKS:
|
|
419
514
|
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
|
|
420
|
-
if
|
|
421
|
-
overrides[
|
|
515
|
+
if "model" not in overrides:
|
|
516
|
+
overrides["model"] = TASK2MODEL[task]
|
|
422
517
|
|
|
423
518
|
# Model
|
|
424
|
-
model = overrides.pop(
|
|
519
|
+
model = overrides.pop("model", DEFAULT_CFG.model)
|
|
425
520
|
if model is None:
|
|
426
|
-
model =
|
|
521
|
+
model = "yolov8n.pt"
|
|
427
522
|
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
|
428
|
-
overrides[
|
|
523
|
+
overrides["model"] = model
|
|
429
524
|
stem = Path(model).stem.lower()
|
|
430
|
-
if
|
|
525
|
+
if "rtdetr" in stem: # guess architecture
|
|
431
526
|
from ultralytics import RTDETR
|
|
527
|
+
|
|
432
528
|
model = RTDETR(model) # no task argument
|
|
433
|
-
elif
|
|
529
|
+
elif "fastsam" in stem:
|
|
434
530
|
from ultralytics import FastSAM
|
|
531
|
+
|
|
435
532
|
model = FastSAM(model)
|
|
436
|
-
elif
|
|
533
|
+
elif "sam" in stem:
|
|
437
534
|
from ultralytics import SAM
|
|
535
|
+
|
|
438
536
|
model = SAM(model)
|
|
439
537
|
else:
|
|
440
538
|
from ultralytics import YOLO
|
|
539
|
+
|
|
441
540
|
model = YOLO(model, task=task)
|
|
442
|
-
if isinstance(overrides.get(
|
|
443
|
-
model.load(overrides[
|
|
541
|
+
if isinstance(overrides.get("pretrained"), str):
|
|
542
|
+
model.load(overrides["pretrained"])
|
|
444
543
|
|
|
445
544
|
# Task Update
|
|
446
545
|
if task != model.task:
|
|
447
546
|
if task:
|
|
448
|
-
LOGGER.warning(
|
|
449
|
-
|
|
547
|
+
LOGGER.warning(
|
|
548
|
+
f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
|
|
549
|
+
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model."
|
|
550
|
+
)
|
|
450
551
|
task = model.task
|
|
451
552
|
|
|
452
553
|
# Mode
|
|
453
|
-
if mode in (
|
|
454
|
-
overrides[
|
|
554
|
+
if mode in ("predict", "track") and "source" not in overrides:
|
|
555
|
+
overrides["source"] = DEFAULT_CFG.source or ASSETS
|
|
455
556
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
|
456
|
-
elif mode in (
|
|
457
|
-
if
|
|
458
|
-
overrides[
|
|
557
|
+
elif mode in ("train", "val"):
|
|
558
|
+
if "data" not in overrides and "resume" not in overrides:
|
|
559
|
+
overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
|
459
560
|
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
|
460
|
-
elif mode ==
|
|
461
|
-
if
|
|
462
|
-
overrides[
|
|
561
|
+
elif mode == "export":
|
|
562
|
+
if "format" not in overrides:
|
|
563
|
+
overrides["format"] = DEFAULT_CFG.format or "torchscript"
|
|
463
564
|
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
|
|
464
565
|
|
|
465
566
|
# Run command in python
|
|
466
567
|
getattr(model, mode)(**overrides) # default args from model
|
|
467
568
|
|
|
468
569
|
# Show help
|
|
469
|
-
LOGGER.info(f
|
|
570
|
+
LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}")
|
|
470
571
|
|
|
471
572
|
|
|
472
573
|
# Special modes --------------------------------------------------------------------------------------------------------
|
|
473
574
|
def copy_default_cfg():
|
|
474
575
|
"""Copy and create a new default configuration file with '_copy' appended to its name."""
|
|
475
|
-
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(
|
|
576
|
+
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")
|
|
476
577
|
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
|
477
|
-
LOGGER.info(
|
|
478
|
-
|
|
578
|
+
LOGGER.info(
|
|
579
|
+
f"{DEFAULT_CFG_PATH} copied to {new_file}\n"
|
|
580
|
+
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8"
|
|
581
|
+
)
|
|
479
582
|
|
|
480
583
|
|
|
481
|
-
if __name__ ==
|
|
584
|
+
if __name__ == "__main__":
|
|
482
585
|
# Example: entrypoint(debug='yolo predict model=yolov8n.pt')
|
|
483
|
-
entrypoint(debug=
|
|
586
|
+
entrypoint(debug="")
|
ultralytics/data/__init__.py
CHANGED
|
@@ -4,5 +4,12 @@ from .base import BaseDataset
|
|
|
4
4
|
from .build import build_dataloader, build_yolo_dataset, load_inference_source
|
|
5
5
|
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
|
6
6
|
|
|
7
|
-
__all__ = (
|
|
8
|
-
|
|
7
|
+
__all__ = (
|
|
8
|
+
"BaseDataset",
|
|
9
|
+
"ClassificationDataset",
|
|
10
|
+
"SemanticDataset",
|
|
11
|
+
"YOLODataset",
|
|
12
|
+
"build_yolo_dataset",
|
|
13
|
+
"build_dataloader",
|
|
14
|
+
"load_inference_source",
|
|
15
|
+
)
|