ultralytics 8.1.42__py3-none-any.whl → 8.1.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ultralytics might be problematic.
- ultralytics/__init__.py +3 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +2 -3
- ultralytics/cfg/models/v9/yolov9e.yaml +2 -3
- ultralytics/data/__init__.py +3 -8
- ultralytics/data/augment.py +14 -11
- ultralytics/data/base.py +1 -1
- ultralytics/data/build.py +1 -1
- ultralytics/data/converter.py +4 -3
- ultralytics/data/dataset.py +149 -144
- ultralytics/data/explorer/explorer.py +10 -11
- ultralytics/data/explorer/gui/dash.py +3 -3
- ultralytics/data/explorer/utils.py +3 -2
- ultralytics/data/loaders.py +3 -3
- ultralytics/data/utils.py +1 -1
- ultralytics/engine/exporter.py +3 -2
- ultralytics/engine/model.py +2 -1
- ultralytics/engine/trainer.py +2 -1
- ultralytics/hub/auth.py +3 -3
- ultralytics/hub/session.py +3 -3
- ultralytics/hub/utils.py +6 -6
- ultralytics/models/fastsam/prompt.py +4 -1
- ultralytics/models/rtdetr/val.py +1 -1
- ultralytics/models/sam/modules/tiny_encoder.py +2 -2
- ultralytics/models/sam/modules/transformer.py +1 -1
- ultralytics/models/sam/predict.py +16 -13
- ultralytics/models/yolo/classify/train.py +2 -1
- ultralytics/models/yolo/detect/val.py +1 -1
- ultralytics/models/yolo/model.py +1 -1
- ultralytics/models/yolo/obb/val.py +1 -1
- ultralytics/models/yolo/world/train_world.py +2 -2
- ultralytics/nn/modules/__init__.py +8 -8
- ultralytics/nn/modules/head.py +1 -1
- ultralytics/nn/tasks.py +7 -7
- ultralytics/solutions/heatmap.py +14 -27
- ultralytics/solutions/object_counter.py +12 -22
- ultralytics/trackers/byte_tracker.py +1 -1
- ultralytics/trackers/utils/kalman_filter.py +4 -4
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +56 -41
- ultralytics/utils/benchmarks.py +1 -2
- ultralytics/utils/callbacks/clearml.py +4 -3
- ultralytics/utils/callbacks/hub.py +1 -4
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +1 -0
- ultralytics/utils/callbacks/wb.py +5 -5
- ultralytics/utils/checks.py +17 -20
- ultralytics/utils/metrics.py +3 -3
- ultralytics/utils/ops.py +1 -1
- ultralytics/utils/plotting.py +67 -40
- ultralytics/utils/torch_utils.py +13 -6
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/METADATA +1 -1
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/RECORD +58 -58
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/WHEEL +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/top_level.txt +0 -0
ultralytics/__init__.py
CHANGED
@@ -1,15 +1,16 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.1.42"
+__version__ = "8.1.44"
 
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
 from ultralytics.models.fastsam import FastSAM
 from ultralytics.models.nas import NAS
-from ultralytics.utils import ASSETS, SETTINGS as settings
+from ultralytics.utils import ASSETS, SETTINGS
 from ultralytics.utils.checks import check_yolo as checks
 from ultralytics.utils.downloads import download
 
+settings = SETTINGS
 __all__ = (
     "__version__",
     "ASSETS",
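Aside from the version bump, the functional change in `ultralytics/__init__.py` is that `settings` becomes an explicit module-level alias (`settings = SETTINGS`) rather than being created by the import statement. A minimal sketch of what this preserves for downstream code, assuming both names remain exported as the diff indicates:

```python
# Both spellings keep working after this change: `settings` is now an explicit
# module-level alias of SETTINGS instead of an import-time rename.
from ultralytics import settings
from ultralytics.utils import SETTINGS

assert settings is SETTINGS  # one object, two names
```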
ultralytics/cfg/models/v9/yolov9e-seg.yaml
CHANGED
@@ -17,13 +17,13 @@ backbone:
   - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]]  # 7
   - [-1, 1, ADown, [1024]]  # 8-P5/32
   - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]]  # 9
-
+
   - [1, 1, CBLinear, [[64]]]  # 10
   - [3, 1, CBLinear, [[64, 128]]]  # 11
   - [5, 1, CBLinear, [[64, 128, 256]]]  # 12
   - [7, 1, CBLinear, [[64, 128, 256, 512]]]  # 13
   - [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]]  # 14
-
+
   - [0, 1, Conv, [64, 3, 2]]  # 15-P1/2
   - [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]]  # 16
   - [-1, 1, Conv, [128, 3, 2]]  # 17-P2/4
@@ -58,5 +58,4 @@ head:
   - [[-1, 29], 1, Concat, [1]]  # cat head P5
   - [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]]  # 41 (P5/32-large)
 
-  # segment
   - [[35, 38, 41], 1, Segment, [nc, 32, 256]]  # Segment (P3, P4, P5)
ultralytics/cfg/models/v9/yolov9e.yaml
CHANGED
@@ -17,13 +17,13 @@ backbone:
   - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]]  # 7
   - [-1, 1, ADown, [1024]]  # 8-P5/32
   - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]]  # 9
-
+
   - [1, 1, CBLinear, [[64]]]  # 10
   - [3, 1, CBLinear, [[64, 128]]]  # 11
   - [5, 1, CBLinear, [[64, 128, 256]]]  # 12
   - [7, 1, CBLinear, [[64, 128, 256, 512]]]  # 13
   - [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]]  # 14
-
+
   - [0, 1, Conv, [64, 3, 2]]  # 15-P1/2
   - [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]]  # 16
   - [-1, 1, Conv, [128, 3, 2]]  # 17-P2/4
@@ -58,5 +58,4 @@ head:
   - [[-1, 29], 1, Concat, [1]]  # cat head P5
   - [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]]  # 41 (P5/32-large)
 
-  # detect
   - [[35, 38, 41], 1, Detect, [nc]]  # Detect(P3, P4, P5)
ultralytics/data/__init__.py
CHANGED
@@ -1,19 +1,14 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
 from .base import BaseDataset
-from .build import (
-    build_dataloader,
-    build_yolo_dataset,
-    build_grounding,
-    load_inference_source,
-)
+from .build import build_dataloader, build_grounding, build_yolo_dataset, load_inference_source
 from .dataset import (
     ClassificationDataset,
+    GroundingDataset,
     SemanticDataset,
+    YOLOConcatDataset,
     YOLODataset,
     YOLOMultiModalDataset,
-    GroundingDataset,
-    YOLOConcatDataset,
 )
 
 __all__ = (
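The `ultralytics/data/__init__.py` edit only collapses the `.build` import to one line and alphabetizes the `.dataset` re-exports, so imports from the package are unaffected. A quick sketch, assuming these names stay re-exported as before:

```python
# Works identically before and after this release; only the internal
# import layout of ultralytics/data/__init__.py changed.
from ultralytics.data import GroundingDataset, YOLOConcatDataset, build_dataloader
```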
ultralytics/data/augment.py
CHANGED
@@ -8,7 +8,7 @@ from typing import Tuple, Union
 import cv2
 import numpy as np
 import torch
-import torchvision.transforms as T
+from PIL import Image
 
 from ultralytics.utils import LOGGER, colorstr
 from ultralytics.utils.checks import check_version
@@ -20,7 +20,7 @@ from .utils import polygons2masks, polygons2masks_overlap
 
 DEFAULT_MEAN = (0.0, 0.0, 0.0)
 DEFAULT_STD = (1.0, 1.0, 1.0)
-DEFAULT_CROP_FTACTION = 1.0
+DEFAULT_CROP_FRACTION = 1.0
 
 
 # TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
@@ -167,8 +167,8 @@ class BaseMixTransform:
         text2id = {text: i for i, text in enumerate(mix_texts)}
 
         for label in [labels] + labels["mix_labels"]:
-            for i,
-                text = label["texts"][int(
+            for i, cls in enumerate(label["cls"].squeeze(-1).tolist()):
+                text = label["texts"][int(cls)]
                 label["cls"][i] = text2id[tuple(text)]
             label["texts"] = mix_texts
         return labels
@@ -1133,8 +1133,8 @@ def classify_transforms(
     size=224,
     mean=DEFAULT_MEAN,
     std=DEFAULT_STD,
-    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
-    crop_fraction: float = DEFAULT_CROP_FTACTION,
+    interpolation=Image.BILINEAR,
+    crop_fraction: float = DEFAULT_CROP_FRACTION,
 ):
     """
     Classification transforms for evaluation/inference. Inspired by timm/data/transforms_factory.py.
@@ -1149,6 +1149,7 @@ def classify_transforms(
     Returns:
         (T.Compose): torchvision transforms
     """
+    import torchvision.transforms as T  # scope for faster 'import ultralytics'
 
     if isinstance(size, (tuple, list)):
         assert len(size) == 2
@@ -1157,12 +1158,12 @@ def classify_transforms(
         scale_size = math.floor(size / crop_fraction)
         scale_size = (scale_size, scale_size)
 
-    # aspect ratio is preserved, crops center within image, no borders are added, image is lost
+    # Aspect ratio is preserved, crops center within image, no borders are added, image is lost
     if scale_size[0] == scale_size[1]:
-        # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
+        # Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
         tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
     else:
-        # resize shortest edge to matching target dim for non-square target
+        # Resize the shortest edge to matching target dim for non-square target
         tfl = [T.Resize(scale_size)]
     tfl += [T.CenterCrop(size)]
 
@@ -1192,7 +1193,7 @@ def classify_augmentations(
     hsv_v=0.4,  # image HSV-Value augmentation (fraction)
     force_color_jitter=False,
     erasing=0.0,
-    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+    interpolation=Image.BILINEAR,
 ):
     """
     Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
@@ -1216,7 +1217,9 @@ def classify_augmentations(
     Returns:
         (T.Compose): torchvision transforms
     """
-    # Transforms to apply if albumentations not installed
+    # Transforms to apply if Albumentations not installed
+    import torchvision.transforms as T  # scope for faster 'import ultralytics'
+
     if not isinstance(size, int):
         raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
     scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
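Most of the `augment.py` edits serve a single theme of this release, called out in the added comments: heavy optional dependencies such as torchvision move from module scope into the functions that use them, so `import ultralytics` itself stays fast. A minimal sketch of the pattern (hypothetical function name, not the library's API):

```python
def classify_transforms_sketch(size=224):
    """Sketch only: defer the torchvision import until a transform is actually built."""
    import torchvision.transforms as T  # import cost paid on first call, not at package import

    return T.Compose([T.Resize(size), T.CenterCrop(size), T.ToTensor()])
```

The trade-off is that a missing torchvision now surfaces at call time rather than at import time.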
ultralytics/data/base.py
CHANGED
@@ -15,7 +15,7 @@ import psutil
 from torch.utils.data import Dataset
 
 from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
-from .utils import HELP_URL, IMG_FORMATS
+from .utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS
 
 
 class BaseDataset(Dataset):
ultralytics/data/build.py
CHANGED
@@ -22,7 +22,7 @@ from ultralytics.data.loaders import (
 from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
 from ultralytics.utils import RANK, colorstr
 from ultralytics.utils.checks import check_file
-from .dataset import YOLODataset, YOLOMultiModalDataset
+from .dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
 from .utils import PIN_MEMORY
 
 
ultralytics/data/converter.py
CHANGED
@@ -519,11 +519,12 @@ def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"):
         ├─ ..
         └─ NNN.txt
     """
+    from tqdm import tqdm
+
+    from ultralytics import SAM
     from ultralytics.data import YOLODataset
-    from ultralytics.utils.ops import xywh2xyxy
     from ultralytics.utils import LOGGER
-    from ultralytics import SAM
-    from tqdm import tqdm
+    from ultralytics.utils.ops import xywh2xyxy
 
     # NOTE: add placeholder to pass class index check
     dataset = YOLODataset(im_dir, data=dict(names=list(range(1000))))
ultralytics/data/dataset.py
CHANGED
@@ -1,18 +1,17 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 import contextlib
-from itertools import repeat
+import json
 from collections import defaultdict
+from itertools import repeat
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
 
 import cv2
-import json
 import numpy as np
 import torch
-import torchvision
 from PIL import Image
-
 from torch.utils.data import ConcatDataset
+
 from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
 from ultralytics.utils.ops import resample_segments
 from .augment import (
@@ -31,10 +30,10 @@ from .utils import (
     LOGGER,
     get_hash,
     img2label_paths,
-    verify_image,
-    verify_image_label,
     load_dataset_cache_file,
     save_dataset_cache_file,
+    verify_image,
+    verify_image_label,
 )
 
 # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
@@ -103,16 +102,16 @@ class YOLODataset(BaseDataset):
                 nc += nc_f
                 if im_file:
                     x["labels"].append(
-                        dict(
-                            im_file=im_file,
-                            shape=shape,
-                            cls=lb[:, 0:1],  # n, 1
-                            bboxes=lb[:, 1:],  # n, 4
-                            segments=segments,
-                            keypoints=keypoint,
-                            normalized=True,
-                            bbox_format="xywh",
-                        )
+                        {
+                            "im_file": im_file,
+                            "shape": shape,
+                            "cls": lb[:, 0:1],  # n, 1
+                            "bboxes": lb[:, 1:],  # n, 4
+                            "segments": segments,
+                            "keypoints": keypoint,
+                            "normalized": True,
+                            "bbox_format": "xywh",
+                        }
                     )
                 if msg:
                     msgs.append(msg)
@@ -245,125 +244,6 @@ class YOLODataset(BaseDataset):
         return new_batch
 
 
-# Classification dataloaders -------------------------------------------------------------------------------------------
-class ClassificationDataset(torchvision.datasets.ImageFolder):
-    """
-    Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
-    augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
-    learning models, with optional image transformations and caching mechanisms to speed up training.
-
-    This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
-    in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
-    to ensure data integrity and consistency.
-
-    Attributes:
-        cache_ram (bool): Indicates if caching in RAM is enabled.
-        cache_disk (bool): Indicates if caching on disk is enabled.
-        samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
-            file (if caching on disk), and optionally the loaded image array (if caching in RAM).
-        torch_transforms (callable): PyTorch transforms to be applied to the images.
-    """
-
-    def __init__(self, root, args, augment=False, prefix=""):
-        """
-        Initialize YOLO object with root, image size, augmentations, and cache settings.
-
-        Args:
-            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
-            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
-                parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
-                of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
-                `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
-            augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
-            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
-                debugging. Default is an empty string.
-        """
-        super().__init__(root=root)
-        if augment and args.fraction < 1.0:  # reduce training fraction
-            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
-        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
-        self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
-        self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
-        self.samples = self.verify_images()  # filter out bad images
-        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
-        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
-        self.torch_transforms = (
-            classify_augmentations(
-                size=args.imgsz,
-                scale=scale,
-                hflip=args.fliplr,
-                vflip=args.flipud,
-                erasing=args.erasing,
-                auto_augment=args.auto_augment,
-                hsv_h=args.hsv_h,
-                hsv_s=args.hsv_s,
-                hsv_v=args.hsv_v,
-            )
-            if augment
-            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
-        )
-
-    def __getitem__(self, i):
-        """Returns subset of data and targets corresponding to given indices."""
-        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
-        if self.cache_ram:
-            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
-                im = self.samples[i][3] = cv2.imread(f)
-        elif self.cache_disk:
-            if not fn.exists():  # load npy
-                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
-            im = np.load(fn)
-        else:  # read image
-            im = cv2.imread(f)  # BGR
-        # Convert NumPy array to PIL image
-        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
-        sample = self.torch_transforms(im)
-        return {"img": sample, "cls": j}
-
-    def __len__(self) -> int:
-        """Return the total number of samples in the dataset."""
-        return len(self.samples)
-
-    def verify_images(self):
-        """Verify all images in dataset."""
-        desc = f"{self.prefix}Scanning {self.root}..."
-        path = Path(self.root).with_suffix(".cache")  # *.cache file path
-
-        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
-            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
-            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
-            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
-            nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
-            if LOCAL_RANK in {-1, 0}:
-                d = f"{desc} {nf} images, {nc} corrupt"
-                TQDM(None, desc=d, total=n, initial=n)
-                if cache["msgs"]:
-                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
-            return samples
-
-        # Run scan if *.cache retrieval failed
-        nf, nc, msgs, samples, x = 0, 0, [], [], {}
-        with ThreadPool(NUM_THREADS) as pool:
-            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
-            pbar = TQDM(results, desc=desc, total=len(self.samples))
-            for sample, nf_f, nc_f, msg in pbar:
-                if nf_f:
-                    samples.append(sample)
-                if msg:
-                    msgs.append(msg)
-                nf += nf_f
-                nc += nc_f
-                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
-            pbar.close()
-        if msgs:
-            LOGGER.info("\n".join(msgs))
-        x["hash"] = get_hash([x[0] for x in self.samples])
-        x["results"] = nf, nc, len(samples), samples
-        x["msgs"] = msgs  # warnings
-        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
-        return samples
-
-
 class YOLOMultiModalDataset(YOLODataset):
     """
     Dataset class for loading object detection and/or segmentation labels in YOLO format.
@@ -447,15 +327,15 @@ class GroundingDataset(YOLODataset):
             bboxes.append(box)
         lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
         labels.append(
-            dict(
-                im_file=im_file,
-                shape=(h, w),
-                cls=lb[:, 0:1],  # n, 1
-                bboxes=lb[:, 1:],  # n, 4
-                normalized=True,
-                bbox_format="xywh",
-                texts=texts,
-            )
+            {
+                "im_file": im_file,
+                "shape": (h, w),
+                "cls": lb[:, 0:1],  # n, 1
+                "bboxes": lb[:, 1:],  # n, 4
+                "normalized": True,
+                "bbox_format": "xywh",
+                "texts": texts,
+            }
         )
     return labels
 
@@ -497,3 +377,128 @@ class SemanticDataset(BaseDataset):
     def __init__(self):
         """Initialize a SemanticDataset object."""
         super().__init__()
+
+
+class ClassificationDataset:
+    """
+    Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
+    augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
+    learning models, with optional image transformations and caching mechanisms to speed up training.
+
+    This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
+    in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
+    to ensure data integrity and consistency.
+
+    Attributes:
+        cache_ram (bool): Indicates if caching in RAM is enabled.
+        cache_disk (bool): Indicates if caching on disk is enabled.
+        samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
+            file (if caching on disk), and optionally the loaded image array (if caching in RAM).
+        torch_transforms (callable): PyTorch transforms to be applied to the images.
+    """
+
+    def __init__(self, root, args, augment=False, prefix=""):
+        """
+        Initialize YOLO object with root, image size, augmentations, and cache settings.
+
+        Args:
+            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
+            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
+                parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
+                of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
+                `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
+            augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
+            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
+                debugging. Default is an empty string.
+        """
+        import torchvision  # scope for faster 'import ultralytics'
+
+        # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
+        self.base = torchvision.datasets.ImageFolder(root=root)
+        self.samples = self.base.samples
+        self.root = self.base.root
+
+        # Initialize attributes
+        if augment and args.fraction < 1.0:  # reduce training fraction
+            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
+        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
+        self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
+        self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
+        self.samples = self.verify_images()  # filter out bad images
+        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
+        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
+        self.torch_transforms = (
+            classify_augmentations(
+                size=args.imgsz,
+                scale=scale,
+                hflip=args.fliplr,
+                vflip=args.flipud,
+                erasing=args.erasing,
+                auto_augment=args.auto_augment,
+                hsv_h=args.hsv_h,
+                hsv_s=args.hsv_s,
+                hsv_v=args.hsv_v,
+            )
+            if augment
+            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
+        )
+
+    def __getitem__(self, i):
+        """Returns subset of data and targets corresponding to given indices."""
+        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
+        if self.cache_ram:
+            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
+                im = self.samples[i][3] = cv2.imread(f)
+        elif self.cache_disk:
+            if not fn.exists():  # load npy
+                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
+            im = np.load(fn)
+        else:  # read image
+            im = cv2.imread(f)  # BGR
+        # Convert NumPy array to PIL image
+        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
+        sample = self.torch_transforms(im)
+        return {"img": sample, "cls": j}
+
+    def __len__(self) -> int:
+        """Return the total number of samples in the dataset."""
+        return len(self.samples)
+
+    def verify_images(self):
+        """Verify all images in dataset."""
+        desc = f"{self.prefix}Scanning {self.root}..."
+        path = Path(self.root).with_suffix(".cache")  # *.cache file path
+
+        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
+            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
+            nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
+            if LOCAL_RANK in {-1, 0}:
+                d = f"{desc} {nf} images, {nc} corrupt"
+                TQDM(None, desc=d, total=n, initial=n)
+                if cache["msgs"]:
+                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
+            return samples
+
+        # Run scan if *.cache retrieval failed
+        nf, nc, msgs, samples, x = 0, 0, [], [], {}
+        with ThreadPool(NUM_THREADS) as pool:
+            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
+            pbar = TQDM(results, desc=desc, total=len(self.samples))
+            for sample, nf_f, nc_f, msg in pbar:
+                if nf_f:
+                    samples.append(sample)
+                if msg:
+                    msgs.append(msg)
+                nf += nf_f
+                nc += nc_f
+                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
+            pbar.close()
+        if msgs:
+            LOGGER.info("\n".join(msgs))
+        x["hash"] = get_hash([x[0] for x in self.samples])
+        x["results"] = nf, nc, len(samples), samples
+        x["msgs"] = msgs  # warnings
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
+        return samples
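The relocated `ClassificationDataset` also stops inheriting from `torchvision.datasets.ImageFolder` and wraps it instead: a base class must exist at class-definition time, but an attribute can be created lazily, which is exactly what the added comment in the diff says. A stripped-down sketch of the trick (hypothetical class name):

```python
class LazyImageFolder:
    """Sketch of the composition pattern ClassificationDataset adopts above."""

    def __init__(self, root):
        import torchvision  # deferred: only imported when a dataset is constructed

        # Hold ImageFolder as an attribute and copy out the fields we need,
        # rather than subclassing it at module-definition time.
        self.base = torchvision.datasets.ImageFolder(root=root)
        self.samples = self.base.samples
        self.root = self.base.root

    def __len__(self):
        return len(self.samples)
```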
ultralytics/data/explorer/explorer.py
CHANGED
@@ -9,14 +9,13 @@ import numpy as np
 import torch
 from PIL import Image
 from matplotlib import pyplot as plt
-from pandas import DataFrame
 from tqdm import tqdm
 
 from ultralytics.data.augment import Format
 from ultralytics.data.dataset import YOLODataset
 from ultralytics.data.utils import check_det_dataset
 from ultralytics.models.yolo.model import YOLO
-from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
+from ultralytics.utils import LOGGER, USER_CONFIG_DIR, IterableSimpleNamespace, checks
 from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
 
 
@@ -172,7 +171,7 @@ class Explorer:
 
     def sql_query(
         self, query: str, return_type: str = "pandas"
-    ) -> Union[DataFrame, Any, None]:  # pandas.DataFrame or pyarrow.Table
+    ) -> Union[Any, None]:  # pandas.DataFrame or pyarrow.Table
         """
         Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
 
@@ -204,7 +203,8 @@ class Explorer:
         table = self.table.to_arrow()  # noqa NOTE: Don't comment this. This line is used by DuckDB
         if not query.startswith("SELECT") and not query.startswith("WHERE"):
             raise ValueError(
-                f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
+                f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE "
+                f"clause. found {query}"
             )
         if query.startswith("WHERE"):
             query = f"SELECT * FROM 'table' {query}"
@@ -247,7 +247,7 @@ class Explorer:
         idx: Union[int, List[int]] = None,
         limit: int = 25,
         return_type: str = "pandas",
-    ) -> Union[DataFrame, Any]:  # pandas.DataFrame or pyarrow.Table
+    ) -> Any:  # pandas.DataFrame or pyarrow.Table
         """
         Query the table for similar images. Accepts a single image or a list of images.
 
@@ -312,20 +312,20 @@ class Explorer:
         img = plot_query_result(similar, plot_labels=labels)
         return Image.fromarray(img)
 
-    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
+    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any:  # pd.DataFrame
         """
         Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
         are max_dist or closer to the image in the embedding space at a given index.
 
         Args:
             max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
-            top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit
+            top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit.
                 vector search. Defaults: None.
             force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
 
         Returns:
-            (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image,
-                and columns include indices of similar images and their respective distances.
+            (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image,
+            and columns include indices of similar images and their respective distances.
 
         Example:
             ```python
@@ -447,12 +447,11 @@ class Explorer:
         """
         result = prompt_sql_query(query)
         try:
-            df = self.sql_query(result)
+            return self.sql_query(result)
         except Exception as e:
             LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
             LOGGER.error(e)
             return None
-        return df
 
     def visualize(self, result):
         """
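In `explorer.py` the concrete `DataFrame` annotations are loosened to `Any` with a trailing comment naming the real type, which is what allows the module-level pandas import to be dropped. A small sketch of the style (hypothetical function, not the Explorer API):

```python
from typing import Any


def similarity_index_sketch() -> Any:  # pandas.DataFrame
    import pandas  # deferred, matching the new explorer.py style

    return pandas.DataFrame({"idx": [0], "count": [1]})
```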
ultralytics/data/explorer/gui/dash.py
CHANGED
@@ -3,8 +3,6 @@
 import time
 from threading import Thread
 
-import pandas as pd
-
 from ultralytics import Explorer
 from ultralytics.utils import ROOT, SETTINGS
 from ultralytics.utils.checks import check_requirements
@@ -148,12 +146,14 @@ def run_ai_query():
             'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
         )
         return
+    import pandas  # scope for faster 'import ultralytics'
+
    st.session_state["error"] = None
    query = st.session_state.get("ai_query")
    if query.rstrip().lstrip():
        exp = st.session_state["explorer"]
        res = exp.ask_ai(query)
-        if not isinstance(res, pd.DataFrame) or res.empty:
+        if not isinstance(res, pandas.DataFrame) or res.empty:
            st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
            return
        st.session_state["imgs"] = res["im_file"].to_list()