ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_exports.py +2 -2
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +118 -30
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +5 -5
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +15 -19
- ultralytics/engine/exporter.py +24 -23
- ultralytics/engine/model.py +67 -88
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +21 -18
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +12 -13
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +20 -11
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +22 -11
- ultralytics/models/nas/predict.py +9 -4
- ultralytics/models/nas/val.py +5 -5
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +18 -15
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +42 -6
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +24 -3
- ultralytics/models/yolo/classify/train.py +77 -10
- ultralytics/models/yolo/classify/val.py +40 -15
- ultralytics/models/yolo/detect/predict.py +23 -10
- ultralytics/models/yolo/detect/train.py +85 -15
- ultralytics/models/yolo/detect/val.py +145 -21
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +12 -4
- ultralytics/models/yolo/obb/train.py +7 -0
- ultralytics/models/yolo/obb/val.py +25 -7
- ultralytics/models/yolo/pose/predict.py +22 -6
- ultralytics/models/yolo/pose/train.py +17 -1
- ultralytics/models/yolo/pose/val.py +46 -21
- ultralytics/models/yolo/segment/predict.py +22 -8
- ultralytics/models/yolo/segment/train.py +6 -0
- ultralytics/models/yolo/segment/val.py +100 -14
- ultralytics/models/yolo/world/train.py +38 -8
- ultralytics/models/yolo/world/train_world.py +39 -10
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +3 -0
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +221 -69
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +32 -27
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +42 -24
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +116 -35
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +13 -9
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +112 -45
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +61 -53
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +65 -45
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +181 -33
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +8 -16
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
- ultralytics-8.3.91.dist-info/RECORD +250 -0
- ultralytics-8.3.89.dist-info/RECORD +0 -250
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py
CHANGED
@@ -119,10 +119,10 @@ class BaseModel(torch.nn.Module):
 
         Args:
             x (torch.Tensor): The input tensor to the model.
-            profile (bool):
-            visualize (bool): Save the feature maps of the model if True
-            augment (bool): Augment image during prediction
-            embed (
+            profile (bool): Print the computation time of each layer if True.
+            visualize (bool): Save the feature maps of the model if True.
+            augment (bool): Augment image during prediction.
+            embed (List, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): The last output of the model.
@@ -137,9 +137,9 @@ class BaseModel(torch.nn.Module):
 
         Args:
             x (torch.Tensor): The input tensor to the model.
-            profile (bool):
-            visualize (bool): Save the feature maps of the model if True
-            embed (
+            profile (bool): Print the computation time of each layer if True.
+            visualize (bool): Save the feature maps of the model if True.
+            embed (List, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): The last output of the model.
@@ -170,13 +170,12 @@ class BaseModel(torch.nn.Module):
 
     def _profile_one_layer(self, m, x, dt):
         """
-        Profile the computation time and FLOPs of a single layer of the model on a given input.
-        the provided list.
+        Profile the computation time and FLOPs of a single layer of the model on a given input.
 
         Args:
             m (torch.nn.Module): The layer to be profiled.
             x (torch.Tensor): The input data to the layer.
-            dt (
+            dt (List): A list to store the computation time of the layer.
         """
         c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
         flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
@@ -192,8 +191,8 @@ class BaseModel(torch.nn.Module):
 
     def fuse(self, verbose=True):
         """
-        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer
-
+        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
+        efficiency.
 
         Returns:
             (torch.nn.Module): The fused model is returned.
@@ -225,7 +224,7 @@ class BaseModel(torch.nn.Module):
         Check if the model has less than a certain threshold of BatchNorm layers.
 
         Args:
-            thresh (int, optional): The threshold number of BatchNorm layers.
+            thresh (int, optional): The threshold number of BatchNorm layers.
 
         Returns:
             (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
@@ -235,21 +234,21 @@ class BaseModel(torch.nn.Module):
 
     def info(self, detailed=False, verbose=True, imgsz=640):
         """
-
+        Print model information.
 
         Args:
-            detailed (bool):
-            verbose (bool):
-            imgsz (int):
+            detailed (bool): If True, prints out detailed information about the model.
+            verbose (bool): If True, prints out the model information.
+            imgsz (int): The size of the image that the model will be trained on.
         """
         return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
 
     def _apply(self, fn):
         """
-
+        Apply a function to all tensors in the model that are not parameters or registered buffers.
 
         Args:
-            fn (function):
+            fn (function): The function to apply to the model.
 
         Returns:
             (BaseModel): An updated BaseModel object.
@@ -264,11 +263,11 @@ class BaseModel(torch.nn.Module):
 
     def load(self, weights, verbose=True):
         """
-        Load
+        Load weights into the model.
 
         Args:
            weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
-            verbose (bool, optional): Whether to log the transfer progress.
+            verbose (bool, optional): Whether to log the transfer progress.
         """
         model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
         csd = model.float().state_dict()  # checkpoint state_dict as FP32
@@ -282,8 +281,8 @@ class BaseModel(torch.nn.Module):
         Compute loss.
 
         Args:
-            batch (dict): Batch to compute loss on
-            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+            batch (dict): Batch to compute loss on.
+            preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
         """
         if getattr(self, "criterion", None) is None:
             self.criterion = self.init_criterion()
@@ -300,7 +299,15 @@ class DetectionModel(BaseModel):
     """YOLO detection model."""
 
     def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):  # model, input channels, number of classes
-        """
+        """
+        Initialize the YOLO detection model with the given config and parameters.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
         super().__init__()
         self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
         if self.yaml["backbone"][0][2] == "Silence":
@@ -327,7 +334,7 @@ class DetectionModel(BaseModel):
            m.inplace = self.inplace
 
            def _forward(x):
-                """
+                """Perform a forward pass through the model, handling different Detect subclass types accordingly."""
                if self.end2end:
                    return self.forward(x)["one2many"]
                return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
@@ -345,7 +352,15 @@ class DetectionModel(BaseModel):
        LOGGER.info("")
 
    def _predict_augment(self, x):
-        """
+        """
+        Perform augmentations on input image x and return augmented inference and train outputs.
+
+        Args:
+            x (torch.Tensor): Input image tensor.
+
+        Returns:
+            (torch.Tensor): Augmented inference output.
+        """
        if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
            LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.")
            return self._predict_once(x)
@@ -363,7 +378,19 @@ class DetectionModel(BaseModel):
 
    @staticmethod
    def _descale_pred(p, flips, scale, img_size, dim=1):
-        """
+        """
+        De-scale predictions following augmented inference (inverse operation).
+
+        Args:
+            p (torch.Tensor): Predictions tensor.
+            flips (int): Flip type (0=none, 2=ud, 3=lr).
+            scale (float): Scale factor.
+            img_size (tuple): Original image size (height, width).
+            dim (int): Dimension to split at.
+
+        Returns:
+            (torch.Tensor): De-scaled predictions.
+        """
        p[:, :4] /= scale  # de-scale
        x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
        if flips == 2:
@@ -373,7 +400,15 @@ class DetectionModel(BaseModel):
        return torch.cat((x, y, wh, cls), dim)
 
    def _clip_augmented(self, y):
-        """
+        """
+        Clip YOLO augmented inference tails.
+
+        Args:
+            y (List[torch.Tensor]): List of detection tensors.
+
+        Returns:
+            (List[torch.Tensor]): Clipped detection tensors.
+        """
        nl = self.model[-1].nl  # number of detection layers (P3-P5)
        g = sum(4**x for x in range(nl))  # grid points
        e = 1  # exclude layer count
@@ -392,7 +427,15 @@ class OBBModel(DetectionModel):
    """YOLO Oriented Bounding Box (OBB) model."""
 
    def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
-        """
+        """
+        Initialize YOLO OBB model with given config and parameters.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
    def init_criterion(self):
@@ -404,7 +447,15 @@ class SegmentationModel(DetectionModel):
    """YOLO segmentation model."""
 
    def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
-        """
+        """
+        Initialize YOLOv8 segmentation model with given config and parameters.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
    def init_criterion(self):
@@ -416,7 +467,16 @@ class PoseModel(DetectionModel):
    """YOLO pose model."""
 
    def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
-        """
+        """
+        Initialize YOLOv8 Pose model.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            data_kpt_shape (tuple): Shape of keypoints data.
+            verbose (bool): Whether to display model information.
+        """
        if not isinstance(cfg, dict):
            cfg = yaml_model_load(cfg)  # load model YAML
        if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
@@ -433,12 +493,28 @@ class ClassificationModel(BaseModel):
    """YOLO classification model."""
 
    def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
-        """
+        """
+        Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
        super().__init__()
        self._from_yaml(cfg, ch, nc, verbose)
 
    def _from_yaml(self, cfg, ch, nc, verbose):
-        """
+        """
+        Set YOLOv8 model configurations and define the model architecture.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
 
        # Define model
@@ -455,7 +531,13 @@ class ClassificationModel(BaseModel):
 
    @staticmethod
    def reshape_outputs(model, nc):
-        """
+        """
+        Update a TorchVision classification model to class count 'n' if required.
+
+        Args:
+            model (torch.nn.Module): Model to update.
+            nc (int): New number of classes.
+        """
        name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1]  # last module
        if isinstance(m, Classify):  # YOLO Classify() head
            if m.linear.out_features != nc:
@@ -500,10 +582,10 @@ class RTDETRDetectionModel(DetectionModel):
        Initialize the RTDETRDetectionModel.
 
        Args:
-            cfg (str): Configuration file name or path.
+            cfg (str | dict): Configuration file name or path.
            ch (int): Number of input channels.
-            nc (int, optional): Number of classes.
-            verbose (bool
+            nc (int, optional): Number of classes.
+            verbose (bool): Print additional information during initialization.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
@@ -519,7 +601,7 @@ class RTDETRDetectionModel(DetectionModel):
 
        Args:
            batch (dict): Dictionary containing image and label data.
-            preds (torch.Tensor, optional): Precomputed model predictions.
+            preds (torch.Tensor, optional): Precomputed model predictions.
 
        Returns:
            (tuple): A tuple containing the total loss and main three losses in a tensor.
@@ -564,11 +646,11 @@ class RTDETRDetectionModel(DetectionModel):
 
        Args:
            x (torch.Tensor): The input tensor.
-            profile (bool
-            visualize (bool
-            batch (dict, optional): Ground truth data for evaluation.
-            augment (bool
-            embed (
+            profile (bool): If True, profile the computation time for each layer.
+            visualize (bool): If True, save feature maps for visualization.
+            batch (dict, optional): Ground truth data for evaluation.
+            augment (bool): If True, perform data augmentation during inference.
+            embed (List, optional): A list of feature vectors/embeddings to return.
 
        Returns:
            (torch.Tensor): Model's output tensor.
@@ -596,13 +678,28 @@ class WorldModel(DetectionModel):
    """YOLOv8 World Model."""
 
    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
-        """
+        """
+        Initialize YOLOv8 world model with given config and parameters.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
        self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
        self.clip_model = None  # CLIP model placeholder
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
    def set_classes(self, text, batch=80, cache_clip_model=True):
-        """
+        """
+        Set classes in advance so that model could do offline-inference without clip model.
+
+        Args:
+            text (List[str]): List of class names.
+            batch (int): Batch size for processing text tokens.
+            cache_clip_model (bool): Whether to cache the CLIP model.
+        """
        try:
            import clip
        except ImportError:
@@ -628,11 +725,11 @@ class WorldModel(DetectionModel):
 
        Args:
            x (torch.Tensor): The input tensor.
-            profile (bool
-            visualize (bool
-            txt_feats (torch.Tensor): The text features, use it if it's given.
-            augment (bool
-            embed (
+            profile (bool): If True, profile the computation time for each layer.
+            visualize (bool): If True, save feature maps for visualization.
+            txt_feats (torch.Tensor, optional): The text features, use it if it's given.
+            augment (bool): If True, perform data augmentation during inference.
+            embed (List, optional): A list of feature vectors/embeddings to return.
 
        Returns:
            (torch.Tensor): Model's output tensor.
@@ -671,7 +768,7 @@ class WorldModel(DetectionModel):
 
        Args:
            batch (dict): Batch to compute loss on.
-            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+            preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
        """
        if not hasattr(self, "criterion"):
            self.criterion = self.init_criterion()
@@ -689,7 +786,18 @@ class Ensemble(torch.nn.ModuleList):
        super().__init__()
 
    def forward(self, x, augment=False, profile=False, visualize=False):
-        """
+        """
+        Generate the YOLO network's final layer.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            augment (bool): Whether to augment the input.
+            profile (bool): Whether to profile the model.
+            visualize (bool): Whether to visualize the features.
+
+        Returns:
+            (tuple): Tuple containing the concatenated predictions and None.
+        """
        y = [module(x, augment, profile, visualize)[0] for module in self]
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
@@ -765,7 +873,16 @@ class SafeUnpickler(pickle.Unpickler):
    """Custom Unpickler that replaces unknown classes with SafeClass."""
 
    def find_class(self, module, name):
-        """
+        """
+        Attempt to find a class, returning SafeClass if not among safe modules.
+
+        Args:
+            module (str): Module name.
+            name (str): Class name.
+
+        Returns:
+            (type): Found class or SafeClass.
+        """
        safe_modules = (
            "torch",
            "collections",
@@ -791,13 +908,13 @@ def torch_safe_load(weight, safe_only=False):
        weight (str): The file path of the PyTorch model.
        safe_only (bool): If True, replace unknown classes with SafeClass during loading.
 
+    Returns:
+        ckpt (dict): The loaded model checkpoint.
+        file (str): The loaded filename.
+
    Examples:
        >>> from ultralytics.nn.tasks import torch_safe_load
        >>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
-
-    Returns:
-        ckpt (dict): The loaded model checkpoint.
-        file (str): The loaded filename
    """
    from ultralytics.utils.downloads import attempt_download_asset
 
@@ -858,7 +975,18 @@ def torch_safe_load(weight, safe_only=False):
 
 
 def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
-    """
+    """
+    Load an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.
+
+    Args:
+        weights (str | List[str]): Model weights path(s).
+        device (torch.device, optional): Device to load model to.
+        inplace (bool): Whether to do inplace operations.
+        fuse (bool): Whether to fuse model.
+
+    Returns:
+        (torch.nn.Module): Loaded model.
+    """
    ensemble = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        ckpt, w = torch_safe_load(w)  # load ckpt
@@ -896,7 +1024,18 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
 
 
 def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
-    """
+    """
+    Load a single model weights.
+
+    Args:
+        weight (str): Model weight path.
+        device (torch.device, optional): Device to load model to.
+        inplace (bool): Whether to do inplace operations.
+        fuse (bool): Whether to fuse model.
+
+    Returns:
+        (tuple): Tuple containing the model and checkpoint.
+    """
    ckpt, weight = torch_safe_load(weight)  # load ckpt
    args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}  # combine model and default args, preferring model args
    model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model
@@ -922,7 +1061,17 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
 
 
 def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
-    """
+    """
+    Parse a YOLO model.yaml dictionary into a PyTorch model.
+
+    Args:
+        d (dict): Model dictionary.
+        ch (int): Input channels.
+        verbose (bool): Whether to print model details.
+
+    Returns:
+        (tuple): Tuple containing the PyTorch model and sorted list of output layers.
+    """
    import ast
 
    # Args
@@ -1086,7 +1235,15 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 
 
 def yaml_model_load(path):
-    """
+    """
+    Load a YOLOv8 model from a YAML file.
+
+    Args:
+        path (str | Path): Path to the YAML file.
+
+    Returns:
+        (dict): Model dictionary.
+    """
    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
@@ -1103,15 +1260,13 @@ def yaml_model_load(path):
 
 def guess_model_scale(model_path):
    """
-
-    uses regular expression matching to find the pattern of the model scale in the YAML file name, which is denoted by
-    n, s, m, l, or x. The function returns the size character of the model scale as a string.
+    Extract the size character n, s, m, l, or x of the model's scale from the model path.
 
    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.
 
    Returns:
-        (str): The size character of the model's scale
+        (str): The size character of the model's scale (n, s, m, l, or x).
    """
    try:
        return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1)  # returns n, s, m, l, or x
@@ -1127,10 +1282,7 @@ def guess_model_task(model):
        model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
 
    Returns:
-        (str): Task of the model ('detect', 'segment', 'classify', 'pose').
-
-    Raises:
-        SyntaxError: If the task of the model could not be determined.
+        (str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb').
    """
 
    def cfg2task(cfg):
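The rewritten docstrings above pin down the signatures of the main entry points in ultralytics/nn/tasks.py: the model constructors take (cfg, ch, nc, verbose), attempt_load_one_weight returns a (model, ckpt) tuple, and yaml_model_load returns a model dict. A minimal sketch of how these pieces fit together, assuming the yolo11n.yaml config and yolo11n.pt checkpoint can be resolved locally (the file names are illustrative, not part of this diff):

```python
# Sketch only: exercises the loader/parser APIs whose docstrings changed above.
from ultralytics.nn.tasks import (
    DetectionModel,
    attempt_load_one_weight,
    guess_model_scale,
    guess_model_task,
    yaml_model_load,
)

cfg = yaml_model_load("yolo11n.yaml")  # returns the model dict described above
model = DetectionModel(cfg=cfg, ch=3, nc=80, verbose=False)  # cfg may be a path or a dict

# attempt_load_one_weight returns (model, ckpt) per the updated Returns section.
loaded, ckpt = attempt_load_one_weight("yolo11n.pt", device="cpu", fuse=True)

print(guess_model_scale("yolo11n.yaml"))  # -> "n"
print(guess_model_task(loaded))           # -> "detect"
```

parse_model() is called internally by the constructors when building from a YAML dict, so it rarely needs to be invoked directly.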
ultralytics/solutions/ai_gym.py
CHANGED
@@ -33,7 +33,7 @@ class AIGym(BaseSolution):
 
    def __init__(self, **kwargs):
        """
-
+        Initialize AIGym for workout monitoring using pose estimation and predefined angles.
 
        Args:
            **kwargs (Any): Keyword arguments passed to the parent class constructor.
@@ -53,7 +53,7 @@ class AIGym(BaseSolution):
 
    def process(self, im0):
        """
-
+        Monitor workouts using Ultralytics YOLO Pose Model.
 
        This function processes an input image to track and analyze human poses for workout monitoring. It uses
        the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined

ultralytics/solutions/analytics.py
CHANGED
@@ -37,8 +37,8 @@ class Analytics(BaseSolution):
        color_mapping (Dict[str, str]): Dictionary mapping class labels to colors for consistent visualization.
 
    Methods:
-        process:
-        update_graph:
+        process: Process image data and update the chart.
+        update_graph: Update the chart with new data points.
 
    Examples:
        >>> analytics = Analytics(analytics_type="line")
@@ -87,7 +87,7 @@ class Analytics(BaseSolution):
 
    def process(self, im0, frame_number):
        """
-
+        Process image data and run object tracking to update analytics charts.
 
        Args:
            im0 (np.ndarray): Input image for processing.
@@ -127,7 +127,7 @@ class Analytics(BaseSolution):
 
    def update_graph(self, frame_number, count_dict=None, plot="line"):
        """
-
+        Update the graph with new data for single or multiple classes.
 
        Args:
            frame_number (int): The current frame number.
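The AIGym and Analytics changes above only fill in docstring summaries; the call pattern they describe is process(im0) for AIGym and process(im0, frame_number) for Analytics. A hedged sketch of that loop, with the video path and keyword arguments as illustrative assumptions:

```python
# Sketch only: drives the AIGym and Analytics solutions documented above.
import cv2
from ultralytics import solutions

gym = solutions.AIGym(model="yolo11n-pose.pt", show=False)          # kwargs forwarded to the parent solution class
analytics = solutions.Analytics(analytics_type="line", show=False)  # as in the Examples block above

cap = cv2.VideoCapture("workout.mp4")  # assumed local video file
frame_number = 0
while cap.isOpened():
    ok, im0 = cap.read()
    if not ok:
        break
    frame_number += 1
    gym_out = gym.process(im0)                         # pose-based repetition monitoring
    chart_out = analytics.process(im0, frame_number)   # updates the line chart for this frame
cap.release()
```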
ultralytics/solutions/heatmap.py
CHANGED
@@ -21,8 +21,8 @@ class Heatmap(ObjectCounter):
        annotator (SolutionAnnotator): Object for drawing annotations on the image.
 
    Methods:
-        heatmap_effect:
-        process:
+        heatmap_effect: Calculate and update the heatmap effect for a given bounding box.
+        process: Generate and apply the heatmap effect to each frame.
 
    Examples:
        >>> from ultralytics.solutions import Heatmap
@@ -33,7 +33,7 @@ class Heatmap(ObjectCounter):
 
    def __init__(self, **kwargs):
        """
-
+        Initialize the Heatmap class for real-time video stream heatmap generation based on object tracks.
 
        Args:
            **kwargs (Any): Keyword arguments passed to the parent ObjectCounter class.
@@ -50,7 +50,7 @@ class Heatmap(ObjectCounter):
 
    def heatmap_effect(self, box):
        """
-        Efficiently
+        Efficiently calculate heatmap area and effect location for applying colormap.
 
        Args:
            box (List[float]): Bounding box coordinates [x0, y0, x1, y1].

ultralytics/solutions/instance_segmentation.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
 from ultralytics.utils.plotting import colors
@@ -13,9 +13,15 @@ class InstanceSegmentation(BaseSolution):
 
    Attributes:
        model (str): The segmentation model to use for inference.
+        line_width (int): Width of the bounding box and text lines.
+        names (Dict[int, str]): Dictionary mapping class indices to class names.
+        clss (List[int]): List of detected class indices.
+        track_ids (List[int]): List of track IDs for detected instances.
+        masks (List[numpy.ndarray]): List of segmentation masks for detected instances.
 
    Methods:
-        process:
+        process: Process the input image to perform instance segmentation and annotate results.
+        extract_tracks: Extract tracks including bounding boxes, classes, and masks from model predictions.
 
    Examples:
        >>> segmenter = InstanceSegmentation()
@@ -26,7 +32,7 @@ class InstanceSegmentation(BaseSolution):
 
    def __init__(self, **kwargs):
        """
-
+        Initialize the InstanceSegmentation class for detecting and annotating segmented instances.
 
        Args:
            **kwargs (Any): Keyword arguments passed to the BaseSolution parent class.
@@ -37,7 +43,7 @@ class InstanceSegmentation(BaseSolution):
 
    def process(self, im0):
        """
-
+        Perform instance segmentation on the input image and annotate the results.
 
        Args:
            im0 (numpy.ndarray): The input image for segmentation.
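Heatmap and InstanceSegmentation follow the same single-call pattern, process(im0), per the docstrings above. A sketch under the same assumptions (model names and video source are illustrative, not taken from this diff):

```python
# Sketch only: per-frame use of the Heatmap and InstanceSegmentation solutions.
import cv2
from ultralytics import solutions

heatmap = solutions.Heatmap(model="yolo11n.pt", show=False)                     # inherits counting behavior from ObjectCounter
segmenter = solutions.InstanceSegmentation(model="yolo11n-seg.pt", show=False)  # as in the Examples block above

cap = cv2.VideoCapture("traffic.mp4")  # assumed input stream
while cap.isOpened():
    ok, im0 = cap.read()
    if not ok:
        break
    heat_out = heatmap.process(im0)   # applies the colormap effect described in heatmap_effect()
    seg_out = segmenter.process(im0)  # segments and annotates instances in the frame
cap.release()
```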