ultralytics 8.1.42__py3-none-any.whl → 8.1.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +1 -1
- ultralytics/data/augment.py +12 -9
- ultralytics/data/dataset.py +147 -142
- ultralytics/data/explorer/explorer.py +4 -6
- ultralytics/data/explorer/gui/dash.py +3 -3
- ultralytics/data/explorer/utils.py +3 -2
- ultralytics/engine/exporter.py +3 -2
- ultralytics/engine/trainer.py +1 -1
- ultralytics/models/fastsam/prompt.py +4 -1
- ultralytics/models/sam/predict.py +4 -1
- ultralytics/models/yolo/classify/train.py +2 -1
- ultralytics/solutions/heatmap.py +14 -27
- ultralytics/solutions/object_counter.py +12 -23
- ultralytics/utils/__init__.py +4 -1
- ultralytics/utils/benchmarks.py +1 -2
- ultralytics/utils/callbacks/clearml.py +4 -3
- ultralytics/utils/callbacks/wb.py +5 -5
- ultralytics/utils/checks.py +6 -9
- ultralytics/utils/metrics.py +3 -3
- ultralytics/utils/ops.py +1 -1
- ultralytics/utils/plotting.py +67 -40
- ultralytics/utils/torch_utils.py +13 -6
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/METADATA +1 -1
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/RECORD +28 -28
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/WHEEL +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/top_level.txt +0 -0
ultralytics/__init__.py
CHANGED
ultralytics/data/augment.py
CHANGED
|
@@ -8,7 +8,7 @@ from typing import Tuple, Union
|
|
|
8
8
|
import cv2
|
|
9
9
|
import numpy as np
|
|
10
10
|
import torch
|
|
11
|
-
|
|
11
|
+
from PIL import Image
|
|
12
12
|
|
|
13
13
|
from ultralytics.utils import LOGGER, colorstr
|
|
14
14
|
from ultralytics.utils.checks import check_version
|
|
@@ -167,8 +167,8 @@ class BaseMixTransform:
|
|
|
167
167
|
text2id = {text: i for i, text in enumerate(mix_texts)}
|
|
168
168
|
|
|
169
169
|
for label in [labels] + labels["mix_labels"]:
|
|
170
|
-
for i,
|
|
171
|
-
text = label["texts"][int(
|
|
170
|
+
for i, cls in enumerate(label["cls"].squeeze(-1).tolist()):
|
|
171
|
+
text = label["texts"][int(cls)]
|
|
172
172
|
label["cls"][i] = text2id[tuple(text)]
|
|
173
173
|
label["texts"] = mix_texts
|
|
174
174
|
return labels
|
|
@@ -1133,7 +1133,7 @@ def classify_transforms(
|
|
|
1133
1133
|
size=224,
|
|
1134
1134
|
mean=DEFAULT_MEAN,
|
|
1135
1135
|
std=DEFAULT_STD,
|
|
1136
|
-
interpolation
|
|
1136
|
+
interpolation=Image.BILINEAR,
|
|
1137
1137
|
crop_fraction: float = DEFAULT_CROP_FTACTION,
|
|
1138
1138
|
):
|
|
1139
1139
|
"""
|
|
@@ -1149,6 +1149,7 @@ def classify_transforms(
|
|
|
1149
1149
|
Returns:
|
|
1150
1150
|
(T.Compose): torchvision transforms
|
|
1151
1151
|
"""
|
|
1152
|
+
import torchvision.transforms as T # scope for faster 'import ultralytics'
|
|
1152
1153
|
|
|
1153
1154
|
if isinstance(size, (tuple, list)):
|
|
1154
1155
|
assert len(size) == 2
|
|
@@ -1157,12 +1158,12 @@ def classify_transforms(
|
|
|
1157
1158
|
scale_size = math.floor(size / crop_fraction)
|
|
1158
1159
|
scale_size = (scale_size, scale_size)
|
|
1159
1160
|
|
|
1160
|
-
#
|
|
1161
|
+
# Aspect ratio is preserved, crops center within image, no borders are added, image is lost
|
|
1161
1162
|
if scale_size[0] == scale_size[1]:
|
|
1162
|
-
#
|
|
1163
|
+
# Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
|
|
1163
1164
|
tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
|
|
1164
1165
|
else:
|
|
1165
|
-
#
|
|
1166
|
+
# Resize the shortest edge to matching target dim for non-square target
|
|
1166
1167
|
tfl = [T.Resize(scale_size)]
|
|
1167
1168
|
tfl += [T.CenterCrop(size)]
|
|
1168
1169
|
|
|
@@ -1192,7 +1193,7 @@ def classify_augmentations(
|
|
|
1192
1193
|
hsv_v=0.4, # image HSV-Value augmentation (fraction)
|
|
1193
1194
|
force_color_jitter=False,
|
|
1194
1195
|
erasing=0.0,
|
|
1195
|
-
interpolation
|
|
1196
|
+
interpolation=Image.BILINEAR,
|
|
1196
1197
|
):
|
|
1197
1198
|
"""
|
|
1198
1199
|
Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
|
|
@@ -1216,7 +1217,9 @@ def classify_augmentations(
|
|
|
1216
1217
|
Returns:
|
|
1217
1218
|
(T.Compose): torchvision transforms
|
|
1218
1219
|
"""
|
|
1219
|
-
# Transforms to apply if
|
|
1220
|
+
# Transforms to apply if Albumentations not installed
|
|
1221
|
+
import torchvision.transforms as T # scope for faster 'import ultralytics'
|
|
1222
|
+
|
|
1220
1223
|
if not isinstance(size, int):
|
|
1221
1224
|
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
|
|
1222
1225
|
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
|
ultralytics/data/dataset.py
CHANGED
|
@@ -1,18 +1,17 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
import contextlib
|
|
3
|
-
|
|
3
|
+
import json
|
|
4
4
|
from collections import defaultdict
|
|
5
|
+
from itertools import repeat
|
|
5
6
|
from multiprocessing.pool import ThreadPool
|
|
6
7
|
from pathlib import Path
|
|
7
8
|
|
|
8
9
|
import cv2
|
|
9
|
-
import json
|
|
10
10
|
import numpy as np
|
|
11
11
|
import torch
|
|
12
|
-
import torchvision
|
|
13
12
|
from PIL import Image
|
|
14
|
-
|
|
15
13
|
from torch.utils.data import ConcatDataset
|
|
14
|
+
|
|
16
15
|
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
|
|
17
16
|
from ultralytics.utils.ops import resample_segments
|
|
18
17
|
from .augment import (
|
|
@@ -103,16 +102,16 @@ class YOLODataset(BaseDataset):
|
|
|
103
102
|
nc += nc_f
|
|
104
103
|
if im_file:
|
|
105
104
|
x["labels"].append(
|
|
106
|
-
|
|
107
|
-
im_file
|
|
108
|
-
shape
|
|
109
|
-
cls
|
|
110
|
-
bboxes
|
|
111
|
-
segments
|
|
112
|
-
keypoints
|
|
113
|
-
normalized
|
|
114
|
-
bbox_format
|
|
115
|
-
|
|
105
|
+
{
|
|
106
|
+
"im_file": im_file,
|
|
107
|
+
"shape": shape,
|
|
108
|
+
"cls": lb[:, 0:1], # n, 1
|
|
109
|
+
"bboxes": lb[:, 1:], # n, 4
|
|
110
|
+
"segments": segments,
|
|
111
|
+
"keypoints": keypoint,
|
|
112
|
+
"normalized": True,
|
|
113
|
+
"bbox_format": "xywh",
|
|
114
|
+
}
|
|
116
115
|
)
|
|
117
116
|
if msg:
|
|
118
117
|
msgs.append(msg)
|
|
@@ -245,125 +244,6 @@ class YOLODataset(BaseDataset):
|
|
|
245
244
|
return new_batch
|
|
246
245
|
|
|
247
246
|
|
|
248
|
-
# Classification dataloaders -------------------------------------------------------------------------------------------
|
|
249
|
-
class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|
250
|
-
"""
|
|
251
|
-
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
|
|
252
|
-
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
|
|
253
|
-
learning models, with optional image transformations and caching mechanisms to speed up training.
|
|
254
|
-
|
|
255
|
-
This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
|
|
256
|
-
in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
|
|
257
|
-
to ensure data integrity and consistency.
|
|
258
|
-
|
|
259
|
-
Attributes:
|
|
260
|
-
cache_ram (bool): Indicates if caching in RAM is enabled.
|
|
261
|
-
cache_disk (bool): Indicates if caching on disk is enabled.
|
|
262
|
-
samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
|
|
263
|
-
file (if caching on disk), and optionally the loaded image array (if caching in RAM).
|
|
264
|
-
torch_transforms (callable): PyTorch transforms to be applied to the images.
|
|
265
|
-
"""
|
|
266
|
-
|
|
267
|
-
def __init__(self, root, args, augment=False, prefix=""):
|
|
268
|
-
"""
|
|
269
|
-
Initialize YOLO object with root, image size, augmentations, and cache settings.
|
|
270
|
-
|
|
271
|
-
Args:
|
|
272
|
-
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
|
|
273
|
-
args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
|
|
274
|
-
parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
|
|
275
|
-
of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
|
|
276
|
-
`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
|
|
277
|
-
augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
|
|
278
|
-
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
|
|
279
|
-
debugging. Default is an empty string.
|
|
280
|
-
"""
|
|
281
|
-
super().__init__(root=root)
|
|
282
|
-
if augment and args.fraction < 1.0: # reduce training fraction
|
|
283
|
-
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
|
|
284
|
-
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
|
|
285
|
-
self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM
|
|
286
|
-
self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files
|
|
287
|
-
self.samples = self.verify_images() # filter out bad images
|
|
288
|
-
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
|
|
289
|
-
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
|
|
290
|
-
self.torch_transforms = (
|
|
291
|
-
classify_augmentations(
|
|
292
|
-
size=args.imgsz,
|
|
293
|
-
scale=scale,
|
|
294
|
-
hflip=args.fliplr,
|
|
295
|
-
vflip=args.flipud,
|
|
296
|
-
erasing=args.erasing,
|
|
297
|
-
auto_augment=args.auto_augment,
|
|
298
|
-
hsv_h=args.hsv_h,
|
|
299
|
-
hsv_s=args.hsv_s,
|
|
300
|
-
hsv_v=args.hsv_v,
|
|
301
|
-
)
|
|
302
|
-
if augment
|
|
303
|
-
else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
|
|
304
|
-
)
|
|
305
|
-
|
|
306
|
-
def __getitem__(self, i):
|
|
307
|
-
"""Returns subset of data and targets corresponding to given indices."""
|
|
308
|
-
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
|
309
|
-
if self.cache_ram:
|
|
310
|
-
if im is None: # Warning: two separate if statements required here, do not combine this with previous line
|
|
311
|
-
im = self.samples[i][3] = cv2.imread(f)
|
|
312
|
-
elif self.cache_disk:
|
|
313
|
-
if not fn.exists(): # load npy
|
|
314
|
-
np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
|
|
315
|
-
im = np.load(fn)
|
|
316
|
-
else: # read image
|
|
317
|
-
im = cv2.imread(f) # BGR
|
|
318
|
-
# Convert NumPy array to PIL image
|
|
319
|
-
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
|
|
320
|
-
sample = self.torch_transforms(im)
|
|
321
|
-
return {"img": sample, "cls": j}
|
|
322
|
-
|
|
323
|
-
def __len__(self) -> int:
|
|
324
|
-
"""Return the total number of samples in the dataset."""
|
|
325
|
-
return len(self.samples)
|
|
326
|
-
|
|
327
|
-
def verify_images(self):
|
|
328
|
-
"""Verify all images in dataset."""
|
|
329
|
-
desc = f"{self.prefix}Scanning {self.root}..."
|
|
330
|
-
path = Path(self.root).with_suffix(".cache") # *.cache file path
|
|
331
|
-
|
|
332
|
-
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
|
|
333
|
-
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
|
|
334
|
-
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
|
335
|
-
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
|
336
|
-
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
|
|
337
|
-
if LOCAL_RANK in {-1, 0}:
|
|
338
|
-
d = f"{desc} {nf} images, {nc} corrupt"
|
|
339
|
-
TQDM(None, desc=d, total=n, initial=n)
|
|
340
|
-
if cache["msgs"]:
|
|
341
|
-
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
|
342
|
-
return samples
|
|
343
|
-
|
|
344
|
-
# Run scan if *.cache retrieval failed
|
|
345
|
-
nf, nc, msgs, samples, x = 0, 0, [], [], {}
|
|
346
|
-
with ThreadPool(NUM_THREADS) as pool:
|
|
347
|
-
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
|
|
348
|
-
pbar = TQDM(results, desc=desc, total=len(self.samples))
|
|
349
|
-
for sample, nf_f, nc_f, msg in pbar:
|
|
350
|
-
if nf_f:
|
|
351
|
-
samples.append(sample)
|
|
352
|
-
if msg:
|
|
353
|
-
msgs.append(msg)
|
|
354
|
-
nf += nf_f
|
|
355
|
-
nc += nc_f
|
|
356
|
-
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
|
|
357
|
-
pbar.close()
|
|
358
|
-
if msgs:
|
|
359
|
-
LOGGER.info("\n".join(msgs))
|
|
360
|
-
x["hash"] = get_hash([x[0] for x in self.samples])
|
|
361
|
-
x["results"] = nf, nc, len(samples), samples
|
|
362
|
-
x["msgs"] = msgs # warnings
|
|
363
|
-
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
|
|
364
|
-
return samples
|
|
365
|
-
|
|
366
|
-
|
|
367
247
|
class YOLOMultiModalDataset(YOLODataset):
|
|
368
248
|
"""
|
|
369
249
|
Dataset class for loading object detection and/or segmentation labels in YOLO format.
|
|
@@ -447,15 +327,15 @@ class GroundingDataset(YOLODataset):
|
|
|
447
327
|
bboxes.append(box)
|
|
448
328
|
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
|
|
449
329
|
labels.append(
|
|
450
|
-
|
|
451
|
-
im_file
|
|
452
|
-
shape
|
|
453
|
-
cls
|
|
454
|
-
bboxes
|
|
455
|
-
normalized
|
|
456
|
-
bbox_format
|
|
457
|
-
texts
|
|
458
|
-
|
|
330
|
+
{
|
|
331
|
+
"im_file": im_file,
|
|
332
|
+
"shape": (h, w),
|
|
333
|
+
"cls": lb[:, 0:1], # n, 1
|
|
334
|
+
"bboxes": lb[:, 1:], # n, 4
|
|
335
|
+
"normalized": True,
|
|
336
|
+
"bbox_format": "xywh",
|
|
337
|
+
"texts": texts,
|
|
338
|
+
}
|
|
459
339
|
)
|
|
460
340
|
return labels
|
|
461
341
|
|
|
@@ -497,3 +377,128 @@ class SemanticDataset(BaseDataset):
|
|
|
497
377
|
def __init__(self):
|
|
498
378
|
"""Initialize a SemanticDataset object."""
|
|
499
379
|
super().__init__()
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
class ClassificationDataset:
|
|
383
|
+
"""
|
|
384
|
+
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
|
|
385
|
+
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
|
|
386
|
+
learning models, with optional image transformations and caching mechanisms to speed up training.
|
|
387
|
+
|
|
388
|
+
This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
|
|
389
|
+
in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
|
|
390
|
+
to ensure data integrity and consistency.
|
|
391
|
+
|
|
392
|
+
Attributes:
|
|
393
|
+
cache_ram (bool): Indicates if caching in RAM is enabled.
|
|
394
|
+
cache_disk (bool): Indicates if caching on disk is enabled.
|
|
395
|
+
samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
|
|
396
|
+
file (if caching on disk), and optionally the loaded image array (if caching in RAM).
|
|
397
|
+
torch_transforms (callable): PyTorch transforms to be applied to the images.
|
|
398
|
+
"""
|
|
399
|
+
|
|
400
|
+
def __init__(self, root, args, augment=False, prefix=""):
|
|
401
|
+
"""
|
|
402
|
+
Initialize YOLO object with root, image size, augmentations, and cache settings.
|
|
403
|
+
|
|
404
|
+
Args:
|
|
405
|
+
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
|
|
406
|
+
args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
|
|
407
|
+
parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
|
|
408
|
+
of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
|
|
409
|
+
`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
|
|
410
|
+
augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
|
|
411
|
+
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
|
|
412
|
+
debugging. Default is an empty string.
|
|
413
|
+
"""
|
|
414
|
+
import torchvision # scope for faster 'import ultralytics'
|
|
415
|
+
|
|
416
|
+
# Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
|
|
417
|
+
self.base = torchvision.datasets.ImageFolder(root=root)
|
|
418
|
+
self.samples = self.base.samples
|
|
419
|
+
self.root = self.base.root
|
|
420
|
+
|
|
421
|
+
# Initialize attributes
|
|
422
|
+
if augment and args.fraction < 1.0: # reduce training fraction
|
|
423
|
+
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
|
|
424
|
+
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
|
|
425
|
+
self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM
|
|
426
|
+
self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files
|
|
427
|
+
self.samples = self.verify_images() # filter out bad images
|
|
428
|
+
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
|
|
429
|
+
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
|
|
430
|
+
self.torch_transforms = (
|
|
431
|
+
classify_augmentations(
|
|
432
|
+
size=args.imgsz,
|
|
433
|
+
scale=scale,
|
|
434
|
+
hflip=args.fliplr,
|
|
435
|
+
vflip=args.flipud,
|
|
436
|
+
erasing=args.erasing,
|
|
437
|
+
auto_augment=args.auto_augment,
|
|
438
|
+
hsv_h=args.hsv_h,
|
|
439
|
+
hsv_s=args.hsv_s,
|
|
440
|
+
hsv_v=args.hsv_v,
|
|
441
|
+
)
|
|
442
|
+
if augment
|
|
443
|
+
else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
|
|
444
|
+
)
|
|
445
|
+
|
|
446
|
+
def __getitem__(self, i):
|
|
447
|
+
"""Returns subset of data and targets corresponding to given indices."""
|
|
448
|
+
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
|
449
|
+
if self.cache_ram:
|
|
450
|
+
if im is None: # Warning: two separate if statements required here, do not combine this with previous line
|
|
451
|
+
im = self.samples[i][3] = cv2.imread(f)
|
|
452
|
+
elif self.cache_disk:
|
|
453
|
+
if not fn.exists(): # load npy
|
|
454
|
+
np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
|
|
455
|
+
im = np.load(fn)
|
|
456
|
+
else: # read image
|
|
457
|
+
im = cv2.imread(f) # BGR
|
|
458
|
+
# Convert NumPy array to PIL image
|
|
459
|
+
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
|
|
460
|
+
sample = self.torch_transforms(im)
|
|
461
|
+
return {"img": sample, "cls": j}
|
|
462
|
+
|
|
463
|
+
def __len__(self) -> int:
|
|
464
|
+
"""Return the total number of samples in the dataset."""
|
|
465
|
+
return len(self.samples)
|
|
466
|
+
|
|
467
|
+
def verify_images(self):
|
|
468
|
+
"""Verify all images in dataset."""
|
|
469
|
+
desc = f"{self.prefix}Scanning {self.root}..."
|
|
470
|
+
path = Path(self.root).with_suffix(".cache") # *.cache file path
|
|
471
|
+
|
|
472
|
+
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
|
|
473
|
+
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
|
|
474
|
+
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
|
475
|
+
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
|
476
|
+
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
|
|
477
|
+
if LOCAL_RANK in {-1, 0}:
|
|
478
|
+
d = f"{desc} {nf} images, {nc} corrupt"
|
|
479
|
+
TQDM(None, desc=d, total=n, initial=n)
|
|
480
|
+
if cache["msgs"]:
|
|
481
|
+
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
|
482
|
+
return samples
|
|
483
|
+
|
|
484
|
+
# Run scan if *.cache retrieval failed
|
|
485
|
+
nf, nc, msgs, samples, x = 0, 0, [], [], {}
|
|
486
|
+
with ThreadPool(NUM_THREADS) as pool:
|
|
487
|
+
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
|
|
488
|
+
pbar = TQDM(results, desc=desc, total=len(self.samples))
|
|
489
|
+
for sample, nf_f, nc_f, msg in pbar:
|
|
490
|
+
if nf_f:
|
|
491
|
+
samples.append(sample)
|
|
492
|
+
if msg:
|
|
493
|
+
msgs.append(msg)
|
|
494
|
+
nf += nf_f
|
|
495
|
+
nc += nc_f
|
|
496
|
+
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
|
|
497
|
+
pbar.close()
|
|
498
|
+
if msgs:
|
|
499
|
+
LOGGER.info("\n".join(msgs))
|
|
500
|
+
x["hash"] = get_hash([x[0] for x in self.samples])
|
|
501
|
+
x["results"] = nf, nc, len(samples), samples
|
|
502
|
+
x["msgs"] = msgs # warnings
|
|
503
|
+
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
|
|
504
|
+
return samples
|
|
@@ -9,7 +9,6 @@ import numpy as np
|
|
|
9
9
|
import torch
|
|
10
10
|
from PIL import Image
|
|
11
11
|
from matplotlib import pyplot as plt
|
|
12
|
-
from pandas import DataFrame
|
|
13
12
|
from tqdm import tqdm
|
|
14
13
|
|
|
15
14
|
from ultralytics.data.augment import Format
|
|
@@ -172,7 +171,7 @@ class Explorer:
|
|
|
172
171
|
|
|
173
172
|
def sql_query(
|
|
174
173
|
self, query: str, return_type: str = "pandas"
|
|
175
|
-
) -> Union[
|
|
174
|
+
) -> Union[Any, None]: # pandas.DataFrame or pyarrow.Table
|
|
176
175
|
"""
|
|
177
176
|
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
|
|
178
177
|
|
|
@@ -247,7 +246,7 @@ class Explorer:
|
|
|
247
246
|
idx: Union[int, List[int]] = None,
|
|
248
247
|
limit: int = 25,
|
|
249
248
|
return_type: str = "pandas",
|
|
250
|
-
) ->
|
|
249
|
+
) -> Any: # pandas.DataFrame or pyarrow.Table
|
|
251
250
|
"""
|
|
252
251
|
Query the table for similar images. Accepts a single image or a list of images.
|
|
253
252
|
|
|
@@ -312,7 +311,7 @@ class Explorer:
|
|
|
312
311
|
img = plot_query_result(similar, plot_labels=labels)
|
|
313
312
|
return Image.fromarray(img)
|
|
314
313
|
|
|
315
|
-
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame
|
|
314
|
+
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any: # pd.DataFrame
|
|
316
315
|
"""
|
|
317
316
|
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
|
|
318
317
|
are max_dist or closer to the image in the embedding space at a given index.
|
|
@@ -447,12 +446,11 @@ class Explorer:
|
|
|
447
446
|
"""
|
|
448
447
|
result = prompt_sql_query(query)
|
|
449
448
|
try:
|
|
450
|
-
|
|
449
|
+
return self.sql_query(result)
|
|
451
450
|
except Exception as e:
|
|
452
451
|
LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
|
|
453
452
|
LOGGER.error(e)
|
|
454
453
|
return None
|
|
455
|
-
return df
|
|
456
454
|
|
|
457
455
|
def visualize(self, result):
|
|
458
456
|
"""
|
|
@@ -3,8 +3,6 @@
|
|
|
3
3
|
import time
|
|
4
4
|
from threading import Thread
|
|
5
5
|
|
|
6
|
-
import pandas as pd
|
|
7
|
-
|
|
8
6
|
from ultralytics import Explorer
|
|
9
7
|
from ultralytics.utils import ROOT, SETTINGS
|
|
10
8
|
from ultralytics.utils.checks import check_requirements
|
|
@@ -148,12 +146,14 @@ def run_ai_query():
|
|
|
148
146
|
'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
|
|
149
147
|
)
|
|
150
148
|
return
|
|
149
|
+
import pandas # scope for faster 'import ultralytics'
|
|
150
|
+
|
|
151
151
|
st.session_state["error"] = None
|
|
152
152
|
query = st.session_state.get("ai_query")
|
|
153
153
|
if query.rstrip().lstrip():
|
|
154
154
|
exp = st.session_state["explorer"]
|
|
155
155
|
res = exp.ask_ai(query)
|
|
156
|
-
if not isinstance(res,
|
|
156
|
+
if not isinstance(res, pandas.DataFrame) or res.empty:
|
|
157
157
|
st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
|
|
158
158
|
return
|
|
159
159
|
st.session_state["imgs"] = res["im_file"].to_list()
|
|
@@ -5,7 +5,6 @@ from typing import List
|
|
|
5
5
|
|
|
6
6
|
import cv2
|
|
7
7
|
import numpy as np
|
|
8
|
-
import pandas as pd
|
|
9
8
|
|
|
10
9
|
from ultralytics.data.augment import LetterBox
|
|
11
10
|
from ultralytics.utils import LOGGER as logger
|
|
@@ -64,8 +63,10 @@ def plot_query_result(similar_set, plot_labels=True):
|
|
|
64
63
|
similar_set (list): Pyarrow or pandas object containing the similar data points
|
|
65
64
|
plot_labels (bool): Whether to plot labels or not
|
|
66
65
|
"""
|
|
66
|
+
import pandas # scope for faster 'import ultralytics'
|
|
67
|
+
|
|
67
68
|
similar_set = (
|
|
68
|
-
similar_set.to_dict(orient="list") if isinstance(similar_set,
|
|
69
|
+
similar_set.to_dict(orient="list") if isinstance(similar_set, pandas.DataFrame) else similar_set.to_pydict()
|
|
69
70
|
)
|
|
70
71
|
empty_masks = [[[]]]
|
|
71
72
|
empty_boxes = [[]]
|
ultralytics/engine/exporter.py
CHANGED
|
@@ -75,6 +75,7 @@ from ultralytics.utils import (
|
|
|
75
75
|
LINUX,
|
|
76
76
|
LOGGER,
|
|
77
77
|
MACOS,
|
|
78
|
+
PYTHON_VERSION,
|
|
78
79
|
ROOT,
|
|
79
80
|
WINDOWS,
|
|
80
81
|
__version__,
|
|
@@ -83,7 +84,7 @@ from ultralytics.utils import (
|
|
|
83
84
|
get_default_args,
|
|
84
85
|
yaml_save,
|
|
85
86
|
)
|
|
86
|
-
from ultralytics.utils.checks import
|
|
87
|
+
from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version
|
|
87
88
|
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
|
|
88
89
|
from ultralytics.utils.files import file_size, spaces_in_path
|
|
89
90
|
from ultralytics.utils.ops import Profile
|
|
@@ -92,7 +93,7 @@ from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_d
|
|
|
92
93
|
|
|
93
94
|
def export_formats():
|
|
94
95
|
"""YOLOv8 export formats."""
|
|
95
|
-
import pandas
|
|
96
|
+
import pandas # scope for faster 'import ultralytics'
|
|
96
97
|
|
|
97
98
|
x = [
|
|
98
99
|
["PyTorch", "-", ".pt", True, True],
|
ultralytics/engine/trainer.py
CHANGED
|
@@ -464,7 +464,7 @@ class BaseTrainer:
|
|
|
464
464
|
def save_model(self):
|
|
465
465
|
"""Save model training checkpoints with additional metadata."""
|
|
466
466
|
import io
|
|
467
|
-
import pandas as pd # scope for faster
|
|
467
|
+
import pandas as pd # scope for faster 'import ultralytics'
|
|
468
468
|
|
|
469
469
|
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
|
|
470
470
|
buffer = io.BytesIO()
|
|
@@ -4,7 +4,6 @@ import os
|
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
|
|
6
6
|
import cv2
|
|
7
|
-
import matplotlib.pyplot as plt
|
|
8
7
|
import numpy as np
|
|
9
8
|
import torch
|
|
10
9
|
from PIL import Image
|
|
@@ -118,6 +117,8 @@ class FastSAMPrompt:
|
|
|
118
117
|
retina (bool, optional): Whether to use retina mask. Defaults to False.
|
|
119
118
|
with_contours (bool, optional): Whether to plot contours. Defaults to True.
|
|
120
119
|
"""
|
|
120
|
+
import matplotlib.pyplot as plt
|
|
121
|
+
|
|
121
122
|
pbar = TQDM(annotations, total=len(annotations))
|
|
122
123
|
for ann in pbar:
|
|
123
124
|
result_name = os.path.basename(ann.path)
|
|
@@ -202,6 +203,8 @@ class FastSAMPrompt:
|
|
|
202
203
|
target_height (int, optional): Target height for resizing. Defaults to 960.
|
|
203
204
|
target_width (int, optional): Target width for resizing. Defaults to 960.
|
|
204
205
|
"""
|
|
206
|
+
import matplotlib.pyplot as plt
|
|
207
|
+
|
|
205
208
|
n, h, w = annotation.shape # batch, height, width
|
|
206
209
|
|
|
207
210
|
areas = np.sum(annotation, axis=(1, 2))
|
|
@@ -11,7 +11,6 @@ segmentation tasks.
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import torch
|
|
13
13
|
import torch.nn.functional as F
|
|
14
|
-
import torchvision
|
|
15
14
|
|
|
16
15
|
from ultralytics.data.augment import LetterBox
|
|
17
16
|
from ultralytics.engine.predictor import BasePredictor
|
|
@@ -246,6 +245,8 @@ class Predictor(BasePredictor):
|
|
|
246
245
|
Returns:
|
|
247
246
|
(tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
|
|
248
247
|
"""
|
|
248
|
+
import torchvision # scope for faster 'import ultralytics'
|
|
249
|
+
|
|
249
250
|
self.segment_all = True
|
|
250
251
|
ih, iw = im.shape[2:]
|
|
251
252
|
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
|
|
@@ -449,6 +450,8 @@ class Predictor(BasePredictor):
|
|
|
449
450
|
- new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
|
|
450
451
|
- keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
|
|
451
452
|
"""
|
|
453
|
+
import torchvision # scope for faster 'import ultralytics'
|
|
454
|
+
|
|
452
455
|
if len(masks) == 0:
|
|
453
456
|
return masks
|
|
454
457
|
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
import torchvision
|
|
5
4
|
|
|
6
5
|
from ultralytics.data import ClassificationDataset, build_dataloader
|
|
7
6
|
from ultralytics.engine.trainer import BaseTrainer
|
|
@@ -59,6 +58,8 @@ class ClassificationTrainer(BaseTrainer):
|
|
|
59
58
|
|
|
60
59
|
def setup_model(self):
|
|
61
60
|
"""Load, create or download model for any task."""
|
|
61
|
+
import torchvision # scope for faster 'import ultralytics'
|
|
62
|
+
|
|
62
63
|
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
|
63
64
|
return
|
|
64
65
|
|