ultralytics 8.3.98__py3-none-any.whl → 8.3.99__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
- tests/test_python.py +56 -0
- ultralytics/__init__.py +3 -2
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/data/augment.py +101 -5
- ultralytics/data/dataset.py +165 -12
- ultralytics/engine/exporter.py +4 -3
- ultralytics/engine/trainer.py +16 -7
- ultralytics/models/__init__.py +2 -2
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/detect/val.py +6 -1
- ultralytics/models/yolo/model.py +182 -3
- ultralytics/models/yolo/segment/val.py +43 -16
- ultralytics/models/yolo/yoloe/__init__.py +21 -0
- ultralytics/models/yolo/yoloe/predict.py +170 -0
- ultralytics/models/yolo/yoloe/train.py +355 -0
- ultralytics/models/yolo/yoloe/train_seg.py +141 -0
- ultralytics/models/yolo/yoloe/val.py +187 -0
- ultralytics/nn/autobackend.py +3 -2
- ultralytics/nn/modules/__init__.py +18 -1
- ultralytics/nn/modules/block.py +17 -1
- ultralytics/nn/modules/head.py +359 -22
- ultralytics/nn/tasks.py +276 -10
- ultralytics/nn/text_model.py +193 -0
- ultralytics/utils/callbacks/comet.py +3 -6
- ultralytics/utils/downloads.py +6 -2
- ultralytics/utils/loss.py +67 -6
- ultralytics/utils/plotting.py +1 -1
- ultralytics/utils/tal.py +1 -1
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/METADATA +10 -10
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/RECORD +37 -27
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/top_level.txt +0 -0
tests/test_python.py
CHANGED
```diff
@@ -608,6 +608,62 @@ def test_yolo_world():
     )
 
 
+@pytest.mark.skipif(checks.IS_PYTHON_3_12 or not TORCH_1_9, reason="YOLOE with CLIP is not supported in Python 3.12")
+def test_yoloe():
+    """Test YOLOE models with MobileClip support."""
+    # Predict
+    # text-prompts
+    model = YOLO(WEIGHTS_DIR / "yoloe-11s-seg.pt")
+    names = ["person", "bus"]
+    model.set_classes(names, model.get_text_pe(names))
+    model(SOURCE, conf=0.01)
+
+    import numpy as np
+
+    from ultralytics import YOLOE
+    from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
+
+    # visual-prompts
+    visuals = dict(
+        bboxes=np.array(
+            [[221.52, 405.8, 344.98, 857.54], [120, 425, 160, 445]],
+        ),
+        cls=np.array([0, 1]),
+    )
+    model.predict(
+        SOURCE,
+        visual_prompts=visuals,
+        predictor=YOLOEVPSegPredictor,
+    )
+
+    # Val
+    model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg.pt")
+    # text prompts
+    model.val(data="coco128-seg.yaml", imgsz=32)
+    # visual prompts
+    model.val(data="coco128-seg.yaml", load_vp=True, imgsz=32)
+
+    # Train, fine-tune
+    from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer
+
+    model = YOLOE("yoloe-11s-seg.pt")
+    model.train(
+        data="coco128-seg.yaml",
+        epochs=1,
+        close_mosaic=1,
+        trainer=YOLOEPESegTrainer,
+        imgsz=32,
+    )
+
+    # prompt-free
+    # predict
+    model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg-pf.pt")
+    model.predict(SOURCE)
+    # val
+    model = YOLOE("yoloe-11s-seg.pt")  # or select yoloe-m/l-seg.pt for different sizes
+    model.val(data="coco128-seg.yaml", imgsz=32)
+
+
 def test_yolov10():
     """Test YOLOv10 model training, validation, and prediction functionality."""
     model = YOLO("yolov10n.yaml")
```
ultralytics/__init__.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-__version__ = "8.3.98"
+__version__ = "8.3.99"
 
 import os
 
@@ -8,7 +8,7 @@ import os
 if not os.environ.get("OMP_NUM_THREADS"):
     os.environ["OMP_NUM_THREADS"] = "1"  # default for reduced CPU utilization during training
 
-from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld
+from ultralytics.models import NAS, RTDETR, SAM, YOLO, YOLOE, FastSAM, YOLOWorld
 from ultralytics.utils import ASSETS, SETTINGS
 from ultralytics.utils.checks import check_yolo as checks
 from ultralytics.utils.downloads import download
@@ -19,6 +19,7 @@ __all__ = (
     "ASSETS",
     "YOLO",
     "YOLOWorld",
+    "YOLOE",
     "NAS",
     "SAM",
     "FastSAM",
```
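A quick sanity check that the new top-level export works as the diff intends (a sketch; only the import surface is shown here):

```python
import ultralytics

assert "YOLOE" in ultralytics.__all__  # added in 8.3.99
from ultralytics import YOLOE  # re-exported from ultralytics.models
```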
ultralytics/cfg/models/11/yoloe-11-seg.yaml
ADDED
```diff
@@ -0,0 +1,48 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# YOLO11-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolo11n-seg.yaml' will call yolo11-seg.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.50, 0.25, 1024] # summary: 355 layers, 2876848 parameters, 2876832 gradients, 10.5 GFLOPs
+  s: [0.50, 0.50, 1024] # summary: 355 layers, 10113248 parameters, 10113232 gradients, 35.8 GFLOPs
+  m: [0.50, 1.00, 512] # summary: 445 layers, 22420896 parameters, 22420880 gradients, 123.9 GFLOPs
+  l: [1.00, 1.00, 512] # summary: 667 layers, 27678368 parameters, 27678352 gradients, 143.0 GFLOPs
+  x: [1.00, 1.50, 512] # summary: 667 layers, 62142656 parameters, 62142640 gradients, 320.2 GFLOPs
+
+# YOLO11n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 2, C3k2, [256, False, 0.25]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 2, C3k2, [512, False, 0.25]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 2, C3k2, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 2, C3k2, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 2, C2PSA, [1024]] # 10
+
+# YOLO11n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 2, C3k2, [512, False]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
+
+  - [-1, 1, Conv, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, YOLOESegment, [nc, 32, 256, 512, True]] # Detect(P3, P4, P5)
```
ultralytics/cfg/models/11/yoloe-11.yaml
ADDED
```diff
@@ -0,0 +1,48 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
+  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
+  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
+  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
+  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
+
+# YOLO11n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 2, C3k2, [256, False, 0.25]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 2, C3k2, [512, False, 0.25]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 2, C3k2, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 2, C3k2, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 2, C2PSA, [1024]] # 10
+
+# YOLO11n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 2, C3k2, [512, False]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
+
+  - [-1, 1, Conv, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, YOLOEDetect, [nc, 512, True]] # Detect(P3, P4, P5)
```
ultralytics/cfg/models/v8/yoloe-v8-seg.yaml
ADDED
```diff
@@ -0,0 +1,45 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.33, 0.25, 1024] # YOLOv8n-world summary: 161 layers, 4204111 parameters, 4204095 gradients, 39.6 GFLOPs
+  s: [0.33, 0.50, 1024] # YOLOv8s-world summary: 161 layers, 13383496 parameters, 13383480 gradients, 71.5 GFLOPs
+  m: [0.67, 0.75, 768] # YOLOv8m-world summary: 201 layers, 29065310 parameters, 29065294 gradients, 131.4 GFLOPs
+  l: [1.00, 1.00, 512] # YOLOv8l-world summary: 241 layers, 47553970 parameters, 47553954 gradients, 225.6 GFLOPs
+  x: [1.00, 1.25, 512] # YOLOv8x-world summary: 241 layers, 73690217 parameters, 73690201 gradients, 330.8 GFLOPs
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2f, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 12
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
+
+  - [15, 1, Conv, [256, 3, 2]]
+  - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
+
+  - [-1, 1, Conv, [512, 3, 2]]
+  - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
+
+  - [[15, 18, 21], 1, YOLOESegment, [nc, 32, 256, 512, True]] # Segment(P3, P4, P5)
```
|
@@ -0,0 +1,45 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
# Parameters
|
4
|
+
nc: 80 # number of classes
|
5
|
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
6
|
+
# [depth, width, max_channels]
|
7
|
+
n: [0.33, 0.25, 1024] # YOLOv8n-worldv2 summary: 148 layers, 3695183 parameters, 3695167 gradients, 19.5 GFLOPS
|
8
|
+
s: [0.33, 0.50, 1024] # YOLOv8s-worldv2 summary: 148 layers, 12759880 parameters, 12759864 gradients, 51.0 GFLOPS
|
9
|
+
m: [0.67, 0.75, 768] # YOLOv8m-worldv2 summary: 188 layers, 28376158 parameters, 28376142 gradients, 110.5 GFLOPS
|
10
|
+
l: [1.00, 1.00, 512] # YOLOv8l-worldv2 summary: 228 layers, 46832050 parameters, 46832034 gradients, 204.5 GFLOPS
|
11
|
+
x: [1.00, 1.25, 512] # YOLOv8x-worldv2 summary: 228 layers, 72886377 parameters, 72886361 gradients, 309.3 GFLOPS
|
12
|
+
|
13
|
+
# YOLOv8.0n backbone
|
14
|
+
backbone:
|
15
|
+
# [from, repeats, module, args]
|
16
|
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
17
|
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
18
|
+
- [-1, 3, C2f, [128, True]]
|
19
|
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
20
|
+
- [-1, 6, C2f, [256, True]]
|
21
|
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
22
|
+
- [-1, 6, C2f, [512, True]]
|
23
|
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
24
|
+
- [-1, 3, C2f, [1024, True]]
|
25
|
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
26
|
+
|
27
|
+
# YOLOv8.0n head
|
28
|
+
head:
|
29
|
+
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
|
30
|
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
31
|
+
- [-1, 3, C2f, [512]] # 12
|
32
|
+
|
33
|
+
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
|
34
|
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
35
|
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
36
|
+
|
37
|
+
- [15, 1, Conv, [256, 3, 2]]
|
38
|
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
39
|
+
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
|
40
|
+
|
41
|
+
- [-1, 1, Conv, [512, 3, 2]]
|
42
|
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
43
|
+
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
|
44
|
+
|
45
|
+
- [[15, 18, 21], 1, YOLOEDetect, [nc, 512, True]] # Detect(P3, P4, P5)
|
ultralytics/data/augment.py
CHANGED
```diff
@@ -3,19 +3,20 @@
 import math
 import random
 from copy import deepcopy
-from typing import Tuple, Union
+from typing import List, Tuple, Union
 
 import cv2
 import numpy as np
 import torch
 from PIL import Image
+from torch.nn import functional as F
 
 from ultralytics.data.utils import polygons2masks, polygons2masks_overlap
 from ultralytics.utils import LOGGER, colorstr
 from ultralytics.utils.checks import check_version
 from ultralytics.utils.instance import Instances
 from ultralytics.utils.metrics import bbox_ioa
-from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr
+from ultralytics.utils.ops import segment2box, xywh2xyxy, xyxyxyxy2xywhr
 from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
 
 DEFAULT_MEAN = (0.0, 0.0, 0.0)
```
```diff
@@ -2140,6 +2141,99 @@ class Format:
         return masks, instances, cls
 
 
+class LoadVisualPrompt:
+    """Creates visual prompts from bounding boxes or masks for model input."""
+
+    def __init__(self, scale_factor=1 / 8):
+        """
+        Initialize the LoadVisualPrompt with a scale factor.
+
+        Args:
+            scale_factor (float): Factor to scale the input image dimensions.
+        """
+        self.scale_factor = scale_factor
+
+    def make_mask(self, boxes, h, w):
+        """
+        Create binary masks from bounding boxes.
+
+        Args:
+            boxes (torch.Tensor): Bounding boxes in xyxy format, shape: (N, 4).
+            h (int): Height of the mask.
+            w (int): Width of the mask.
+
+        Returns:
+            (torch.Tensor): Binary masks with shape (N, h, w).
+        """
+        x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
+        r = torch.arange(w)[None, None, :]  # rows shape(1,1,w)
+        c = torch.arange(h)[None, :, None]  # cols shape(1,h,1)
+
+        return (r >= x1) * (r < x2) * (c >= y1) * (c < y2)
+
+    def __call__(self, labels):
+        """
+        Process labels to create visual prompts.
+
+        Args:
+            labels (dict): Dictionary containing image data and annotations.
+
+        Returns:
+            (dict): Updated labels with visual prompts added.
+        """
+        imgsz = labels["img"].shape[1:]
+        bboxes, masks = None, None
+        if "bboxes" in labels:
+            bboxes = labels["bboxes"]
+            bboxes = xywh2xyxy(bboxes) * torch.tensor(imgsz)[[1, 0, 1, 0]]  # denormalize boxes
+
+        cls = labels["cls"].squeeze(-1).to(torch.int)
+        visuals = self.get_visuals(cls, imgsz, bboxes=bboxes, masks=masks)
+        labels["visuals"] = visuals
+        return labels
+
+    def get_visuals(self, category, shape, bboxes=None, masks=None):
+        """
+        Generate visual masks based on bounding boxes or masks.
+
+        Args:
+            category (int | np.ndarray | torch.Tensor): The category labels for the objects.
+            shape (tuple): The shape of the image (height, width).
+            bboxes (np.ndarray | torch.Tensor, optional): Bounding boxes for the objects, xyxy format. Defaults to None.
+            masks (np.ndarray | torch.Tensor, optional): Masks for the objects. Defaults to None.
+
+        Returns:
+            (torch.Tensor): A tensor containing the visual masks for each category.
+
+        Raises:
+            ValueError: If neither bboxes nor masks are provided.
+        """
+        masksz = (int(shape[0] * self.scale_factor), int(shape[1] * self.scale_factor))
+        if bboxes is not None:
+            if isinstance(bboxes, np.ndarray):
+                bboxes = torch.from_numpy(bboxes)
+            bboxes *= self.scale_factor
+            masks = self.make_mask(bboxes, *masksz).float()
+        elif masks is not None:
+            if isinstance(masks, np.ndarray):
+                masks = torch.from_numpy(masks)  # (N, H, W)
+            masks = F.interpolate(masks.unsqueeze(1), masksz, mode="nearest").squeeze(1).float()
+        else:
+            raise ValueError("LoadVisualPrompt must have bboxes or masks in the label")
+        if not isinstance(category, torch.Tensor):
+            category = torch.tensor(category, dtype=torch.int)
+        cls_unique, inverse_indices = torch.unique(category, sorted=True, return_inverse=True)
+        # NOTE: `cls` indices from RandomLoadText should be continuous.
+        # if len(cls_unique):
+        #     assert len(cls_unique) == cls_unique[-1] + 1, (
+        #         f"Expected a continuous range of class indices, but got {cls_unique}"
+        #     )
+        visuals = torch.zeros(len(cls_unique), *masksz)
+        for idx, mask in zip(inverse_indices, masks):
+            visuals[idx] = torch.logical_or(visuals[idx], mask)
+        return visuals
+
+
 class RandomLoadText:
     """
     Randomly samples positive and negative texts and updates class indices accordingly.
```
```diff
@@ -2172,7 +2266,7 @@ class RandomLoadText:
         neg_samples: Tuple[int, int] = (80, 80),
         max_samples: int = 80,
         padding: bool = False,
-        padding_value: str = "",
+        padding_value: List[str] = [""],
     ) -> None:
         """
         Initializes the RandomLoadText class for randomly sampling positive and negative texts.
```
```diff
@@ -2246,7 +2340,8 @@ class RandomLoadText:
             neg_labels = random.sample(neg_labels, k=neg_samples)
 
         sampled_labels = pos_labels + neg_labels
-
+        # Randomness
+        # random.shuffle(sampled_labels)
 
         label2ids = {label: i for i, label in enumerate(sampled_labels)}
         valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
@@ -2271,8 +2366,9 @@ class RandomLoadText:
         valid_labels = len(pos_labels) + len(neg_labels)
         num_padding = self.max_samples - valid_labels
         if num_padding > 0:
-            texts += [self.padding_value] * num_padding
+            texts += random.choices(self.padding_value, k=num_padding)
 
+        assert len(texts) == self.max_samples
         labels["texts"] = texts
         return labels
```
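Two behavioral changes fall out of this last hunk: `padding_value` is now a list, so padding entries are drawn from a pool via `random.choices` (sampling with replacement) rather than repeating one fixed string, and the new assert guarantees that exactly `max_samples` texts leave the transform. The padding step in isolation (values are illustrative):

```python
import random

max_samples = 80
texts = ["person", "bus"]  # prompts already sampled for this image
padding_pool = [""]  # default pool; callers may supply several padding strings
num_padding = max_samples - len(texts)
if num_padding > 0:
    texts += random.choices(padding_pool, k=num_padding)  # sample with replacement
assert len(texts) == max_samples
```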