ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +3 -1
- ultralytics/data/explorer/explorer.py +120 -100
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +123 -89
- ultralytics/data/explorer/utils.py +37 -39
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +61 -41
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -11
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +70 -59
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +60 -52
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +17 -14
- ultralytics/solutions/heatmap.py +57 -55
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -152
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +38 -28
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.238.dist-info/RECORD +0 -188
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/engine/model.py
CHANGED
|
@@ -5,10 +5,11 @@ import sys
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import Union
|
|
7
7
|
|
|
8
|
+
from hub_sdk.config import HUB_WEB_ROOT
|
|
9
|
+
|
|
8
10
|
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
|
9
|
-
from ultralytics.hub.utils import HUB_WEB_ROOT
|
|
10
11
|
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
|
11
|
-
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
|
|
12
|
+
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
class Model(nn.Module):
|
|
@@ -52,7 +53,7 @@ class Model(nn.Module):
|
|
|
52
53
|
list(ultralytics.engine.results.Results): The prediction results.
|
|
53
54
|
"""
|
|
54
55
|
|
|
55
|
-
def __init__(self, model: Union[str, Path] =
|
|
56
|
+
def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None) -> None:
|
|
56
57
|
"""
|
|
57
58
|
Initializes the YOLO model.
|
|
58
59
|
|
|
@@ -76,8 +77,8 @@ class Model(nn.Module):
|
|
|
76
77
|
|
|
77
78
|
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
|
78
79
|
if self.is_hub_model(model):
|
|
79
|
-
|
|
80
|
-
self.session =
|
|
80
|
+
# Fetch model from HUB
|
|
81
|
+
self.session = self._get_hub_session(model)
|
|
81
82
|
model = self.session.model_file
|
|
82
83
|
|
|
83
84
|
# Check if Triton Server model
|
|
@@ -88,29 +89,43 @@ class Model(nn.Module):
|
|
|
88
89
|
|
|
89
90
|
# Load or create new YOLO model
|
|
90
91
|
model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
|
91
|
-
if Path(model).suffix in (
|
|
92
|
+
if Path(model).suffix in (".yaml", ".yml"):
|
|
92
93
|
self._new(model, task)
|
|
93
94
|
else:
|
|
94
95
|
self._load(model, task)
|
|
95
96
|
|
|
97
|
+
self.model_name = model
|
|
98
|
+
|
|
96
99
|
def __call__(self, source=None, stream=False, **kwargs):
|
|
97
100
|
"""Calls the predict() method with given arguments to perform object detection."""
|
|
98
101
|
return self.predict(source, stream, **kwargs)
|
|
99
102
|
|
|
103
|
+
@staticmethod
|
|
104
|
+
def _get_hub_session(model: str):
|
|
105
|
+
"""Creates a session for Hub Training."""
|
|
106
|
+
from ultralytics.hub.session import HUBTrainingSession
|
|
107
|
+
|
|
108
|
+
session = HUBTrainingSession(model)
|
|
109
|
+
return session if session.client.authenticated else None
|
|
110
|
+
|
|
100
111
|
@staticmethod
|
|
101
112
|
def is_triton_model(model):
|
|
102
113
|
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
|
|
103
114
|
from urllib.parse import urlsplit
|
|
115
|
+
|
|
104
116
|
url = urlsplit(model)
|
|
105
|
-
return url.netloc and url.path and url.scheme in {
|
|
117
|
+
return url.netloc and url.path and url.scheme in {"http", "grpc"}
|
|
106
118
|
|
|
107
119
|
@staticmethod
|
|
108
120
|
def is_hub_model(model):
|
|
109
121
|
"""Check if the provided model is a HUB model."""
|
|
110
|
-
return any(
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
122
|
+
return any(
|
|
123
|
+
(
|
|
124
|
+
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
|
125
|
+
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID
|
|
126
|
+
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"),
|
|
127
|
+
)
|
|
128
|
+
) # MODELID
|
|
114
129
|
|
|
115
130
|
def _new(self, cfg: str, task=None, model=None, verbose=True):
|
|
116
131
|
"""
|
|
@@ -125,9 +140,9 @@ class Model(nn.Module):
|
|
|
125
140
|
cfg_dict = yaml_model_load(cfg)
|
|
126
141
|
self.cfg = cfg
|
|
127
142
|
self.task = task or guess_model_task(cfg_dict)
|
|
128
|
-
self.model = (model or self._smart_load(
|
|
129
|
-
self.overrides[
|
|
130
|
-
self.overrides[
|
|
143
|
+
self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1) # build model
|
|
144
|
+
self.overrides["model"] = self.cfg
|
|
145
|
+
self.overrides["task"] = self.task
|
|
131
146
|
|
|
132
147
|
# Below added to allow export from YAMLs
|
|
133
148
|
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
|
|
@@ -142,9 +157,9 @@ class Model(nn.Module):
|
|
|
142
157
|
task (str | None): model task
|
|
143
158
|
"""
|
|
144
159
|
suffix = Path(weights).suffix
|
|
145
|
-
if suffix ==
|
|
160
|
+
if suffix == ".pt":
|
|
146
161
|
self.model, self.ckpt = attempt_load_one_weight(weights)
|
|
147
|
-
self.task = self.model.args[
|
|
162
|
+
self.task = self.model.args["task"]
|
|
148
163
|
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
|
149
164
|
self.ckpt_path = self.model.pt_path
|
|
150
165
|
else:
|
|
@@ -152,12 +167,12 @@ class Model(nn.Module):
|
|
|
152
167
|
self.model, self.ckpt = weights, None
|
|
153
168
|
self.task = task or guess_model_task(weights)
|
|
154
169
|
self.ckpt_path = weights
|
|
155
|
-
self.overrides[
|
|
156
|
-
self.overrides[
|
|
170
|
+
self.overrides["model"] = weights
|
|
171
|
+
self.overrides["task"] = self.task
|
|
157
172
|
|
|
158
173
|
def _check_is_pytorch_model(self):
|
|
159
174
|
"""Raises TypeError is model is not a PyTorch model."""
|
|
160
|
-
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix ==
|
|
175
|
+
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
|
|
161
176
|
pt_module = isinstance(self.model, nn.Module)
|
|
162
177
|
if not (pt_module or pt_str):
|
|
163
178
|
raise TypeError(
|
|
@@ -165,19 +180,20 @@ class Model(nn.Module):
|
|
|
165
180
|
f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
|
|
166
181
|
f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
|
|
167
182
|
f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
|
|
168
|
-
f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
|
|
183
|
+
f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
|
|
184
|
+
)
|
|
169
185
|
|
|
170
186
|
def reset_weights(self):
|
|
171
187
|
"""Resets the model modules parameters to randomly initialized values, losing all training information."""
|
|
172
188
|
self._check_is_pytorch_model()
|
|
173
189
|
for m in self.model.modules():
|
|
174
|
-
if hasattr(m,
|
|
190
|
+
if hasattr(m, "reset_parameters"):
|
|
175
191
|
m.reset_parameters()
|
|
176
192
|
for p in self.model.parameters():
|
|
177
193
|
p.requires_grad = True
|
|
178
194
|
return self
|
|
179
195
|
|
|
180
|
-
def load(self, weights=
|
|
196
|
+
def load(self, weights="yolov8n.pt"):
|
|
181
197
|
"""Transfers parameters with matching names and shapes from 'weights' to model."""
|
|
182
198
|
self._check_is_pytorch_model()
|
|
183
199
|
if isinstance(weights, (str, Path)):
|
|
@@ -215,8 +231,8 @@ class Model(nn.Module):
|
|
|
215
231
|
Returns:
|
|
216
232
|
(List[torch.Tensor]): A list of image embeddings.
|
|
217
233
|
"""
|
|
218
|
-
if not kwargs.get(
|
|
219
|
-
kwargs[
|
|
234
|
+
if not kwargs.get("embed"):
|
|
235
|
+
kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
|
|
220
236
|
return self.predict(source, stream, **kwargs)
|
|
221
237
|
|
|
222
238
|
def predict(self, source=None, stream=False, predictor=None, **kwargs):
|
|
@@ -238,21 +254,22 @@ class Model(nn.Module):
|
|
|
238
254
|
source = ASSETS
|
|
239
255
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
|
240
256
|
|
|
241
|
-
is_cli = (sys.argv[0].endswith(
|
|
242
|
-
x in sys.argv for x in (
|
|
257
|
+
is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any(
|
|
258
|
+
x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
|
|
259
|
+
)
|
|
243
260
|
|
|
244
|
-
custom = {
|
|
245
|
-
args = {**self.overrides, **custom, **kwargs,
|
|
246
|
-
prompts = args.pop(
|
|
261
|
+
custom = {"conf": 0.25, "save": is_cli} # method defaults
|
|
262
|
+
args = {**self.overrides, **custom, **kwargs, "mode": "predict"} # highest priority args on the right
|
|
263
|
+
prompts = args.pop("prompts", None) # for SAM-type models
|
|
247
264
|
|
|
248
265
|
if not self.predictor:
|
|
249
|
-
self.predictor = predictor or self._smart_load(
|
|
266
|
+
self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
|
|
250
267
|
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
|
251
268
|
else: # only update args if predictor is already setup
|
|
252
269
|
self.predictor.args = get_cfg(self.predictor.args, args)
|
|
253
|
-
if
|
|
270
|
+
if "project" in args or "name" in args:
|
|
254
271
|
self.predictor.save_dir = get_save_dir(self.predictor.args)
|
|
255
|
-
if prompts and hasattr(self.predictor,
|
|
272
|
+
if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
|
|
256
273
|
self.predictor.set_prompts(prompts)
|
|
257
274
|
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
|
258
275
|
|
|
@@ -269,11 +286,12 @@ class Model(nn.Module):
|
|
|
269
286
|
Returns:
|
|
270
287
|
(List[ultralytics.engine.results.Results]): The tracking results.
|
|
271
288
|
"""
|
|
272
|
-
if not hasattr(self.predictor,
|
|
289
|
+
if not hasattr(self.predictor, "trackers"):
|
|
273
290
|
from ultralytics.trackers import register_tracker
|
|
291
|
+
|
|
274
292
|
register_tracker(self, persist)
|
|
275
|
-
kwargs[
|
|
276
|
-
kwargs[
|
|
293
|
+
kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input
|
|
294
|
+
kwargs["mode"] = "track"
|
|
277
295
|
return self.predict(source=source, stream=stream, **kwargs)
|
|
278
296
|
|
|
279
297
|
def val(self, validator=None, **kwargs):
|
|
@@ -284,10 +302,10 @@ class Model(nn.Module):
|
|
|
284
302
|
validator (BaseValidator): Customized validator.
|
|
285
303
|
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
|
286
304
|
"""
|
|
287
|
-
custom = {
|
|
288
|
-
args = {**self.overrides, **custom, **kwargs,
|
|
305
|
+
custom = {"rect": True} # method defaults
|
|
306
|
+
args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
|
|
289
307
|
|
|
290
|
-
validator = (validator or self._smart_load(
|
|
308
|
+
validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
|
|
291
309
|
validator(model=self.model)
|
|
292
310
|
self.metrics = validator.metrics
|
|
293
311
|
return validator.metrics
|
|
@@ -302,16 +320,17 @@ class Model(nn.Module):
|
|
|
302
320
|
self._check_is_pytorch_model()
|
|
303
321
|
from ultralytics.utils.benchmarks import benchmark
|
|
304
322
|
|
|
305
|
-
custom = {
|
|
306
|
-
args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs,
|
|
323
|
+
custom = {"verbose": False} # method defaults
|
|
324
|
+
args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
|
|
307
325
|
return benchmark(
|
|
308
326
|
model=self,
|
|
309
|
-
data=kwargs.get(
|
|
310
|
-
imgsz=args[
|
|
311
|
-
half=args[
|
|
312
|
-
int8=args[
|
|
313
|
-
device=args[
|
|
314
|
-
verbose=kwargs.get(
|
|
327
|
+
data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets
|
|
328
|
+
imgsz=args["imgsz"],
|
|
329
|
+
half=args["half"],
|
|
330
|
+
int8=args["int8"],
|
|
331
|
+
device=args["device"],
|
|
332
|
+
verbose=kwargs.get("verbose"),
|
|
333
|
+
)
|
|
315
334
|
|
|
316
335
|
def export(self, **kwargs):
|
|
317
336
|
"""
|
|
@@ -323,8 +342,8 @@ class Model(nn.Module):
|
|
|
323
342
|
self._check_is_pytorch_model()
|
|
324
343
|
from .exporter import Exporter
|
|
325
344
|
|
|
326
|
-
custom = {
|
|
327
|
-
args = {**self.overrides, **custom, **kwargs,
|
|
345
|
+
custom = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False} # method defaults
|
|
346
|
+
args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
|
|
328
347
|
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
|
329
348
|
|
|
330
349
|
def train(self, trainer=None, **kwargs):
|
|
@@ -336,22 +355,37 @@ class Model(nn.Module):
|
|
|
336
355
|
**kwargs (Any): Any number of arguments representing the training configuration.
|
|
337
356
|
"""
|
|
338
357
|
self._check_is_pytorch_model()
|
|
339
|
-
if self.session: # Ultralytics HUB session
|
|
358
|
+
if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
|
|
340
359
|
if any(kwargs):
|
|
341
|
-
LOGGER.warning(
|
|
342
|
-
kwargs = self.session.train_args
|
|
360
|
+
LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
|
|
361
|
+
kwargs = self.session.train_args # overwrite kwargs
|
|
362
|
+
|
|
343
363
|
checks.check_pip_update_available()
|
|
344
364
|
|
|
345
|
-
overrides = yaml_load(checks.check_yaml(kwargs[
|
|
346
|
-
custom = {
|
|
347
|
-
args = {**overrides, **custom, **kwargs,
|
|
348
|
-
if args.get(
|
|
349
|
-
args[
|
|
365
|
+
overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
|
|
366
|
+
custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults
|
|
367
|
+
args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
|
|
368
|
+
if args.get("resume"):
|
|
369
|
+
args["resume"] = self.ckpt_path
|
|
350
370
|
|
|
351
|
-
self.trainer = (trainer or self._smart_load(
|
|
352
|
-
if not args.get(
|
|
371
|
+
self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
|
|
372
|
+
if not args.get("resume"): # manually set model only if not resuming
|
|
353
373
|
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
|
354
374
|
self.model = self.trainer.model
|
|
375
|
+
|
|
376
|
+
if SETTINGS["hub"] is True and not self.session:
|
|
377
|
+
# Create a model in HUB
|
|
378
|
+
try:
|
|
379
|
+
self.session = self._get_hub_session(self.model_name)
|
|
380
|
+
if self.session:
|
|
381
|
+
self.session.create_model(args)
|
|
382
|
+
# Check model was created
|
|
383
|
+
if not getattr(self.session.model, "id", None):
|
|
384
|
+
self.session = None
|
|
385
|
+
except PermissionError:
|
|
386
|
+
# Ignore permission error
|
|
387
|
+
pass
|
|
388
|
+
|
|
355
389
|
self.trainer.hub_session = self.session # attach optional HUB session
|
|
356
390
|
self.trainer.train()
|
|
357
391
|
# Update model and cfg after training
|
|
@@ -359,7 +393,7 @@ class Model(nn.Module):
|
|
|
359
393
|
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
|
|
360
394
|
self.model, _ = attempt_load_one_weight(ckpt)
|
|
361
395
|
self.overrides = self.model.args
|
|
362
|
-
self.metrics = getattr(self.trainer.validator,
|
|
396
|
+
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
|
|
363
397
|
return self.metrics
|
|
364
398
|
|
|
365
399
|
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
|
|
@@ -372,12 +406,13 @@ class Model(nn.Module):
|
|
|
372
406
|
self._check_is_pytorch_model()
|
|
373
407
|
if use_ray:
|
|
374
408
|
from ultralytics.utils.tuner import run_ray_tune
|
|
409
|
+
|
|
375
410
|
return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
|
|
376
411
|
else:
|
|
377
412
|
from .tuner import Tuner
|
|
378
413
|
|
|
379
414
|
custom = {} # method defaults
|
|
380
|
-
args = {**self.overrides, **custom, **kwargs,
|
|
415
|
+
args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
|
|
381
416
|
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
|
|
382
417
|
|
|
383
418
|
def _apply(self, fn):
|
|
@@ -385,13 +420,13 @@ class Model(nn.Module):
|
|
|
385
420
|
self._check_is_pytorch_model()
|
|
386
421
|
self = super()._apply(fn) # noqa
|
|
387
422
|
self.predictor = None # reset predictor as device may have changed
|
|
388
|
-
self.overrides[
|
|
423
|
+
self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
|
|
389
424
|
return self
|
|
390
425
|
|
|
391
426
|
@property
|
|
392
427
|
def names(self):
|
|
393
428
|
"""Returns class names of the loaded model."""
|
|
394
|
-
return self.model.names if hasattr(self.model,
|
|
429
|
+
return self.model.names if hasattr(self.model, "names") else None
|
|
395
430
|
|
|
396
431
|
@property
|
|
397
432
|
def device(self):
|
|
@@ -401,7 +436,7 @@ class Model(nn.Module):
|
|
|
401
436
|
@property
|
|
402
437
|
def transforms(self):
|
|
403
438
|
"""Returns transform of the loaded model."""
|
|
404
|
-
return self.model.transforms if hasattr(self.model,
|
|
439
|
+
return self.model.transforms if hasattr(self.model, "transforms") else None
|
|
405
440
|
|
|
406
441
|
def add_callback(self, event: str, func):
|
|
407
442
|
"""Add a callback."""
|
|
@@ -419,7 +454,7 @@ class Model(nn.Module):
|
|
|
419
454
|
@staticmethod
|
|
420
455
|
def _reset_ckpt_args(args):
|
|
421
456
|
"""Reset arguments when loading a PyTorch model."""
|
|
422
|
-
include = {
|
|
457
|
+
include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
|
|
423
458
|
return {k: v for k, v in args.items() if k in include}
|
|
424
459
|
|
|
425
460
|
# def __getattr__(self, attr):
|
|
@@ -435,7 +470,8 @@ class Model(nn.Module):
|
|
|
435
470
|
name = self.__class__.__name__
|
|
436
471
|
mode = inspect.stack()[1][3] # get the function name.
|
|
437
472
|
raise NotImplementedError(
|
|
438
|
-
emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")
|
|
473
|
+
emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")
|
|
474
|
+
) from e
|
|
439
475
|
|
|
440
476
|
@property
|
|
441
477
|
def task_map(self):
|
|
@@ -445,4 +481,4 @@ class Model(nn.Module):
|
|
|
445
481
|
Returns:
|
|
446
482
|
task_map (dict): The map of model task to mode classes.
|
|
447
483
|
"""
|
|
448
|
-
raise NotImplementedError(
|
|
484
|
+
raise NotImplementedError("Please provide task map for your model!")
|
ultralytics/engine/predictor.py
CHANGED
|
@@ -132,8 +132,11 @@ class BasePredictor:
|
|
|
132
132
|
|
|
133
133
|
def inference(self, im, *args, **kwargs):
|
|
134
134
|
"""Runs inference on a given image using the specified model and arguments."""
|
|
135
|
-
visualize =
|
|
136
|
-
|
|
135
|
+
visualize = (
|
|
136
|
+
increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
|
|
137
|
+
if self.args.visualize and (not self.source_type.tensor)
|
|
138
|
+
else False
|
|
139
|
+
)
|
|
137
140
|
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
|
|
138
141
|
|
|
139
142
|
def pre_transform(self, im):
|
|
@@ -153,35 +156,38 @@ class BasePredictor:
|
|
|
153
156
|
def write_results(self, idx, results, batch):
|
|
154
157
|
"""Write inference results to a file or directory."""
|
|
155
158
|
p, im, _ = batch
|
|
156
|
-
log_string =
|
|
159
|
+
log_string = ""
|
|
157
160
|
if len(im.shape) == 3:
|
|
158
161
|
im = im[None] # expand for batch dim
|
|
159
162
|
if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
|
|
160
|
-
log_string += f
|
|
163
|
+
log_string += f"{idx}: "
|
|
161
164
|
frame = self.dataset.count
|
|
162
165
|
else:
|
|
163
|
-
frame = getattr(self.dataset,
|
|
166
|
+
frame = getattr(self.dataset, "frame", 0)
|
|
164
167
|
self.data_path = p
|
|
165
|
-
self.txt_path = str(self.save_dir /
|
|
166
|
-
log_string +=
|
|
168
|
+
self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}")
|
|
169
|
+
log_string += "%gx%g " % im.shape[2:] # print string
|
|
167
170
|
result = results[idx]
|
|
168
171
|
log_string += result.verbose()
|
|
169
172
|
|
|
170
173
|
if self.args.save or self.args.show: # Add bbox to image
|
|
171
174
|
plot_args = {
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
175
|
+
"line_width": self.args.line_width,
|
|
176
|
+
"boxes": self.args.show_boxes,
|
|
177
|
+
"conf": self.args.show_conf,
|
|
178
|
+
"labels": self.args.show_labels,
|
|
179
|
+
}
|
|
176
180
|
if not self.args.retina_masks:
|
|
177
|
-
plot_args[
|
|
181
|
+
plot_args["im_gpu"] = im[idx]
|
|
178
182
|
self.plotted_img = result.plot(**plot_args)
|
|
179
183
|
# Write
|
|
180
184
|
if self.args.save_txt:
|
|
181
|
-
result.save_txt(f
|
|
185
|
+
result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
|
|
182
186
|
if self.args.save_crop:
|
|
183
|
-
result.save_crop(
|
|
184
|
-
|
|
187
|
+
result.save_crop(
|
|
188
|
+
save_dir=self.save_dir / "crops",
|
|
189
|
+
file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"),
|
|
190
|
+
)
|
|
185
191
|
|
|
186
192
|
return log_string
|
|
187
193
|
|
|
@@ -210,17 +216,24 @@ class BasePredictor:
|
|
|
210
216
|
def setup_source(self, source):
|
|
211
217
|
"""Sets up source and inference mode."""
|
|
212
218
|
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
|
213
|
-
self.transforms =
|
|
214
|
-
|
|
215
|
-
self.
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
219
|
+
self.transforms = (
|
|
220
|
+
getattr(
|
|
221
|
+
self.model.model,
|
|
222
|
+
"transforms",
|
|
223
|
+
classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
|
|
224
|
+
)
|
|
225
|
+
if self.args.task == "classify"
|
|
226
|
+
else None
|
|
227
|
+
)
|
|
228
|
+
self.dataset = load_inference_source(
|
|
229
|
+
source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
|
|
230
|
+
)
|
|
220
231
|
self.source_type = self.dataset.source_type
|
|
221
|
-
if not getattr(self,
|
|
222
|
-
|
|
223
|
-
|
|
232
|
+
if not getattr(self, "stream", True) and (
|
|
233
|
+
self.dataset.mode == "stream" # streams
|
|
234
|
+
or len(self.dataset) > 1000 # images
|
|
235
|
+
or any(getattr(self.dataset, "video_flag", [False]))
|
|
236
|
+
): # videos
|
|
224
237
|
LOGGER.warning(STREAM_WARNING)
|
|
225
238
|
self.vid_path = [None] * self.dataset.bs
|
|
226
239
|
self.vid_writer = [None] * self.dataset.bs
|
|
@@ -230,7 +243,7 @@ class BasePredictor:
|
|
|
230
243
|
def stream_inference(self, source=None, model=None, *args, **kwargs):
|
|
231
244
|
"""Streams real-time inference on camera feed and saves results to file."""
|
|
232
245
|
if self.args.verbose:
|
|
233
|
-
LOGGER.info(
|
|
246
|
+
LOGGER.info("")
|
|
234
247
|
|
|
235
248
|
# Setup model
|
|
236
249
|
if not self.model:
|
|
@@ -242,7 +255,7 @@ class BasePredictor:
|
|
|
242
255
|
|
|
243
256
|
# Check if save_dir/ label file exists
|
|
244
257
|
if self.args.save or self.args.save_txt:
|
|
245
|
-
(self.save_dir /
|
|
258
|
+
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
|
246
259
|
|
|
247
260
|
# Warmup model
|
|
248
261
|
if not self.done_warmup:
|
|
@@ -250,10 +263,10 @@ class BasePredictor:
|
|
|
250
263
|
self.done_warmup = True
|
|
251
264
|
|
|
252
265
|
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
|
|
253
|
-
self.run_callbacks(
|
|
266
|
+
self.run_callbacks("on_predict_start")
|
|
254
267
|
|
|
255
268
|
for batch in self.dataset:
|
|
256
|
-
self.run_callbacks(
|
|
269
|
+
self.run_callbacks("on_predict_batch_start")
|
|
257
270
|
self.batch = batch
|
|
258
271
|
path, im0s, vid_cap, s = batch
|
|
259
272
|
|
|
@@ -272,15 +285,16 @@ class BasePredictor:
|
|
|
272
285
|
with profilers[2]:
|
|
273
286
|
self.results = self.postprocess(preds, im, im0s)
|
|
274
287
|
|
|
275
|
-
self.run_callbacks(
|
|
288
|
+
self.run_callbacks("on_predict_postprocess_end")
|
|
276
289
|
# Visualize, save, write results
|
|
277
290
|
n = len(im0s)
|
|
278
291
|
for i in range(n):
|
|
279
292
|
self.seen += 1
|
|
280
293
|
self.results[i].speed = {
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
294
|
+
"preprocess": profilers[0].dt * 1e3 / n,
|
|
295
|
+
"inference": profilers[1].dt * 1e3 / n,
|
|
296
|
+
"postprocess": profilers[2].dt * 1e3 / n,
|
|
297
|
+
}
|
|
284
298
|
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
|
|
285
299
|
p = Path(p)
|
|
286
300
|
|
|
@@ -293,12 +307,12 @@ class BasePredictor:
|
|
|
293
307
|
if self.args.save and self.plotted_img is not None:
|
|
294
308
|
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
|
295
309
|
|
|
296
|
-
self.run_callbacks(
|
|
310
|
+
self.run_callbacks("on_predict_batch_end")
|
|
297
311
|
yield from self.results
|
|
298
312
|
|
|
299
313
|
# Print time (inference-only)
|
|
300
314
|
if self.args.verbose:
|
|
301
|
-
LOGGER.info(f
|
|
315
|
+
LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms")
|
|
302
316
|
|
|
303
317
|
# Release assets
|
|
304
318
|
if isinstance(self.vid_writer[-1], cv2.VideoWriter):
|
|
@@ -306,25 +320,29 @@ class BasePredictor:
|
|
|
306
320
|
|
|
307
321
|
# Print results
|
|
308
322
|
if self.args.verbose and self.seen:
|
|
309
|
-
t = tuple(x.t / self.seen *
|
|
310
|
-
LOGGER.info(
|
|
311
|
-
|
|
323
|
+
t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
|
|
324
|
+
LOGGER.info(
|
|
325
|
+
f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
|
|
326
|
+
f"{(1, 3, *im.shape[2:])}" % t
|
|
327
|
+
)
|
|
312
328
|
if self.args.save or self.args.save_txt or self.args.save_crop:
|
|
313
|
-
nl = len(list(self.save_dir.glob(
|
|
314
|
-
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else
|
|
329
|
+
nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels
|
|
330
|
+
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
|
|
315
331
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
|
316
332
|
|
|
317
|
-
self.run_callbacks(
|
|
333
|
+
self.run_callbacks("on_predict_end")
|
|
318
334
|
|
|
319
335
|
def setup_model(self, model, verbose=True):
|
|
320
336
|
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
|
|
321
|
-
self.model = AutoBackend(
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
337
|
+
self.model = AutoBackend(
|
|
338
|
+
model or self.args.model,
|
|
339
|
+
device=select_device(self.args.device, verbose=verbose),
|
|
340
|
+
dnn=self.args.dnn,
|
|
341
|
+
data=self.args.data,
|
|
342
|
+
fp16=self.args.half,
|
|
343
|
+
fuse=True,
|
|
344
|
+
verbose=verbose,
|
|
345
|
+
)
|
|
328
346
|
|
|
329
347
|
self.device = self.model.device # update device
|
|
330
348
|
self.args.half = self.model.fp16 # update half
|
|
@@ -333,18 +351,18 @@ class BasePredictor:
|
|
|
333
351
|
def show(self, p):
|
|
334
352
|
"""Display an image in a window using OpenCV imshow()."""
|
|
335
353
|
im0 = self.plotted_img
|
|
336
|
-
if platform.system() ==
|
|
354
|
+
if platform.system() == "Linux" and p not in self.windows:
|
|
337
355
|
self.windows.append(p)
|
|
338
356
|
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
|
339
357
|
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
|
|
340
358
|
cv2.imshow(str(p), im0)
|
|
341
|
-
cv2.waitKey(500 if self.batch[3].startswith(
|
|
359
|
+
cv2.waitKey(500 if self.batch[3].startswith("image") else 1) # 1 millisecond
|
|
342
360
|
|
|
343
361
|
def save_preds(self, vid_cap, idx, save_path):
|
|
344
362
|
"""Save video predictions as mp4 at specified path."""
|
|
345
363
|
im0 = self.plotted_img
|
|
346
364
|
# Save imgs
|
|
347
|
-
if self.dataset.mode ==
|
|
365
|
+
if self.dataset.mode == "image":
|
|
348
366
|
cv2.imwrite(save_path, im0)
|
|
349
367
|
else: # 'video' or 'stream'
|
|
350
368
|
frames_path = f'{save_path.split(".", 1)[0]}_frames/'
|
|
@@ -361,15 +379,16 @@ class BasePredictor:
|
|
|
361
379
|
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
362
380
|
else: # stream
|
|
363
381
|
fps, w, h = 30, im0.shape[1], im0.shape[0]
|
|
364
|
-
suffix, fourcc = (
|
|
365
|
-
self.vid_writer[idx] = cv2.VideoWriter(
|
|
366
|
-
|
|
382
|
+
suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
|
|
383
|
+
self.vid_writer[idx] = cv2.VideoWriter(
|
|
384
|
+
str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)
|
|
385
|
+
)
|
|
367
386
|
# Write video
|
|
368
387
|
self.vid_writer[idx].write(im0)
|
|
369
388
|
|
|
370
389
|
# Write frame
|
|
371
390
|
if self.args.save_frames:
|
|
372
|
-
cv2.imwrite(f
|
|
391
|
+
cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0)
|
|
373
392
|
self.vid_frame[idx] += 1
|
|
374
393
|
|
|
375
394
|
def run_callbacks(self, event: str):
|