ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +34 -0
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +5 -0
- ultralytics/data/explorer/explorer.py +170 -97
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +146 -76
- ultralytics/data/explorer/utils.py +87 -25
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +63 -40
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -12
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +80 -58
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +67 -59
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +22 -15
- ultralytics/solutions/heatmap.py +76 -54
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -151
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +39 -29
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.237.dist-info/RECORD +0 -187
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py
CHANGED
|
@@ -7,16 +7,54 @@ from pathlib import Path
|
|
|
7
7
|
import torch
|
|
8
8
|
import torch.nn as nn
|
|
9
9
|
|
|
10
|
-
from ultralytics.nn.modules import (
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
10
|
+
from ultralytics.nn.modules import (
|
|
11
|
+
AIFI,
|
|
12
|
+
C1,
|
|
13
|
+
C2,
|
|
14
|
+
C3,
|
|
15
|
+
C3TR,
|
|
16
|
+
OBB,
|
|
17
|
+
SPP,
|
|
18
|
+
SPPF,
|
|
19
|
+
Bottleneck,
|
|
20
|
+
BottleneckCSP,
|
|
21
|
+
C2f,
|
|
22
|
+
C3Ghost,
|
|
23
|
+
C3x,
|
|
24
|
+
Classify,
|
|
25
|
+
Concat,
|
|
26
|
+
Conv,
|
|
27
|
+
Conv2,
|
|
28
|
+
ConvTranspose,
|
|
29
|
+
Detect,
|
|
30
|
+
DWConv,
|
|
31
|
+
DWConvTranspose2d,
|
|
32
|
+
Focus,
|
|
33
|
+
GhostBottleneck,
|
|
34
|
+
GhostConv,
|
|
35
|
+
HGBlock,
|
|
36
|
+
HGStem,
|
|
37
|
+
Pose,
|
|
38
|
+
RepC3,
|
|
39
|
+
RepConv,
|
|
40
|
+
ResNetLayer,
|
|
41
|
+
RTDETRDecoder,
|
|
42
|
+
Segment,
|
|
43
|
+
)
|
|
14
44
|
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
|
15
45
|
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
|
|
16
46
|
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
|
|
17
47
|
from ultralytics.utils.plotting import feature_visualization
|
|
18
|
-
from ultralytics.utils.torch_utils import (
|
|
19
|
-
|
|
48
|
+
from ultralytics.utils.torch_utils import (
|
|
49
|
+
fuse_conv_and_bn,
|
|
50
|
+
fuse_deconv_and_bn,
|
|
51
|
+
initialize_weights,
|
|
52
|
+
intersect_dicts,
|
|
53
|
+
make_divisible,
|
|
54
|
+
model_info,
|
|
55
|
+
scale_img,
|
|
56
|
+
time_sync,
|
|
57
|
+
)
|
|
20
58
|
|
|
21
59
|
try:
|
|
22
60
|
import thop
|
|
@@ -90,8 +128,10 @@ class BaseModel(nn.Module):
|
|
|
90
128
|
|
|
91
129
|
def _predict_augment(self, x):
|
|
92
130
|
"""Perform augmentations on input image x and return augmented inference."""
|
|
93
|
-
LOGGER.warning(
|
|
94
|
-
|
|
131
|
+
LOGGER.warning(
|
|
132
|
+
f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
|
|
133
|
+
f"Reverting to single-scale inference instead."
|
|
134
|
+
)
|
|
95
135
|
return self._predict_once(x)
|
|
96
136
|
|
|
97
137
|
def _profile_one_layer(self, m, x, dt):
|
|
@@ -108,14 +148,14 @@ class BaseModel(nn.Module):
|
|
|
108
148
|
None
|
|
109
149
|
"""
|
|
110
150
|
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
|
|
111
|
-
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] /
|
|
151
|
+
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # FLOPs
|
|
112
152
|
t = time_sync()
|
|
113
153
|
for _ in range(10):
|
|
114
154
|
m(x.copy() if c else x)
|
|
115
155
|
dt.append((time_sync() - t) * 100)
|
|
116
156
|
if m == self.model[0]:
|
|
117
157
|
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
|
118
|
-
LOGGER.info(f
|
|
158
|
+
LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}")
|
|
119
159
|
if c:
|
|
120
160
|
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
|
121
161
|
|
|
@@ -129,15 +169,15 @@ class BaseModel(nn.Module):
|
|
|
129
169
|
"""
|
|
130
170
|
if not self.is_fused():
|
|
131
171
|
for m in self.model.modules():
|
|
132
|
-
if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m,
|
|
172
|
+
if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
|
|
133
173
|
if isinstance(m, Conv2):
|
|
134
174
|
m.fuse_convs()
|
|
135
175
|
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
|
136
|
-
delattr(m,
|
|
176
|
+
delattr(m, "bn") # remove batchnorm
|
|
137
177
|
m.forward = m.forward_fuse # update forward
|
|
138
|
-
if isinstance(m, ConvTranspose) and hasattr(m,
|
|
178
|
+
if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
|
|
139
179
|
m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
|
|
140
|
-
delattr(m,
|
|
180
|
+
delattr(m, "bn") # remove batchnorm
|
|
141
181
|
m.forward = m.forward_fuse # update forward
|
|
142
182
|
if isinstance(m, RepConv):
|
|
143
183
|
m.fuse_convs()
|
|
@@ -156,7 +196,7 @@ class BaseModel(nn.Module):
|
|
|
156
196
|
Returns:
|
|
157
197
|
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
|
158
198
|
"""
|
|
159
|
-
bn = tuple(v for k, v in nn.__dict__.items() if
|
|
199
|
+
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
|
160
200
|
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
|
161
201
|
|
|
162
202
|
def info(self, detailed=False, verbose=True, imgsz=640):
|
|
@@ -196,12 +236,12 @@ class BaseModel(nn.Module):
|
|
|
196
236
|
weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
|
|
197
237
|
verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
|
|
198
238
|
"""
|
|
199
|
-
model = weights[
|
|
239
|
+
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
|
200
240
|
csd = model.float().state_dict() # checkpoint state_dict as FP32
|
|
201
241
|
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
|
202
242
|
self.load_state_dict(csd, strict=False) # load
|
|
203
243
|
if verbose:
|
|
204
|
-
LOGGER.info(f
|
|
244
|
+
LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
|
|
205
245
|
|
|
206
246
|
def loss(self, batch, preds=None):
|
|
207
247
|
"""
|
|
@@ -211,33 +251,33 @@ class BaseModel(nn.Module):
|
|
|
211
251
|
batch (dict): Batch to compute loss on
|
|
212
252
|
preds (torch.Tensor | List[torch.Tensor]): Predictions.
|
|
213
253
|
"""
|
|
214
|
-
if not hasattr(self,
|
|
254
|
+
if not hasattr(self, "criterion"):
|
|
215
255
|
self.criterion = self.init_criterion()
|
|
216
256
|
|
|
217
|
-
preds = self.forward(batch[
|
|
257
|
+
preds = self.forward(batch["img"]) if preds is None else preds
|
|
218
258
|
return self.criterion(preds, batch)
|
|
219
259
|
|
|
220
260
|
def init_criterion(self):
|
|
221
261
|
"""Initialize the loss criterion for the BaseModel."""
|
|
222
|
-
raise NotImplementedError(
|
|
262
|
+
raise NotImplementedError("compute_loss() needs to be implemented by task heads")
|
|
223
263
|
|
|
224
264
|
|
|
225
265
|
class DetectionModel(BaseModel):
|
|
226
266
|
"""YOLOv8 detection model."""
|
|
227
267
|
|
|
228
|
-
def __init__(self, cfg=
|
|
268
|
+
def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes
|
|
229
269
|
"""Initialize the YOLOv8 detection model with the given config and parameters."""
|
|
230
270
|
super().__init__()
|
|
231
271
|
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
|
232
272
|
|
|
233
273
|
# Define model
|
|
234
|
-
ch = self.yaml[
|
|
235
|
-
if nc and nc != self.yaml[
|
|
274
|
+
ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
|
|
275
|
+
if nc and nc != self.yaml["nc"]:
|
|
236
276
|
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
|
237
|
-
self.yaml[
|
|
277
|
+
self.yaml["nc"] = nc # override YAML value
|
|
238
278
|
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
|
239
|
-
self.names = {i: f
|
|
240
|
-
self.inplace = self.yaml.get(
|
|
279
|
+
self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
|
|
280
|
+
self.inplace = self.yaml.get("inplace", True)
|
|
241
281
|
|
|
242
282
|
# Build strides
|
|
243
283
|
m = self.model[-1] # Detect()
|
|
@@ -255,7 +295,7 @@ class DetectionModel(BaseModel):
|
|
|
255
295
|
initialize_weights(self)
|
|
256
296
|
if verbose:
|
|
257
297
|
self.info()
|
|
258
|
-
LOGGER.info(
|
|
298
|
+
LOGGER.info("")
|
|
259
299
|
|
|
260
300
|
def _predict_augment(self, x):
|
|
261
301
|
"""Perform augmentations on input image x and return augmented inference and train outputs."""
|
|
@@ -285,9 +325,9 @@ class DetectionModel(BaseModel):
|
|
|
285
325
|
def _clip_augmented(self, y):
|
|
286
326
|
"""Clip YOLO augmented inference tails."""
|
|
287
327
|
nl = self.model[-1].nl # number of detection layers (P3-P5)
|
|
288
|
-
g = sum(4
|
|
328
|
+
g = sum(4**x for x in range(nl)) # grid points
|
|
289
329
|
e = 1 # exclude layer count
|
|
290
|
-
i = (y[0].shape[-1] // g) * sum(4
|
|
330
|
+
i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices
|
|
291
331
|
y[0] = y[0][..., :-i] # large
|
|
292
332
|
i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
|
|
293
333
|
y[-1] = y[-1][..., i:] # small
|
|
@@ -301,18 +341,19 @@ class DetectionModel(BaseModel):
|
|
|
301
341
|
class OBBModel(DetectionModel):
|
|
302
342
|
""""YOLOv8 Oriented Bounding Box (OBB) model."""
|
|
303
343
|
|
|
304
|
-
def __init__(self, cfg=
|
|
344
|
+
def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):
|
|
305
345
|
"""Initialize YOLOv8 OBB model with given config and parameters."""
|
|
306
346
|
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
|
307
347
|
|
|
308
348
|
def init_criterion(self):
|
|
349
|
+
"""Initialize the loss criterion for the model."""
|
|
309
350
|
return v8OBBLoss(self)
|
|
310
351
|
|
|
311
352
|
|
|
312
353
|
class SegmentationModel(DetectionModel):
|
|
313
354
|
"""YOLOv8 segmentation model."""
|
|
314
355
|
|
|
315
|
-
def __init__(self, cfg=
|
|
356
|
+
def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True):
|
|
316
357
|
"""Initialize YOLOv8 segmentation model with given config and parameters."""
|
|
317
358
|
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
|
318
359
|
|
|
@@ -324,13 +365,13 @@ class SegmentationModel(DetectionModel):
|
|
|
324
365
|
class PoseModel(DetectionModel):
|
|
325
366
|
"""YOLOv8 pose model."""
|
|
326
367
|
|
|
327
|
-
def __init__(self, cfg=
|
|
368
|
+
def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
|
|
328
369
|
"""Initialize YOLOv8 Pose model."""
|
|
329
370
|
if not isinstance(cfg, dict):
|
|
330
371
|
cfg = yaml_model_load(cfg) # load model YAML
|
|
331
|
-
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg[
|
|
372
|
+
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
|
|
332
373
|
LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
|
|
333
|
-
cfg[
|
|
374
|
+
cfg["kpt_shape"] = data_kpt_shape
|
|
334
375
|
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
|
335
376
|
|
|
336
377
|
def init_criterion(self):
|
|
@@ -341,7 +382,7 @@ class PoseModel(DetectionModel):
|
|
|
341
382
|
class ClassificationModel(BaseModel):
|
|
342
383
|
"""YOLOv8 classification model."""
|
|
343
384
|
|
|
344
|
-
def __init__(self, cfg=
|
|
385
|
+
def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):
|
|
345
386
|
"""Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
|
|
346
387
|
super().__init__()
|
|
347
388
|
self._from_yaml(cfg, ch, nc, verbose)
|
|
@@ -351,21 +392,21 @@ class ClassificationModel(BaseModel):
|
|
|
351
392
|
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
|
352
393
|
|
|
353
394
|
# Define model
|
|
354
|
-
ch = self.yaml[
|
|
355
|
-
if nc and nc != self.yaml[
|
|
395
|
+
ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
|
|
396
|
+
if nc and nc != self.yaml["nc"]:
|
|
356
397
|
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
|
357
|
-
self.yaml[
|
|
358
|
-
elif not nc and not self.yaml.get(
|
|
359
|
-
raise ValueError(
|
|
398
|
+
self.yaml["nc"] = nc # override YAML value
|
|
399
|
+
elif not nc and not self.yaml.get("nc", None):
|
|
400
|
+
raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
|
|
360
401
|
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
|
361
402
|
self.stride = torch.Tensor([1]) # no stride constraints
|
|
362
|
-
self.names = {i: f
|
|
403
|
+
self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
|
|
363
404
|
self.info()
|
|
364
405
|
|
|
365
406
|
@staticmethod
|
|
366
407
|
def reshape_outputs(model, nc):
|
|
367
408
|
"""Update a TorchVision classification model to class count 'n' if required."""
|
|
368
|
-
name, m = list((model.model if hasattr(model,
|
|
409
|
+
name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
|
|
369
410
|
if isinstance(m, Classify): # YOLO Classify() head
|
|
370
411
|
if m.linear.out_features != nc:
|
|
371
412
|
m.linear = nn.Linear(m.linear.in_features, nc)
|
|
@@ -408,7 +449,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|
|
408
449
|
predict: Performs a forward pass through the network and returns the output.
|
|
409
450
|
"""
|
|
410
451
|
|
|
411
|
-
def __init__(self, cfg=
|
|
452
|
+
def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
|
|
412
453
|
"""
|
|
413
454
|
Initialize the RTDETRDetectionModel.
|
|
414
455
|
|
|
@@ -437,39 +478,39 @@ class RTDETRDetectionModel(DetectionModel):
|
|
|
437
478
|
Returns:
|
|
438
479
|
(tuple): A tuple containing the total loss and main three losses in a tensor.
|
|
439
480
|
"""
|
|
440
|
-
if not hasattr(self,
|
|
481
|
+
if not hasattr(self, "criterion"):
|
|
441
482
|
self.criterion = self.init_criterion()
|
|
442
483
|
|
|
443
|
-
img = batch[
|
|
484
|
+
img = batch["img"]
|
|
444
485
|
# NOTE: preprocess gt_bbox and gt_labels to list.
|
|
445
486
|
bs = len(img)
|
|
446
|
-
batch_idx = batch[
|
|
487
|
+
batch_idx = batch["batch_idx"]
|
|
447
488
|
gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
|
|
448
489
|
targets = {
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
490
|
+
"cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
|
|
491
|
+
"bboxes": batch["bboxes"].to(device=img.device),
|
|
492
|
+
"batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
|
|
493
|
+
"gt_groups": gt_groups,
|
|
494
|
+
}
|
|
453
495
|
|
|
454
496
|
preds = self.predict(img, batch=targets) if preds is None else preds
|
|
455
497
|
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
|
|
456
498
|
if dn_meta is None:
|
|
457
499
|
dn_bboxes, dn_scores = None, None
|
|
458
500
|
else:
|
|
459
|
-
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta[
|
|
460
|
-
dn_scores, dec_scores = torch.split(dec_scores, dn_meta[
|
|
501
|
+
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
|
|
502
|
+
dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)
|
|
461
503
|
|
|
462
504
|
dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
|
|
463
505
|
dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
|
|
464
506
|
|
|
465
|
-
loss = self.criterion(
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
dn_scores=dn_scores,
|
|
469
|
-
dn_meta=dn_meta)
|
|
507
|
+
loss = self.criterion(
|
|
508
|
+
(dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
|
|
509
|
+
)
|
|
470
510
|
# NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
|
|
471
|
-
return sum(loss.values()), torch.as_tensor(
|
|
472
|
-
|
|
511
|
+
return sum(loss.values()), torch.as_tensor(
|
|
512
|
+
[loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
|
|
513
|
+
)
|
|
473
514
|
|
|
474
515
|
def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
|
|
475
516
|
"""
|
|
@@ -552,6 +593,7 @@ def temporary_modules(modules=None):
|
|
|
552
593
|
|
|
553
594
|
import importlib
|
|
554
595
|
import sys
|
|
596
|
+
|
|
555
597
|
try:
|
|
556
598
|
# Set modules in sys.modules under their old name
|
|
557
599
|
for old, new in modules.items():
|
|
@@ -579,30 +621,38 @@ def torch_safe_load(weight):
|
|
|
579
621
|
"""
|
|
580
622
|
from ultralytics.utils.downloads import attempt_download_asset
|
|
581
623
|
|
|
582
|
-
check_suffix(file=weight, suffix=
|
|
624
|
+
check_suffix(file=weight, suffix=".pt")
|
|
583
625
|
file = attempt_download_asset(weight) # search online if missing locally
|
|
584
626
|
try:
|
|
585
|
-
with temporary_modules(
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
627
|
+
with temporary_modules(
|
|
628
|
+
{
|
|
629
|
+
"ultralytics.yolo.utils": "ultralytics.utils",
|
|
630
|
+
"ultralytics.yolo.v8": "ultralytics.models.yolo",
|
|
631
|
+
"ultralytics.yolo.data": "ultralytics.data",
|
|
632
|
+
}
|
|
633
|
+
): # for legacy 8.0 Classify and Pose models
|
|
634
|
+
return torch.load(file, map_location="cpu"), file # load
|
|
590
635
|
|
|
591
636
|
except ModuleNotFoundError as e: # e.name is missing module name
|
|
592
|
-
if e.name ==
|
|
637
|
+
if e.name == "models":
|
|
593
638
|
raise TypeError(
|
|
594
|
-
emojis(
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
639
|
+
emojis(
|
|
640
|
+
f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
|
|
641
|
+
f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
|
|
642
|
+
f"YOLOv8 at https://github.com/ultralytics/ultralytics."
|
|
643
|
+
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
|
644
|
+
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
|
|
645
|
+
)
|
|
646
|
+
) from e
|
|
647
|
+
LOGGER.warning(
|
|
648
|
+
f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
|
|
649
|
+
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
|
|
650
|
+
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
|
651
|
+
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
|
|
652
|
+
)
|
|
603
653
|
check_requirements(e.name) # install missing module
|
|
604
654
|
|
|
605
|
-
return torch.load(file, map_location=
|
|
655
|
+
return torch.load(file, map_location="cpu"), file # load
|
|
606
656
|
|
|
607
657
|
|
|
608
658
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|
@@ -611,25 +661,25 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|
|
611
661
|
ensemble = Ensemble()
|
|
612
662
|
for w in weights if isinstance(weights, list) else [weights]:
|
|
613
663
|
ckpt, w = torch_safe_load(w) # load ckpt
|
|
614
|
-
args = {**DEFAULT_CFG_DICT, **ckpt[
|
|
615
|
-
model = (ckpt.get(
|
|
664
|
+
args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args
|
|
665
|
+
model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
|
|
616
666
|
|
|
617
667
|
# Model compatibility updates
|
|
618
668
|
model.args = args # attach args to model
|
|
619
669
|
model.pt_path = w # attach *.pt file path to model
|
|
620
670
|
model.task = guess_model_task(model)
|
|
621
|
-
if not hasattr(model,
|
|
622
|
-
model.stride = torch.tensor([32.])
|
|
671
|
+
if not hasattr(model, "stride"):
|
|
672
|
+
model.stride = torch.tensor([32.0])
|
|
623
673
|
|
|
624
674
|
# Append
|
|
625
|
-
ensemble.append(model.fuse().eval() if fuse and hasattr(model,
|
|
675
|
+
ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode
|
|
626
676
|
|
|
627
677
|
# Module updates
|
|
628
678
|
for m in ensemble.modules():
|
|
629
679
|
t = type(m)
|
|
630
680
|
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
|
631
681
|
m.inplace = inplace
|
|
632
|
-
elif t is nn.Upsample and not hasattr(m,
|
|
682
|
+
elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
|
|
633
683
|
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
|
634
684
|
|
|
635
685
|
# Return model
|
|
@@ -637,35 +687,35 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|
|
637
687
|
return ensemble[-1]
|
|
638
688
|
|
|
639
689
|
# Return ensemble
|
|
640
|
-
LOGGER.info(f
|
|
641
|
-
for k in
|
|
690
|
+
LOGGER.info(f"Ensemble created with {weights}\n")
|
|
691
|
+
for k in "names", "nc", "yaml":
|
|
642
692
|
setattr(ensemble, k, getattr(ensemble[0], k))
|
|
643
693
|
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
|
|
644
|
-
assert all(ensemble[0].nc == m.nc for m in ensemble), f
|
|
694
|
+
assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
|
|
645
695
|
return ensemble
|
|
646
696
|
|
|
647
697
|
|
|
648
698
|
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
|
649
699
|
"""Loads a single model weights."""
|
|
650
700
|
ckpt, weight = torch_safe_load(weight) # load ckpt
|
|
651
|
-
args = {**DEFAULT_CFG_DICT, **(ckpt.get(
|
|
652
|
-
model = (ckpt.get(
|
|
701
|
+
args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args
|
|
702
|
+
model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
|
|
653
703
|
|
|
654
704
|
# Model compatibility updates
|
|
655
705
|
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
|
656
706
|
model.pt_path = weight # attach *.pt file path to model
|
|
657
707
|
model.task = guess_model_task(model)
|
|
658
|
-
if not hasattr(model,
|
|
659
|
-
model.stride = torch.tensor([32.])
|
|
708
|
+
if not hasattr(model, "stride"):
|
|
709
|
+
model.stride = torch.tensor([32.0])
|
|
660
710
|
|
|
661
|
-
model = model.fuse().eval() if fuse and hasattr(model,
|
|
711
|
+
model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode
|
|
662
712
|
|
|
663
713
|
# Module updates
|
|
664
714
|
for m in model.modules():
|
|
665
715
|
t = type(m)
|
|
666
716
|
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
|
667
717
|
m.inplace = inplace
|
|
668
|
-
elif t is nn.Upsample and not hasattr(m,
|
|
718
|
+
elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
|
|
669
719
|
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
|
670
720
|
|
|
671
721
|
# Return model and ckpt
|
|
@@ -677,11 +727,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
|
677
727
|
import ast
|
|
678
728
|
|
|
679
729
|
# Args
|
|
680
|
-
max_channels = float(
|
|
681
|
-
nc, act, scales = (d.get(x) for x in (
|
|
682
|
-
depth, width, kpt_shape = (d.get(x, 1.0) for x in (
|
|
730
|
+
max_channels = float("inf")
|
|
731
|
+
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
|
|
732
|
+
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
|
|
683
733
|
if scales:
|
|
684
|
-
scale = d.get(
|
|
734
|
+
scale = d.get("scale")
|
|
685
735
|
if not scale:
|
|
686
736
|
scale = tuple(scales.keys())[0]
|
|
687
737
|
LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
|
|
@@ -696,16 +746,37 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
|
696
746
|
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
|
|
697
747
|
ch = [ch]
|
|
698
748
|
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
|
699
|
-
for i, (f, n, m, args) in enumerate(d[
|
|
700
|
-
m = getattr(torch.nn, m[3:]) if
|
|
749
|
+
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
|
|
750
|
+
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
|
|
701
751
|
for j, a in enumerate(args):
|
|
702
752
|
if isinstance(a, str):
|
|
703
753
|
with contextlib.suppress(ValueError):
|
|
704
754
|
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
|
705
755
|
|
|
706
756
|
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
|
707
|
-
if m in (
|
|
708
|
-
|
|
757
|
+
if m in (
|
|
758
|
+
Classify,
|
|
759
|
+
Conv,
|
|
760
|
+
ConvTranspose,
|
|
761
|
+
GhostConv,
|
|
762
|
+
Bottleneck,
|
|
763
|
+
GhostBottleneck,
|
|
764
|
+
SPP,
|
|
765
|
+
SPPF,
|
|
766
|
+
DWConv,
|
|
767
|
+
Focus,
|
|
768
|
+
BottleneckCSP,
|
|
769
|
+
C1,
|
|
770
|
+
C2,
|
|
771
|
+
C2f,
|
|
772
|
+
C3,
|
|
773
|
+
C3TR,
|
|
774
|
+
C3Ghost,
|
|
775
|
+
nn.ConvTranspose2d,
|
|
776
|
+
DWConvTranspose2d,
|
|
777
|
+
C3x,
|
|
778
|
+
RepC3,
|
|
779
|
+
):
|
|
709
780
|
c1, c2 = ch[f], args[0]
|
|
710
781
|
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
|
711
782
|
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
|
@@ -738,11 +809,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|
|
738
809
|
c2 = ch[f]
|
|
739
810
|
|
|
740
811
|
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
|
741
|
-
t = str(m)[8:-2].replace(
|
|
812
|
+
t = str(m)[8:-2].replace("__main__.", "") # module type
|
|
742
813
|
m.np = sum(x.numel() for x in m_.parameters()) # number params
|
|
743
814
|
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
|
744
815
|
if verbose:
|
|
745
|
-
LOGGER.info(f
|
|
816
|
+
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print
|
|
746
817
|
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
|
747
818
|
layers.append(m_)
|
|
748
819
|
if i == 0:
|
|
@@ -756,16 +827,16 @@ def yaml_model_load(path):
|
|
|
756
827
|
import re
|
|
757
828
|
|
|
758
829
|
path = Path(path)
|
|
759
|
-
if path.stem in (f
|
|
760
|
-
new_stem = re.sub(r
|
|
761
|
-
LOGGER.warning(f
|
|
830
|
+
if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
|
|
831
|
+
new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
|
|
832
|
+
LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
|
|
762
833
|
path = path.with_name(new_stem + path.suffix)
|
|
763
834
|
|
|
764
|
-
unified_path = re.sub(r
|
|
835
|
+
unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
|
|
765
836
|
yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
|
|
766
837
|
d = yaml_load(yaml_file) # model dict
|
|
767
|
-
d[
|
|
768
|
-
d[
|
|
838
|
+
d["scale"] = guess_model_scale(path)
|
|
839
|
+
d["yaml_file"] = str(path)
|
|
769
840
|
return d
|
|
770
841
|
|
|
771
842
|
|
|
@@ -783,8 +854,9 @@ def guess_model_scale(model_path):
|
|
|
783
854
|
"""
|
|
784
855
|
with contextlib.suppress(AttributeError):
|
|
785
856
|
import re
|
|
786
|
-
|
|
787
|
-
|
|
857
|
+
|
|
858
|
+
return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
|
|
859
|
+
return ""
|
|
788
860
|
|
|
789
861
|
|
|
790
862
|
def guess_model_task(model):
|
|
@@ -803,17 +875,17 @@ def guess_model_task(model):
|
|
|
803
875
|
|
|
804
876
|
def cfg2task(cfg):
|
|
805
877
|
"""Guess from YAML dictionary."""
|
|
806
|
-
m = cfg[
|
|
807
|
-
if m in (
|
|
808
|
-
return
|
|
809
|
-
if m ==
|
|
810
|
-
return
|
|
811
|
-
if m ==
|
|
812
|
-
return
|
|
813
|
-
if m ==
|
|
814
|
-
return
|
|
815
|
-
if m ==
|
|
816
|
-
return
|
|
878
|
+
m = cfg["head"][-1][-2].lower() # output module name
|
|
879
|
+
if m in ("classify", "classifier", "cls", "fc"):
|
|
880
|
+
return "classify"
|
|
881
|
+
if m == "detect":
|
|
882
|
+
return "detect"
|
|
883
|
+
if m == "segment":
|
|
884
|
+
return "segment"
|
|
885
|
+
if m == "pose":
|
|
886
|
+
return "pose"
|
|
887
|
+
if m == "obb":
|
|
888
|
+
return "obb"
|
|
817
889
|
|
|
818
890
|
# Guess from model cfg
|
|
819
891
|
if isinstance(model, dict):
|
|
@@ -822,40 +894,42 @@ def guess_model_task(model):
|
|
|
822
894
|
|
|
823
895
|
# Guess from PyTorch model
|
|
824
896
|
if isinstance(model, nn.Module): # PyTorch model
|
|
825
|
-
for x in
|
|
897
|
+
for x in "model.args", "model.model.args", "model.model.model.args":
|
|
826
898
|
with contextlib.suppress(Exception):
|
|
827
|
-
return eval(x)[
|
|
828
|
-
for x in
|
|
899
|
+
return eval(x)["task"]
|
|
900
|
+
for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
|
|
829
901
|
with contextlib.suppress(Exception):
|
|
830
902
|
return cfg2task(eval(x))
|
|
831
903
|
|
|
832
904
|
for m in model.modules():
|
|
833
905
|
if isinstance(m, Detect):
|
|
834
|
-
return
|
|
906
|
+
return "detect"
|
|
835
907
|
elif isinstance(m, Segment):
|
|
836
|
-
return
|
|
908
|
+
return "segment"
|
|
837
909
|
elif isinstance(m, Classify):
|
|
838
|
-
return
|
|
910
|
+
return "classify"
|
|
839
911
|
elif isinstance(m, Pose):
|
|
840
|
-
return
|
|
912
|
+
return "pose"
|
|
841
913
|
elif isinstance(m, OBB):
|
|
842
|
-
return
|
|
914
|
+
return "obb"
|
|
843
915
|
|
|
844
916
|
# Guess from model filename
|
|
845
917
|
if isinstance(model, (str, Path)):
|
|
846
918
|
model = Path(model)
|
|
847
|
-
if
|
|
848
|
-
return
|
|
849
|
-
elif
|
|
850
|
-
return
|
|
851
|
-
elif
|
|
852
|
-
return
|
|
853
|
-
elif
|
|
854
|
-
return
|
|
855
|
-
elif
|
|
856
|
-
return
|
|
919
|
+
if "-seg" in model.stem or "segment" in model.parts:
|
|
920
|
+
return "segment"
|
|
921
|
+
elif "-cls" in model.stem or "classify" in model.parts:
|
|
922
|
+
return "classify"
|
|
923
|
+
elif "-pose" in model.stem or "pose" in model.parts:
|
|
924
|
+
return "pose"
|
|
925
|
+
elif "-obb" in model.stem or "obb" in model.parts:
|
|
926
|
+
return "obb"
|
|
927
|
+
elif "detect" in model.parts:
|
|
928
|
+
return "detect"
|
|
857
929
|
|
|
858
930
|
# Unable to determine task from model
|
|
859
|
-
LOGGER.warning(
|
|
860
|
-
|
|
861
|
-
|
|
931
|
+
LOGGER.warning(
|
|
932
|
+
"WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
|
|
933
|
+
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
|
|
934
|
+
)
|
|
935
|
+
return "detect" # assume detect
|