ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
@@ -1,472 +0,0 @@
|
|
1
|
-
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2
|
-
|
3
|
-
from io import BytesIO
|
4
|
-
from pathlib import Path
|
5
|
-
from typing import Any, List, Tuple, Union
|
6
|
-
|
7
|
-
import cv2
|
8
|
-
import numpy as np
|
9
|
-
import torch
|
10
|
-
from PIL import Image
|
11
|
-
from matplotlib import pyplot as plt
|
12
|
-
from pandas import DataFrame
|
13
|
-
from tqdm import tqdm
|
14
|
-
|
15
|
-
from ultralytics.data.augment import Format
|
16
|
-
from ultralytics.data.dataset import YOLODataset
|
17
|
-
from ultralytics.data.utils import check_det_dataset
|
18
|
-
from ultralytics.models.yolo.model import YOLO
|
19
|
-
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks, USER_CONFIG_DIR
|
20
|
-
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
|
21
|
-
|
22
|
-
|
23
|
-
class ExplorerDataset(YOLODataset):
|
24
|
-
def __init__(self, *args, data: dict = None, **kwargs) -> None:
|
25
|
-
super().__init__(*args, data=data, **kwargs)
|
26
|
-
|
27
|
-
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
|
28
|
-
"""Loads 1 image from dataset index 'i' without any resize ops."""
|
29
|
-
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
|
30
|
-
if im is None: # not cached in RAM
|
31
|
-
if fn.exists(): # load npy
|
32
|
-
im = np.load(fn)
|
33
|
-
else: # read image
|
34
|
-
im = cv2.imread(f) # BGR
|
35
|
-
if im is None:
|
36
|
-
raise FileNotFoundError(f"Image Not Found {f}")
|
37
|
-
h0, w0 = im.shape[:2] # orig hw
|
38
|
-
return im, (h0, w0), im.shape[:2]
|
39
|
-
|
40
|
-
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
41
|
-
|
42
|
-
def build_transforms(self, hyp: IterableSimpleNamespace = None):
|
43
|
-
"""Creates transforms for dataset images without resizing."""
|
44
|
-
return Format(
|
45
|
-
bbox_format="xyxy",
|
46
|
-
normalize=False,
|
47
|
-
return_mask=self.use_segments,
|
48
|
-
return_keypoint=self.use_keypoints,
|
49
|
-
batch_idx=True,
|
50
|
-
mask_ratio=hyp.mask_ratio,
|
51
|
-
mask_overlap=hyp.overlap_mask,
|
52
|
-
)
|
53
|
-
|
54
|
-
|
55
|
-
class Explorer:
|
56
|
-
def __init__(
|
57
|
-
self,
|
58
|
-
data: Union[str, Path] = "coco128.yaml",
|
59
|
-
model: str = "yolov8n.pt",
|
60
|
-
uri: str = USER_CONFIG_DIR / "explorer",
|
61
|
-
) -> None:
|
62
|
-
# Note duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181
|
63
|
-
checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])
|
64
|
-
import lancedb
|
65
|
-
|
66
|
-
self.connection = lancedb.connect(uri)
|
67
|
-
self.table_name = Path(data).name.lower() + "_" + model.lower()
|
68
|
-
self.sim_idx_base_name = (
|
69
|
-
f"{self.table_name}_sim_idx".lower()
|
70
|
-
) # Use this name and append thres and top_k to reuse the table
|
71
|
-
self.model = YOLO(model)
|
72
|
-
self.data = data # None
|
73
|
-
self.choice_set = None
|
74
|
-
|
75
|
-
self.table = None
|
76
|
-
self.progress = 0
|
77
|
-
|
78
|
-
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
|
79
|
-
"""
|
80
|
-
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
|
81
|
-
already exists. Pass force=True to overwrite the existing table.
|
82
|
-
|
83
|
-
Args:
|
84
|
-
force (bool): Whether to overwrite the existing table or not. Defaults to False.
|
85
|
-
split (str): Split of the dataset to use. Defaults to 'train'.
|
86
|
-
|
87
|
-
Example:
|
88
|
-
```python
|
89
|
-
exp = Explorer()
|
90
|
-
exp.create_embeddings_table()
|
91
|
-
```
|
92
|
-
"""
|
93
|
-
if self.table is not None and not force:
|
94
|
-
LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
|
95
|
-
return
|
96
|
-
if self.table_name in self.connection.table_names() and not force:
|
97
|
-
LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
|
98
|
-
self.table = self.connection.open_table(self.table_name)
|
99
|
-
self.progress = 1
|
100
|
-
return
|
101
|
-
if self.data is None:
|
102
|
-
raise ValueError("Data must be provided to create embeddings table")
|
103
|
-
|
104
|
-
data_info = check_det_dataset(self.data)
|
105
|
-
if split not in data_info:
|
106
|
-
raise ValueError(
|
107
|
-
f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
|
108
|
-
)
|
109
|
-
|
110
|
-
choice_set = data_info[split]
|
111
|
-
choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
|
112
|
-
self.choice_set = choice_set
|
113
|
-
dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)
|
114
|
-
|
115
|
-
# Create the table schema
|
116
|
-
batch = dataset[0]
|
117
|
-
vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
|
118
|
-
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
|
119
|
-
table.add(
|
120
|
-
self._yield_batches(
|
121
|
-
dataset,
|
122
|
-
data_info,
|
123
|
-
self.model,
|
124
|
-
exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
|
125
|
-
)
|
126
|
-
)
|
127
|
-
|
128
|
-
self.table = table
|
129
|
-
|
130
|
-
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
|
131
|
-
"""Generates batches of data for embedding, excluding specified keys."""
|
132
|
-
for i in tqdm(range(len(dataset))):
|
133
|
-
self.progress = float(i + 1) / len(dataset)
|
134
|
-
batch = dataset[i]
|
135
|
-
for k in exclude_keys:
|
136
|
-
batch.pop(k, None)
|
137
|
-
batch = sanitize_batch(batch, data_info)
|
138
|
-
batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
|
139
|
-
yield [batch]
|
140
|
-
|
141
|
-
def query(
|
142
|
-
self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
|
143
|
-
) -> Any: # pyarrow.Table
|
144
|
-
"""
|
145
|
-
Query the table for similar images. Accepts a single image or a list of images.
|
146
|
-
|
147
|
-
Args:
|
148
|
-
imgs (str or list): Path to the image or a list of paths to the images.
|
149
|
-
limit (int): Number of results to return.
|
150
|
-
|
151
|
-
Returns:
|
152
|
-
(pyarrow.Table): An arrow table containing the results. Supports converting to:
|
153
|
-
- pandas dataframe: `result.to_pandas()`
|
154
|
-
- dict of lists: `result.to_pydict()`
|
155
|
-
|
156
|
-
Example:
|
157
|
-
```python
|
158
|
-
exp = Explorer()
|
159
|
-
exp.create_embeddings_table()
|
160
|
-
similar = exp.query(img='https://ultralytics.com/images/zidane.jpg')
|
161
|
-
```
|
162
|
-
"""
|
163
|
-
if self.table is None:
|
164
|
-
raise ValueError("Table is not created. Please create the table first.")
|
165
|
-
if isinstance(imgs, str):
|
166
|
-
imgs = [imgs]
|
167
|
-
assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
|
168
|
-
embeds = self.model.embed(imgs)
|
169
|
-
# Get avg if multiple images are passed (len > 1)
|
170
|
-
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
|
171
|
-
return self.table.search(embeds).limit(limit).to_arrow()
|
172
|
-
|
173
|
-
def sql_query(
|
174
|
-
self, query: str, return_type: str = "pandas"
|
175
|
-
) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
|
176
|
-
"""
|
177
|
-
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
|
178
|
-
|
179
|
-
Args:
|
180
|
-
query (str): SQL query to run.
|
181
|
-
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
|
182
|
-
|
183
|
-
Returns:
|
184
|
-
(pyarrow.Table): An arrow table containing the results.
|
185
|
-
|
186
|
-
Example:
|
187
|
-
```python
|
188
|
-
exp = Explorer()
|
189
|
-
exp.create_embeddings_table()
|
190
|
-
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
|
191
|
-
result = exp.sql_query(query)
|
192
|
-
```
|
193
|
-
"""
|
194
|
-
assert return_type in {
|
195
|
-
"pandas",
|
196
|
-
"arrow",
|
197
|
-
}, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
|
198
|
-
import duckdb
|
199
|
-
|
200
|
-
if self.table is None:
|
201
|
-
raise ValueError("Table is not created. Please create the table first.")
|
202
|
-
|
203
|
-
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
|
204
|
-
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
|
205
|
-
if not query.startswith("SELECT") and not query.startswith("WHERE"):
|
206
|
-
raise ValueError(
|
207
|
-
f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
|
208
|
-
)
|
209
|
-
if query.startswith("WHERE"):
|
210
|
-
query = f"SELECT * FROM 'table' {query}"
|
211
|
-
LOGGER.info(f"Running query: {query}")
|
212
|
-
|
213
|
-
rs = duckdb.sql(query)
|
214
|
-
if return_type == "arrow":
|
215
|
-
return rs.arrow()
|
216
|
-
elif return_type == "pandas":
|
217
|
-
return rs.df()
|
218
|
-
|
219
|
-
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
|
220
|
-
"""
|
221
|
-
Plot the results of a SQL-Like query on the table.
|
222
|
-
Args:
|
223
|
-
query (str): SQL query to run.
|
224
|
-
labels (bool): Whether to plot the labels or not.
|
225
|
-
|
226
|
-
Returns:
|
227
|
-
(PIL.Image): Image containing the plot.
|
228
|
-
|
229
|
-
Example:
|
230
|
-
```python
|
231
|
-
exp = Explorer()
|
232
|
-
exp.create_embeddings_table()
|
233
|
-
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
|
234
|
-
result = exp.plot_sql_query(query)
|
235
|
-
```
|
236
|
-
"""
|
237
|
-
result = self.sql_query(query, return_type="arrow")
|
238
|
-
if len(result) == 0:
|
239
|
-
LOGGER.info("No results found.")
|
240
|
-
return None
|
241
|
-
img = plot_query_result(result, plot_labels=labels)
|
242
|
-
return Image.fromarray(img)
|
243
|
-
|
244
|
-
def get_similar(
|
245
|
-
self,
|
246
|
-
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
247
|
-
idx: Union[int, List[int]] = None,
|
248
|
-
limit: int = 25,
|
249
|
-
return_type: str = "pandas",
|
250
|
-
) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
|
251
|
-
"""
|
252
|
-
Query the table for similar images. Accepts a single image or a list of images.
|
253
|
-
|
254
|
-
Args:
|
255
|
-
img (str or list): Path to the image or a list of paths to the images.
|
256
|
-
idx (int or list): Index of the image in the table or a list of indexes.
|
257
|
-
limit (int): Number of results to return. Defaults to 25.
|
258
|
-
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
|
259
|
-
|
260
|
-
Returns:
|
261
|
-
(pandas.DataFrame): A dataframe containing the results.
|
262
|
-
|
263
|
-
Example:
|
264
|
-
```python
|
265
|
-
exp = Explorer()
|
266
|
-
exp.create_embeddings_table()
|
267
|
-
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
|
268
|
-
```
|
269
|
-
"""
|
270
|
-
assert return_type in {
|
271
|
-
"pandas",
|
272
|
-
"arrow",
|
273
|
-
}, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
|
274
|
-
img = self._check_imgs_or_idxs(img, idx)
|
275
|
-
similar = self.query(img, limit=limit)
|
276
|
-
|
277
|
-
if return_type == "arrow":
|
278
|
-
return similar
|
279
|
-
elif return_type == "pandas":
|
280
|
-
return similar.to_pandas()
|
281
|
-
|
282
|
-
def plot_similar(
|
283
|
-
self,
|
284
|
-
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
285
|
-
idx: Union[int, List[int]] = None,
|
286
|
-
limit: int = 25,
|
287
|
-
labels: bool = True,
|
288
|
-
) -> Image.Image:
|
289
|
-
"""
|
290
|
-
Plot the similar images. Accepts images or indexes.
|
291
|
-
|
292
|
-
Args:
|
293
|
-
img (str or list): Path to the image or a list of paths to the images.
|
294
|
-
idx (int or list): Index of the image in the table or a list of indexes.
|
295
|
-
labels (bool): Whether to plot the labels or not.
|
296
|
-
limit (int): Number of results to return. Defaults to 25.
|
297
|
-
|
298
|
-
Returns:
|
299
|
-
(PIL.Image): Image containing the plot.
|
300
|
-
|
301
|
-
Example:
|
302
|
-
```python
|
303
|
-
exp = Explorer()
|
304
|
-
exp.create_embeddings_table()
|
305
|
-
similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
|
306
|
-
```
|
307
|
-
"""
|
308
|
-
similar = self.get_similar(img, idx, limit, return_type="arrow")
|
309
|
-
if len(similar) == 0:
|
310
|
-
LOGGER.info("No results found.")
|
311
|
-
return None
|
312
|
-
img = plot_query_result(similar, plot_labels=labels)
|
313
|
-
return Image.fromarray(img)
|
314
|
-
|
315
|
-
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
|
316
|
-
"""
|
317
|
-
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
|
318
|
-
are max_dist or closer to the image in the embedding space at a given index.
|
319
|
-
|
320
|
-
Args:
|
321
|
-
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
|
322
|
-
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
|
323
|
-
vector search. Defaults: None.
|
324
|
-
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
|
325
|
-
|
326
|
-
Returns:
|
327
|
-
(pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns
|
328
|
-
include indices of similar images and their respective distances.
|
329
|
-
|
330
|
-
Example:
|
331
|
-
```python
|
332
|
-
exp = Explorer()
|
333
|
-
exp.create_embeddings_table()
|
334
|
-
sim_idx = exp.similarity_index()
|
335
|
-
```
|
336
|
-
"""
|
337
|
-
if self.table is None:
|
338
|
-
raise ValueError("Table is not created. Please create the table first.")
|
339
|
-
sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
|
340
|
-
if sim_idx_table_name in self.connection.table_names() and not force:
|
341
|
-
LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
|
342
|
-
return self.connection.open_table(sim_idx_table_name).to_pandas()
|
343
|
-
|
344
|
-
if top_k and not (1.0 >= top_k >= 0.0):
|
345
|
-
raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
|
346
|
-
if max_dist < 0.0:
|
347
|
-
raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
|
348
|
-
|
349
|
-
top_k = int(top_k * len(self.table)) if top_k else len(self.table)
|
350
|
-
top_k = max(top_k, 1)
|
351
|
-
features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
|
352
|
-
im_files = features["im_file"]
|
353
|
-
embeddings = features["vector"]
|
354
|
-
|
355
|
-
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
|
356
|
-
|
357
|
-
def _yield_sim_idx():
|
358
|
-
"""Generates a dataframe with similarity indices and distances for images."""
|
359
|
-
for i in tqdm(range(len(embeddings))):
|
360
|
-
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
|
361
|
-
yield [
|
362
|
-
{
|
363
|
-
"idx": i,
|
364
|
-
"im_file": im_files[i],
|
365
|
-
"count": len(sim_idx),
|
366
|
-
"sim_im_files": sim_idx["im_file"].tolist(),
|
367
|
-
}
|
368
|
-
]
|
369
|
-
|
370
|
-
sim_table.add(_yield_sim_idx())
|
371
|
-
self.sim_index = sim_table
|
372
|
-
return sim_table.to_pandas()
|
373
|
-
|
374
|
-
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
|
375
|
-
"""
|
376
|
-
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
|
377
|
-
max_dist or closer to the image in the embedding space at a given index.
|
378
|
-
|
379
|
-
Args:
|
380
|
-
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
|
381
|
-
top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when
|
382
|
-
running vector search. Defaults to 0.01.
|
383
|
-
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
|
384
|
-
|
385
|
-
Returns:
|
386
|
-
(PIL.Image): Image containing the plot.
|
387
|
-
|
388
|
-
Example:
|
389
|
-
```python
|
390
|
-
exp = Explorer()
|
391
|
-
exp.create_embeddings_table()
|
392
|
-
|
393
|
-
similarity_idx_plot = exp.plot_similarity_index()
|
394
|
-
similarity_idx_plot.show() # view image preview
|
395
|
-
similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
|
396
|
-
```
|
397
|
-
"""
|
398
|
-
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
|
399
|
-
sim_count = sim_idx["count"].tolist()
|
400
|
-
sim_count = np.array(sim_count)
|
401
|
-
|
402
|
-
indices = np.arange(len(sim_count))
|
403
|
-
|
404
|
-
# Create the bar plot
|
405
|
-
plt.bar(indices, sim_count)
|
406
|
-
|
407
|
-
# Customize the plot (optional)
|
408
|
-
plt.xlabel("data idx")
|
409
|
-
plt.ylabel("Count")
|
410
|
-
plt.title("Similarity Count")
|
411
|
-
buffer = BytesIO()
|
412
|
-
plt.savefig(buffer, format="png")
|
413
|
-
buffer.seek(0)
|
414
|
-
|
415
|
-
# Use Pillow to open the image from the buffer
|
416
|
-
return Image.fromarray(np.array(Image.open(buffer)))
|
417
|
-
|
418
|
-
def _check_imgs_or_idxs(
|
419
|
-
self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
|
420
|
-
) -> List[np.ndarray]:
|
421
|
-
if img is None and idx is None:
|
422
|
-
raise ValueError("Either img or idx must be provided.")
|
423
|
-
if img is not None and idx is not None:
|
424
|
-
raise ValueError("Only one of img or idx must be provided.")
|
425
|
-
if idx is not None:
|
426
|
-
idx = idx if isinstance(idx, list) else [idx]
|
427
|
-
img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
|
428
|
-
|
429
|
-
return img if isinstance(img, list) else [img]
|
430
|
-
|
431
|
-
def ask_ai(self, query):
|
432
|
-
"""
|
433
|
-
Ask AI a question.
|
434
|
-
|
435
|
-
Args:
|
436
|
-
query (str): Question to ask.
|
437
|
-
|
438
|
-
Returns:
|
439
|
-
(pandas.DataFrame): A dataframe containing filtered results to the SQL query.
|
440
|
-
|
441
|
-
Example:
|
442
|
-
```python
|
443
|
-
exp = Explorer()
|
444
|
-
exp.create_embeddings_table()
|
445
|
-
answer = exp.ask_ai('Show images with 1 person and 2 dogs')
|
446
|
-
```
|
447
|
-
"""
|
448
|
-
result = prompt_sql_query(query)
|
449
|
-
try:
|
450
|
-
df = self.sql_query(result)
|
451
|
-
except Exception as e:
|
452
|
-
LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
|
453
|
-
LOGGER.error(e)
|
454
|
-
return None
|
455
|
-
return df
|
456
|
-
|
457
|
-
def visualize(self, result):
|
458
|
-
"""
|
459
|
-
Visualize the results of a query. TODO.
|
460
|
-
|
461
|
-
Args:
|
462
|
-
result (pyarrow.Table): Table containing the results of a query.
|
463
|
-
"""
|
464
|
-
pass
|
465
|
-
|
466
|
-
def generate_report(self, result):
|
467
|
-
"""
|
468
|
-
Generate a report of the dataset.
|
469
|
-
|
470
|
-
TODO
|
471
|
-
"""
|
472
|
-
pass
|
@@ -1 +0,0 @@
|
|
1
|
-
# Ultralytics YOLO 🚀, AGPL-3.0 license
|