dgenerate-ultralytics-headless 8.3.134 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
ultralytics/models/rtdetr/predict.py
@@ -0,0 +1,84 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch

from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops


class RTDETRPredictor(BasePredictor):
    """
    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.

    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.
    It supports key features like efficient hybrid encoding and IoU-aware query selection.

    Attributes:
        imgsz (int): Image size for inference (must be square and scale-filled).
        args (dict): Argument overrides for the predictor.
        model (torch.nn.Module): The loaded RT-DETR model.
        batch (list): Current batch of processed inputs.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.rtdetr import RTDETRPredictor
        >>> args = dict(model="rtdetr-l.pt", source=ASSETS)
        >>> predictor = RTDETRPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def postprocess(self, preds, img, orig_imgs):
        """
        Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.

        The method filters detections based on confidence and class if specified in `self.args`. It converts
        model predictions to Results objects containing properly scaled bounding boxes.

        Args:
            preds (List | Tuple): List of [predictions, extra] from the model, where predictions contain
                bounding boxes and scores.
            img (torch.Tensor): Processed input images with shape (N, 3, H, W).
            orig_imgs (List | torch.Tensor): Original, unprocessed images.

        Returns:
            (List[Results]): A list of Results objects containing the post-processed bounding boxes, confidence
                scores, and class labels.
        """
        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
            preds = [preds, None]

        nd = preds[0].shape[-1]
        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
            bbox = ops.xywh2xyxy(bbox)
            max_score, cls = score.max(-1, keepdim=True)  # (300, 1)
            idx = max_score.squeeze(-1) > self.args.conf  # (300, )
            if self.args.classes is not None:
                idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
            pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
            oh, ow = orig_img.shape[:2]
            pred[..., [0, 2]] *= ow  # scale x coordinates to original width
            pred[..., [1, 3]] *= oh  # scale y coordinates to original height
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results

    def pre_transform(self, im):
        """
        Pre-transform the input images before feeding them into the model for inference. The input images are
        letterboxed to ensure a square aspect ratio and scale-filled. The size must be square (640) and scale-filled.

        Args:
            im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

        Returns:
            (list): List of pre-transformed images ready for model inference.
        """
        letterbox = LetterBox(self.imgsz, auto=False, scale_fill=True)
        return [letterbox(image=x) for x in im]
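RT-DETR's predictor emits a fixed set of 300 queries with normalized center-format boxes and filters them by confidence instead of running NMS, so recovering pixel coordinates only needs the original image height and width. Below is a minimal standalone sketch of that scaling step in plain PyTorch (not part of the package), with the `xywh2xyxy` conversion inlined here for illustration rather than imported from `ultralytics.utils.ops`:

import torch


def xywh2xyxy(x: torch.Tensor) -> torch.Tensor:
    """Convert (cx, cy, w, h) boxes to (x0, y0, x1, y1); mirrors ultralytics.utils.ops.xywh2xyxy for this sketch."""
    y = x.clone()
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # left
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # right
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom
    return y


bbox = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # one normalized query: centered, 20% wide, 40% tall
oh, ow = 480, 640  # original image height and width
xyxy = xywh2xyxy(bbox)
xyxy[..., [0, 2]] *= ow  # scale x coordinates to original width
xyxy[..., [1, 3]] *= oh  # scale y coordinates to original height
print(xyxy)  # tensor([[256., 144., 384., 336.]])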
ultralytics/models/rtdetr/train.py
@@ -0,0 +1,85 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from copy import copy

from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import RANK, colorstr

from .val import RTDETRDataset, RTDETRValidator


class RTDETRTrainer(DetectionTrainer):
    """
    Trainer class for the RT-DETR model developed by Baidu for real-time object detection.

    This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
    RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and
    adaptable inference speed.

    Attributes:
        loss_names (Tuple[str]): Names of the loss components used for training.
        data (dict): Dataset configuration containing class count and other parameters.
        args (dict): Training arguments and hyperparameters.
        save_dir (Path): Directory to save training results.
        test_loader (DataLoader): DataLoader for validation/testing data.

    Notes:
        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Examples:
        >>> from ultralytics.models.rtdetr.train import RTDETRTrainer
        >>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
        >>> trainer = RTDETRTrainer(overrides=args)
        >>> trainer.train()
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Initialize and return an RT-DETR model for object detection tasks.

        Args:
            cfg (dict, optional): Model configuration.
            weights (str, optional): Path to pre-trained model weights.
            verbose (bool): Verbose logging if True.

        Returns:
            (RTDETRDetectionModel): Initialized model.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build and return an RT-DETR dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Dataset mode, either 'train' or 'val'.
            batch (int, optional): Batch size for rectangle training.

        Returns:
            (RTDETRDataset): Dataset object for the specific mode.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == "train",
            hyp=self.args,
            rect=False,
            cache=self.args.cache or None,
            single_cls=self.args.single_cls or False,
            prefix=colorstr(f"{mode}: "),
            classes=self.args.classes,
            data=self.data,
            fraction=self.args.fraction if mode == "train" else 1.0,
        )

    def get_validator(self):
        """Returns a DetectionValidator suitable for RT-DETR model validation."""
        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
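The Notes above flag two practical caveats for RT-DETR training: F.grid_sample is incompatible with `deterministic=True`, and AMP can produce NaNs during bipartite matching. Below is a short training sketch following the docstring Examples, additionally assuming the standard `amp` and `deterministic` override keys from ultralytics/cfg/default.yaml (that file is listed in this package but its contents are not shown here):

from ultralytics.models.rtdetr.train import RTDETRTrainer

# Overrides mirror the docstring example; amp/deterministic are disabled per the Notes (assumed config keys).
args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3, amp=False, deterministic=False)
trainer = RTDETRTrainer(overrides=args)
trainer.train()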
ultralytics/models/rtdetr/val.py
@@ -0,0 +1,191 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch

from ultralytics.data import YOLODataset
from ultralytics.data.augment import Compose, Format, v8_transforms
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import colorstr, ops

__all__ = ("RTDETRValidator",)  # tuple or list


class RTDETRDataset(YOLODataset):
    """
    Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.

    This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
    real-time detection and tracking tasks.
    """

    def __init__(self, *args, data=None, **kwargs):
        """
        Initialize the RTDETRDataset class by inheriting from the YOLODataset class.

        This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
        model, building upon the base YOLODataset functionality.

        Args:
            *args (Any): Variable length argument list passed to the parent YOLODataset class.
            data (Dict | None): Dictionary containing dataset information. If None, default values will be used.
            **kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.
        """
        super().__init__(*args, data=data, **kwargs)

    def load_image(self, i, rect_mode=False):
        """
        Load one image from dataset index 'i'.

        Args:
            i (int): Index of the image to load.
            rect_mode (bool, optional): Whether to use rectangular mode for batch inference.

        Returns:
            im (numpy.ndarray): The loaded image.
            resized_hw (tuple): Height and width of the resized image with shape (2,).

        Examples:
            >>> dataset = RTDETRDataset(...)
            >>> image, hw = dataset.load_image(0)
        """
        return super().load_image(i=i, rect_mode=rect_mode)

    def build_transforms(self, hyp=None):
        """
        Build transformation pipeline for the dataset.

        Args:
            hyp (dict, optional): Hyperparameters for transformations.

        Returns:
            (Compose): Composition of transformation functions.
        """
        if self.augment:
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            hyp.cutmix = hyp.cutmix if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
        else:
            # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
            transforms = Compose([])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
            )
        )
        return transforms


class RTDETRValidator(DetectionValidator):
    """
    RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
    the RT-DETR (Real-Time DETR) object detection model.

    The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
    post-processing, and updates evaluation metrics accordingly.

    Examples:
        >>> from ultralytics.models.rtdetr import RTDETRValidator
        >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
        >>> validator = RTDETRValidator(args=args)
        >>> validator()

    Note:
        For further details on the attributes and methods, refer to the parent DetectionValidator class.
    """

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build an RTDETR Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`.

        Returns:
            (RTDETRDataset): Dataset configured for RT-DETR validation.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=False,  # no augmentation
            hyp=self.args,
            rect=False,  # no rect
            cache=self.args.cache or None,
            prefix=colorstr(f"{mode}: "),
            data=self.data,
        )

    def postprocess(self, preds):
        """
        Apply Non-maximum suppression to prediction outputs.

        Args:
            preds (List | Tuple | torch.Tensor): Raw predictions from the model.

        Returns:
            (List[torch.Tensor]): List of processed predictions for each image in batch.
        """
        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
            preds = [preds, None]

        bs, _, nd = preds[0].shape
        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
        bboxes *= self.args.imgsz
        outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
        for i, bbox in enumerate(bboxes):  # (300, 4)
            bbox = ops.xywh2xyxy(bbox)
            score, cls = scores[i].max(-1)  # (300, )
            pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1)  # filter
            # Sort by confidence to correctly get internal metrics
            pred = pred[score.argsort(descending=True)]
            outputs[i] = pred[score > self.args.conf]

        return outputs

    def _prepare_batch(self, si, batch):
        """
        Prepares a batch for validation by applying necessary transformations.

        Args:
            si (int): Batch index.
            batch (dict): Batch data containing images and annotations.

        Returns:
            (dict): Prepared batch with transformed annotations.
        """
        idx = batch["batch_idx"] == si
        cls = batch["cls"][idx].squeeze(-1)
        bbox = batch["bboxes"][idx]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]
        ratio_pad = batch["ratio_pad"][si]
        if len(cls):
            bbox = ops.xywh2xyxy(bbox)  # target boxes
            bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred
            bbox[..., [1, 3]] *= ori_shape[0]  # native-space pred
        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}

    def _prepare_pred(self, pred, pbatch):
        """
        Prepares predictions by scaling bounding boxes to original image dimensions.

        Args:
            pred (torch.Tensor): Raw predictions.
            pbatch (dict): Prepared batch information.

        Returns:
            (torch.Tensor): Predictions scaled to original image dimensions.
        """
        predn = pred.clone()
        predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
        predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
        return predn.float()
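During validation the boxes are scaled twice: `postprocess` maps the normalized (cx, cy, w, h) outputs to imgsz-pixel xyxy boxes, and `_prepare_pred` then rescales them to the original image size by the ratio ori_shape / imgsz. Below is a standalone numeric check of that two-step scaling in plain PyTorch (not part of the package), with illustrative values only:

import torch

imgsz = 640
ori_shape = (480, 960)  # (height, width) of the original image

# postprocess step: one normalized (cx, cy, w, h) query -> imgsz-space xyxy
bbox = torch.tensor([[0.25, 0.50, 0.10, 0.20]]) * imgsz
xyxy = torch.cat([bbox[..., :2] - bbox[..., 2:] / 2, bbox[..., :2] + bbox[..., 2:] / 2], dim=-1)

# _prepare_pred step: rescale imgsz-space boxes to native image space
predn = xyxy.clone()
predn[..., [0, 2]] *= ori_shape[1] / imgsz  # x coordinates by original width / imgsz
predn[..., [1, 3]] *= ori_shape[0] / imgsz  # y coordinates by original height / imgsz
print(predn)  # tensor([[192., 192., 288., 288.]])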
ultralytics/models/sam/amg.py
@@ -0,0 +1,260 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import math
from itertools import product
from typing import Any, Generator, List, Tuple

import numpy as np
import torch


def is_box_near_crop_edge(
    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
    """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
    return torch.any(near_crop_edge, dim=1)


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    """
    Yield batches of data from input arguments with specified batch size for efficient processing.

    This function takes a batch size and any number of iterables, then yields batches of elements from those
    iterables. All input iterables must have the same length.

    Args:
        batch_size (int): Size of each batch to yield.
        *args (Any): Variable length input iterables to batch. All iterables must have the same length.

    Yields:
        (List[Any]): A list of batched elements from each input iterable.

    Examples:
        >>> data = [1, 2, 3, 4, 5]
        >>> labels = ["a", "b", "c", "d", "e"]
        >>> for batch in batch_iterator(2, data, labels):
        ...     print(batch)
        [[1, 2], ['a', 'b']]
        [[3, 4], ['c', 'd']]
        [[5], ['e']]
    """
    assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
    """
    Computes the stability score for a batch of masks.

    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
    high and low values.

    Args:
        masks (torch.Tensor): Batch of predicted mask logits.
        mask_threshold (float): Threshold value for creating binary masks.
        threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.

    Returns:
        (torch.Tensor): Stability scores for each mask in the batch.

    Notes:
        - One mask is always contained inside the other.
        - Memory is saved by preventing unnecessary cast to torch.int64.

    Examples:
        >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
        >>> mask_threshold = 0.5
        >>> threshold_offset = 0.1
        >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
    """
    intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    return intersections / unions


def build_point_grid(n_per_side: int) -> np.ndarray:
    """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)


def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
    """Generates point grids for multiple crop layers with varying scales and densities."""
    return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]


def generate_crop_boxes(
    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
    """
    Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.

    Args:
        im_size (Tuple[int, ...]): Height and width of the input image.
        n_layers (int): Number of layers to generate crop boxes for.
        overlap_ratio (float): Ratio of overlap between adjacent crop boxes.

    Returns:
        (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.
        (List[int]): List of layer indices corresponding to each crop box.

    Examples:
        >>> im_size = (800, 1200)  # Height, width
        >>> n_layers = 3
        >>> overlap_ratio = 0.25
        >>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)
    """
    crop_boxes, layer_idxs = [], []
    im_h, im_w = im_size
    short_side = min(im_h, im_w)

    # Original image
    crop_boxes.append([0, 0, im_w, im_h])
    layer_idxs.append(0)

    def crop_len(orig_len, n_crops, overlap):
        """Calculates the length of each crop given the original length, number of crops, and overlap."""
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    for i_layer in range(n_layers):
        n_crops_per_side = 2 ** (i_layer + 1)
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))

        crop_w = crop_len(im_w, n_crops_per_side, overlap)
        crop_h = crop_len(im_h, n_crops_per_side, overlap)

        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]

        # Crops in XYWH format
        for x0, y0 in product(crop_box_x0, crop_box_y0):
            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
            crop_boxes.append(box)
            layer_idxs.append(i_layer + 1)

    return crop_boxes, layer_idxs


def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
    x0, y0, _, _ = crop_box
    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
    # Check if boxes has a channel dimension
    if len(boxes.shape) == 3:
        offset = offset.unsqueeze(1)
    return boxes + offset


def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Uncrop points by adding the crop box offset to their coordinates."""
    x0, y0, _, _ = crop_box
    offset = torch.tensor([[x0, y0]], device=points.device)
    # Check if points has a channel dimension
    if len(points.shape) == 3:
        offset = offset.unsqueeze(1)
    return points + offset


def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
    """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
    x0, y0, x1, y1 = crop_box
    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
        return masks
    # Coordinate transform masks
    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
    pad = (x0, pad_x - x0, y0, pad_y - y0)
    return torch.nn.functional.pad(masks, pad, value=0)


def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions or holes in a mask based on area threshold and mode.

    Args:
        mask (np.ndarray): Binary mask to process.
        area_thresh (float): Area threshold below which regions will be removed.
        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
            regions.

    Returns:
        (np.ndarray): Processed binary mask with small regions removed.
        (bool): Whether any regions were modified.

    Examples:
        >>> mask = np.zeros((100, 100), dtype=np.bool_)
        >>> mask[40:60, 40:60] = True  # Create a square
        >>> mask[45:55, 45:55] = False  # Create a hole
        >>> processed_mask, modified = remove_small_regions(mask, 50, "holes")
    """
    import cv2  # type: ignore

    assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if not small_regions:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        # If every region is below threshold, keep largest
        fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    """
    Calculates bounding boxes in XYXY format around binary masks.

    Args:
        masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).

    Returns:
        (torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).

    Notes:
        - Handles empty masks by returning zero boxes.
        - Preserves input tensor dimensions in the output.
    """
    # torch.max below raises an error on empty inputs, just skip in this case
    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    # Normalize shape to CxHxW
    shape = masks.shape
    h, w = shape[-2:]
    masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
    # Get top and bottom edges
    in_height, _ = torch.max(masks, dim=-1)
    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
    in_height_coords = in_height_coords + h * (~in_height)
    top_edges, _ = torch.min(in_height_coords, dim=-1)

    # Get left and right edges
    in_width, _ = torch.max(masks, dim=-2)
    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
    right_edges, _ = torch.max(in_width_coords, dim=-1)
    in_width_coords = in_width_coords + w * (~in_width)
    left_edges, _ = torch.min(in_width_coords, dim=-1)

    # If the mask is empty the right edge will be to the left of the left edge.
    # Replace these boxes with [0, 0, 0, 0]
    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
    out = out * (~empty_filter).unsqueeze(-1)

    # Return to original shape
    return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]
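The functions above are the crop and mask bookkeeping helpers used by SAM's automatic mask generation. Below is a quick usage sketch (not part of the package), assuming the helpers are imported from the ultralytics/models/sam/amg.py module listed in this package:

import torch

from ultralytics.models.sam.amg import batched_mask_to_box, generate_crop_boxes

# One crop layer over an 800x1200 image: the full image plus a 2x2 grid of overlapping crops
crop_boxes, layer_idxs = generate_crop_boxes((800, 1200), n_layers=1, overlap_ratio=0.25)
print(len(crop_boxes), layer_idxs)  # 5 [0, 1, 1, 1, 1]

# Tight XYXY boxes around binary masks; empty masks map to [0, 0, 0, 0]
masks = torch.zeros((2, 32, 32), dtype=torch.bool)
masks[0, 8:16, 4:20] = True  # one occupied mask, one empty mask
print(batched_mask_to_box(masks))  # tensor([[ 4,  8, 19, 15], [ 0,  0,  0,  0]])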