vectordb-bench 0.0.1__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +30 -0
- vectordb_bench/__main__.py +39 -0
- vectordb_bench/backend/__init__.py +0 -0
- vectordb_bench/backend/assembler.py +57 -0
- vectordb_bench/backend/cases.py +124 -0
- vectordb_bench/backend/clients/__init__.py +57 -0
- vectordb_bench/backend/clients/api.py +179 -0
- vectordb_bench/backend/clients/elastic_cloud/config.py +56 -0
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +152 -0
- vectordb_bench/backend/clients/milvus/config.py +123 -0
- vectordb_bench/backend/clients/milvus/milvus.py +182 -0
- vectordb_bench/backend/clients/pinecone/config.py +15 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +113 -0
- vectordb_bench/backend/clients/qdrant_cloud/config.py +16 -0
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +169 -0
- vectordb_bench/backend/clients/weaviate_cloud/config.py +45 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +151 -0
- vectordb_bench/backend/clients/zilliz_cloud/config.py +34 -0
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +35 -0
- vectordb_bench/backend/dataset.py +393 -0
- vectordb_bench/backend/result_collector.py +15 -0
- vectordb_bench/backend/runner/__init__.py +12 -0
- vectordb_bench/backend/runner/mp_runner.py +124 -0
- vectordb_bench/backend/runner/serial_runner.py +164 -0
- vectordb_bench/backend/task_runner.py +290 -0
- vectordb_bench/backend/utils.py +85 -0
- vectordb_bench/base.py +6 -0
- vectordb_bench/frontend/components/check_results/charts.py +175 -0
- vectordb_bench/frontend/components/check_results/data.py +86 -0
- vectordb_bench/frontend/components/check_results/filters.py +97 -0
- vectordb_bench/frontend/components/check_results/headerIcon.py +18 -0
- vectordb_bench/frontend/components/check_results/nav.py +21 -0
- vectordb_bench/frontend/components/check_results/priceTable.py +48 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +10 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +87 -0
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +47 -0
- vectordb_bench/frontend/components/run_test/dbSelector.py +36 -0
- vectordb_bench/frontend/components/run_test/generateTasks.py +21 -0
- vectordb_bench/frontend/components/run_test/hideSidebar.py +10 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +69 -0
- vectordb_bench/frontend/const.py +391 -0
- vectordb_bench/frontend/pages/qps_with_price.py +60 -0
- vectordb_bench/frontend/pages/run_test.py +59 -0
- vectordb_bench/frontend/utils.py +6 -0
- vectordb_bench/frontend/vdb_benchmark.py +42 -0
- vectordb_bench/interface.py +239 -0
- vectordb_bench/log_util.py +103 -0
- vectordb_bench/metric.py +53 -0
- vectordb_bench/models.py +234 -0
- vectordb_bench/results/result_20230609_standard.json +3228 -0
- vectordb_bench-0.0.1.dist-info/LICENSE +21 -0
- vectordb_bench-0.0.1.dist-info/METADATA +226 -0
- vectordb_bench-0.0.1.dist-info/RECORD +56 -0
- vectordb_bench-0.0.1.dist-info/WHEEL +5 -0
- vectordb_bench-0.0.1.dist-info/entry_points.txt +2 -0
- vectordb_bench-0.0.1.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/vectordb_bench/__init__.py
@@ -0,0 +1,30 @@
+import environs
+import inspect
+import pathlib
+from . import log_util
+
+
+env = environs.Env()
+env.read_env(".env")
+
+class config:
+    LOG_LEVEL = env.str("LOG_LEVEL", "INFO")
+
+    DEFAULT_DATASET_URL = env.str("DEFAULT_DATASET_URL", "assets.zilliz.com/benchmark/")
+    DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset")
+    NUM_PER_BATCH = env.int("NUM_PER_BATCH", 5000)
+
+    DROP_OLD = env.bool("DROP_OLD", True)
+    USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True)
+
+    RESULTS_LOCAL_DIR = pathlib.Path(__file__).parent.joinpath("results")
+    CASE_TIMEOUT_IN_SECOND = 24 * 60 * 60
+
+
+    def display(self) -> list:
+        tmp = [i for i in inspect.getmembers(self)
+               if not inspect.ismethod(i[1]) and not i[0].startswith('_')
+               ]
+        return tmp
+
+log_util.init(config.LOG_LEVEL)
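A quick sketch of how this environs-backed config behaves (assumes the wheel is installed; the override value below is made up). Because the class attributes are evaluated at import time, overrides must be in the environment, or a local ".env" file, before the first import:

import os

os.environ["NUM_PER_BATCH"] = "10000"  # must be set before the first import

from vectordb_bench import config

print(config.NUM_PER_BATCH)   # 10000, read from the environment
print(config.LOG_LEVEL)       # "INFO" unless LOG_LEVEL is set
print(config().display())     # [(name, value), ...] for every setting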
--- /dev/null
+++ b/vectordb_bench/__main__.py
@@ -0,0 +1,39 @@
+import traceback
+import logging
+import subprocess
+import os
+from . import config
+
+log = logging.getLogger("vectordb_bench")
+
+
+def main():
+    log.info(f"all configs: {config().display()}")
+    run_streamlit()
+
+
+def run_streamlit():
+    cmd = [
+        "streamlit",
+        "run",
+        f"{os.path.dirname(__file__)}/frontend/vdb_benchmark.py",
+        "--logger.level",
+        "info",
+        "--theme.base",
+        "light",
+        "--theme.primaryColor",
+        "#3670F2",
+        "--theme.secondaryBackgroundColor",
+        "#F0F2F6",
+    ]
+    log.debug(f"cmd: {cmd}")
+    try:
+        subprocess.run(cmd, check=True)
+    except KeyboardInterrupt:
+        log.info("exit streamlit...")
+    except Exception as e:
+        log.warning(f"exit, err={e}\nstack trace={traceback.format_exc(chain=True)}")
+
+
+if __name__ == "__main__":
+    main()
--- /dev/null
+++ b/vectordb_bench/backend/__init__.py
(new empty file, no content)
--- /dev/null
+++ b/vectordb_bench/backend/assembler.py
@@ -0,0 +1,57 @@
+from .cases import type2case, CaseLabel
+from .task_runner import CaseRunner, RunningStatus, TaskRunner
+from ..models import TaskConfig
+from ..backend.clients import EmptyDBCaseConfig
+import logging
+
+
+log = logging.getLogger(__name__)
+
+
+class Assembler:
+    @classmethod
+    def assemble(cls, run_id: str, task: TaskConfig) -> CaseRunner:
+        c_cls = type2case.get(task.case_config.case_id)
+
+        c = c_cls()
+        if type(task.db_case_config) != EmptyDBCaseConfig:
+            task.db_case_config.metric_type = c.dataset.data.metric_type
+
+        runner = CaseRunner(
+            run_id=run_id,
+            config=task,
+            ca=c,
+            status=RunningStatus.PENDING,
+        )
+
+        return runner
+
+    @classmethod
+    def assemble_all(cls, run_id: str, task_label: str, tasks: list[TaskConfig]) -> TaskRunner:
+        """Group by case type, db, and case dataset."""
+        runners = [cls.assemble(run_id, task) for task in tasks]
+        load_runners = [r for r in runners if r.ca.label == CaseLabel.Load]
+        perf_runners = [r for r in runners if r.ca.label == CaseLabel.Performance]
+
+        # group by db
+        db2runner = {}
+        for r in perf_runners:
+            db = r.config.db
+            if db not in db2runner:
+                db2runner[db] = []
+            db2runner[db].append(r)
+
+        # sort by dataset size
+        for k in db2runner.keys():
+            db2runner[k].sort(key=lambda x: x.ca.dataset.data.size)
+
+        all_runners = []
+        all_runners.extend(load_runners)
+        for v in db2runner.values():
+            all_runners.extend(v)
+
+        return TaskRunner(
+            run_id=run_id,
+            task_label=task_label,
+            case_runners=all_runners,
+        )
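The ordering contract of assemble_all is easiest to see with stand-in objects. A self-contained sketch using tuples instead of real CaseRunner instances: load cases run first, then performance cases grouped per database and sorted so smaller datasets run before larger ones:

runners = [
    ("load", "Milvus", 0),
    ("perf", "Milvus", 10_000_000),
    ("perf", "Pinecone", 500_000),
    ("perf", "Milvus", 500_000),
]

load = [r for r in runners if r[0] == "load"]
perf = [r for r in runners if r[0] == "perf"]

db2runner: dict[str, list] = {}
for r in perf:
    db2runner.setdefault(r[1], []).append(r)
for v in db2runner.values():
    v.sort(key=lambda r: r[2])  # smaller datasets first

ordered = load + [r for v in db2runner.values() for r in v]
print(ordered)
# [('load', 'Milvus', 0), ('perf', 'Milvus', 500000),
#  ('perf', 'Milvus', 10000000), ('perf', 'Pinecone', 500000)]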
--- /dev/null
+++ b/vectordb_bench/backend/cases.py
@@ -0,0 +1,124 @@
+import logging
+from enum import Enum, auto
+
+from . import dataset as ds
+from ..base import BaseModel
+from ..models import CaseType
+
+
+log = logging.getLogger(__name__)
+
+
+class CaseLabel(Enum):
+    Load = auto()
+    Performance = auto()
+
+
+class Case(BaseModel):
+    """Undefined case
+
+    Fields:
+        case_id(CaseType): the 11 default case types plus one custom case.
+        label(CaseLabel): performance or load.
+        dataset(DataSet): dataset for this case runner.
+        filter_rate(float | None): one of 0.99 | 0.01 | None
+        filters(dict | None): filters for search
+    """
+
+    case_id: CaseType
+    label: CaseLabel
+    dataset: ds.DataSet
+
+    filter_rate: float | None
+
+    @property
+    def filters(self) -> dict | None:
+        if self.filter_rate is not None:
+            ID = round(self.filter_rate * self.dataset.data.size)
+            return {
+                "metadata": f">={ID}",
+                "id": ID,
+            }
+
+        return None
+
+
+class CapacityCase(Case, BaseModel):
+    label: CaseLabel = CaseLabel.Load
+    filter_rate: float | int | None = None
+
+class PerformanceCase(Case, BaseModel):
+    label: CaseLabel = CaseLabel.Performance
+    filter_rate: float | int | None = None
+
+class CapacityLDimCase(CapacityCase):
+    case_id: CaseType = CaseType.CapacityLDim
+    dataset: ds.DataSet = ds.get(ds.Name.GIST, ds.Label.SMALL)
+
+class CapacitySDimCase(CapacityCase):
+    case_id: CaseType = CaseType.CapacitySDim
+    dataset: ds.DataSet = ds.get(ds.Name.SIFT, ds.Label.SMALL)
+
+class PerformanceLZero(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceLZero
+    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.LARGE)
+
+class PerformanceMZero(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceMZero
+    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.MEDIUM)
+
+class PerformanceSZero(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceSZero
+    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.SMALL)
+
+class PerformanceLLow(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceLLow
+    filter_rate: float | int | None = 0.01
+    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.LARGE)
+
+class PerformanceMLow(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceMLow
+    filter_rate: float | int | None = 0.01
+    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.MEDIUM)
+
+class PerformanceSLow(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceSLow
+    filter_rate: float | int | None = 0.01
+    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.SMALL)
+
+class PerformanceLHigh(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceLHigh
+    filter_rate: float | int | None = 0.99
+    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.LARGE)
+
+class PerformanceMHigh(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceMHigh
+    filter_rate: float | int | None = 0.99
+    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.MEDIUM)
+
+class PerformanceSHigh(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceSHigh
+    filter_rate: float | int | None = 0.99
+    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.SMALL)
+
+class Performance100M(PerformanceCase):
+    case_id: CaseType = CaseType.Performance100M
+    filter_rate: float | int | None = None
+    dataset: ds.DataSet = ds.get(ds.Name.LAION, ds.Label.LARGE)
+
+type2case = {
+    CaseType.CapacityLDim: CapacityLDimCase,
+    CaseType.CapacitySDim: CapacitySDimCase,
+
+    CaseType.PerformanceLZero: PerformanceLZero,
+    CaseType.PerformanceMZero: PerformanceMZero,
+    CaseType.PerformanceSZero: PerformanceSZero,
+
+    CaseType.PerformanceLLow: PerformanceLLow,
+    CaseType.PerformanceMLow: PerformanceMLow,
+    CaseType.PerformanceSLow: PerformanceSLow,
+    CaseType.PerformanceLHigh: PerformanceLHigh,
+    CaseType.PerformanceMHigh: PerformanceMHigh,
+    CaseType.PerformanceSHigh: PerformanceSHigh,
+    CaseType.Performance100M: Performance100M,
+}
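The filters property is the only derived field here; a worked example of its arithmetic for a high-filter case (the dataset size is hypothetical, real sizes come from ds.DataSet):

filter_rate = 0.99
dataset_size = 1_000_000  # hypothetical, for illustration only

ID = round(filter_rate * dataset_size)
print({"metadata": f">={ID}", "id": ID})
# {'metadata': '>=990000', 'id': 990000} -- only about 1% of ids pass the filter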
--- /dev/null
+++ b/vectordb_bench/backend/clients/__init__.py
@@ -0,0 +1,57 @@
+from enum import Enum
+from typing import Type
+from .api import (
+    VectorDB,
+    DBConfig,
+    DBCaseConfig,
+    EmptyDBCaseConfig,
+    IndexType,
+    MetricType,
+)
+
+from .milvus.milvus import Milvus
+from .elastic_cloud.elastic_cloud import ElasticCloud
+from .pinecone.pinecone import Pinecone
+from .weaviate_cloud.weaviate_cloud import WeaviateCloud
+from .qdrant_cloud.qdrant_cloud import QdrantCloud
+from .zilliz_cloud.zilliz_cloud import ZillizCloud
+
+
+class DB(Enum):
+    """Database types
+
+    Examples:
+        >>> DB.Milvus
+        <DB.Milvus: 'Milvus'>
+        >>> DB.Milvus.value
+        'Milvus'
+        >>> DB.Milvus.name
+        'Milvus'
+    """
+
+    Milvus = "Milvus"
+    ZillizCloud = "ZillizCloud"
+    Pinecone = "Pinecone"
+    ElasticCloud = "ElasticCloud"
+    QdrantCloud = "QdrantCloud"
+    WeaviateCloud = "WeaviateCloud"
+
+
+    @property
+    def init_cls(self) -> Type[VectorDB]:
+        return db2client.get(self)
+
+
+db2client = {
+    DB.Milvus: Milvus,
+    DB.ZillizCloud: ZillizCloud,
+    DB.WeaviateCloud: WeaviateCloud,
+    DB.ElasticCloud: ElasticCloud,
+    DB.QdrantCloud: QdrantCloud,
+    DB.Pinecone: Pinecone,
+}
+
+
+__all__ = [
+    "DB", "VectorDB", "DBConfig", "DBCaseConfig", "IndexType", "MetricType", "EmptyDBCaseConfig",
+]
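A short sketch of the enum round trip (assumes the package and all of its client dependencies are importable, since this module imports every client eagerly):

from vectordb_bench.backend.clients import DB

db = DB("Milvus")          # value -> member, since DB is a plain Enum
print(db.name, db.value)   # Milvus Milvus
print(db.init_cls)         # the Milvus client class, looked up via db2client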
--- /dev/null
+++ b/vectordb_bench/backend/clients/api.py
@@ -0,0 +1,179 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any, Type
+from contextlib import contextmanager
+
+from pydantic import BaseModel
+
+
+class MetricType(str, Enum):
+    L2 = "L2"
+    COSINE = "COSINE"
+    IP = "IP"
+
+
+class IndexType(str, Enum):
+    HNSW = "HNSW"
+    DISKANN = "DISKANN"
+    IVFFlat = "IVF_FLAT"
+    Flat = "FLAT"
+    AUTOINDEX = "AUTOINDEX"
+    ES_HNSW = "hnsw"
+
+
+class DBConfig(ABC, BaseModel):
+    """DBConfig contains the connection info for a vector database.
+
+    Args:
+        db_label(str): label to distinguish different deployments of the same database, e.g.
+
+            MilvusConfig.db_label = 2c8g
+            MilvusConfig.db_label = 16c64g
+            ZillizCloudConfig.db_label = 1cu-perf
+    """
+
+    db_label: str | None = None
+
+    @abstractmethod
+    def to_dict(self) -> dict:
+        raise NotImplementedError
+
+
+class DBCaseConfig(ABC):
+    """Case-specific vector database configs, usually used for index params like HNSW."""
+    @abstractmethod
+    def index_param(self) -> dict:
+        raise NotImplementedError
+
+    @abstractmethod
+    def search_param(self) -> dict:
+        raise NotImplementedError
+
+
+class EmptyDBCaseConfig(BaseModel, DBCaseConfig):
+    """EmptyDBCaseConfig is used if the vector database has no case-specific configs."""
+    null: str | None = None
+    def index_param(self) -> dict:
+        return {}
+
+    def search_param(self) -> dict:
+        return {}
+
+
+class VectorDB(ABC):
+    """Each VectorDB is initialized once per case; the object is then copied into multiple processes.
+
+    In each process, the benchmark ensures VectorDB.init() is called before any other method.
+
+    insert_embeddings, search_embedding, and ready_to_search are timed on each call.
+
+    Examples:
+        >>> milvus = Milvus()
+        >>> with milvus.init():
+        >>>     milvus.insert_embeddings()
+        >>>     milvus.search_embedding()
+    """
+
+    @abstractmethod
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig | None,
+        collection_name: str,
+        drop_old: bool = False,
+        **kwargs,
+    ) -> None:
+        """Initialize wrapper around the vector database client.
+
+        Args:
+            dim(int): the dimension of the dataset
+            db_config(dict): configs to establish connections with the vector database
+            db_case_config(DBCaseConfig | None): case-specific configs for indexing and searching
+            drop_old(bool): whether to drop the existing collection of the dataset.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def config_cls(cls) -> Type[DBConfig]:
+        raise NotImplementedError
+
+
+    @classmethod
+    @abstractmethod
+    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+        raise NotImplementedError
+
+
+    @abstractmethod
+    @contextmanager
+    def init(self) -> None:
+        """Create and destroy connections to the database.
+
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> int:
+        """Insert the embeddings into the vector database. The default number of embeddings
+        per insert_embeddings call is 5000.
+
+        Args:
+            embeddings(list[list[float]]): list of embeddings to add to the vector database.
+            metadata(list[int]): metadata associated with the embeddings, for filtering.
+            kwargs(Any): vector database specific parameters.
+
+        Returns:
+            int: inserted data count
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+    ) -> list[int]:
+        """Get the k most similar embeddings to the query vector.
+
+        Args:
+            query(list[float]): query embedding to look up documents similar to.
+            k(int): number of most similar embeddings to return. Defaults to 100.
+            filters(dict, optional): filtering expression to filter the data while searching.
+
+        Returns:
+            list[int]: list of the k most similar embedding IDs to the query embedding.
+        """
+        raise NotImplementedError
+
+    # TODO: remove
+    @abstractmethod
+    def ready_to_search(self):
+        """ready_to_search will be called between insertion and search in performance cases.
+
+        Should block until the vector database is ready to be tested on
+        heavy performance cases.
+
+        Time(insert the dataset) + Time(ready_to_search) is recorded as the "load_duration" metric.
+        """
+        raise NotImplementedError
+
+    # TODO: remove
+    @abstractmethod
+    def ready_to_load(self):
+        """ready_to_load will be called before load in load cases.
+
+        Should block until the vector database is ready to be tested on
+        heavy load cases.
+        """
+        raise NotImplementedError
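To make the contract concrete, a minimal in-memory implementation of the VectorDB interface; brute-force L2 search with no persistence, and every name below is invented for illustration, not part of the package:

import heapq
from contextlib import contextmanager
from typing import Any, Type

from vectordb_bench.backend.clients.api import (
    VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType,
)


class InMemoryDB(VectorDB):
    """Toy backend: stores vectors in a list, searches by brute force."""

    def __init__(self, dim: int, db_config: dict,
                 db_case_config: DBCaseConfig | None,
                 collection_name: str = "demo",
                 drop_old: bool = False, **kwargs: Any) -> None:
        self.dim = dim
        self.data: list[tuple[int, list[float]]] = []

    @classmethod
    def config_cls(cls) -> Type[DBConfig]:
        raise NotImplementedError  # no connection info in this sketch

    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
        return EmptyDBCaseConfig

    @contextmanager
    def init(self):
        yield  # nothing to connect or tear down

    def insert_embeddings(self, embeddings, metadata, **kwargs) -> int:
        self.data.extend(zip(metadata, embeddings))
        return len(embeddings)

    def search_embedding(self, query, k=100, filters=None) -> list[int]:
        def l2(vec):  # squared L2 distance, order-equivalent to L2
            return sum((a - b) ** 2 for a, b in zip(query, vec))
        return [i for i, v in heapq.nsmallest(k, self.data, key=lambda t: l2(t[1]))]

    def ready_to_search(self):
        pass

    def ready_to_load(self):
        pass


db = InMemoryDB(dim=2, db_config={}, db_case_config=None)
with db.init():
    db.insert_embeddings([[0.0, 0.0], [1.0, 1.0]], metadata=[10, 11])
    print(db.search_embedding([0.1, 0.1], k=1))  # [10]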
--- /dev/null
+++ b/vectordb_bench/backend/clients/elastic_cloud/config.py
@@ -0,0 +1,56 @@
+from enum import Enum
+from pydantic import SecretStr, BaseModel
+
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+
+
+class ElasticsearchConfig(DBConfig, BaseModel):
+    cloud_id: SecretStr
+    password: SecretStr | None = None
+
+    def to_dict(self) -> dict:
+        return {
+            "cloud_id": self.cloud_id.get_secret_value(),
+            "basic_auth": ("elastic", self.password.get_secret_value()),
+        }
+
+
+class ESElementType(str, Enum):
+    float = "float"  # 4 bytes
+    byte = "byte"  # 1 byte, -128 to 127
+
+
+class ElasticsearchIndexConfig(BaseModel, DBCaseConfig):
+    element_type: ESElementType = ESElementType.float
+    index: IndexType = IndexType.ES_HNSW  # ES only supports 'hnsw'
+
+    metric_type: MetricType | None = None
+    efConstruction: int | None = None
+    M: int | None = None
+    num_candidates: int | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "l2_norm"
+        elif self.metric_type == MetricType.IP:
+            return "dot_product"
+        return "cosine"
+
+    def index_param(self) -> dict:
+        params = {
+            "type": "dense_vector",
+            "index": True,
+            "element_type": self.element_type.value,
+            "similarity": self.parse_metric(),
+            "index_options": {
+                "type": self.index.value,
+                "m": self.M,
+                "ef_construction": self.efConstruction,
+            },
+        }
+        return params
+
+    def search_param(self) -> dict:
+        return {
+            "num_candidates": self.num_candidates,
+        }
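A small sketch of the mapping this config produces (the parameter values are illustrative, not recommended settings):

from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.elastic_cloud.config import (
    ElasticsearchIndexConfig,
)

conf = ElasticsearchIndexConfig(
    metric_type=MetricType.L2,  # mapped to "l2_norm" by parse_metric()
    efConstruction=360,
    M=30,
    num_candidates=100,
)
print(conf.index_param())
# {'type': 'dense_vector', 'index': True, 'element_type': 'float',
#  'similarity': 'l2_norm',
#  'index_options': {'type': 'hnsw', 'm': 30, 'ef_construction': 360}}
print(conf.search_param())  # {'num_candidates': 100}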
--- /dev/null
+++ b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py
@@ -0,0 +1,152 @@
+import logging
+from contextlib import contextmanager
+from typing import Type
+from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
+from .config import ElasticsearchIndexConfig, ElasticsearchConfig
+from elasticsearch.helpers import bulk
+
+
+for logger in ("elasticsearch", "elastic_transport"):
+    logging.getLogger(logger).setLevel(logging.WARNING)
+
+log = logging.getLogger(__name__)
+
+class ElasticCloud(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: ElasticsearchIndexConfig,
+        indice: str = "vdb_bench_indice",  # must be lowercase
+        id_col_name: str = "id",
+        vector_col_name: str = "vector",
+        drop_old: bool = False,
+    ):
+        self.dim = dim
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.indice = indice
+        self.id_col_name = id_col_name
+        self.vector_col_name = vector_col_name
+
+        from elasticsearch import Elasticsearch
+
+        client = Elasticsearch(**self.db_config)
+
+        if drop_old:
+            log.info(f"Elasticsearch client drop_old indices: {self.indice}")
+            is_existed_res = client.indices.exists(index=self.indice)
+            if is_existed_res.raw:
+                client.indices.delete(index=self.indice)
+            self._create_indice(client)
+
+
+    @classmethod
+    def config_cls(cls) -> Type[DBConfig]:
+        return ElasticsearchConfig
+
+
+    @classmethod
+    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+        return ElasticsearchIndexConfig
+
+
+    @contextmanager
+    def init(self) -> None:
+        """Connect to Elasticsearch."""
+        from elasticsearch import Elasticsearch
+        self.client = Elasticsearch(**self.db_config, request_timeout=30)
+
+        yield
+        # self.client.transport.close()
+        self.client = None
+        del self.client
+
+    def _create_indice(self, client) -> None:
+        mappings = {
+            "properties": {
+                self.id_col_name: {"type": "integer"},
+                self.vector_col_name: {
+                    "dims": self.dim,
+                    **self.case_config.index_param(),
+                },
+            }
+        }
+
+        try:
+            client.indices.create(index=self.indice, mappings=mappings)
+        except Exception as e:
+            log.warning(f"Failed to create indice: {self.indice} error: {str(e)}")
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+    ) -> int:
+        """Insert the embeddings into Elasticsearch."""
+        assert self.client is not None, "should self.init() first"
+
+        insert_data = [
+            {
+                "_index": self.indice,
+                "_source": {
+                    self.id_col_name: metadata[i],
+                    self.vector_col_name: embeddings[i],
+                },
+            }
+            for i in range(len(embeddings))
+        ]
+        try:
+            bulk_insert_res = bulk(self.client, insert_data)
+            return bulk_insert_res[0]
+        except Exception as e:
+            log.warning(f"Failed to insert data: {self.indice} error: {str(e)}")
+            raise e from None
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+    ) -> list[int]:
+        """Get the k most similar embeddings to the query vector.
+
+        Args:
+            query(list[float]): query embedding to look up documents similar to.
+            k(int): number of most similar embeddings to return. Defaults to 100.
+            filters(dict, optional): filtering expression to filter the data while searching.
+
+        Returns:
+            list[int]: list of the k most similar embedding IDs to the query embedding.
+        """
+        assert self.client is not None, "should self.init() first"
+        # is_existed_res = self.client.indices.exists(index=self.indice)
+        # assert is_existed_res.raw == True, "should self.init() first"
+
+        knn = {
+            "field": self.vector_col_name,
+            "k": k,
+            "num_candidates": self.case_config.num_candidates,
+            "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}]
+            if filters
+            else [],
+            "query_vector": query,
+        }
+        size = k
+        try:
+            search_res = self.client.search(index=self.indice, knn=knn, size=size)
+            res = [d["_source"][self.id_col_name] for d in search_res["hits"]["hits"]]
+
+            return res
+        except Exception as e:
+            log.warning(f"Failed to search: {self.indice} error: {str(e)}")
+            raise e from None
+
+    def ready_to_search(self):
+        """ready_to_search will be called between insertion and search in performance cases."""
+        pass
+
+    def ready_to_load(self):
+        """ready_to_load will be called before load in load cases."""
+        pass
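Putting the pieces together, a hedged usage sketch of this wrapper; the cloud_id and password are placeholders, and a live Elastic Cloud deployment is required for it to run:

from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.elastic_cloud.config import (
    ElasticsearchConfig, ElasticsearchIndexConfig,
)
from vectordb_bench.backend.clients.elastic_cloud.elastic_cloud import ElasticCloud

db_config = ElasticsearchConfig(
    cloud_id="<your-cloud-id>",   # placeholder
    password="<your-password>",   # placeholder
).to_dict()
case_config = ElasticsearchIndexConfig(
    metric_type=MetricType.COSINE, efConstruction=360, M=30, num_candidates=100,
)

es = ElasticCloud(dim=2, db_config=db_config,
                  db_case_config=case_config, drop_old=True)
with es.init():
    es.insert_embeddings(embeddings=[[0.1, 0.2], [0.3, 0.4]], metadata=[0, 1])
    # Note: freshly indexed documents may not be visible to search until
    # Elasticsearch refreshes the index.
    print(es.search_embedding(query=[0.1, 0.2], k=1))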