ultralytics 8.1.42__py3-none-any.whl → 8.1.44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics has been flagged as potentially problematic.

Files changed (58)
  1. ultralytics/__init__.py +3 -2
  2. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  3. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  4. ultralytics/cfg/models/v9/yolov9e-seg.yaml +2 -3
  5. ultralytics/cfg/models/v9/yolov9e.yaml +2 -3
  6. ultralytics/data/__init__.py +3 -8
  7. ultralytics/data/augment.py +14 -11
  8. ultralytics/data/base.py +1 -1
  9. ultralytics/data/build.py +1 -1
  10. ultralytics/data/converter.py +4 -3
  11. ultralytics/data/dataset.py +149 -144
  12. ultralytics/data/explorer/explorer.py +10 -11
  13. ultralytics/data/explorer/gui/dash.py +3 -3
  14. ultralytics/data/explorer/utils.py +3 -2
  15. ultralytics/data/loaders.py +3 -3
  16. ultralytics/data/utils.py +1 -1
  17. ultralytics/engine/exporter.py +3 -2
  18. ultralytics/engine/model.py +2 -1
  19. ultralytics/engine/trainer.py +2 -1
  20. ultralytics/hub/auth.py +3 -3
  21. ultralytics/hub/session.py +3 -3
  22. ultralytics/hub/utils.py +6 -6
  23. ultralytics/models/fastsam/prompt.py +4 -1
  24. ultralytics/models/rtdetr/val.py +1 -1
  25. ultralytics/models/sam/modules/tiny_encoder.py +2 -2
  26. ultralytics/models/sam/modules/transformer.py +1 -1
  27. ultralytics/models/sam/predict.py +16 -13
  28. ultralytics/models/yolo/classify/train.py +2 -1
  29. ultralytics/models/yolo/detect/val.py +1 -1
  30. ultralytics/models/yolo/model.py +1 -1
  31. ultralytics/models/yolo/obb/val.py +1 -1
  32. ultralytics/models/yolo/world/train_world.py +2 -2
  33. ultralytics/nn/modules/__init__.py +8 -8
  34. ultralytics/nn/modules/head.py +1 -1
  35. ultralytics/nn/tasks.py +7 -7
  36. ultralytics/solutions/heatmap.py +14 -27
  37. ultralytics/solutions/object_counter.py +12 -22
  38. ultralytics/trackers/byte_tracker.py +1 -1
  39. ultralytics/trackers/utils/kalman_filter.py +4 -4
  40. ultralytics/trackers/utils/matching.py +1 -1
  41. ultralytics/utils/__init__.py +56 -41
  42. ultralytics/utils/benchmarks.py +1 -2
  43. ultralytics/utils/callbacks/clearml.py +4 -3
  44. ultralytics/utils/callbacks/hub.py +1 -4
  45. ultralytics/utils/callbacks/mlflow.py +1 -1
  46. ultralytics/utils/callbacks/tensorboard.py +1 -0
  47. ultralytics/utils/callbacks/wb.py +5 -5
  48. ultralytics/utils/checks.py +17 -20
  49. ultralytics/utils/metrics.py +3 -3
  50. ultralytics/utils/ops.py +1 -1
  51. ultralytics/utils/plotting.py +67 -40
  52. ultralytics/utils/torch_utils.py +13 -6
  53. {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/METADATA +1 -1
  54. {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/RECORD +58 -58
  55. {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/LICENSE +0 -0
  56. {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/WHEEL +0 -0
  57. {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/entry_points.txt +0 -0
  58. {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/top_level.txt +0 -0
ultralytics/__init__.py CHANGED
@@ -1,15 +1,16 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
 
- __version__ = "8.1.42"
+ __version__ = "8.1.44"
 
  from ultralytics.data.explorer.explorer import Explorer
  from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
  from ultralytics.models.fastsam import FastSAM
  from ultralytics.models.nas import NAS
- from ultralytics.utils import ASSETS, SETTINGS as settings
+ from ultralytics.utils import ASSETS, SETTINGS
  from ultralytics.utils.checks import check_yolo as checks
  from ultralytics.utils.downloads import download
 
+ settings = SETTINGS
  __all__ = (
      "__version__",
      "ASSETS",
ultralytics/cfg/models/v9/yolov9c-seg.yaml CHANGED
@@ -35,4 +35,4 @@ head:
    - [[-1, 9], 1, Concat, [1]] # cat head P5
    - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 21 (P5/32-large)
 
-   - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)
\ No newline at end of file
+   - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)
ultralytics/cfg/models/v9/yolov9c.yaml CHANGED
@@ -35,4 +35,4 @@ head:
    - [[-1, 9], 1, Concat, [1]] # cat head P5
    - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 21 (P5/32-large)
 
-   - [[15, 18, 21], 1, Detect, [nc]] # DDetect(P3, P4, P5)
+   - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/cfg/models/v9/yolov9e-seg.yaml CHANGED
@@ -17,13 +17,13 @@ backbone:
    - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 7
    - [-1, 1, ADown, [1024]] # 8-P5/32
    - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 9
-
+
    - [1, 1, CBLinear, [[64]]] # 10
    - [3, 1, CBLinear, [[64, 128]]] # 11
    - [5, 1, CBLinear, [[64, 128, 256]]] # 12
    - [7, 1, CBLinear, [[64, 128, 256, 512]]] # 13
    - [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]] # 14
-
+
    - [0, 1, Conv, [64, 3, 2]] # 15-P1/2
    - [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]] # 16
    - [-1, 1, Conv, [128, 3, 2]] # 17-P2/4
@@ -58,5 +58,4 @@ head:
    - [[-1, 29], 1, Concat, [1]] # cat head P5
    - [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]] # 41 (P5/32-large)
 
-   # segment
    - [[35, 38, 41], 1, Segment, [nc, 32, 256]] # Segment (P3, P4, P5)
ultralytics/cfg/models/v9/yolov9e.yaml CHANGED
@@ -17,13 +17,13 @@ backbone:
    - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 7
    - [-1, 1, ADown, [1024]] # 8-P5/32
    - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 9
-
+
    - [1, 1, CBLinear, [[64]]] # 10
    - [3, 1, CBLinear, [[64, 128]]] # 11
    - [5, 1, CBLinear, [[64, 128, 256]]] # 12
    - [7, 1, CBLinear, [[64, 128, 256, 512]]] # 13
    - [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]] # 14
-
+
    - [0, 1, Conv, [64, 3, 2]] # 15-P1/2
    - [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]] # 16
    - [-1, 1, Conv, [128, 3, 2]] # 17-P2/4
@@ -58,5 +58,4 @@ head:
    - [[-1, 29], 1, Concat, [1]] # cat head P5
    - [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]] # 41 (P5/32-large)
 
-   # detect
    - [[35, 38, 41], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/data/__init__.py CHANGED
@@ -1,19 +1,14 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
 
  from .base import BaseDataset
- from .build import (
-     build_dataloader,
-     build_yolo_dataset,
-     build_grounding,
-     load_inference_source,
- )
+ from .build import build_dataloader, build_grounding, build_yolo_dataset, load_inference_source
  from .dataset import (
      ClassificationDataset,
+     GroundingDataset,
      SemanticDataset,
+     YOLOConcatDataset,
      YOLODataset,
      YOLOMultiModalDataset,
-     GroundingDataset,
-     YOLOConcatDataset,
  )
 
  __all__ = (
ultralytics/data/augment.py CHANGED
@@ -8,7 +8,7 @@ from typing import Tuple, Union
  import cv2
  import numpy as np
  import torch
- import torchvision.transforms as T
+ from PIL import Image
 
  from ultralytics.utils import LOGGER, colorstr
  from ultralytics.utils.checks import check_version
@@ -20,7 +20,7 @@ from .utils import polygons2masks, polygons2masks_overlap
 
  DEFAULT_MEAN = (0.0, 0.0, 0.0)
  DEFAULT_STD = (1.0, 1.0, 1.0)
- DEFAULT_CROP_FTACTION = 1.0
+ DEFAULT_CROP_FRACTION = 1.0
 
 
  # TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
@@ -167,8 +167,8 @@ class BaseMixTransform:
          text2id = {text: i for i, text in enumerate(mix_texts)}
 
          for label in [labels] + labels["mix_labels"]:
-             for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
-                 text = label["texts"][int(l)]
+             for i, cls in enumerate(label["cls"].squeeze(-1).tolist()):
+                 text = label["texts"][int(cls)]
                  label["cls"][i] = text2id[tuple(text)]
              label["texts"] = mix_texts
          return labels
@@ -1133,8 +1133,8 @@ def classify_transforms(
      size=224,
      mean=DEFAULT_MEAN,
      std=DEFAULT_STD,
-     interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
-     crop_fraction: float = DEFAULT_CROP_FTACTION,
+     interpolation=Image.BILINEAR,
+     crop_fraction: float = DEFAULT_CROP_FRACTION,
  ):
      """
      Classification transforms for evaluation/inference. Inspired by timm/data/transforms_factory.py.
@@ -1149,6 +1149,7 @@ def classify_transforms(
      Returns:
          (T.Compose): torchvision transforms
      """
+     import torchvision.transforms as T  # scope for faster 'import ultralytics'
 
      if isinstance(size, (tuple, list)):
          assert len(size) == 2
@@ -1157,12 +1158,12 @@
      scale_size = math.floor(size / crop_fraction)
      scale_size = (scale_size, scale_size)
 
-     # aspect ratio is preserved, crops center within image, no borders are added, image is lost
+     # Aspect ratio is preserved, crops center within image, no borders are added, image is lost
      if scale_size[0] == scale_size[1]:
-         # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
+         # Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
          tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
      else:
-         # resize shortest edge to matching target dim for non-square target
+         # Resize the shortest edge to matching target dim for non-square target
          tfl = [T.Resize(scale_size)]
      tfl += [T.CenterCrop(size)]
 
@@ -1192,7 +1193,7 @@ def classify_augmentations(
      hsv_v=0.4,  # image HSV-Value augmentation (fraction)
      force_color_jitter=False,
      erasing=0.0,
-     interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+     interpolation=Image.BILINEAR,
  ):
      """
      Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
@@ -1216,7 +1217,9 @@
      Returns:
          (T.Compose): torchvision transforms
      """
-     # Transforms to apply if albumentations not installed
+     # Transforms to apply if Albumentations not installed
+     import torchvision.transforms as T  # scope for faster 'import ultralytics'
+
      if not isinstance(size, int):
          raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
      scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
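
Note: the recurring pattern in these augment.py hunks is moving the torchvision.transforms import from module scope into the function bodies that need it, so a plain 'import ultralytics' no longer pays torchvision's startup cost. It is also why the interpolation defaults switch from T.InterpolationMode.BILINEAR to Image.BILINEAR: the signatures can no longer reference T at definition time. A generic sketch of the lazy-import pattern (function name and transform chain are illustrative, not from the diff):

    def build_eval_transforms(size=224):
        # Heavy dependency imported at call time rather than at module import time,
        # so users who never build classification transforms never pay for it.
        import torchvision.transforms as T

        return T.Compose([T.Resize(size), T.CenterCrop(size), T.ToTensor()])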
ultralytics/data/base.py CHANGED
@@ -15,7 +15,7 @@ import psutil
  from torch.utils.data import Dataset
 
  from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
- from .utils import HELP_URL, FORMATS_HELP_MSG, IMG_FORMATS
+ from .utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS
 
 
  class BaseDataset(Dataset):
ultralytics/data/build.py CHANGED
@@ -22,7 +22,7 @@ from ultralytics.data.loaders import (
  from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
  from ultralytics.utils import RANK, colorstr
  from ultralytics.utils.checks import check_file
- from .dataset import YOLODataset, YOLOMultiModalDataset, GroundingDataset
+ from .dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
  from .utils import PIN_MEMORY
 
 
ultralytics/data/converter.py CHANGED
@@ -519,11 +519,12 @@ def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"):
              ├─ ..
              └─ NNN.txt
      """
+     from tqdm import tqdm
+
+     from ultralytics import SAM
      from ultralytics.data import YOLODataset
-     from ultralytics.utils.ops import xywh2xyxy
      from ultralytics.utils import LOGGER
-     from ultralytics import SAM
-     from tqdm import tqdm
+     from ultralytics.utils.ops import xywh2xyxy
 
      # NOTE: add placeholder to pass class index check
      dataset = YOLODataset(im_dir, data=dict(names=list(range(1000))))
ultralytics/data/dataset.py CHANGED
@@ -1,18 +1,17 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
  import contextlib
- from itertools import repeat
+ import json
  from collections import defaultdict
+ from itertools import repeat
  from multiprocessing.pool import ThreadPool
  from pathlib import Path
 
  import cv2
- import json
  import numpy as np
  import torch
- import torchvision
  from PIL import Image
-
  from torch.utils.data import ConcatDataset
+
  from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
  from ultralytics.utils.ops import resample_segments
  from .augment import (
@@ -31,10 +30,10 @@ from .utils import (
      LOGGER,
      get_hash,
      img2label_paths,
-     verify_image,
-     verify_image_label,
      load_dataset_cache_file,
      save_dataset_cache_file,
+     verify_image,
+     verify_image_label,
  )
 
  # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
@@ -103,16 +102,16 @@ class YOLODataset(BaseDataset):
                  nc += nc_f
                  if im_file:
                      x["labels"].append(
-                         dict(
-                             im_file=im_file,
-                             shape=shape,
-                             cls=lb[:, 0:1],  # n, 1
-                             bboxes=lb[:, 1:],  # n, 4
-                             segments=segments,
-                             keypoints=keypoint,
-                             normalized=True,
-                             bbox_format="xywh",
-                         )
+                         {
+                             "im_file": im_file,
+                             "shape": shape,
+                             "cls": lb[:, 0:1],  # n, 1
+                             "bboxes": lb[:, 1:],  # n, 4
+                             "segments": segments,
+                             "keypoints": keypoint,
+                             "normalized": True,
+                             "bbox_format": "xywh",
+                         }
                      )
                  if msg:
                      msgs.append(msg)
@@ -245,125 +244,6 @@ class YOLODataset(BaseDataset):
          return new_batch
 
 
- # Classification dataloaders -------------------------------------------------------------------------------------------
- class ClassificationDataset(torchvision.datasets.ImageFolder):
-     """
-     Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
-     augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
-     learning models, with optional image transformations and caching mechanisms to speed up training.
-
-     This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
-     in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
-     to ensure data integrity and consistency.
-
-     Attributes:
-         cache_ram (bool): Indicates if caching in RAM is enabled.
-         cache_disk (bool): Indicates if caching on disk is enabled.
-         samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
-             file (if caching on disk), and optionally the loaded image array (if caching in RAM).
-         torch_transforms (callable): PyTorch transforms to be applied to the images.
-     """
-
-     def __init__(self, root, args, augment=False, prefix=""):
-         """
-         Initialize YOLO object with root, image size, augmentations, and cache settings.
-
-         Args:
-             root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
-             args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
-                 parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
-                 of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
-                 `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
-             augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
-             prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
-                 debugging. Default is an empty string.
-         """
-         super().__init__(root=root)
-         if augment and args.fraction < 1.0:  # reduce training fraction
-             self.samples = self.samples[: round(len(self.samples) * args.fraction)]
-         self.prefix = colorstr(f"{prefix}: ") if prefix else ""
-         self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
-         self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
-         self.samples = self.verify_images()  # filter out bad images
-         self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
-         scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
-         self.torch_transforms = (
-             classify_augmentations(
-                 size=args.imgsz,
-                 scale=scale,
-                 hflip=args.fliplr,
-                 vflip=args.flipud,
-                 erasing=args.erasing,
-                 auto_augment=args.auto_augment,
-                 hsv_h=args.hsv_h,
-                 hsv_s=args.hsv_s,
-                 hsv_v=args.hsv_v,
-             )
-             if augment
-             else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
-         )
-
-     def __getitem__(self, i):
-         """Returns subset of data and targets corresponding to given indices."""
-         f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
-         if self.cache_ram:
-             if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
-                 im = self.samples[i][3] = cv2.imread(f)
-         elif self.cache_disk:
-             if not fn.exists():  # load npy
-                 np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
-             im = np.load(fn)
-         else:  # read image
-             im = cv2.imread(f)  # BGR
-         # Convert NumPy array to PIL image
-         im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
-         sample = self.torch_transforms(im)
-         return {"img": sample, "cls": j}
-
-     def __len__(self) -> int:
-         """Return the total number of samples in the dataset."""
-         return len(self.samples)
-
-     def verify_images(self):
-         """Verify all images in dataset."""
-         desc = f"{self.prefix}Scanning {self.root}..."
-         path = Path(self.root).with_suffix(".cache")  # *.cache file path
-
-         with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
-             cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
-             assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
-             assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
-             nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
-             if LOCAL_RANK in {-1, 0}:
-                 d = f"{desc} {nf} images, {nc} corrupt"
-                 TQDM(None, desc=d, total=n, initial=n)
-                 if cache["msgs"]:
-                     LOGGER.info("\n".join(cache["msgs"]))  # display warnings
-             return samples
-
-         # Run scan if *.cache retrieval failed
-         nf, nc, msgs, samples, x = 0, 0, [], [], {}
-         with ThreadPool(NUM_THREADS) as pool:
-             results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
-             pbar = TQDM(results, desc=desc, total=len(self.samples))
-             for sample, nf_f, nc_f, msg in pbar:
-                 if nf_f:
-                     samples.append(sample)
-                 if msg:
-                     msgs.append(msg)
-                 nf += nf_f
-                 nc += nc_f
-                 pbar.desc = f"{desc} {nf} images, {nc} corrupt"
-             pbar.close()
-         if msgs:
-             LOGGER.info("\n".join(msgs))
-         x["hash"] = get_hash([x[0] for x in self.samples])
-         x["results"] = nf, nc, len(samples), samples
-         x["msgs"] = msgs  # warnings
-         save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
-         return samples
-
-
  class YOLOMultiModalDataset(YOLODataset):
      """
      Dataset class for loading object detection and/or segmentation labels in YOLO format.
@@ -447,15 +327,15 @@ class GroundingDataset(YOLODataset):
                  bboxes.append(box)
              lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
              labels.append(
-                 dict(
-                     im_file=im_file,
-                     shape=(h, w),
-                     cls=lb[:, 0:1],  # n, 1
-                     bboxes=lb[:, 1:],  # n, 4
-                     normalized=True,
-                     bbox_format="xywh",
-                     texts=texts,
-                 )
+                 {
+                     "im_file": im_file,
+                     "shape": (h, w),
+                     "cls": lb[:, 0:1],  # n, 1
+                     "bboxes": lb[:, 1:],  # n, 4
+                     "normalized": True,
+                     "bbox_format": "xywh",
+                     "texts": texts,
+                 }
              )
          return labels
 
@@ -497,3 +377,128 @@ class SemanticDataset(BaseDataset):
      def __init__(self):
          """Initialize a SemanticDataset object."""
          super().__init__()
+
+
+ class ClassificationDataset:
+     """
+     Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
+     augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
+     learning models, with optional image transformations and caching mechanisms to speed up training.
+
+     This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
+     in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
+     to ensure data integrity and consistency.
+
+     Attributes:
+         cache_ram (bool): Indicates if caching in RAM is enabled.
+         cache_disk (bool): Indicates if caching on disk is enabled.
+         samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
+             file (if caching on disk), and optionally the loaded image array (if caching in RAM).
+         torch_transforms (callable): PyTorch transforms to be applied to the images.
+     """
+
+     def __init__(self, root, args, augment=False, prefix=""):
+         """
+         Initialize YOLO object with root, image size, augmentations, and cache settings.
+
+         Args:
+             root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
+             args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
+                 parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
+                 of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
+                 `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
+             augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
+             prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
+                 debugging. Default is an empty string.
+         """
+         import torchvision  # scope for faster 'import ultralytics'
+
+         # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
+         self.base = torchvision.datasets.ImageFolder(root=root)
+         self.samples = self.base.samples
+         self.root = self.base.root
+
+         # Initialize attributes
+         if augment and args.fraction < 1.0:  # reduce training fraction
+             self.samples = self.samples[: round(len(self.samples) * args.fraction)]
+         self.prefix = colorstr(f"{prefix}: ") if prefix else ""
+         self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
+         self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
+         self.samples = self.verify_images()  # filter out bad images
+         self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
+         scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
+         self.torch_transforms = (
+             classify_augmentations(
+                 size=args.imgsz,
+                 scale=scale,
+                 hflip=args.fliplr,
+                 vflip=args.flipud,
+                 erasing=args.erasing,
+                 auto_augment=args.auto_augment,
+                 hsv_h=args.hsv_h,
+                 hsv_s=args.hsv_s,
+                 hsv_v=args.hsv_v,
+             )
+             if augment
+             else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
+         )
+
+     def __getitem__(self, i):
+         """Returns subset of data and targets corresponding to given indices."""
+         f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
+         if self.cache_ram:
+             if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
+                 im = self.samples[i][3] = cv2.imread(f)
+         elif self.cache_disk:
+             if not fn.exists():  # load npy
+                 np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
+             im = np.load(fn)
+         else:  # read image
+             im = cv2.imread(f)  # BGR
+         # Convert NumPy array to PIL image
+         im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
+         sample = self.torch_transforms(im)
+         return {"img": sample, "cls": j}
+
+     def __len__(self) -> int:
+         """Return the total number of samples in the dataset."""
+         return len(self.samples)
+
+     def verify_images(self):
+         """Verify all images in dataset."""
+         desc = f"{self.prefix}Scanning {self.root}..."
+         path = Path(self.root).with_suffix(".cache")  # *.cache file path
+
+         with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
+             cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
+             assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+             assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
+             nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
+             if LOCAL_RANK in {-1, 0}:
+                 d = f"{desc} {nf} images, {nc} corrupt"
+                 TQDM(None, desc=d, total=n, initial=n)
+                 if cache["msgs"]:
+                     LOGGER.info("\n".join(cache["msgs"]))  # display warnings
+             return samples
+
+         # Run scan if *.cache retrieval failed
+         nf, nc, msgs, samples, x = 0, 0, [], [], {}
+         with ThreadPool(NUM_THREADS) as pool:
+             results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
+             pbar = TQDM(results, desc=desc, total=len(self.samples))
+             for sample, nf_f, nc_f, msg in pbar:
+                 if nf_f:
+                     samples.append(sample)
+                 if msg:
+                     msgs.append(msg)
+                 nf += nf_f
+                 nc += nc_f
+                 pbar.desc = f"{desc} {nf} images, {nc} corrupt"
+             pbar.close()
+         if msgs:
+             LOGGER.info("\n".join(msgs))
+         x["hash"] = get_hash([x[0] for x in self.samples])
+         x["results"] = nf, nc, len(samples), samples
+         x["msgs"] = msgs  # warnings
+         save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
+         return samples
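
Note: the re-added ClassificationDataset drops torchvision.datasets.ImageFolder as a base class and instead holds an ImageFolder instance as self.base, which is what allows the torchvision import to move inside __init__. A reduced sketch of this composition-over-inheritance trade (simplified from the hunk above, class name is illustrative):

    class LazyImageFolderDataset:
        """Holds torchvision's ImageFolder by composition so the heavy import is deferred."""

        def __init__(self, root):
            import torchvision  # deferred: only paid when a dataset is actually constructed

            self.base = torchvision.datasets.ImageFolder(root=root)
            self.samples = self.base.samples  # list of (image_path, class_index) tuples
            self.root = self.base.root

        def __len__(self):
            return len(self.samples)

The cost is that isinstance checks against ImageFolder no longer hold and inherited attributes must be re-exposed by hand, which is why the hunk copies samples and root onto the wrapper.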
ultralytics/data/explorer/explorer.py CHANGED
@@ -9,14 +9,13 @@ import numpy as np
  import torch
  from PIL import Image
  from matplotlib import pyplot as plt
- from pandas import DataFrame
  from tqdm import tqdm
 
  from ultralytics.data.augment import Format
  from ultralytics.data.dataset import YOLODataset
  from ultralytics.data.utils import check_det_dataset
  from ultralytics.models.yolo.model import YOLO
- from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks, USER_CONFIG_DIR
+ from ultralytics.utils import LOGGER, USER_CONFIG_DIR, IterableSimpleNamespace, checks
  from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
 
 
@@ -172,7 +171,7 @@ class Explorer:
 
      def sql_query(
          self, query: str, return_type: str = "pandas"
-     ) -> Union[DataFrame, Any, None]:  # pandas.dataframe or pyarrow.Table
+     ) -> Union[Any, None]:  # pandas.DataFrame or pyarrow.Table
          """
          Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
 
@@ -204,7 +203,8 @@ class Explorer:
          table = self.table.to_arrow()  # noqa NOTE: Don't comment this. This line is used by DuckDB
          if not query.startswith("SELECT") and not query.startswith("WHERE"):
              raise ValueError(
-                 f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
+                 f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE "
+                 f"clause. found {query}"
              )
          if query.startswith("WHERE"):
              query = f"SELECT * FROM 'table' {query}"
@@ -247,7 +247,7 @@ class Explorer:
          idx: Union[int, List[int]] = None,
          limit: int = 25,
          return_type: str = "pandas",
-     ) -> Union[DataFrame, Any]:  # pandas.dataframe or pyarrow.Table
+     ) -> Any:  # pandas.DataFrame or pyarrow.Table
          """
          Query the table for similar images. Accepts a single image or a list of images.
 
@@ -312,20 +312,20 @@ class Explorer:
          img = plot_query_result(similar, plot_labels=labels)
          return Image.fromarray(img)
 
-     def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
+     def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any:  # pd.DataFrame
          """
          Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
          are max_dist or closer to the image in the embedding space at a given index.
 
          Args:
              max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
-             top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
+             top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit.
                  vector search. Defaults: None.
              force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
 
          Returns:
-             (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns
-                 include indices of similar images and their respective distances.
+             (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image,
+                 and columns include indices of similar images and their respective distances.
 
          Example:
              ```python
@@ -447,12 +447,11 @@ class Explorer:
          """
          result = prompt_sql_query(query)
          try:
-             df = self.sql_query(result)
+             return self.sql_query(result)
          except Exception as e:
              LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
              LOGGER.error(e)
              return None
-         return df
 
      def visualize(self, result):
          """
ultralytics/data/explorer/gui/dash.py CHANGED
@@ -3,8 +3,6 @@
  import time
  from threading import Thread
 
- import pandas as pd
-
  from ultralytics import Explorer
  from ultralytics.utils import ROOT, SETTINGS
  from ultralytics.utils.checks import check_requirements
@@ -148,12 +146,14 @@ def run_ai_query():
              'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
          )
          return
+     import pandas  # scope for faster 'import ultralytics'
+
      st.session_state["error"] = None
      query = st.session_state.get("ai_query")
      if query.rstrip().lstrip():
          exp = st.session_state["explorer"]
          res = exp.ask_ai(query)
-         if not isinstance(res, pd.DataFrame) or res.empty:
+         if not isinstance(res, pandas.DataFrame) or res.empty:
              st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
              return
          st.session_state["imgs"] = res["im_file"].to_list()