ultralytics 8.3.98__py3-none-any.whl → 8.3.100__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (38)
  1. tests/test_python.py +56 -0
  2. ultralytics/__init__.py +3 -2
  3. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  4. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  5. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  6. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  7. ultralytics/data/augment.py +101 -5
  8. ultralytics/data/dataset.py +165 -12
  9. ultralytics/engine/exporter.py +5 -4
  10. ultralytics/engine/trainer.py +16 -7
  11. ultralytics/models/__init__.py +2 -2
  12. ultralytics/models/yolo/__init__.py +3 -3
  13. ultralytics/models/yolo/detect/val.py +6 -1
  14. ultralytics/models/yolo/model.py +183 -3
  15. ultralytics/models/yolo/segment/val.py +43 -16
  16. ultralytics/models/yolo/yoloe/__init__.py +21 -0
  17. ultralytics/models/yolo/yoloe/predict.py +170 -0
  18. ultralytics/models/yolo/yoloe/train.py +355 -0
  19. ultralytics/models/yolo/yoloe/train_seg.py +141 -0
  20. ultralytics/models/yolo/yoloe/val.py +187 -0
  21. ultralytics/nn/autobackend.py +17 -7
  22. ultralytics/nn/modules/__init__.py +18 -1
  23. ultralytics/nn/modules/block.py +17 -1
  24. ultralytics/nn/modules/head.py +359 -22
  25. ultralytics/nn/tasks.py +276 -10
  26. ultralytics/nn/text_model.py +193 -0
  27. ultralytics/utils/benchmarks.py +1 -0
  28. ultralytics/utils/callbacks/comet.py +3 -6
  29. ultralytics/utils/downloads.py +6 -2
  30. ultralytics/utils/loss.py +67 -6
  31. ultralytics/utils/plotting.py +1 -1
  32. ultralytics/utils/tal.py +1 -1
  33. {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/METADATA +10 -10
  34. {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/RECORD +38 -28
  35. {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/WHEEL +0 -0
  36. {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/entry_points.txt +0 -0
  37. {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/licenses/LICENSE +0 -0
  38. {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/top_level.txt +0 -0
tests/test_python.py CHANGED
@@ -608,6 +608,62 @@ def test_yolo_world():
     )
 
 
+@pytest.mark.skipif(checks.IS_PYTHON_3_12 or not TORCH_1_9, reason="YOLOE with CLIP is not supported in Python 3.12")
+def test_yoloe():
+    """Test YOLOE models with MobileClip support."""
+    # Predict
+    # text-prompts
+    model = YOLO(WEIGHTS_DIR / "yoloe-11s-seg.pt")
+    names = ["person", "bus"]
+    model.set_classes(names, model.get_text_pe(names))
+    model(SOURCE, conf=0.01)
+
+    import numpy as np
+
+    from ultralytics import YOLOE
+    from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
+
+    # visual-prompts
+    visuals = dict(
+        bboxes=np.array(
+            [[221.52, 405.8, 344.98, 857.54], [120, 425, 160, 445]],
+        ),
+        cls=np.array([0, 1]),
+    )
+    model.predict(
+        SOURCE,
+        visual_prompts=visuals,
+        predictor=YOLOEVPSegPredictor,
+    )
+
+    # Val
+    model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg.pt")
+    # text prompts
+    model.val(data="coco128-seg.yaml", imgsz=32)
+    # visual prompts
+    model.val(data="coco128-seg.yaml", load_vp=True, imgsz=32)
+
+    # Train, fine-tune
+    from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer
+
+    model = YOLOE("yoloe-11s-seg.pt")
+    model.train(
+        data="coco128-seg.yaml",
+        epochs=1,
+        close_mosaic=1,
+        trainer=YOLOEPESegTrainer,
+        imgsz=32,
+    )
+
+    # prompt-free
+    # predict
+    model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg-pf.pt")
+    model.predict(SOURCE)
+    # val
+    model = YOLOE("yoloe-11s-seg.pt")  # or select yoloe-m/l-seg.pt for different sizes
+    model.val(data="coco128-seg.yaml", imgsz=32)
+
+
 def test_yolov10():
     """Test YOLOv10 model training, validation, and prediction functionality."""
     model = YOLO("yolov10n.yaml")
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-__version__ = "8.3.98"
+__version__ = "8.3.100"
 
 import os
 
@@ -8,7 +8,7 @@ import os
 if not os.environ.get("OMP_NUM_THREADS"):
     os.environ["OMP_NUM_THREADS"] = "1"  # default for reduced CPU utilization during training
 
-from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld
+from ultralytics.models import NAS, RTDETR, SAM, YOLO, YOLOE, FastSAM, YOLOWorld
 from ultralytics.utils import ASSETS, SETTINGS
 from ultralytics.utils.checks import check_yolo as checks
 from ultralytics.utils.downloads import download
@@ -19,6 +19,7 @@ __all__ = (
     "ASSETS",
     "YOLO",
     "YOLOWorld",
+    "YOLOE",
     "NAS",
     "SAM",
     "FastSAM",
ultralytics/cfg/models/11/yoloe-11-seg.yaml ADDED
@@ -0,0 +1,48 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# YOLO11-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolo11n-seg.yaml' will call yolo11-seg.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.50, 0.25, 1024] # summary: 355 layers, 2876848 parameters, 2876832 gradients, 10.5 GFLOPs
+  s: [0.50, 0.50, 1024] # summary: 355 layers, 10113248 parameters, 10113232 gradients, 35.8 GFLOPs
+  m: [0.50, 1.00, 512] # summary: 445 layers, 22420896 parameters, 22420880 gradients, 123.9 GFLOPs
+  l: [1.00, 1.00, 512] # summary: 667 layers, 27678368 parameters, 27678352 gradients, 143.0 GFLOPs
+  x: [1.00, 1.50, 512] # summary: 667 layers, 62142656 parameters, 62142640 gradients, 320.2 GFLOPs
+
+# YOLO11n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 2, C3k2, [256, False, 0.25]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 2, C3k2, [512, False, 0.25]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 2, C3k2, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 2, C3k2, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 2, C2PSA, [1024]] # 10
+
+# YOLO11n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 2, C3k2, [512, False]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
+
+  - [-1, 1, Conv, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, YOLOESegment, [nc, 32, 256, 512, True]] # Detect(P3, P4, P5)
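As with other Ultralytics configs, the new YAML should be loadable directly to build an untrained model. A sketch, assuming YOLOE YAMLs resolve scale suffixes the same way yolo11 configs do:

```python
from ultralytics import YOLOE

# "yoloe-11s-seg.yaml" selects scale "s" from the `scales` table above
model = YOLOE("yoloe-11s-seg.yaml")
model.info()  # layer/parameter summary, cf. the per-scale comments in the YAML
```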
ultralytics/cfg/models/11/yoloe-11.yaml ADDED
@@ -0,0 +1,48 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
+  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
+  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
+  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
+  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
+
+# YOLO11n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 2, C3k2, [256, False, 0.25]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 2, C3k2, [512, False, 0.25]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 2, C3k2, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 2, C3k2, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 2, C2PSA, [1024]] # 10
+
+# YOLO11n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 2, C3k2, [512, False]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
+
+  - [-1, 1, Conv, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, YOLOEDetect, [nc, 512, True]] # Detect(P3, P4, P5)
ultralytics/cfg/models/v8/yoloe-v8-seg.yaml ADDED
@@ -0,0 +1,45 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.33, 0.25, 1024] # YOLOv8n-world summary: 161 layers, 4204111 parameters, 4204095 gradients, 39.6 GFLOPs
+  s: [0.33, 0.50, 1024] # YOLOv8s-world summary: 161 layers, 13383496 parameters, 13383480 gradients, 71.5 GFLOPs
+  m: [0.67, 0.75, 768] # YOLOv8m-world summary: 201 layers, 29065310 parameters, 29065294 gradients, 131.4 GFLOPs
+  l: [1.00, 1.00, 512] # YOLOv8l-world summary: 241 layers, 47553970 parameters, 47553954 gradients, 225.6 GFLOPs
+  x: [1.00, 1.25, 512] # YOLOv8x-world summary: 241 layers, 73690217 parameters, 73690201 gradients, 330.8 GFLOPs
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2f, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 12
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
+
+  - [15, 1, Conv, [256, 3, 2]]
+  - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
+
+  - [-1, 1, Conv, [512, 3, 2]]
+  - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
+
+  - [[15, 18, 21], 1, YOLOESegment, [nc, 32, 256, 512, True]] # Segment(P3, P4, P5)
ultralytics/cfg/models/v8/yoloe-v8.yaml ADDED
@@ -0,0 +1,45 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.33, 0.25, 1024] # YOLOv8n-worldv2 summary: 148 layers, 3695183 parameters, 3695167 gradients, 19.5 GFLOPS
+  s: [0.33, 0.50, 1024] # YOLOv8s-worldv2 summary: 148 layers, 12759880 parameters, 12759864 gradients, 51.0 GFLOPS
+  m: [0.67, 0.75, 768] # YOLOv8m-worldv2 summary: 188 layers, 28376158 parameters, 28376142 gradients, 110.5 GFLOPS
+  l: [1.00, 1.00, 512] # YOLOv8l-worldv2 summary: 228 layers, 46832050 parameters, 46832034 gradients, 204.5 GFLOPS
+  x: [1.00, 1.25, 512] # YOLOv8x-worldv2 summary: 228 layers, 72886377 parameters, 72886361 gradients, 309.3 GFLOPS
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2f, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 12
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
+
+  - [15, 1, Conv, [256, 3, 2]]
+  - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
+
+  - [-1, 1, Conv, [512, 3, 2]]
+  - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
+
+  - [[15, 18, 21], 1, YOLOEDetect, [nc, 512, True]] # Detect(P3, P4, P5)
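In all four configs the per-scale `[depth, width, max_channels]` triple rescales the nominal repeat and channel counts at build time. A sketch of the arithmetic (assumption: this mirrors the usual `parse_model`/`make_divisible` logic; the function names here are illustrative):

```python
import math

def scale_channels(c, width, max_channels, divisor=8):
    # Channels are capped at max_channels, multiplied by width, rounded up to a multiple of 8
    return math.ceil(min(c, max_channels) * width / divisor) * divisor

def scale_repeats(n, depth):
    # Block repeat counts are multiplied by depth, never dropping below 1
    return max(round(n * depth), 1)

# yoloe-v8 scale "n": depth=0.33, width=0.25, max_channels=1024
print(scale_channels(1024, 0.25, 1024))  # SPPF output channels -> 256
print(scale_repeats(6, 0.33))            # C2f repeats 6 -> 2
```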
ultralytics/data/augment.py CHANGED
@@ -3,19 +3,20 @@
 import math
 import random
 from copy import deepcopy
-from typing import Tuple, Union
+from typing import List, Tuple, Union
 
 import cv2
 import numpy as np
 import torch
 from PIL import Image
+from torch.nn import functional as F
 
 from ultralytics.data.utils import polygons2masks, polygons2masks_overlap
 from ultralytics.utils import LOGGER, colorstr
 from ultralytics.utils.checks import check_version
 from ultralytics.utils.instance import Instances
 from ultralytics.utils.metrics import bbox_ioa
-from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr
+from ultralytics.utils.ops import segment2box, xywh2xyxy, xyxyxyxy2xywhr
 from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
 
 DEFAULT_MEAN = (0.0, 0.0, 0.0)
@@ -2140,6 +2141,99 @@ class Format:
         return masks, instances, cls
 
 
+class LoadVisualPrompt:
+    """Creates visual prompts from bounding boxes or masks for model input."""
+
+    def __init__(self, scale_factor=1 / 8):
+        """
+        Initialize the LoadVisualPrompt with a scale factor.
+
+        Args:
+            scale_factor (float): Factor to scale the input image dimensions.
+        """
+        self.scale_factor = scale_factor
+
+    def make_mask(self, boxes, h, w):
+        """
+        Create binary masks from bounding boxes.
+
+        Args:
+            boxes (torch.Tensor): Bounding boxes in xyxy format, shape: (N, 4).
+            h (int): Height of the mask.
+            w (int): Width of the mask.
+
+        Returns:
+            (torch.Tensor): Binary masks with shape (N, h, w).
+        """
+        x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
+        r = torch.arange(w)[None, None, :]  # rows shape(1,1,w)
+        c = torch.arange(h)[None, :, None]  # cols shape(1,h,1)
+
+        return (r >= x1) * (r < x2) * (c >= y1) * (c < y2)
+
+    def __call__(self, labels):
+        """
+        Process labels to create visual prompts.
+
+        Args:
+            labels (dict): Dictionary containing image data and annotations.
+
+        Returns:
+            (dict): Updated labels with visual prompts added.
+        """
+        imgsz = labels["img"].shape[1:]
+        bboxes, masks = None, None
+        if "bboxes" in labels:
+            bboxes = labels["bboxes"]
+            bboxes = xywh2xyxy(bboxes) * torch.tensor(imgsz)[[1, 0, 1, 0]]  # denormalize boxes
+
+        cls = labels["cls"].squeeze(-1).to(torch.int)
+        visuals = self.get_visuals(cls, imgsz, bboxes=bboxes, masks=masks)
+        labels["visuals"] = visuals
+        return labels
+
+    def get_visuals(self, category, shape, bboxes=None, masks=None):
+        """
+        Generate visual masks based on bounding boxes or masks.
+
+        Args:
+            category (int | np.ndarray | torch.Tensor): The category labels for the objects.
+            shape (tuple): The shape of the image (height, width).
+            bboxes (np.ndarray | torch.Tensor, optional): Bounding boxes for the objects, xyxy format. Defaults to None.
+            masks (np.ndarray | torch.Tensor, optional): Masks for the objects. Defaults to None.
+
+        Returns:
+            (torch.Tensor): A tensor containing the visual masks for each category.
+
+        Raises:
+            ValueError: If neither bboxes nor masks are provided.
+        """
+        masksz = (int(shape[0] * self.scale_factor), int(shape[1] * self.scale_factor))
+        if bboxes is not None:
+            if isinstance(bboxes, np.ndarray):
+                bboxes = torch.from_numpy(bboxes)
+            bboxes *= self.scale_factor
+            masks = self.make_mask(bboxes, *masksz).float()
+        elif masks is not None:
+            if isinstance(masks, np.ndarray):
+                masks = torch.from_numpy(masks)  # (N, H, W)
+            masks = F.interpolate(masks.unsqueeze(1), masksz, mode="nearest").squeeze(1).float()
+        else:
+            raise ValueError("LoadVisualPrompt must have bboxes or masks in the label")
+        if not isinstance(category, torch.Tensor):
+            category = torch.tensor(category, dtype=torch.int)
+        cls_unique, inverse_indices = torch.unique(category, sorted=True, return_inverse=True)
+        # NOTE: `cls` indices from RandomLoadText should be continuous.
+        # if len(cls_unique):
+        #     assert len(cls_unique) == cls_unique[-1] + 1, (
+        #         f"Expected a continuous range of class indices, but got {cls_unique}"
+        #     )
+        visuals = torch.zeros(len(cls_unique), *masksz)
+        for idx, mask in zip(inverse_indices, masks):
+            visuals[idx] = torch.logical_or(visuals[idx], mask)
+        return visuals
+
+
 class RandomLoadText:
     """
     Randomly samples positive and negative texts and updates class indices accordingly.
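`make_mask` rasterizes xyxy boxes onto the downscaled grid with broadcast comparisons. A toy check of its semantics, standalone rather than importing the class:

```python
import torch

boxes = torch.tensor([[1.0, 1.0, 3.0, 3.0]])       # one box, xyxy, already in grid units
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)
r = torch.arange(4)[None, None, :]                 # column indices, shape (1, 1, w)
c = torch.arange(4)[None, :, None]                 # row indices, shape (1, h, 1)
mask = (r >= x1) * (r < x2) * (c >= y1) * (c < y2)
print(mask.int().squeeze(0))
# tensor([[0, 0, 0, 0],
#         [0, 1, 1, 0],
#         [0, 1, 1, 0],
#         [0, 0, 0, 0]])
```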
@@ -2172,7 +2266,7 @@ class RandomLoadText:
         neg_samples: Tuple[int, int] = (80, 80),
         max_samples: int = 80,
         padding: bool = False,
-        padding_value: str = "",
+        padding_value: List[str] = [""],
     ) -> None:
         """
         Initializes the RandomLoadText class for randomly sampling positive and negative texts.
@@ -2246,7 +2340,8 @@ class RandomLoadText:
         neg_labels = random.sample(neg_labels, k=neg_samples)
 
         sampled_labels = pos_labels + neg_labels
-        random.shuffle(sampled_labels)
+        # Randomness
+        # random.shuffle(sampled_labels)
 
         label2ids = {label: i for i, label in enumerate(sampled_labels)}
         valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
@@ -2271,8 +2366,9 @@ class RandomLoadText:
         valid_labels = len(pos_labels) + len(neg_labels)
         num_padding = self.max_samples - valid_labels
         if num_padding > 0:
-            texts += [self.padding_value] * num_padding
+            texts += random.choices(self.padding_value, k=num_padding)
 
+        assert len(texts) == self.max_samples
         labels["texts"] = texts
         return labels
 
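The `padding_value` change means padding texts are now sampled from a list instead of repeating a single string, and the new assert pins the final text count to `max_samples`. A small sketch of the new padding behavior in isolation:

```python
import random

max_samples = 5
padding_value = [""]        # could hold several filler strings to sample from
texts = ["person", "bus"]   # sampled positive + negative labels

num_padding = max_samples - len(texts)
if num_padding > 0:
    texts += random.choices(padding_value, k=num_padding)  # new in 8.3.100

assert len(texts) == max_samples
```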