dgenerate-ultralytics-headless 8.3.253 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1028 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1580 -0
- ultralytics/engine/model.py +1125 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +977 -0
- ultralytics/engine/tuner.py +449 -0
- ultralytics/engine/validator.py +387 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +964 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +453 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +1020 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +506 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1563 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1047 -0
- ultralytics/utils/tal.py +404 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +443 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +168 -0
ultralytics/models/rtdetr/train.py (new file)
@@ -0,0 +1,89 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+from copy import copy
+
+from ultralytics.models.yolo.detect import DetectionTrainer
+from ultralytics.nn.tasks import RTDETRDetectionModel
+from ultralytics.utils import RANK, colorstr
+
+from .val import RTDETRDataset, RTDETRValidator
+
+
+class RTDETRTrainer(DetectionTrainer):
+    """Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
+
+    This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
+    RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable
+    inference speed.
+
+    Attributes:
+        loss_names (tuple): Names of the loss components used for training.
+        data (dict): Dataset configuration containing class count and other parameters.
+        args (dict): Training arguments and hyperparameters.
+        save_dir (Path): Directory to save training results.
+        test_loader (DataLoader): DataLoader for validation/testing data.
+
+    Methods:
+        get_model: Initialize and return an RT-DETR model for object detection tasks.
+        build_dataset: Build and return an RT-DETR dataset for training or validation.
+        get_validator: Return a DetectionValidator suitable for RT-DETR model validation.
+
+    Examples:
+        >>> from ultralytics.models.rtdetr.train import RTDETRTrainer
+        >>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
+        >>> trainer = RTDETRTrainer(overrides=args)
+        >>> trainer.train()
+
+    Notes:
+        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
+        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
+    """
+
+    def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
+        """Initialize and return an RT-DETR model for object detection tasks.
+
+        Args:
+            cfg (dict, optional): Model configuration.
+            weights (str, optional): Path to pre-trained model weights.
+            verbose (bool): Verbose logging if True.
+
+        Returns:
+            (RTDETRDetectionModel): Initialized model.
+        """
+        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+        return model
+
+    def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
+        """Build and return an RT-DETR dataset for training or validation.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): Dataset mode, either 'train' or 'val'.
+            batch (int, optional): Batch size for rectangle training.
+
+        Returns:
+            (RTDETRDataset): Dataset object for the specific mode.
+        """
+        return RTDETRDataset(
+            img_path=img_path,
+            imgsz=self.args.imgsz,
+            batch_size=batch,
+            augment=mode == "train",
+            hyp=self.args,
+            rect=False,
+            cache=self.args.cache or None,
+            single_cls=self.args.single_cls or False,
+            prefix=colorstr(f"{mode}: "),
+            classes=self.args.classes,
+            data=self.data,
+            fraction=self.args.fraction if mode == "train" else 1.0,
+        )
+
+    def get_validator(self):
+        """Return a DetectionValidator suitable for RT-DETR model validation."""
+        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
+        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
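For orientation outside the diff view, here is a minimal usage sketch of the trainer introduced above. It simply mirrors the class docstring example; `rtdetr-l.yaml` and `coco8.yaml` are stock configs listed in this wheel, and the image size and epoch count are illustrative only.

# Sketch only: mirrors the RTDETRTrainer docstring example from this release.
from ultralytics.models.rtdetr.train import RTDETRTrainer

args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)  # illustrative settings
trainer = RTDETRTrainer(overrides=args)
trainer.train()  # builds the RTDETRDataset, RTDETRDetectionModel, and RTDETRValidator shown in this hunk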
ultralytics/models/rtdetr/val.py (new file)
@@ -0,0 +1,216 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+import torch
+
+from ultralytics.data import YOLODataset
+from ultralytics.data.augment import Compose, Format, v8_transforms
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import colorstr, ops
+
+__all__ = ("RTDETRValidator",)  # tuple or list
+
+
+class RTDETRDataset(YOLODataset):
+    """Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
+
+    This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
+    real-time detection and tracking tasks.
+
+    Attributes:
+        augment (bool): Whether to apply data augmentation.
+        rect (bool): Whether to use rectangular training.
+        use_segments (bool): Whether to use segmentation masks.
+        use_keypoints (bool): Whether to use keypoint annotations.
+        imgsz (int): Target image size for training.
+
+    Methods:
+        load_image: Load one image from dataset index.
+        build_transforms: Build transformation pipeline for the dataset.
+
+    Examples:
+        Initialize an RT-DETR dataset
+        >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
+        >>> image, hw0, hw = dataset.load_image(0)
+    """
+
+    def __init__(self, *args, data=None, **kwargs):
+        """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
+
+        This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
+        model, building upon the base YOLODataset functionality.
+
+        Args:
+            *args (Any): Variable length argument list passed to the parent YOLODataset class.
+            data (dict | None): Dictionary containing dataset information. If None, default values will be used.
+            **kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.
+        """
+        super().__init__(*args, data=data, **kwargs)
+
+    def load_image(self, i, rect_mode=False):
+        """Load one image from dataset index 'i'.
+
+        Args:
+            i (int): Index of the image to load.
+            rect_mode (bool, optional): Whether to use rectangular mode for batch inference.
+
+        Returns:
+            im (np.ndarray): Loaded image as a NumPy array.
+            hw_original (tuple[int, int]): Original image dimensions in (height, width) format.
+            hw_resized (tuple[int, int]): Resized image dimensions in (height, width) format.
+
+        Examples:
+            Load an image from the dataset
+            >>> dataset = RTDETRDataset(img_path="path/to/images")
+            >>> image, hw0, hw = dataset.load_image(0)
+        """
+        return super().load_image(i=i, rect_mode=rect_mode)
+
+    def build_transforms(self, hyp=None):
+        """Build transformation pipeline for the dataset.
+
+        Args:
+            hyp (dict, optional): Hyperparameters for transformations.
+
+        Returns:
+            (Compose): Composition of transformation functions.
+        """
+        if self.augment:
+            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
+            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
+            hyp.cutmix = hyp.cutmix if self.augment and not self.rect else 0.0
+            transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
+        else:
+            # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
+            transforms = Compose([])
+        transforms.append(
+            Format(
+                bbox_format="xywh",
+                normalize=True,
+                return_mask=self.use_segments,
+                return_keypoint=self.use_keypoints,
+                batch_idx=True,
+                mask_ratio=hyp.mask_ratio,
+                mask_overlap=hyp.overlap_mask,
+            )
+        )
+        return transforms
+
+
+class RTDETRValidator(DetectionValidator):
+    """RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
+    the RT-DETR (Real-Time DETR) object detection model.
+
+    The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
+    post-processing, and updates evaluation metrics accordingly.
+
+    Attributes:
+        args (Namespace): Configuration arguments for validation.
+        data (dict): Dataset configuration dictionary.
+
+    Methods:
+        build_dataset: Build an RTDETR Dataset for validation.
+        postprocess: Apply Non-maximum suppression to prediction outputs.
+
+    Examples:
+        Initialize and run RT-DETR validation
+        >>> from ultralytics.models.rtdetr import RTDETRValidator
+        >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
+        >>> validator = RTDETRValidator(args=args)
+        >>> validator()
+
+    Notes:
+        For further details on the attributes and methods, refer to the parent DetectionValidator class.
+    """
+
+    def build_dataset(self, img_path, mode="val", batch=None):
+        """Build an RTDETR Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str, optional): `train` mode or `val` mode, users are able to customize different augmentations for
+                each mode.
+            batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (RTDETRDataset): Dataset configured for RT-DETR validation.
+        """
+        return RTDETRDataset(
+            img_path=img_path,
+            imgsz=self.args.imgsz,
+            batch_size=batch,
+            augment=False,  # no augmentation
+            hyp=self.args,
+            rect=False,  # no rect
+            cache=self.args.cache or None,
+            prefix=colorstr(f"{mode}: "),
+            data=self.data,
+        )
+
+    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
+        """Scales predictions to the original image size."""
+        return predn
+
+    def postprocess(
+        self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
+    ) -> list[dict[str, torch.Tensor]]:
+        """Apply Non-maximum suppression to prediction outputs.
+
+        Args:
+            preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
+                (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and
+                class scores.
+
+        Returns:
+            (list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
+                - 'bboxes': Tensor of shape (N, 4) with bounding box coordinates
+                - 'conf': Tensor of shape (N,) with confidence scores
+                - 'cls': Tensor of shape (N,) with class indices
+        """
+        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
+            preds = [preds, None]
+
+        bs, _, nd = preds[0].shape
+        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
+        bboxes *= self.args.imgsz
+        outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
+        for i, bbox in enumerate(bboxes):  # (300, 4)
+            bbox = ops.xywh2xyxy(bbox)
+            score, cls = scores[i].max(-1)  # (300, )
+            pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1)  # filter
+            # Sort by confidence to correctly get internal metrics
+            pred = pred[score.argsort(descending=True)]
+            outputs[i] = pred[score > self.args.conf]
+
+        return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
+
+    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
+        """Serialize YOLO predictions to COCO json format.
+
+        Args:
+            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
+            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
+        """
+        path = Path(pbatch["im_file"])
+        stem = path.stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        box = predn["bboxes"].clone()
+        box[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
+        box[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
+        box = ops.xyxy2xywh(box)  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "file_name": path.name,
+                    "category_id": self.class_map[int(c)],
+                    "bbox": [round(x, 3) for x in b],
+                    "score": round(s, 5),
+                }
+            )
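The `postprocess` method above replaces conventional NMS: RT-DETR already emits a fixed set of queries with normalized xywh boxes and per-class scores, so decoding reduces to rescaling, an xywh to xyxy conversion, a per-query argmax, and a confidence cut. The following standalone sketch replays those steps on dummy tensors; the local `xywh2xyxy` helper stands in for `ultralytics.utils.ops.xywh2xyxy`, and all shapes and thresholds are illustrative.

# Standalone sketch of the decoding done in RTDETRValidator.postprocess (dummy data).
import torch


def xywh2xyxy(x):
    """Convert center-x, center-y, width, height boxes to x1, y1, x2, y2."""
    y = x.clone()
    y[..., 0] = x[..., 0] - x[..., 2] / 2
    y[..., 1] = x[..., 1] - x[..., 3] / 2
    y[..., 2] = x[..., 0] + x[..., 2] / 2
    y[..., 3] = x[..., 1] + x[..., 3] / 2
    return y


imgsz, conf_thres = 640, 0.25                      # illustrative values
preds = torch.rand(2, 300, 4 + 80)                 # dummy raw output: 2 images, 300 queries, 80 classes
bboxes, scores = preds.split((4, 80), dim=-1)
bboxes = bboxes * imgsz                            # normalized xywh -> pixel xywh
results = []
for i, bbox in enumerate(bboxes):
    bbox = xywh2xyxy(bbox)
    score, cls = scores[i].max(-1)                 # best class per query
    keep = score > conf_thres                      # confidence cut instead of NMS
    order = score[keep].argsort(descending=True)   # sort survivors by confidence
    results.append({"bboxes": bbox[keep][order], "conf": score[keep][order], "cls": cls[keep][order]})
print(results[0]["bboxes"].shape)                  # (N, 4) boxes kept for the first image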
ultralytics/models/sam/__init__.py (new file)
@@ -0,0 +1,25 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .model import SAM
+from .predict import (
+    Predictor,
+    SAM2DynamicInteractivePredictor,
+    SAM2Predictor,
+    SAM2VideoPredictor,
+    SAM3Predictor,
+    SAM3SemanticPredictor,
+    SAM3VideoPredictor,
+    SAM3VideoSemanticPredictor,
+)
+
+__all__ = (
+    "SAM",
+    "Predictor",
+    "SAM2DynamicInteractivePredictor",
+    "SAM2Predictor",
+    "SAM2VideoPredictor",
+    "SAM3Predictor",
+    "SAM3SemanticPredictor",
+    "SAM3VideoPredictor",
+    "SAM3VideoSemanticPredictor",
+)  # tuple or list of exportable items
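The `__all__` tuple above re-exports the new SAM3 predictor classes at the subpackage level, so (as a small illustrative check, assuming this wheel is installed) they can be imported directly:

# Sketch only: names taken from the __all__ tuple in this hunk.
from ultralytics.models.sam import SAM, SAM3Predictor, SAM3SemanticPredictor, SAM3VideoPredictor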
ultralytics/models/sam/amg.py (new file)
@@ -0,0 +1,275 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+import math
+from collections.abc import Generator
+from itertools import product
+from typing import Any
+
+import numpy as np
+import torch
+
+
+def is_box_near_crop_edge(
+    boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
+) -> torch.Tensor:
+    """Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
+
+    Args:
+        boxes (torch.Tensor): Bounding boxes in XYXY format.
+        crop_box (list[int]): Crop box coordinates in [x0, y0, x1, y1] format.
+        orig_box (list[int]): Original image box coordinates in [x0, y0, x1, y1] format.
+        atol (float, optional): Absolute tolerance for edge proximity detection.
+
+    Returns:
+        (torch.Tensor): Boolean tensor indicating which boxes are near crop edges.
+
+    Examples:
+        >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])
+        >>> crop_box = [0, 0, 200, 200]
+        >>> orig_box = [0, 0, 300, 300]
+        >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)
+    """
+    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+    return torch.any(near_crop_edge, dim=1)
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
+    """Yield batches of data from input arguments with specified batch size for efficient processing.
+
+    This function takes a batch size and any number of iterables, then yields batches of elements from those
+    iterables. All input iterables must have the same length.
+
+    Args:
+        batch_size (int): Size of each batch to yield.
+        *args (Any): Variable length input iterables to batch. All iterables must have the same length.
+
+    Yields:
+        (list[Any]): A list of batched elements from each input iterable.
+
+    Examples:
+        >>> data = [1, 2, 3, 4, 5]
+        >>> labels = ["a", "b", "c", "d", "e"]
+        >>> for batch in batch_iterator(2, data, labels):
+        ...     print(batch)
+        [[1, 2], ['a', 'b']]
+        [[3, 4], ['c', 'd']]
+        [[5], ['e']]
+    """
+    assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
+    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+    for b in range(n_batches):
+        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
+    """Compute the stability score for a batch of masks.
+
+    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high and
+    low values.
+
+    Args:
+        masks (torch.Tensor): Batch of predicted mask logits.
+        mask_threshold (float): Threshold value for creating binary masks.
+        threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
+
+    Returns:
+        (torch.Tensor): Stability scores for each mask in the batch.
+
+    Examples:
+        >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
+        >>> mask_threshold = 0.5
+        >>> threshold_offset = 0.1
+        >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
+
+    Notes:
+        - One mask is always contained inside the other.
+        - Memory is saved by preventing unnecessary cast to torch.int64.
+    """
+    intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+    unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+    return intersections / unions
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+    """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
+    offset = 1 / (2 * n_per_side)
+    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+    return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+
+
+def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> list[np.ndarray]:
+    """Generate point grids for multiple crop layers with varying scales and densities."""
+    return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
+
+
+def generate_crop_boxes(
+    im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> tuple[list[list[int]], list[int]]:
+    """Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
+
+    Args:
+        im_size (tuple[int, ...]): Height and width of the input image.
+        n_layers (int): Number of layers to generate crop boxes for.
+        overlap_ratio (float): Ratio of overlap between adjacent crop boxes.
+
+    Returns:
+        crop_boxes (list[list[int]]): List of crop boxes in [x0, y0, x1, y1] format.
+        layer_idxs (list[int]): List of layer indices corresponding to each crop box.
+
+    Examples:
+        >>> im_size = (800, 1200)  # Height, width
+        >>> n_layers = 3
+        >>> overlap_ratio = 0.25
+        >>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)
+    """
+    crop_boxes, layer_idxs = [], []
+    im_h, im_w = im_size
+    short_side = min(im_h, im_w)
+
+    # Original image
+    crop_boxes.append([0, 0, im_w, im_h])
+    layer_idxs.append(0)
+
+    def crop_len(orig_len, n_crops, overlap):
+        """Calculate the length of each crop given the original length, number of crops, and overlap."""
+        return math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)
+
+    for i_layer in range(n_layers):
+        n_crops_per_side = 2 ** (i_layer + 1)
+        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+        crop_w = crop_len(im_w, n_crops_per_side, overlap)
+        crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+        # Crops in XYWH format
+        for x0, y0 in product(crop_box_x0, crop_box_y0):
+            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+            crop_boxes.append(box)
+            layer_idxs.append(i_layer + 1)
+
+    return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
+    """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
+    """Uncrop points by adding the crop box offset to their coordinates."""
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0]], device=points.device)
+    # Check if points has a channel dimension
+    if len(points.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return points + offset
+
+
+def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w: int) -> torch.Tensor:
+    """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
+    x0, y0, x1, y1 = crop_box
+    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+        return masks
+    # Coordinate transform masks
+    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+    pad = (x0, pad_x - x0, y0, pad_y - y0)
+    return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
+    """Remove small disconnected regions or holes in a mask based on area threshold and mode.
+
+    Args:
+        mask (np.ndarray): Binary mask to process.
+        area_thresh (float): Area threshold below which regions will be removed.
+        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
+            regions.
+
+    Returns:
+        processed_mask (np.ndarray): Processed binary mask with small regions removed.
+        modified (bool): Whether any regions were modified.
+
+    Examples:
+        >>> mask = np.zeros((100, 100), dtype=np.bool_)
+        >>> mask[40:60, 40:60] = True  # Create a square
+        >>> mask[45:55, 45:55] = False  # Create a hole
+        >>> processed_mask, modified = remove_small_regions(mask, 50, "holes")
+    """
+    import cv2  # type: ignore
+
+    assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
+    correct_holes = mode == "holes"
+    working_mask = (correct_holes ^ mask).astype(np.uint8)
+    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+    sizes = stats[:, -1][1:]  # Row 0 is background label
+    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+    if not small_regions:
+        return mask, False
+    fill_labels = [0, *small_regions]
+    if not correct_holes:
+        # If every region is below threshold, keep largest
+        fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
+    mask = np.isin(regions, fill_labels)
+    return mask, True
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+    """Calculate bounding boxes in XYXY format around binary masks.
+
+    Args:
+        masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
+
+    Returns:
+        (torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).
+
+    Notes:
+        - Handles empty masks by returning zero boxes.
+        - Preserves input tensor dimensions in the output.
+    """
+    # torch.max below raises an error on empty inputs, just skip in this case
+    if torch.numel(masks) == 0:
+        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+    # Normalize shape to CxHxW
+    shape = masks.shape
+    h, w = shape[-2:]
+    masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
+    # Get top and bottom edges
+    in_height, _ = torch.max(masks, dim=-1)
+    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+    in_height_coords = in_height_coords + h * (~in_height)
+    top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+    # Get left and right edges
+    in_width, _ = torch.max(masks, dim=-2)
+    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+    right_edges, _ = torch.max(in_width_coords, dim=-1)
+    in_width_coords = in_width_coords + w * (~in_width)
+    left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+    out = out * (~empty_filter).unsqueeze(-1)
+
+    # Return to original shape
+    return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]
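To make the crop and mask helpers above concrete, here is a small, hedged sketch that exercises a few of them on dummy inputs. It assumes this wheel (and its OpenCV dependency, used lazily inside remove_small_regions) is installed; all sizes, thresholds, and layer counts are illustrative.

# Sketch only: exercise the AMG helpers from this hunk on dummy data.
import torch

from ultralytics.models.sam.amg import (
    batch_iterator,
    batched_mask_to_box,
    build_point_grid,
    generate_crop_boxes,
    remove_small_regions,
)

# One full-image crop plus a 2x2 layer of overlapping crops for a 600x800 image
crop_boxes, layer_idxs = generate_crop_boxes((600, 800), n_layers=1, overlap_ratio=0.25)
print(len(crop_boxes), crop_boxes[0])    # 5 [0, 0, 800, 600]

# A 32x32 grid of normalized point prompts, consumed in batches of 64
points = build_point_grid(32)            # shape (1024, 2), values in [0, 1]
for (point_batch,) in batch_iterator(64, points):
    pass                                 # each point_batch holds at most 64 prompts

# Tight XYXY boxes around binary masks; empty masks map to [0, 0, 0, 0]
masks = torch.zeros(2, 64, 64, dtype=torch.bool)
masks[0, 10:30, 20:40] = True
print(batched_mask_to_box(masks))        # [[20, 10, 39, 29], [0, 0, 0, 0]]

# Fill a small hole in a NumPy mask
mask = masks[0].numpy()
mask[15:17, 25:27] = False               # punch a 2x2 hole
cleaned, changed = remove_small_regions(mask, area_thresh=16, mode="holes")
print(changed, cleaned[15, 25])          # True True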