ultralytics 8.1.38__py3-none-any.whl → 8.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ultralytics might be problematic.

ultralytics/data/__init__.py CHANGED
@@ -1,15 +1,31 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license

  from .base import BaseDataset
- from .build import build_dataloader, build_yolo_dataset, load_inference_source
- from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
+ from .build import (
+     build_dataloader,
+     build_yolo_dataset,
+     build_grounding,
+     load_inference_source,
+ )
+ from .dataset import (
+     ClassificationDataset,
+     SemanticDataset,
+     YOLODataset,
+     YOLOMultiModalDataset,
+     GroundingDataset,
+     YOLOConcatDataset,
+ )

  __all__ = (
      "BaseDataset",
      "ClassificationDataset",
      "SemanticDataset",
      "YOLODataset",
+     "YOLOMultiModalDataset",
+     "YOLOConcatDataset",
+     "GroundingDataset",
      "build_yolo_dataset",
+     "build_grounding",
      "build_dataloader",
      "load_inference_source",
  )
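For orientation, a small sketch (not part of the diff) of the import surface this release adds; all of these names are exported from `ultralytics.data` as of 8.1.39:

```python
# New public names in ultralytics.data as of 8.1.39 (sketch, not from the diff).
from ultralytics.data import (
    GroundingDataset,       # caption-grounded detection dataset
    YOLOConcatDataset,      # concatenation of multiple YOLO datasets
    YOLOMultiModalDataset,  # detection dataset with text labels
    build_grounding,        # builder for GroundingDataset
)
```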
ultralytics/data/augment.py CHANGED
@@ -3,6 +3,7 @@
  import math
  import random
  from copy import deepcopy
+ from typing import Tuple, Union

  import cv2
  import numpy as np
@@ -66,7 +67,7 @@ class Compose:

      def __init__(self, transforms):
          """Initializes the Compose object with a list of transforms."""
-         self.transforms = transforms
+         self.transforms = transforms if isinstance(transforms, list) else [transforms]

      def __call__(self, data):
          """Applies a series of transformations to input data."""
@@ -78,6 +79,29 @@ class Compose:
          """Appends a new transform to the existing list of transforms."""
          self.transforms.append(transform)

+     def insert(self, index, transform):
+         """Inserts a new transform into the existing list of transforms."""
+         self.transforms.insert(index, transform)
+
+     def __getitem__(self, index: Union[list, int]) -> "Compose":
+         """Retrieve a specific transform or a set of transforms using indexing."""
+         assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
+         index = [index] if isinstance(index, int) else index
+         return Compose([self.transforms[i] for i in index])
+
+     def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
+         """Set a specific transform or a set of transforms using indexing."""
+         assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
+         if isinstance(index, list):
+             assert isinstance(
+                 value, list
+             ), f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
+         if isinstance(index, int):
+             index, value = [index], [value]
+         for i, v in zip(index, value):
+             assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}."
+             self.transforms[i] = v
+
      def tolist(self):
          """Converts the list of transforms to a standard Python list."""
          return self.transforms
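The new indexing support makes it easy to slice and patch an augmentation pipeline. A minimal sketch, with hypothetical stand-in transforms (`add_one`, `times_two`) that are not part of ultralytics:

```python
from ultralytics.data.augment import Compose

add_one = lambda x: x + 1    # hypothetical toy transforms
times_two = lambda x: x * 2

pipeline = Compose([add_one, times_two])
first = pipeline[0]            # __getitem__ with an int wraps that transform in a new Compose
subset = pipeline[[0, 1]]      # a list of indices selects several transforms
pipeline[1] = add_one          # __setitem__ replaces a transform in place
pipeline.insert(0, times_two)  # insert() splices a transform at any position
```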
@@ -118,6 +142,8 @@ class BaseMixTransform:
              mix_labels[i] = self.pre_transform(data)
          labels["mix_labels"] = mix_labels

+         # Update cls and texts
+         labels = self._update_label_text(labels)
          # Mosaic or MixUp
          labels = self._mix_transform(labels)
          labels.pop("mix_labels", None)
@@ -131,6 +157,22 @@ class BaseMixTransform:
          """Gets a list of shuffled indexes for mosaic augmentation."""
          raise NotImplementedError

+     def _update_label_text(self, labels):
+         """Update label text."""
+         if "texts" not in labels:
+             return labels
+
+         mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
+         mix_texts = list({tuple(x) for x in mix_texts})
+         text2id = {text: i for i, text in enumerate(mix_texts)}
+
+         for label in [labels] + labels["mix_labels"]:
+             for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
+                 text = label["texts"][int(l)]
+                 label["cls"][i] = text2id[tuple(text)]
+             label["texts"] = mix_texts
+         return labels
+

  class Mosaic(BaseMixTransform):
      """
@@ -320,6 +362,8 @@ class Mosaic(BaseMixTransform):
          final_labels["instances"].clip(imgsz, imgsz)
          good = final_labels["instances"].remove_zero_area_boxes()
          final_labels["cls"] = final_labels["cls"][good]
+         if "texts" in mosaic_labels[0]:
+             final_labels["texts"] = mosaic_labels[0]["texts"]
          return final_labels

@@ -970,6 +1014,83 @@ class Format:
          return masks, instances, cls


+ class RandomLoadText:
+     """
+     Randomly sample positive and negative texts and update the class indices to match the number of samples.
+
+     Attributes:
+         prompt_format (str): Format string for prompts. Default is '{}'.
+         neg_samples (tuple[int]): A range from which to randomly sample the number of negative texts. Default is (80, 80).
+         max_samples (int): The maximum number of different text samples in one image. Default is 80.
+         padding (bool): Whether to pad texts to max_samples. Default is False.
+         padding_value (str): The padding text. Default is "".
+     """
+
+     def __init__(
+         self,
+         prompt_format: str = "{}",
+         neg_samples: Tuple[int, int] = (80, 80),
+         max_samples: int = 80,
+         padding: bool = False,
+         padding_value: str = "",
+     ) -> None:
+         """Initializes the RandomLoadText class with given parameters."""
+         self.prompt_format = prompt_format
+         self.neg_samples = neg_samples
+         self.max_samples = max_samples
+         self.padding = padding
+         self.padding_value = padding_value
+
+     def __call__(self, labels: dict) -> dict:
+         """Return updated classes and texts."""
+         assert "texts" in labels, "No texts found in labels."
+         class_texts = labels["texts"]
+         num_classes = len(class_texts)
+         cls = np.asarray(labels.pop("cls"), dtype=int)
+         pos_labels = np.unique(cls).tolist()
+
+         if len(pos_labels) > self.max_samples:
+             pos_labels = set(random.sample(pos_labels, k=self.max_samples))
+
+         neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
+         neg_labels = []
+         for i in range(num_classes):
+             if i not in pos_labels:
+                 neg_labels.append(i)
+         neg_labels = random.sample(neg_labels, k=neg_samples)
+
+         sampled_labels = pos_labels + neg_labels
+         random.shuffle(sampled_labels)
+
+         label2ids = {label: i for i, label in enumerate(sampled_labels)}
+         valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
+         new_cls = []
+         for i, label in enumerate(cls.squeeze(-1).tolist()):
+             if label not in label2ids:
+                 continue
+             valid_idx[i] = True
+             new_cls.append([label2ids[label]])
+         labels["instances"] = labels["instances"][valid_idx]
+         labels["cls"] = np.array(new_cls)
+
+         # Randomly select one prompt when there is more than one prompt
+         texts = []
+         for label in sampled_labels:
+             prompts = class_texts[label]
+             assert len(prompts) > 0
+             prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))])
+             texts.append(prompt)
+
+         if self.padding:
+             valid_labels = len(pos_labels) + len(neg_labels)
+             num_padding = self.max_samples - valid_labels
+             if num_padding > 0:
+                 texts += [self.padding_value] * num_padding
+
+         labels["texts"] = texts
+         return labels
+
+
  def v8_transforms(dataset, imgsz, hyp, stretch=False):
      """Convert images to a size suitable for YOLOv8 training."""
      pre_transform = Compose(
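A runnable sketch (with assumed toy inputs, not from the diff) of how `RandomLoadText` rewrites a label dict: positive classes present in `cls` are kept, negatives are sampled, and `cls` is remapped into the sampled `texts` list:

```python
import numpy as np
from ultralytics.data.augment import Instances, RandomLoadText

loader = RandomLoadText(neg_samples=(2, 2), max_samples=5, padding=True)
labels = {  # toy labels: two boxes, classes "person" and "car"
    "texts": [["person"], ["bicycle"], ["car"], ["dog"], ["cat"]],
    "cls": np.array([[0], [2]]),
    "instances": Instances(np.array([[0.5, 0.5, 0.2, 0.2], [0.3, 0.3, 0.1, 0.1]], dtype=np.float32)),
}
out = loader(labels)
print(out["texts"])  # e.g. ['car', 'dog', 'person', 'cat', ''] - padded to max_samples
print(out["cls"])    # the two box classes, remapped to indices into out["texts"]
```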
ultralytics/data/build.py CHANGED
@@ -22,7 +22,7 @@ from ultralytics.data.loaders import (
  from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
  from ultralytics.utils import RANK, colorstr
  from ultralytics.utils.checks import check_file
- from .dataset import YOLODataset
+ from .dataset import YOLODataset, YOLOMultiModalDataset, GroundingDataset
  from .utils import PIN_MEMORY

@@ -82,9 +82,10 @@ def seed_worker(worker_id):  # noqa
      random.seed(worker_seed)


- def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
+ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
      """Build YOLO Dataset."""
-     return YOLODataset(
+     dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
+     return dataset(
          img_path=img_path,
          imgsz=cfg.imgsz,
          batch_size=batch,
@@ -103,6 +104,27 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
      )


+ def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
+     """Build a YOLO grounding dataset."""
+     return GroundingDataset(
+         img_path=img_path,
+         json_file=json_file,
+         imgsz=cfg.imgsz,
+         batch_size=batch,
+         augment=mode == "train",  # augmentation
+         hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
+         rect=cfg.rect or rect,  # rectangular batches
+         cache=cfg.cache or None,
+         single_cls=cfg.single_cls or False,
+         stride=int(stride),
+         pad=0.0 if mode == "train" else 0.5,
+         prefix=colorstr(f"{mode}: "),
+         task=cfg.task,
+         classes=cfg.classes,
+         fraction=cfg.fraction if mode == "train" else 1.0,
+     )
+
+
  def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
      """Return an InfiniteDataLoader or DataLoader for training or validation set."""
      batch = min(batch, len(dataset))
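A hedged sketch of how these builders might be invoked; `get_cfg()` defaults and the paths/dataset dict below are placeholder assumptions, not values from the diff:

```python
from ultralytics.cfg import get_cfg
from ultralytics.data.build import build_grounding, build_yolo_dataset

cfg = get_cfg()  # default train settings (imgsz, rect, cache, task, ...)
data = {"nc": 1, "names": {0: "person"}}  # stand-in for a parsed dataset YAML

# New multi_modal flag swaps in YOLOMultiModalDataset:
ds = build_yolo_dataset(cfg, "images/train", batch=16, data=data, multi_modal=True)

# Grounding data pairs images with caption-derived boxes from a COCO-style JSON:
gds = build_grounding(cfg, "flickr30k/images", "flickr30k/annotations.json", batch=16)
```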
ultralytics/data/converter.py CHANGED
@@ -219,6 +219,7 @@ def convert_coco(
      use_segments=False,
      use_keypoints=False,
      cls91to80=True,
+     lvis=False,
  ):
      """
      Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
@@ -229,12 +230,14 @@ def convert_coco(
          use_segments (bool, optional): Whether to include segmentation masks in the output.
          use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
          cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
+         lvis (bool, optional): Whether to convert the data in the LVIS dataset layout.

      Example:
          ```python
          from ultralytics.data.converter import convert_coco

          convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
+         convert_coco('../datasets/lvis/annotations/', use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
          ```

      Output:
@@ -251,8 +254,14 @@ def convert_coco(

      # Import json
      for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
-         fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "")  # folder name
+         lname = "" if lvis else json_file.stem.replace("instances_", "")
+         fn = Path(save_dir) / "labels" / lname  # folder name
          fn.mkdir(parents=True, exist_ok=True)
+         if lvis:
+             # NOTE: create folders for both train and val in advance,
+             # since LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.
+             (fn / "train2017").mkdir(parents=True, exist_ok=True)
+             (fn / "val2017").mkdir(parents=True, exist_ok=True)
          with open(json_file) as f:
              data = json.load(f)
@@ -263,16 +272,20 @@ def convert_coco(
          for ann in data["annotations"]:
              imgToAnns[ann["image_id"]].append(ann)

+         image_txt = []
          # Write labels file
          for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
              img = images[f"{img_id:d}"]
-             h, w, f = img["height"], img["width"], img["file_name"]
+             h, w = img["height"], img["width"]
+             f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"]
+             if lvis:
+                 image_txt.append(str(Path("./images") / f))

              bboxes = []
              segments = []
              keypoints = []
              for ann in anns:
-                 if ann["iscrowd"]:
+                 if ann.get("iscrowd", False):
                      continue
                  # The COCO box format is [top left x, top left y, width, height]
                  box = np.array(ann["bbox"], dtype=np.float64)
@@ -314,7 +327,12 @@ def convert_coco(
                  )  # cls, box or segments
                  file.write(("%g " * len(line)).rstrip() % line + "\n")

-     LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}")
+     if lvis:
+         with open((Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")), "a") as f:
+             for l in image_txt:
+                 f.write(f"{l}\n")
+
+     LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")


  def convert_dota_to_yolo_obb(dota_root_path: str):
ultralytics/data/dataset.py CHANGED
@@ -1,20 +1,41 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
  import contextlib
  from itertools import repeat
+ from collections import defaultdict
  from multiprocessing.pool import ThreadPool
  from pathlib import Path

  import cv2
+ import json
  import numpy as np
  import torch
  import torchvision
  from PIL import Image

- from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
+ from torch.utils.data import ConcatDataset
+ from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
  from ultralytics.utils.ops import resample_segments
- from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
+ from .augment import (
+     Compose,
+     Format,
+     Instances,
+     LetterBox,
+     RandomLoadText,
+     classify_augmentations,
+     classify_transforms,
+     v8_transforms,
+ )
  from .base import BaseDataset
- from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
+ from .utils import (
+     HELP_URL,
+     LOGGER,
+     get_hash,
+     img2label_paths,
+     verify_image,
+     verify_image_label,
+     load_dataset_cache_file,
+     save_dataset_cache_file,
+ )

  # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
  DATASET_CACHE_VERSION = "1.0.3"
@@ -105,7 +126,7 @@ class YOLODataset(BaseDataset):
          x["hash"] = get_hash(self.label_files + self.im_files)
          x["results"] = nf, nm, ne, nc, len(self.im_files)
          x["msgs"] = msgs  # warnings
-         save_dataset_cache_file(self.prefix, path, x)
+         save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
          return x

      def get_labels(self):
@@ -339,31 +360,125 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
          x["hash"] = get_hash([x[0] for x in self.samples])
          x["results"] = nf, nc, len(samples), samples
          x["msgs"] = msgs  # warnings
-         save_dataset_cache_file(self.prefix, path, x)
+         save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
          return samples


- def load_dataset_cache_file(path):
-     """Load an Ultralytics *.cache dictionary from path."""
-     import gc
-
-     gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
-     cache = np.load(str(path), allow_pickle=True).item()  # load dict
-     gc.enable()
-     return cache
-
-
- def save_dataset_cache_file(prefix, path, x):
-     """Save an Ultralytics dataset *.cache dictionary x to path."""
-     x["version"] = DATASET_CACHE_VERSION  # add cache version
-     if is_dir_writeable(path.parent):
-         if path.exists():
-             path.unlink()  # remove *.cache file if exists
-         np.save(str(path), x)  # save cache for next time
-         path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
-         LOGGER.info(f"{prefix}New cache created: {path}")
-     else:
-         LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
+ class YOLOMultiModalDataset(YOLODataset):
+     """
+     Dataset class for loading object detection and/or segmentation labels in YOLO format, with text labels for
+     multi-modal training.
+
+     Args:
+         data (dict, optional): A dataset YAML dictionary. Defaults to None.
+         task (str): An explicit arg to point to the current task. Defaults to 'detect'.
+
+     Returns:
+         (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
+     """
+
+     def __init__(self, *args, data=None, task="detect", **kwargs):
+         """Initializes a dataset object for object detection tasks with optional specifications."""
+         super().__init__(*args, data=data, task=task, **kwargs)
+
+     def update_labels_info(self, label):
+         """Add texts information for multi-modal model training."""
+         labels = super().update_labels_info(label)
+         # NOTE: some categories are concatenated with their synonyms by `/`.
+         labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
+         return labels
+
+     def build_transforms(self, hyp=None):
+         """Enhances data transformations with optional text augmentation for multi-modal training."""
+         transforms = super().build_transforms(hyp)
+         if self.augment:
+             # NOTE: hard-coded the args for now.
+             transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
+         return transforms
+
+
+ class GroundingDataset(YOLODataset):
+     def __init__(self, *args, task="detect", json_file, **kwargs):
+         """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
+         assert task == "detect", "`GroundingDataset` only supports `detect` task for now!"
+         self.json_file = json_file
+         super().__init__(*args, task=task, data={}, **kwargs)
+
+     def get_img_files(self, img_path):
+         """The image files are read in the `get_labels` function; return an empty list here."""
+         return []
+
+     def get_labels(self):
+         """Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
+         labels = []
+         LOGGER.info("Loading annotation file...")
+         with open(self.json_file, "r") as f:
+             annotations = json.load(f)
+         images = {f'{x["id"]:d}': x for x in annotations["images"]}
+         imgToAnns = defaultdict(list)
+         for ann in annotations["annotations"]:
+             imgToAnns[ann["image_id"]].append(ann)
+         for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"):
+             img = images[f"{img_id:d}"]
+             h, w, f = img["height"], img["width"], img["file_name"]
+             im_file = Path(self.img_path) / f
+             if not im_file.exists():
+                 continue
+             self.im_files.append(str(im_file))
+             bboxes = []
+             cat2id = {}
+             texts = []
+             for ann in anns:
+                 if ann["iscrowd"]:
+                     continue
+                 box = np.array(ann["bbox"], dtype=np.float32)
+                 box[:2] += box[2:] / 2
+                 box[[0, 2]] /= float(w)
+                 box[[1, 3]] /= float(h)
+                 if box[2] <= 0 or box[3] <= 0:
+                     continue
+
+                 cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
+                 if cat_name not in cat2id:
+                     cat2id[cat_name] = len(cat2id)
+                     texts.append([cat_name])
+                 cls = cat2id[cat_name]  # class
+                 box = [cls] + box.tolist()
+                 if box not in bboxes:
+                     bboxes.append(box)
+             lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
+             labels.append(
+                 dict(
+                     im_file=im_file,
+                     shape=(h, w),
+                     cls=lb[:, 0:1],  # n, 1
+                     bboxes=lb[:, 1:],  # n, 4
+                     normalized=True,
+                     bbox_format="xywh",
+                     texts=texts,
+                 )
+             )
+         return labels
+
+     def build_transforms(self, hyp=None):
+         """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
+         transforms = super().build_transforms(hyp)
+         if self.augment:
+             # NOTE: hard-coded the args for now.
+             transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
+         return transforms
+
+
+ class YOLOConcatDataset(ConcatDataset):
+     """
+     Dataset as a concatenation of multiple datasets.
+
+     This class is useful to assemble different existing datasets.
+     """
+
+     @staticmethod
+     def collate_fn(batch):
+         """Collates data samples into batches."""
+         return YOLODataset.collate_fn(batch)


  # TODO: support semantic segmentation
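A hedged sketch of how the concatenation class can be used; `detection_ds` and `grounding_ds` are assumed to be already-built datasets that yield YOLO-style label dicts:

```python
from torch.utils.data import DataLoader
from ultralytics.data.dataset import YOLOConcatDataset

combined = YOLOConcatDataset([detection_ds, grounding_ds])  # placeholder datasets
loader = DataLoader(combined, batch_size=16, collate_fn=YOLOConcatDataset.collate_fn)
```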
ultralytics/data/utils.py CHANGED
@@ -29,6 +29,7 @@ from ultralytics.utils import (
      emojis,
      yaml_load,
      yaml_save,
+     is_dir_writeable,
  )
  from ultralytics.utils.checks import check_file, check_font, is_ascii
  from ultralytics.utils.downloads import download, safe_download, unzip_file
@@ -303,7 +304,7 @@ def check_det_dataset(dataset, autodownload=True):

      # Set paths
      data["path"] = path  # download scripts
-     for k in "train", "val", "test":
+     for k in "train", "val", "test", "minival":
          if data.get(k):  # prepend path
              if isinstance(data[k], str):
                  x = (path / data[k]).resolve()
@@ -649,3 +650,26 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
          if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
              with open(path.parent / txt[i], "a") as f:
                  f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file
+
+
+ def load_dataset_cache_file(path):
+     """Load an Ultralytics *.cache dictionary from path."""
+     import gc
+
+     gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
+     cache = np.load(str(path), allow_pickle=True).item()  # load dict
+     gc.enable()
+     return cache
+
+
+ def save_dataset_cache_file(prefix, path, x, version):
+     """Save an Ultralytics dataset *.cache dictionary x to path."""
+     x["version"] = version  # add cache version
+     if is_dir_writeable(path.parent):
+         if path.exists():
+             path.unlink()  # remove *.cache file if exists
+         np.save(str(path), x)  # save cache for next time
+         path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
+         LOGGER.info(f"{prefix}New cache created: {path}")
+     else:
+         LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
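A minimal round-trip sketch of the relocated cache helpers; the path and dictionary contents below are illustrative assumptions:

```python
from pathlib import Path
from ultralytics.data.utils import load_dataset_cache_file, save_dataset_cache_file

cache_path = Path("labels.cache")  # hypothetical cache location
save_dataset_cache_file("demo: ", cache_path, {"results": (1, 0, 0, 0, 1)}, version="1.0.3")
cache = load_dataset_cache_file(cache_path)  # dict comes back with the stamped "version"
```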
ultralytics/engine/model.py CHANGED
@@ -733,7 +733,10 @@ class Model(nn.Module):
          """
          from ultralytics.nn.autobackend import check_class_names

-         return check_class_names(self.model.names) if hasattr(self.model, "names") else None
+         if hasattr(self.model, "names"):
+             return check_class_names(self.model.names)
+         elif self.predictor:
+             return self.predictor.model.names

      @property
      def device(self) -> torch.device:
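A hedged sketch of the behavior this change enables: for an exported model whose underlying object has no `names` attribute, `Model.names` now falls back to the predictor once one exists (the file names below are placeholders):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.onnx")  # exported model; the raw backend object has no `names`
model.predict("bus.jpg")      # running predict creates model.predictor
print(model.names)            # now resolved via model.predictor.model.names
```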