vectordb-bench 0.0.1__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. vectordb_bench/__init__.py +30 -0
  2. vectordb_bench/__main__.py +39 -0
  3. vectordb_bench/backend/__init__.py +0 -0
  4. vectordb_bench/backend/assembler.py +57 -0
  5. vectordb_bench/backend/cases.py +124 -0
  6. vectordb_bench/backend/clients/__init__.py +57 -0
  7. vectordb_bench/backend/clients/api.py +179 -0
  8. vectordb_bench/backend/clients/elastic_cloud/config.py +56 -0
  9. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +152 -0
  10. vectordb_bench/backend/clients/milvus/config.py +123 -0
  11. vectordb_bench/backend/clients/milvus/milvus.py +182 -0
  12. vectordb_bench/backend/clients/pinecone/config.py +15 -0
  13. vectordb_bench/backend/clients/pinecone/pinecone.py +113 -0
  14. vectordb_bench/backend/clients/qdrant_cloud/config.py +16 -0
  15. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +169 -0
  16. vectordb_bench/backend/clients/weaviate_cloud/config.py +45 -0
  17. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +151 -0
  18. vectordb_bench/backend/clients/zilliz_cloud/config.py +34 -0
  19. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +35 -0
  20. vectordb_bench/backend/dataset.py +393 -0
  21. vectordb_bench/backend/result_collector.py +15 -0
  22. vectordb_bench/backend/runner/__init__.py +12 -0
  23. vectordb_bench/backend/runner/mp_runner.py +124 -0
  24. vectordb_bench/backend/runner/serial_runner.py +164 -0
  25. vectordb_bench/backend/task_runner.py +290 -0
  26. vectordb_bench/backend/utils.py +85 -0
  27. vectordb_bench/base.py +6 -0
  28. vectordb_bench/frontend/components/check_results/charts.py +175 -0
  29. vectordb_bench/frontend/components/check_results/data.py +86 -0
  30. vectordb_bench/frontend/components/check_results/filters.py +97 -0
  31. vectordb_bench/frontend/components/check_results/headerIcon.py +18 -0
  32. vectordb_bench/frontend/components/check_results/nav.py +21 -0
  33. vectordb_bench/frontend/components/check_results/priceTable.py +48 -0
  34. vectordb_bench/frontend/components/run_test/autoRefresh.py +10 -0
  35. vectordb_bench/frontend/components/run_test/caseSelector.py +87 -0
  36. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +47 -0
  37. vectordb_bench/frontend/components/run_test/dbSelector.py +36 -0
  38. vectordb_bench/frontend/components/run_test/generateTasks.py +21 -0
  39. vectordb_bench/frontend/components/run_test/hideSidebar.py +10 -0
  40. vectordb_bench/frontend/components/run_test/submitTask.py +69 -0
  41. vectordb_bench/frontend/const.py +391 -0
  42. vectordb_bench/frontend/pages/qps_with_price.py +60 -0
  43. vectordb_bench/frontend/pages/run_test.py +59 -0
  44. vectordb_bench/frontend/utils.py +6 -0
  45. vectordb_bench/frontend/vdb_benchmark.py +42 -0
  46. vectordb_bench/interface.py +239 -0
  47. vectordb_bench/log_util.py +103 -0
  48. vectordb_bench/metric.py +53 -0
  49. vectordb_bench/models.py +234 -0
  50. vectordb_bench/results/result_20230609_standard.json +3228 -0
  51. vectordb_bench-0.0.1.dist-info/LICENSE +21 -0
  52. vectordb_bench-0.0.1.dist-info/METADATA +226 -0
  53. vectordb_bench-0.0.1.dist-info/RECORD +56 -0
  54. vectordb_bench-0.0.1.dist-info/WHEEL +5 -0
  55. vectordb_bench-0.0.1.dist-info/entry_points.txt +2 -0
  56. vectordb_bench-0.0.1.dist-info/top_level.txt +1 -0
vectordb_bench/__init__.py
@@ -0,0 +1,30 @@
+ import environs
+ import inspect
+ import pathlib
+ from . import log_util
+
+
+ env = environs.Env()
+ env.read_env(".env")
+
+ class config:
+     LOG_LEVEL = env.str("LOG_LEVEL", "INFO")
+
+     DEFAULT_DATASET_URL = env.str("DEFAULT_DATASET_URL", "assets.zilliz.com/benchmark/")
+     DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset")
+     NUM_PER_BATCH = env.int("NUM_PER_BATCH", 5000)
+
+     DROP_OLD = env.bool("DROP_OLD", True)
+     USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True)
+
+     RESULTS_LOCAL_DIR = pathlib.Path(__file__).parent.joinpath("results")
+     CASE_TIMEOUT_IN_SECOND = 24 * 60 * 60
+
+
+     def display(self) -> list:
+         tmp = [
+             i for i in inspect.getmembers(self)
+             if not inspect.ismethod(i[1]) and not i[0].startswith("_")
+         ]
+         return tmp
+
+ log_util.init(config.LOG_LEVEL)
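Because every setting on the config class is read through environs at import time, each value can be overridden from the process environment or a local .env file. A minimal sketch of how an override takes effect (the values are illustrative):

import os

# Must be set before the first import of vectordb_bench, since the
# env.str()/env.int() lookups above run when the package is imported.
os.environ["LOG_LEVEL"] = "DEBUG"
os.environ["NUM_PER_BATCH"] = "10000"

import vectordb_bench

assert vectordb_bench.config.NUM_PER_BATCH == 10000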
vectordb_bench/__main__.py
@@ -0,0 +1,39 @@
+ import traceback
+ import logging
+ import subprocess
+ import os
+ from . import config
+
+ log = logging.getLogger("vectordb_bench")
+
+
+ def main():
+     log.info(f"all configs: {config().display()}")
+     run_streamlit()
+
+
+ def run_streamlit():
+     cmd = [
+         "streamlit",
+         "run",
+         f"{os.path.dirname(__file__)}/frontend/vdb_benchmark.py",
+         "--logger.level",
+         "info",
+         "--theme.base",
+         "light",
+         "--theme.primaryColor",
+         "#3670F2",
+         "--theme.secondaryBackgroundColor",
+         "#F0F2F6",
+     ]
+     log.debug(f"cmd: {cmd}")
+     try:
+         subprocess.run(cmd, check=True)
+     except KeyboardInterrupt:
+         log.info("exit streamlit...")
+     except Exception as e:
+         log.warning(f"exit, err={e}\nstack trace={traceback.format_exc(chain=True)}")
+
+
+ if __name__ == "__main__":
+     main()
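main() only logs the effective config and shells out to Streamlit, so the UI can also be started programmatically; a sketch using the standard library, equivalent to running python -m vectordb_bench:

import runpy

# Executes vectordb_bench/__main__.py in a fresh __main__ namespace,
# which calls main() and launches the Streamlit frontend.
runpy.run_module("vectordb_bench", run_name="__main__")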
vectordb_bench/backend/__init__.py
File without changes
vectordb_bench/backend/assembler.py
@@ -0,0 +1,57 @@
+ from .cases import type2case, CaseLabel
+ from .task_runner import CaseRunner, RunningStatus, TaskRunner
+ from ..models import TaskConfig
+ from ..backend.clients import EmptyDBCaseConfig
+ import logging
+
+
+ log = logging.getLogger(__name__)
+
+
+ class Assembler:
+     @classmethod
+     def assemble(cls, run_id, task: TaskConfig) -> CaseRunner:
+         c_cls = type2case.get(task.case_config.case_id)
+
+         c = c_cls()
+         if type(task.db_case_config) != EmptyDBCaseConfig:
+             task.db_case_config.metric_type = c.dataset.data.metric_type
+
+         runner = CaseRunner(
+             run_id=run_id,
+             config=task,
+             ca=c,
+             status=RunningStatus.PENDING,
+         )
+
+         return runner
+
+     @classmethod
+     def assemble_all(cls, run_id: str, task_label: str, tasks: list[TaskConfig]) -> TaskRunner:
+         """Group runners by case type, db, and case dataset."""
+         runners = [cls.assemble(run_id, task) for task in tasks]
+         load_runners = [r for r in runners if r.ca.label == CaseLabel.Load]
+         perf_runners = [r for r in runners if r.ca.label == CaseLabel.Performance]
+
+         # group by db
+         db2runner = {}
+         for r in perf_runners:
+             db = r.config.db
+             if db not in db2runner:
+                 db2runner[db] = []
+             db2runner[db].append(r)
+
+         # sort each db's runners by dataset size
+         for k in db2runner.keys():
+             db2runner[k].sort(key=lambda x: x.ca.dataset.data.size)
+
+         # load runners first, then performance runners grouped by db
+         all_runners = []
+         all_runners.extend(load_runners)
+         for v in db2runner.values():
+             all_runners.extend(v)
+
+         return TaskRunner(
+             run_id=run_id,
+             task_label=task_label,
+             case_runners=all_runners,
+         )
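assemble_all therefore yields a fixed execution order: all load (capacity) runners first, then performance runners grouped per database and sorted smallest dataset first. A toy illustration of that ordering, with made-up tuples standing in for CaseRunner objects:

# (db, label, dataset_size) stand-ins for CaseRunner objects
runners = [
    ("Milvus", "Performance", 10_000_000),
    ("Milvus", "Load", 5_000_000),
    ("QdrantCloud", "Performance", 1_000_000),
    ("Milvus", "Performance", 1_000_000),
]

load = [r for r in runners if r[1] == "Load"]
db2runner = {}
for r in runners:
    if r[1] == "Performance":
        db2runner.setdefault(r[0], []).append(r)
for rs in db2runner.values():
    rs.sort(key=lambda r: r[2])

ordered = load + [r for rs in db2runner.values() for r in rs]
# -> Milvus Load, Milvus 1M, Milvus 10M, QdrantCloud 1M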
vectordb_bench/backend/cases.py
@@ -0,0 +1,124 @@
+ import logging
+ from enum import Enum, auto
+
+ from . import dataset as ds
+ from ..base import BaseModel
+ from ..models import CaseType
+
+
+ log = logging.getLogger(__name__)
+
+
+ class CaseLabel(Enum):
+     Load = auto()
+     Performance = auto()
+
+
+ class Case(BaseModel):
+     """Undefined case
+
+     Fields:
+         case_id(CaseType): the 11 default case types plus one custom case.
+         label(CaseLabel): performance or load.
+         dataset(DataSet): dataset for this case runner.
+         filter_rate(float | None): one of 0.99 (99%) | 0.01 (1%) | None
+         filters(dict | None): filters for search
+     """
+
+     case_id: CaseType
+     label: CaseLabel
+     dataset: ds.DataSet
+
+     filter_rate: float | None
+
+     @property
+     def filters(self) -> dict | None:
+         if self.filter_rate is not None:
+             ID = round(self.filter_rate * self.dataset.data.size)
+             return {
+                 "metadata": f">={ID}",
+                 "id": ID,
+             }
+
+         return None
+
+
+ class CapacityCase(Case, BaseModel):
+     label: CaseLabel = CaseLabel.Load
+     filter_rate: float | int | None = None
+
+ class PerformanceCase(Case, BaseModel):
+     label: CaseLabel = CaseLabel.Performance
+     filter_rate: float | int | None = None
+
+ class CapacityLDimCase(CapacityCase):
+     case_id: CaseType = CaseType.CapacityLDim
+     dataset: ds.DataSet = ds.get(ds.Name.GIST, ds.Label.SMALL)
+
+ class CapacitySDimCase(CapacityCase):
+     case_id: CaseType = CaseType.CapacitySDim
+     dataset: ds.DataSet = ds.get(ds.Name.SIFT, ds.Label.SMALL)
+
+ class PerformanceLZero(PerformanceCase):
+     case_id: CaseType = CaseType.PerformanceLZero
+     dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.LARGE)
+
+ class PerformanceMZero(PerformanceCase):
+     case_id: CaseType = CaseType.PerformanceMZero
+     dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.MEDIUM)
+
+ class PerformanceSZero(PerformanceCase):
+     case_id: CaseType = CaseType.PerformanceSZero
+     dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.SMALL)
+
+ class PerformanceLLow(PerformanceCase):
+     case_id: CaseType = CaseType.PerformanceLLow
+     filter_rate: float | int | None = 0.01
+     dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.LARGE)
+
+ class PerformanceMLow(PerformanceCase):
+     case_id: CaseType = CaseType.PerformanceMLow
+     filter_rate: float | int | None = 0.01
+     dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.MEDIUM)
+
+ class PerformanceSLow(PerformanceCase):
+     case_id: CaseType = CaseType.PerformanceSLow
+     filter_rate: float | int | None = 0.01
+     dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.SMALL)
+
+ class PerformanceLHigh(PerformanceCase):
+     case_id: CaseType = CaseType.PerformanceLHigh
+     filter_rate: float | int | None = 0.99
+     dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.LARGE)
+
+ class PerformanceMHigh(PerformanceCase):
+     case_id: CaseType = CaseType.PerformanceMHigh
+     filter_rate: float | int | None = 0.99
+     dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.MEDIUM)
+
+ class PerformanceSHigh(PerformanceCase):
+     case_id: CaseType = CaseType.PerformanceSHigh
+     filter_rate: float | int | None = 0.99
+     dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.SMALL)
+
+ class Performance100M(PerformanceCase):
+     case_id: CaseType = CaseType.Performance100M
+     filter_rate: float | int | None = None
+     dataset: ds.DataSet = ds.get(ds.Name.LAION, ds.Label.LARGE)
+
+ type2case = {
+     CaseType.CapacityLDim: CapacityLDimCase,
+     CaseType.CapacitySDim: CapacitySDimCase,
+
+     CaseType.PerformanceLZero: PerformanceLZero,
+     CaseType.PerformanceMZero: PerformanceMZero,
+     CaseType.PerformanceSZero: PerformanceSZero,
+
+     CaseType.PerformanceLLow: PerformanceLLow,
+     CaseType.PerformanceMLow: PerformanceMLow,
+     CaseType.PerformanceSLow: PerformanceSLow,
+     CaseType.PerformanceLHigh: PerformanceLHigh,
+     CaseType.PerformanceMHigh: PerformanceMHigh,
+     CaseType.PerformanceSHigh: PerformanceSHigh,
+     CaseType.Performance100M: Performance100M,
+ }
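The type2case table is how a TaskConfig's case_id is turned into a concrete Case in Assembler.assemble. A sketch of the lookup and of what the filters property computes (the dataset size here is illustrative; the real value comes from ds.get()):

case = type2case[CaseType.PerformanceSLow]()

# With filter_rate = 0.01 and, say, a 500_000-vector dataset,
# ID = round(0.01 * 500_000) = 5_000, so case.filters evaluates to:
# {"metadata": ">=5000", "id": 5000}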
vectordb_bench/backend/clients/__init__.py
@@ -0,0 +1,57 @@
+ from enum import Enum
+ from typing import Type
+ from .api import (
+     VectorDB,
+     DBConfig,
+     DBCaseConfig,
+     EmptyDBCaseConfig,
+     IndexType,
+     MetricType,
+ )
+
+ from .milvus.milvus import Milvus
+ from .elastic_cloud.elastic_cloud import ElasticCloud
+ from .pinecone.pinecone import Pinecone
+ from .weaviate_cloud.weaviate_cloud import WeaviateCloud
+ from .qdrant_cloud.qdrant_cloud import QdrantCloud
+ from .zilliz_cloud.zilliz_cloud import ZillizCloud
+
+
+ class DB(Enum):
+     """Database types
+
+     Examples:
+         >>> DB.Milvus
+         <DB.Milvus: 'Milvus'>
+         >>> DB.Milvus.value
+         "Milvus"
+         >>> DB.Milvus.name
+         "Milvus"
+     """
+
+     Milvus = "Milvus"
+     ZillizCloud = "ZillizCloud"
+     Pinecone = "Pinecone"
+     ElasticCloud = "ElasticCloud"
+     QdrantCloud = "QdrantCloud"
+     WeaviateCloud = "WeaviateCloud"
+
+     @property
+     def init_cls(self) -> Type[VectorDB]:
+         return db2client.get(self)
+
+
+ db2client = {
+     DB.Milvus: Milvus,
+     DB.ZillizCloud: ZillizCloud,
+     DB.WeaviateCloud: WeaviateCloud,
+     DB.ElasticCloud: ElasticCloud,
+     DB.QdrantCloud: QdrantCloud,
+     DB.Pinecone: Pinecone,
+ }
+
+
+ __all__ = [
+     "DB", "VectorDB", "DBConfig", "DBCaseConfig", "IndexType", "MetricType", "EmptyDBCaseConfig",
+ ]
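Because each DB member's value equals its name, a client class can be resolved from a plain string, which is how a selected database name maps to a concrete client. A short sketch:

db = DB("Milvus")                      # <DB.Milvus: 'Milvus'>
client_cls = db.init_cls               # the Milvus class, via db2client
config_cls = client_cls.config_cls()   # its DBConfig subclass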
vectordb_bench/backend/clients/api.py
@@ -0,0 +1,179 @@
+ from abc import ABC, abstractmethod
+ from enum import Enum
+ from typing import Any, Type
+ from contextlib import contextmanager
+
+ from pydantic import BaseModel
+
+
+ class MetricType(str, Enum):
+     L2 = "L2"
+     COSINE = "COSINE"
+     IP = "IP"
+
+
+ class IndexType(str, Enum):
+     HNSW = "HNSW"
+     DISKANN = "DISKANN"
+     IVFFlat = "IVF_FLAT"
+     Flat = "FLAT"
+     AUTOINDEX = "AUTOINDEX"
+     ES_HNSW = "hnsw"
+
+
+ class DBConfig(ABC, BaseModel):
+     """DBConfig contains the connection info of a vector database.
+
+     Args:
+         db_label(str): label to distinguish different deployments of the same database, e.g.
+
+             MilvusConfig.db_label = 2c8g
+             MilvusConfig.db_label = 16c64g
+             ZillizCloudConfig.db_label = 1cu-perf
+     """
+
+     db_label: str | None = None
+
+     @abstractmethod
+     def to_dict(self) -> dict:
+         raise NotImplementedError
+
+
+ class DBCaseConfig(ABC):
+     """Case-specific vector database configs, usually used for index params like HNSW."""
+     @abstractmethod
+     def index_param(self) -> dict:
+         raise NotImplementedError
+
+     @abstractmethod
+     def search_param(self) -> dict:
+         raise NotImplementedError
+
+
+ class EmptyDBCaseConfig(BaseModel, DBCaseConfig):
+     """EmptyDBCaseConfig is used when the vector database has no case-specific configs."""
+     null: str | None = None
+
+     def index_param(self) -> dict:
+         return {}
+
+     def search_param(self) -> dict:
+         return {}
+
+
+ class VectorDB(ABC):
+     """Each VectorDB is initialized once per case; the object is then copied into multiple processes.
+
+     In each process, the benchmark ensures VectorDB.init() is called before any other method.
+
+     insert_embeddings, search_embedding, and ready_to_search are timed on each call.
+
+     Examples:
+         >>> milvus = Milvus()
+         >>> with milvus.init():
+         >>>     milvus.insert_embeddings()
+         >>>     milvus.search_embedding()
+     """
+
+     @abstractmethod
+     def __init__(
+         self,
+         dim: int,
+         db_config: dict,
+         db_case_config: DBCaseConfig | None,
+         collection_name: str,
+         drop_old: bool = False,
+         **kwargs
+     ) -> None:
+         """Initialize wrapper around the vector database client.
+
+         Args:
+             dim(int): the dimension of the dataset
+             db_config(dict): configs to establish connections with the vector database
+             db_case_config(DBCaseConfig | None): case-specific configs for indexing and searching
+             drop_old(bool): whether to drop the existing collection of the dataset.
+         """
+         raise NotImplementedError
+
+     @classmethod
+     @abstractmethod
+     def config_cls(cls) -> Type[DBConfig]:
+         raise NotImplementedError
+
+     @classmethod
+     @abstractmethod
+     def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+         raise NotImplementedError
+
+     @abstractmethod
+     @contextmanager
+     def init(self) -> None:
+         """Create and destroy connections to the database.
+
+         Examples:
+             >>> with self.init():
+             >>>     self.insert_embeddings()
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def insert_embeddings(
+         self,
+         embeddings: list[list[float]],
+         metadata: list[int],
+         kwargs: Any,
+     ) -> int:
+         """Insert the embeddings into the vector database. The default number of embeddings
+         per insert_embeddings call is 5000.
+
+         Args:
+             embeddings(list[list[float]]): list of embeddings to add to the vector database.
+             metadata(list[int]): metadata associated with the embeddings, for filtering.
+             kwargs(Any): vector database specific parameters.
+
+         Returns:
+             int: inserted data count
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def search_embedding(
+         self,
+         query: list[float],
+         k: int = 100,
+         filters: dict | None = None,
+     ) -> list[int]:
+         """Get the k most similar embeddings to the query vector.
+
+         Args:
+             query(list[float]): query embedding to look up documents similar to.
+             k(int): number of most similar embeddings to return. Defaults to 100.
+             filters(dict, optional): filtering expression to filter the data while searching.
+
+         Returns:
+             list[int]: IDs of the k most similar embeddings to the query embedding.
+         """
+         raise NotImplementedError
+
+     # TODO: remove
+     @abstractmethod
+     def ready_to_search(self):
+         """ready_to_search is called between insertion and search in performance cases.
+
+         Should block until the vector database is ready to be tested on
+         heavy performance cases.
+
+         Time(insert the dataset) + Time(ready_to_search) is recorded as the "load_duration" metric.
+         """
+         raise NotImplementedError
+
+     # TODO: remove
+     @abstractmethod
+     def ready_to_load(self):
+         """ready_to_load is called before the load in load cases.
+
+         Should block until the vector database is ready to be tested on
+         heavy load cases.
+         """
+         raise NotImplementedError
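To make the contract concrete, here is a deliberately naive in-memory implementation of the interface (a hypothetical sketch, not part of the package; it assumes VectorDB and EmptyDBCaseConfig from api.py are in scope):

from contextlib import contextmanager


class InMemoryDB(VectorDB):
    def __init__(self, dim, db_config, db_case_config, collection_name="demo", drop_old=False, **kwargs):
        self.store: dict[int, list[float]] = {}  # id -> vector

    @classmethod
    def config_cls(cls):
        raise NotImplementedError  # a real client returns its DBConfig subclass

    @classmethod
    def case_config_cls(cls, index_type=None):
        return EmptyDBCaseConfig  # no index params for a plain dict

    @contextmanager
    def init(self):
        yield  # nothing to connect or tear down

    def insert_embeddings(self, embeddings, metadata, **kwargs):
        self.store.update(zip(metadata, embeddings))
        return len(metadata)

    def search_embedding(self, query, k=100, filters=None):
        # brute-force L2 ranking; filters are ignored in this sketch
        dist = lambda vec: sum((a - b) ** 2 for a, b in zip(query, vec))
        return sorted(self.store, key=lambda i: dist(self.store[i]))[:k]

    def ready_to_search(self):
        pass

    def ready_to_load(self):
        pass


db = InMemoryDB(dim=2, db_config={}, db_case_config=None)
with db.init():
    db.insert_embeddings([[0.0, 1.0], [1.0, 0.0]], [1, 2])
    assert db.search_embedding([0.0, 0.9], k=1) == [1]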
vectordb_bench/backend/clients/elastic_cloud/config.py
@@ -0,0 +1,56 @@
+ from enum import Enum
+ from pydantic import SecretStr, BaseModel
+
+ from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+
+
+ class ElasticsearchConfig(DBConfig, BaseModel):
+     cloud_id: SecretStr
+     password: SecretStr | None = None
+
+     def to_dict(self) -> dict:
+         return {
+             "cloud_id": self.cloud_id.get_secret_value(),
+             "basic_auth": ("elastic", self.password.get_secret_value()),
+         }
+
+
+ class ESElementType(str, Enum):
+     float = "float"  # 4 bytes
+     byte = "byte"    # 1 byte, -128 to 127
+
+
+ class ElasticsearchIndexConfig(BaseModel, DBCaseConfig):
+     element_type: ESElementType = ESElementType.float
+     index: IndexType = IndexType.ES_HNSW  # ES only supports 'hnsw'
+
+     metric_type: MetricType | None = None
+     efConstruction: int | None = None
+     M: int | None = None
+     num_candidates: int | None = None
+
+     def parse_metric(self) -> str:
+         if self.metric_type == MetricType.L2:
+             return "l2_norm"
+         elif self.metric_type == MetricType.IP:
+             return "dot_product"
+         return "cosine"
+
+     def index_param(self) -> dict:
+         params = {
+             "type": "dense_vector",
+             "index": True,
+             "element_type": self.element_type.value,
+             "similarity": self.parse_metric(),
+             "index_options": {
+                 "type": self.index.value,
+                 "m": self.M,
+                 "ef_construction": self.efConstruction,
+             },
+         }
+         return params
+
+     def search_param(self) -> dict:
+         return {
+             "num_candidates": self.num_candidates,
+         }
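A sketch of what these render to for a typical HNSW setup (parameter values are illustrative):

cfg = ElasticsearchIndexConfig(
    metric_type=MetricType.L2, efConstruction=360, M=30, num_candidates=100,
)

cfg.index_param()
# {'type': 'dense_vector', 'index': True, 'element_type': 'float',
#  'similarity': 'l2_norm',
#  'index_options': {'type': 'hnsw', 'm': 30, 'ef_construction': 360}}

cfg.search_param()
# {'num_candidates': 100}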
vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py
@@ -0,0 +1,152 @@
+ import logging
+ from contextlib import contextmanager
+ from typing import Iterable, Type
+ from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
+ from .config import ElasticsearchIndexConfig, ElasticsearchConfig
+ from elasticsearch.helpers import bulk
+
+
+ for logger in ("elasticsearch", "elastic_transport"):
+     logging.getLogger(logger).setLevel(logging.WARNING)
+
+ log = logging.getLogger(__name__)
+
+ class ElasticCloud(VectorDB):
+     def __init__(
+         self,
+         dim: int,
+         db_config: dict,
+         db_case_config: ElasticsearchIndexConfig,
+         indice: str = "vdb_bench_indice",  # must be lowercase
+         id_col_name: str = "id",
+         vector_col_name: str = "vector",
+         drop_old: bool = False,
+     ):
+         self.dim = dim
+         self.db_config = db_config
+         self.case_config = db_case_config
+         self.indice = indice
+         self.id_col_name = id_col_name
+         self.vector_col_name = vector_col_name
+
+         from elasticsearch import Elasticsearch
+
+         client = Elasticsearch(**self.db_config)
+
+         if drop_old:
+             log.info(f"Elasticsearch client drop_old indices: {self.indice}")
+             is_existed_res = client.indices.exists(index=self.indice)
+             if is_existed_res.raw:
+                 client.indices.delete(index=self.indice)
+             self._create_indice(client)
+
+     @classmethod
+     def config_cls(cls) -> Type[DBConfig]:
+         return ElasticsearchConfig
+
+     @classmethod
+     def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+         return ElasticsearchIndexConfig
+
+     @contextmanager
+     def init(self) -> None:
+         """connect to elasticsearch"""
+         from elasticsearch import Elasticsearch
+         self.client = Elasticsearch(**self.db_config, request_timeout=30)
+
+         yield
+         # self.client.transport.close()
+         self.client = None
+         del self.client
+
+     def _create_indice(self, client) -> None:
+         mappings = {
+             "properties": {
+                 self.id_col_name: {"type": "integer"},
+                 self.vector_col_name: {
+                     "dims": self.dim,
+                     **self.case_config.index_param(),
+                 },
+             }
+         }
+
+         try:
+             client.indices.create(index=self.indice, mappings=mappings)
+         except Exception as e:
+             log.warning(f"Failed to create indice: {self.indice} error: {str(e)}")
+             raise e from None
+
+     def insert_embeddings(
+         self,
+         embeddings: Iterable[list[float]],
+         metadata: list[int],
+     ) -> int:
+         """Insert the embeddings into Elasticsearch."""
+         assert self.client is not None, "should self.init() first"
+
+         insert_data = [
+             {
+                 "_index": self.indice,
+                 "_source": {
+                     self.id_col_name: metadata[i],
+                     self.vector_col_name: embeddings[i],
+                 },
+             }
+             for i in range(len(embeddings))
+         ]
+         try:
+             bulk_insert_res = bulk(self.client, insert_data)
+             return bulk_insert_res[0]
+         except Exception as e:
+             log.warning(f"Failed to insert data: {self.indice} error: {str(e)}")
+             raise e from None
+
+     def search_embedding(
+         self,
+         query: list[float],
+         k: int = 100,
+         filters: dict | None = None,
+     ) -> list[int]:
+         """Get the k most similar embeddings to the query vector.
+
+         Args:
+             query(list[float]): query embedding to look up documents similar to.
+             k(int): number of most similar embeddings to return. Defaults to 100.
+             filters(dict, optional): filtering expression to filter the data while searching.
+
+         Returns:
+             list[int]: IDs of the k most similar embeddings to the query embedding.
+         """
+         assert self.client is not None, "should self.init() first"
+         # is_existed_res = self.client.indices.exists(index=self.indice)
+         # assert is_existed_res.raw == True, "should self.init() first"
+
+         knn = {
+             "field": self.vector_col_name,
+             "k": k,
+             "num_candidates": self.case_config.num_candidates,
+             "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}]
+             if filters
+             else [],
+             "query_vector": query,
+         }
+         size = k
+         try:
+             search_res = self.client.search(index=self.indice, knn=knn, size=size)
+             res = [d["_source"][self.id_col_name] for d in search_res["hits"]["hits"]]
+
+             return res
+         except Exception as e:
+             log.warning(f"Failed to search: {self.indice} error: {str(e)}")
+             raise e from None
+
+     def ready_to_search(self):
+         """ready_to_search is called between insertion and search in performance cases."""
+         pass
+
+     def ready_to_load(self):
+         """ready_to_load is called before the load in load cases."""
+         pass
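An end-to-end usage sketch for this client. The cloud_id and password are placeholders, the vectors are toy data, and in practice the indice may need a refresh before newly inserted documents become searchable:

config = ElasticsearchConfig(cloud_id="<cloud-id>", password="<password>")
case_config = ElasticCloud.case_config_cls()(efConstruction=360, M=30, num_candidates=100)

db = ElasticCloud(dim=3, db_config=config.to_dict(), db_case_config=case_config, drop_old=True)
with db.init():
    db.insert_embeddings([[0.1, 0.2, 0.3]], metadata=[0])
    ids = db.search_embedding([0.1, 0.2, 0.3], k=1)  # -> [0] once the index has refreshed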