ultralytics 8.1.38__py3-none-any.whl → 8.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics might be problematic.

Files changed (58)
  1. ultralytics/__init__.py +1 -1
  2. ultralytics/cfg/__init__.py +3 -3
  3. ultralytics/cfg/datasets/lvis.yaml +1239 -0
  4. ultralytics/data/__init__.py +18 -2
  5. ultralytics/data/augment.py +124 -3
  6. ultralytics/data/base.py +2 -2
  7. ultralytics/data/build.py +25 -3
  8. ultralytics/data/converter.py +24 -6
  9. ultralytics/data/dataset.py +142 -27
  10. ultralytics/data/loaders.py +11 -8
  11. ultralytics/data/split_dota.py +1 -1
  12. ultralytics/data/utils.py +33 -8
  13. ultralytics/engine/exporter.py +3 -3
  14. ultralytics/engine/model.py +6 -3
  15. ultralytics/engine/results.py +2 -2
  16. ultralytics/engine/trainer.py +59 -55
  17. ultralytics/engine/validator.py +2 -2
  18. ultralytics/hub/utils.py +1 -1
  19. ultralytics/models/fastsam/model.py +1 -1
  20. ultralytics/models/fastsam/prompt.py +4 -5
  21. ultralytics/models/nas/model.py +1 -1
  22. ultralytics/models/sam/model.py +1 -1
  23. ultralytics/models/sam/modules/tiny_encoder.py +1 -1
  24. ultralytics/models/yolo/__init__.py +2 -2
  25. ultralytics/models/yolo/classify/train.py +1 -1
  26. ultralytics/models/yolo/detect/train.py +1 -1
  27. ultralytics/models/yolo/detect/val.py +36 -17
  28. ultralytics/models/yolo/model.py +1 -0
  29. ultralytics/models/yolo/world/__init__.py +5 -0
  30. ultralytics/models/yolo/world/train.py +92 -0
  31. ultralytics/models/yolo/world/train_world.py +108 -0
  32. ultralytics/nn/autobackend.py +5 -5
  33. ultralytics/nn/modules/block.py +4 -2
  34. ultralytics/nn/modules/conv.py +1 -1
  35. ultralytics/nn/modules/head.py +13 -4
  36. ultralytics/nn/tasks.py +30 -14
  37. ultralytics/solutions/ai_gym.py +1 -1
  38. ultralytics/solutions/heatmap.py +85 -47
  39. ultralytics/solutions/object_counter.py +79 -64
  40. ultralytics/trackers/byte_tracker.py +1 -1
  41. ultralytics/trackers/track.py +1 -1
  42. ultralytics/trackers/utils/gmc.py +1 -1
  43. ultralytics/utils/__init__.py +4 -4
  44. ultralytics/utils/benchmarks.py +2 -2
  45. ultralytics/utils/callbacks/comet.py +1 -1
  46. ultralytics/utils/callbacks/mlflow.py +1 -1
  47. ultralytics/utils/checks.py +3 -3
  48. ultralytics/utils/downloads.py +2 -2
  49. ultralytics/utils/loss.py +1 -1
  50. ultralytics/utils/metrics.py +1 -1
  51. ultralytics/utils/plotting.py +36 -22
  52. ultralytics/utils/torch_utils.py +17 -3
  53. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/METADATA +1 -1
  54. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/RECORD +58 -54
  55. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/LICENSE +0 -0
  56. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/WHEEL +0 -0
  57. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/entry_points.txt +0 -0
  58. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/top_level.txt +0 -0
ultralytics/data/__init__.py CHANGED
@@ -1,15 +1,31 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
 from .base import BaseDataset
-from .build import build_dataloader, build_yolo_dataset, load_inference_source
-from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
+from .build import (
+    build_dataloader,
+    build_yolo_dataset,
+    build_grounding,
+    load_inference_source,
+)
+from .dataset import (
+    ClassificationDataset,
+    SemanticDataset,
+    YOLODataset,
+    YOLOMultiModalDataset,
+    GroundingDataset,
+    YOLOConcatDataset,
+)
 
 __all__ = (
     "BaseDataset",
     "ClassificationDataset",
     "SemanticDataset",
     "YOLODataset",
+    "YOLOMultiModalDataset",
+    "YOLOConcatDataset",
+    "GroundingDataset",
     "build_yolo_dataset",
+    "build_grounding",
     "build_dataloader",
    "load_inference_source",
 )
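
For orientation, the expanded public API can be exercised directly. A minimal sketch (names taken from the `__all__` change above; nothing beyond the install is needed just to import):

```python
# Quick import check of the names newly exported by ultralytics.data in 8.1.40.
from ultralytics.data import (
    GroundingDataset,
    YOLOConcatDataset,
    YOLOMultiModalDataset,
    build_grounding,
)
```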
ultralytics/data/augment.py CHANGED
@@ -3,6 +3,7 @@
 import math
 import random
 from copy import deepcopy
+from typing import Tuple, Union
 
 import cv2
 import numpy as np
@@ -66,7 +67,7 @@ class Compose:
 
     def __init__(self, transforms):
         """Initializes the Compose object with a list of transforms."""
-        self.transforms = transforms
+        self.transforms = transforms if isinstance(transforms, list) else [transforms]
 
     def __call__(self, data):
         """Applies a series of transformations to input data."""
@@ -78,6 +79,29 @@ class Compose:
         """Appends a new transform to the existing list of transforms."""
         self.transforms.append(transform)
 
+    def insert(self, index, transform):
+        """Inserts a new transform into the existing list of transforms at the given index."""
+        self.transforms.insert(index, transform)
+
+    def __getitem__(self, index: Union[list, int]) -> "Compose":
+        """Retrieve a specific transform or a set of transforms using indexing."""
+        assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
+        index = [index] if isinstance(index, int) else index
+        return Compose([self.transforms[i] for i in index])
+
+    def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
+        """Set a specific transform or a set of transforms using indexing."""
+        assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
+        if isinstance(index, list):
+            assert isinstance(
+                value, list
+            ), f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
+        if isinstance(index, int):
+            index, value = [index], [value]
+        for i, v in zip(index, value):
+            assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}."
+            self.transforms[i] = v
+
     def tolist(self):
         """Converts the list of transforms to a standard Python list."""
         return self.transforms
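
The new `insert`, `__getitem__`, and `__setitem__` methods make a `Compose` pipeline editable in place. A hedged sketch with placeholder callables standing in for real transforms:

```python
from ultralytics.data.augment import Compose

# Illustrative stand-in transforms; any callables work for this sketch.
double = lambda d: d * 2
add_one = lambda d: d + 1

pipeline = Compose([double, add_one])
sub = pipeline[0]            # __getitem__ returns a new Compose containing only `double`
pipeline[1] = double         # __setitem__ replaces the transform at index 1
pipeline.insert(0, add_one)  # insert() adds a transform at the front
print(pipeline(3))           # applies add_one, double, double -> 16
```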
@@ -118,6 +142,8 @@ class BaseMixTransform:
                 mix_labels[i] = self.pre_transform(data)
         labels["mix_labels"] = mix_labels
 
+        # Update cls and texts
+        labels = self._update_label_text(labels)
         # Mosaic or MixUp
         labels = self._mix_transform(labels)
         labels.pop("mix_labels", None)
@@ -131,6 +157,22 @@ class BaseMixTransform:
         """Gets a list of shuffled indexes for mosaic augmentation."""
         raise NotImplementedError
 
+    def _update_label_text(self, labels):
+        """Update label text."""
+        if "texts" not in labels:
+            return labels
+
+        mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
+        mix_texts = list({tuple(x) for x in mix_texts})
+        text2id = {text: i for i, text in enumerate(mix_texts)}
+
+        for label in [labels] + labels["mix_labels"]:
+            for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
+                text = label["texts"][int(l)]
+                label["cls"][i] = text2id[tuple(text)]
+            label["texts"] = mix_texts
+        return labels
+
 
 class Mosaic(BaseMixTransform):
     """
@@ -149,7 +191,7 @@ class Mosaic(BaseMixTransform):
     def __init__(self, dataset, imgsz=640, p=1.0, n=4):
         """Initializes the object with a dataset, image size, probability, and border."""
         assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
-        assert n in (4, 9), "grid must be equal to 4 or 9."
+        assert n in {4, 9}, "grid must be equal to 4 or 9."
         super().__init__(dataset=dataset, p=p)
         self.dataset = dataset
         self.imgsz = imgsz
@@ -320,6 +362,8 @@ class Mosaic(BaseMixTransform):
         final_labels["instances"].clip(imgsz, imgsz)
         good = final_labels["instances"].remove_zero_area_boxes()
         final_labels["cls"] = final_labels["cls"][good]
+        if "texts" in mosaic_labels[0]:
+            final_labels["texts"] = mosaic_labels[0]["texts"]
         return final_labels
 
 
@@ -641,7 +685,7 @@ class RandomFlip:
                 Default is 'horizontal'.
             flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
         """
-        assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
+        assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
         assert 0 <= p <= 1.0
 
         self.p = p
@@ -970,6 +1014,83 @@ class Format:
         return masks, instances, cls
 
 
+class RandomLoadText:
+    """
+    Randomly sample positive and negative texts and update the class indices according to the number of samples.
+
+    Attributes:
+        prompt_format (str): Format string for the prompt. Default is '{}'.
+        neg_samples (tuple[int]): A range from which to randomly sample the number of negative texts. Default is (80, 80).
+        max_samples (int): The maximum number of different text samples in one image. Default is 80.
+        padding (bool): Whether to pad texts to max_samples. Default is False.
+        padding_value (str): The padding text. Default is "".
+    """
+
+    def __init__(
+        self,
+        prompt_format: str = "{}",
+        neg_samples: Tuple[int, int] = (80, 80),
+        max_samples: int = 80,
+        padding: bool = False,
+        padding_value: str = "",
+    ) -> None:
+        """Initializes the RandomLoadText class with given parameters."""
+        self.prompt_format = prompt_format
+        self.neg_samples = neg_samples
+        self.max_samples = max_samples
+        self.padding = padding
+        self.padding_value = padding_value
+
+    def __call__(self, labels: dict) -> dict:
+        """Return updated classes and texts."""
+        assert "texts" in labels, "No texts found in labels."
+        class_texts = labels["texts"]
+        num_classes = len(class_texts)
+        cls = np.asarray(labels.pop("cls"), dtype=int)
+        pos_labels = np.unique(cls).tolist()
+
+        if len(pos_labels) > self.max_samples:
+            pos_labels = set(random.sample(pos_labels, k=self.max_samples))
+
+        neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
+        neg_labels = []
+        for i in range(num_classes):
+            if i not in pos_labels:
+                neg_labels.append(i)
+        neg_labels = random.sample(neg_labels, k=neg_samples)
+
+        sampled_labels = pos_labels + neg_labels
+        random.shuffle(sampled_labels)
+
+        label2ids = {label: i for i, label in enumerate(sampled_labels)}
+        valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
+        new_cls = []
+        for i, label in enumerate(cls.squeeze(-1).tolist()):
+            if label not in label2ids:
+                continue
+            valid_idx[i] = True
+            new_cls.append([label2ids[label]])
+        labels["instances"] = labels["instances"][valid_idx]
+        labels["cls"] = np.array(new_cls)
+
+        # Randomly select one prompt when there is more than one prompt
+        texts = []
+        for label in sampled_labels:
+            prompts = class_texts[label]
+            assert len(prompts) > 0
+            prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))])
+            texts.append(prompt)
+
+        if self.padding:
+            valid_labels = len(pos_labels) + len(neg_labels)
+            num_padding = self.max_samples - valid_labels
+            if num_padding > 0:
+                texts += [self.padding_value] * num_padding
+
+        labels["texts"] = texts
+        return labels
+
+
 def v8_transforms(dataset, imgsz, hyp, stretch=False):
     """Convert images to a size suitable for YOLOv8 training."""
     pre_transform = Compose(
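
To see what `RandomLoadText` does in isolation, here is a hedged sketch with a hand-built label dict. Real pipelines receive `instances` as an `Instances` object, but any value supporting `len()` and boolean-mask indexing suffices for the demo:

```python
import numpy as np

from ultralytics.data.augment import RandomLoadText

labels = {
    "texts": [["person"], ["bicycle"], ["car"]],  # one synonym list per class
    "cls": np.array([[0], [2]]),                  # two boxes: classes 0 and 2
    "instances": np.zeros((2, 4)),                # placeholder stand-in for two boxes
}

t = RandomLoadText(neg_samples=(1, 1), max_samples=3, padding=True)
out = t(labels)
print(out["texts"])  # e.g. ['car', 'person', 'bicycle'] -- shuffled, padded to max_samples
print(out["cls"])    # class indices remapped into the sampled text list
```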
ultralytics/data/base.py CHANGED
@@ -15,7 +15,7 @@ import psutil
 from torch.utils.data import Dataset
 
 from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
-from .utils import HELP_URL, IMG_FORMATS
+from .utils import HELP_URL, FORMATS_HELP_MSG, IMG_FORMATS
 
 
 class BaseDataset(Dataset):
@@ -118,7 +118,7 @@ class BaseDataset(Dataset):
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
-            assert im_files, f"{self.prefix}No images found in {img_path}"
+            assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
        except Exception as e:
            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
        if self.fraction < 1:
ultralytics/data/build.py CHANGED
@@ -22,7 +22,7 @@ from ultralytics.data.loaders import (
 from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
 from ultralytics.utils import RANK, colorstr
 from ultralytics.utils.checks import check_file
-from .dataset import YOLODataset
+from .dataset import YOLODataset, YOLOMultiModalDataset, GroundingDataset
 from .utils import PIN_MEMORY
 
 
@@ -82,9 +82,10 @@ def seed_worker(worker_id):  # noqa
     random.seed(worker_seed)
 
 
-def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
+def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
     """Build YOLO Dataset."""
-    return YOLODataset(
+    dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
+    return dataset(
         img_path=img_path,
         imgsz=cfg.imgsz,
         batch_size=batch,
@@ -103,6 +104,27 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, str
     )
 
 
+def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
+    """Build a YOLO grounding dataset."""
+    return GroundingDataset(
+        img_path=img_path,
+        json_file=json_file,
+        imgsz=cfg.imgsz,
+        batch_size=batch,
+        augment=mode == "train",  # augmentation
+        hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
+        rect=cfg.rect or rect,  # rectangular batches
+        cache=cfg.cache or None,
+        single_cls=cfg.single_cls or False,
+        stride=int(stride),
+        pad=0.0 if mode == "train" else 0.5,
+        prefix=colorstr(f"{mode}: "),
+        task=cfg.task,
+        classes=cfg.classes,
+        fraction=cfg.fraction if mode == "train" else 1.0,
+    )
+
+
 def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
     """Return an InfiniteDataLoader or DataLoader for training or validation set."""
     batch = min(batch, len(dataset))
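
A hedged sketch of the new grounding entry point (the annotation JSON and image directory are placeholder paths; `get_cfg()` supplies the default settings that `build_grounding` reads from `cfg`):

```python
from ultralytics.cfg import get_cfg
from ultralytics.data.build import build_dataloader, build_grounding

cfg = get_cfg()  # default settings: imgsz, rect, cache, task, ...
dataset = build_grounding(
    cfg,
    img_path="datasets/flickr30k/images",             # placeholder path
    json_file="datasets/flickr30k/annotations.json",  # placeholder path
    batch=16,
)
loader = build_dataloader(dataset, batch=16, workers=4, shuffle=True, rank=-1)
```

The same release also threads `multi_modal=True` through `build_yolo_dataset` to construct a `YOLOMultiModalDataset` instead of a plain `YOLODataset`.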
ultralytics/data/converter.py CHANGED
@@ -219,6 +219,7 @@
     use_segments=False,
     use_keypoints=False,
     cls91to80=True,
+    lvis=False,
 ):
     """
     Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
@@ -229,12 +230,14 @@
         use_segments (bool, optional): Whether to include segmentation masks in the output.
         use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
         cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
+        lvis (bool, optional): Whether the input annotations follow the LVIS dataset layout.
 
     Example:
         ```python
         from ultralytics.data.converter import convert_coco
 
         convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
+        convert_coco('../datasets/lvis/annotations/', use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
         ```
 
     Output:
@@ -251,8 +254,14 @@
 
     # Import json
     for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
-        fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "")  # folder name
+        lname = "" if lvis else json_file.stem.replace("instances_", "")
+        fn = Path(save_dir) / "labels" / lname  # folder name
         fn.mkdir(parents=True, exist_ok=True)
+        if lvis:
+            # NOTE: create folders for both train and val in advance,
+            # since the LVIS val set contains images from the COCO 2017 train split in addition to the COCO 2017 val split.
+            (fn / "train2017").mkdir(parents=True, exist_ok=True)
+            (fn / "val2017").mkdir(parents=True, exist_ok=True)
         with open(json_file) as f:
             data = json.load(f)
 
@@ -263,16 +272,20 @@
         for ann in data["annotations"]:
             imgToAnns[ann["image_id"]].append(ann)
 
+        image_txt = []
         # Write labels file
         for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
             img = images[f"{img_id:d}"]
-            h, w, f = img["height"], img["width"], img["file_name"]
+            h, w = img["height"], img["width"]
+            f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"]
+            if lvis:
+                image_txt.append(str(Path("./images") / f))
 
             bboxes = []
             segments = []
             keypoints = []
             for ann in anns:
-                if ann["iscrowd"]:
+                if ann.get("iscrowd", False):
                     continue
                 # The COCO box format is [top left x, top left y, width, height]
                 box = np.array(ann["bbox"], dtype=np.float64)
@@ -314,7 +327,12 @@
                 )  # cls, box or segments
                 file.write(("%g " * len(line)).rstrip() % line + "\n")
 
-    LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}")
+        if lvis:
+            with open((Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")), "a") as f:
+                for l in image_txt:
+                    f.write(f"{l}\n")
+
+    LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")
 
 
 def convert_dota_to_yolo_obb(dota_root_path: str):
@@ -463,7 +481,7 @@ def merge_multi_segment(segments):
                segments[i] = np.roll(segments[i], -idx[0], axis=0)
                segments[i] = np.concatenate([segments[i], segments[i][:1]])
                # Deal with the first segment and the last one
-                if i in [0, len(idx_list) - 1]:
+                if i in {0, len(idx_list) - 1}:
                    s.append(segments[i])
                else:
                    idx = [0, idx[1] - idx[0]]
@@ -471,7 +489,7 @@ def merge_multi_segment(segments):
 
        else:
            for i in range(len(idx_list) - 1, -1, -1):
-                if i not in [0, len(idx_list) - 1]:
+                if i not in {0, len(idx_list) - 1}:
                    idx = idx_list[i]
                    nidx = abs(idx[1] - idx[0])
                    s.append(segments[i][nidx:])
ultralytics/data/dataset.py CHANGED
@@ -1,20 +1,41 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 import contextlib
 from itertools import repeat
+from collections import defaultdict
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
 
 import cv2
+import json
 import numpy as np
 import torch
 import torchvision
 from PIL import Image
 
-from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
+from torch.utils.data import ConcatDataset
+from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
 from ultralytics.utils.ops import resample_segments
-from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
+from .augment import (
+    Compose,
+    Format,
+    Instances,
+    LetterBox,
+    RandomLoadText,
+    classify_augmentations,
+    classify_transforms,
+    v8_transforms,
+)
 from .base import BaseDataset
-from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
+from .utils import (
+    HELP_URL,
+    LOGGER,
+    get_hash,
+    img2label_paths,
+    verify_image,
+    verify_image_label,
+    load_dataset_cache_file,
+    save_dataset_cache_file,
+)
 
 # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
 DATASET_CACHE_VERSION = "1.0.3"
@@ -56,7 +77,7 @@ class YOLODataset(BaseDataset):
         desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
         total = len(self.im_files)
         nkpt, ndim = self.data.get("kpt_shape", (0, 0))
-        if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
+        if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
             raise ValueError(
                 "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                 "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
@@ -105,7 +126,7 @@ class YOLODataset(BaseDataset):
         x["hash"] = get_hash(self.label_files + self.im_files)
         x["results"] = nf, nm, ne, nc, len(self.im_files)
         x["msgs"] = msgs  # warnings
-        save_dataset_cache_file(self.prefix, path, x)
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
         return x
 
     def get_labels(self):
@@ -121,7 +142,7 @@ class YOLODataset(BaseDataset):
 
         # Display cache
         nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
-        if exists and LOCAL_RANK in (-1, 0):
+        if exists and LOCAL_RANK in {-1, 0}:
             d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
             TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
         if cache["msgs"]:
@@ -214,7 +235,7 @@ class YOLODataset(BaseDataset):
             value = values[i]
             if k == "img":
                 value = torch.stack(value, 0)
-            if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
+            if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                 value = torch.cat(value, 0)
             new_batch[k] = value
         new_batch["batch_idx"] = list(new_batch["batch_idx"])
@@ -313,7 +334,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
             assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
             assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
             nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
-            if LOCAL_RANK in (-1, 0):
+            if LOCAL_RANK in {-1, 0}:
                 d = f"{desc} {nf} images, {nc} corrupt"
                 TQDM(None, desc=d, total=n, initial=n)
             if cache["msgs"]:
@@ -339,31 +360,125 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         x["hash"] = get_hash([x[0] for x in self.samples])
         x["results"] = nf, nc, len(samples), samples
         x["msgs"] = msgs  # warnings
-        save_dataset_cache_file(self.prefix, path, x)
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
         return samples
 
 
-def load_dataset_cache_file(path):
-    """Load an Ultralytics *.cache dictionary from path."""
-    import gc
-
-    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
-    cache = np.load(str(path), allow_pickle=True).item()  # load dict
-    gc.enable()
-    return cache
-
-
-def save_dataset_cache_file(prefix, path, x):
-    """Save an Ultralytics dataset *.cache dictionary x to path."""
-    x["version"] = DATASET_CACHE_VERSION  # add cache version
-    if is_dir_writeable(path.parent):
-        if path.exists():
-            path.unlink()  # remove *.cache file if exists
-        np.save(str(path), x)  # save cache for next time
-        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
-        LOGGER.info(f"{prefix}New cache created: {path}")
-    else:
-        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
+class YOLOMultiModalDataset(YOLODataset):
+    """
+    Dataset class for loading object detection and/or segmentation labels in YOLO format.
+
+    Args:
+        data (dict, optional): A dataset YAML dictionary. Defaults to None.
+        task (str): An explicit arg to specify the current task. Defaults to 'detect'.
+
+    Returns:
+        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
+    """
+
+    def __init__(self, *args, data=None, task="detect", **kwargs):
+        """Initializes a dataset object for object detection tasks with optional specifications."""
+        super().__init__(*args, data=data, task=task, **kwargs)
+
+    def update_labels_info(self, label):
+        """Add texts information for multi-modal model training."""
+        labels = super().update_labels_info(label)
+        # NOTE: some categories are concatenated with their synonyms by `/`.
+        labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
+        return labels
+
+    def build_transforms(self, hyp=None):
+        """Enhances data transformations with optional text augmentation for multi-modal training."""
+        transforms = super().build_transforms(hyp)
+        if self.augment:
+            # NOTE: hard-coded the args for now.
+            transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
+        return transforms
+
+
+class GroundingDataset(YOLODataset):
+    def __init__(self, *args, task="detect", json_file, **kwargs):
+        """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
+        assert task == "detect", "`GroundingDataset` only supports the `detect` task for now!"
+        self.json_file = json_file
+        super().__init__(*args, task=task, data={}, **kwargs)
+
+    def get_img_files(self, img_path):
+        """Image files are read in the `get_labels` function; return an empty list here."""
+        return []
+
+    def get_labels(self):
+        """Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
+        labels = []
+        LOGGER.info("Loading annotation file...")
+        with open(self.json_file, "r") as f:
+            annotations = json.load(f)
+        images = {f'{x["id"]:d}': x for x in annotations["images"]}
+        imgToAnns = defaultdict(list)
+        for ann in annotations["annotations"]:
+            imgToAnns[ann["image_id"]].append(ann)
+        for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"):
+            img = images[f"{img_id:d}"]
+            h, w, f = img["height"], img["width"], img["file_name"]
+            im_file = Path(self.img_path) / f
+            if not im_file.exists():
+                continue
+            self.im_files.append(str(im_file))
+            bboxes = []
+            cat2id = {}
+            texts = []
+            for ann in anns:
+                if ann["iscrowd"]:
+                    continue
+                box = np.array(ann["bbox"], dtype=np.float32)
+                box[:2] += box[2:] / 2
+                box[[0, 2]] /= float(w)
+                box[[1, 3]] /= float(h)
+                if box[2] <= 0 or box[3] <= 0:
+                    continue
+
+                cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
+                if cat_name not in cat2id:
+                    cat2id[cat_name] = len(cat2id)
+                    texts.append([cat_name])
+                cls = cat2id[cat_name]  # class
+                box = [cls] + box.tolist()
+                if box not in bboxes:
+                    bboxes.append(box)
+            lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
+            labels.append(
+                dict(
+                    im_file=im_file,
+                    shape=(h, w),
+                    cls=lb[:, 0:1],  # n, 1
+                    bboxes=lb[:, 1:],  # n, 4
+                    normalized=True,
+                    bbox_format="xywh",
+                    texts=texts,
+                )
+            )
+        return labels
+
+    def build_transforms(self, hyp=None):
+        """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
+        transforms = super().build_transforms(hyp)
+        if self.augment:
+            # NOTE: hard-coded the args for now.
+            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
+        return transforms
+
+
+class YOLOConcatDataset(ConcatDataset):
+    """
+    Dataset as a concatenation of multiple datasets.
+
+    This class is useful to assemble different existing datasets.
+    """
+
+    @staticmethod
+    def collate_fn(batch):
+        """Collates data samples into batches."""
+        return YOLODataset.collate_fn(batch)
 
 
 # TODO: support semantic segmentation
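
A brief, hedged sketch of the new concatenation helper; `d1` and `d2` are assumed to be already-built `YOLODataset` (or subclass) instances with compatible label formats:

```python
from torch.utils.data import DataLoader

from ultralytics.data.dataset import YOLOConcatDataset

combined = YOLOConcatDataset([d1, d2])  # d1, d2: pre-built YOLODataset instances
loader = DataLoader(combined, batch_size=16, collate_fn=YOLOConcatDataset.collate_fn)
```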
ultralytics/data/loaders.py CHANGED
@@ -15,7 +15,7 @@ import requests
 import torch
 from PIL import Image
 
-from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
+from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS, FORMATS_HELP_MSG
 from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
 from ultralytics.utils.checks import check_requirements
 
@@ -83,7 +83,7 @@ class LoadStreams:
         for i, s in enumerate(sources):  # index, source
             # Start thread to read frames from video stream
             st = f"{i + 1}/{n}: {s}... "
-            if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"):  # if source is YouTube video
+            if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}:  # if source is YouTube video
                 # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
                 s = get_best_youtube_url(s)
             s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
@@ -291,8 +291,14 @@ class LoadImagesAndVideos:
            else:
                raise FileNotFoundError(f"{p} does not exist")
 
-        images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
-        videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
+        # Define files as images or videos
+        images, videos = [], []
+        for f in files:
+            suffix = f.split(".")[-1].lower()  # Get file extension without the dot and lowercase
+            if suffix in IMG_FORMATS:
+                images.append(f)
+            elif suffix in VID_FORMATS:
+                videos.append(f)
         ni, nv = len(images), len(videos)
 
         self.files = images + videos
@@ -307,10 +313,7 @@ class LoadImagesAndVideos:
        else:
            self.cap = None
        if self.nf == 0:
-            raise FileNotFoundError(
-                f"No images or videos found in {p}. "
-                f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
-            )
+            raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")
 
    def __iter__(self):
        """Returns an iterator object for VideoStream or ImageFolder."""
ultralytics/data/split_dota.py CHANGED
@@ -71,7 +71,7 @@ def load_yolo_dota(data_root, split="train"):
                    - train
                    - val
    """
-    assert split in ["train", "val"]
+    assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
    im_dir = Path(data_root) / "images" / split
    assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
    im_files = glob(str(Path(data_root) / "images" / split / "*"))
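
And a hedged sketch of the function whose assertion message improved here (the data root is a placeholder; per the docstring above, `load_yolo_dota` expects `images/{train,val}` and `labels/{train,val}` beneath it):

```python
from ultralytics.data.split_dota import load_yolo_dota

annos = load_yolo_dota("datasets/DOTAv1", split="val")  # any other split now raises the clearer message
```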