vectordb-bench 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +19 -5
- vectordb_bench/backend/assembler.py +1 -1
- vectordb_bench/backend/cases.py +93 -27
- vectordb_bench/backend/clients/__init__.py +14 -0
- vectordb_bench/backend/clients/api.py +1 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
- vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
- vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
- vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
- vectordb_bench/backend/clients/milvus/cli.py +291 -0
- vectordb_bench/backend/clients/milvus/milvus.py +13 -6
- vectordb_bench/backend/clients/pgvector/cli.py +116 -0
- vectordb_bench/backend/clients/pgvector/config.py +1 -1
- vectordb_bench/backend/clients/pgvector/pgvector.py +7 -4
- vectordb_bench/backend/clients/redis/cli.py +74 -0
- vectordb_bench/backend/clients/test/cli.py +25 -0
- vectordb_bench/backend/clients/test/config.py +18 -0
- vectordb_bench/backend/clients/test/test.py +62 -0
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +41 -0
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +55 -0
- vectordb_bench/backend/dataset.py +27 -5
- vectordb_bench/backend/runner/mp_runner.py +14 -3
- vectordb_bench/backend/runner/serial_runner.py +7 -3
- vectordb_bench/backend/task_runner.py +76 -26
- vectordb_bench/cli/__init__.py +0 -0
- vectordb_bench/cli/cli.py +362 -0
- vectordb_bench/cli/vectordbbench.py +22 -0
- vectordb_bench/config-files/sample_config.yml +17 -0
- vectordb_bench/custom/custom_case.json +18 -0
- vectordb_bench/frontend/components/check_results/charts.py +6 -6
- vectordb_bench/frontend/components/check_results/data.py +23 -20
- vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
- vectordb_bench/frontend/components/check_results/filters.py +20 -13
- vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
- vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
- vectordb_bench/frontend/components/concurrent/charts.py +79 -0
- vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
- vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
- vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
- vectordb_bench/frontend/components/custom/initStyle.py +15 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
- vectordb_bench/frontend/components/run_test/caseSelector.py +40 -28
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -5
- vectordb_bench/frontend/components/run_test/dbSelector.py +8 -14
- vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
- vectordb_bench/frontend/components/run_test/initStyle.py +14 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +13 -5
- vectordb_bench/frontend/components/tables/data.py +44 -0
- vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +140 -32
- vectordb_bench/frontend/{const → config}/styles.py +2 -0
- vectordb_bench/frontend/pages/concurrent.py +65 -0
- vectordb_bench/frontend/pages/custom.py +64 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
- vectordb_bench/frontend/pages/run_test.py +4 -0
- vectordb_bench/frontend/pages/tables.py +24 -0
- vectordb_bench/frontend/utils.py +17 -1
- vectordb_bench/frontend/vdb_benchmark.py +3 -3
- vectordb_bench/interface.py +21 -25
- vectordb_bench/metric.py +23 -1
- vectordb_bench/models.py +45 -1
- vectordb_bench/results/getLeaderboardData.py +1 -1
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/METADATA +228 -14
- vectordb_bench-0.0.12.dist-info/RECORD +115 -0
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/entry_points.txt +1 -0
- vectordb_bench-0.0.10.dist-info/RECORD +0 -88
- /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/top_level.txt +0 -0
vectordb_bench/__init__.py
CHANGED
@@ -1,11 +1,13 @@
-import environs
 import inspect
 import pathlib
-from . import log_util
 
+import environs
+
+from . import log_util
 
 env = environs.Env()
-env.read_env(".env")
+env.read_env(".env", False)
+
 
 class config:
     ALIYUN_OSS_URL = "assets.zilliz.com.cn/benchmark/"
@@ -19,9 +21,21 @@ class config:
 
     DROP_OLD = env.bool("DROP_OLD", True)
     USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True)
-    NUM_CONCURRENCY = [1, 5, 10, 15, 20, 25, 30, 35]
 
-
+    NUM_CONCURRENCY = env.list("NUM_CONCURRENCY", [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100], subcast=int)
+
+    CONCURRENCY_DURATION = 30
+
+    RESULTS_LOCAL_DIR = env.path(
+        "RESULTS_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("results")
+    )
+    CONFIG_LOCAL_DIR = env.path(
+        "CONFIG_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("config-files")
+    )
+
+
+    K_DEFAULT = 100  # default return top k nearest neighbors during search
+    CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")
 
     CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600  # 24h
     LOAD_TIMEOUT_DEFAULT = 2.5 * 3600  # 2.5h
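Note: with this change the concurrency levels, results directory, and config directory become overridable through the environment instead of being hard-coded, and env.read_env(".env", False) stops environs from recursing up parent directories to find a .env file. A minimal sketch of how environs resolves these values (the environment-variable values shown are illustrative, not package defaults):

import environs

env = environs.Env()
# Suppose the process environment contains:
#   NUM_CONCURRENCY=1,10,50
#   RESULTS_LOCAL_DIR=/tmp/vdb_results
num_concurrency = env.list("NUM_CONCURRENCY", [1, 5, 10], subcast=int)
results_dir = env.path("RESULTS_LOCAL_DIR", "/default/results")

print(num_concurrency)  # [1, 10, 50] -- comma-separated string, each item cast to int
print(results_dir)      # PosixPath('/tmp/vdb_results')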
vectordb_bench/backend/assembler.py
CHANGED
@@ -14,7 +14,7 @@ class Assembler:
     def assemble(cls, run_id, task: TaskConfig, source: DatasetSource) -> CaseRunner:
         c_cls = task.case_config.case_id.case_cls
 
-        c = c_cls()
+        c = c_cls(task.case_config.custom_case)
         if type(task.db_case_config) != EmptyDBCaseConfig:
             task.db_case_config.metric_type = c.dataset.data.metric_type
 
vectordb_bench/backend/cases.py
CHANGED
@@ -1,17 +1,20 @@
 import typing
 import logging
 from enum import Enum, auto
+from typing import Type
 
 from vectordb_bench import config
+from vectordb_bench.backend.clients.api import MetricType
 from vectordb_bench.base import BaseModel
+from vectordb_bench.frontend.components.custom.getCustomConfig import (
+    CustomDatasetConfig,
+)
 
-from .dataset import Dataset, DatasetManager
+from .dataset import CustomDataset, Dataset, DatasetManager
 
 
 log = logging.getLogger(__name__)
 
-Case = typing.TypeVar("Case")
-
 
 class CaseType(Enum):
     """
@@ -42,24 +45,27 @@ class CaseType(Enum):
     Performance1536D500K99P = 14
     Performance1536D5M99P = 15
 
+    Performance1536D50K = 50
+
     Custom = 100
+    PerformanceCustomDataset = 101
 
-
-
-
+    def case_cls(self, custom_configs: dict | None = None) -> Type["Case"]:
+        if custom_configs is None:
+            return type2case.get(self)()
+        else:
+            return type2case.get(self)(**custom_configs)
 
-
-
-        c = self.case_cls
+    def case_name(self, custom_configs: dict | None = None) -> str:
+        c = self.case_cls(custom_configs)
         if c is not None:
-            return c
+            return c.name
         raise ValueError("Case unsupported")
 
-
-
-        c = self.case_cls
+    def case_description(self, custom_configs: dict | None = None) -> str:
+        c = self.case_cls(custom_configs)
         if c is not None:
-            return c
+            return c.description
         raise ValueError("Case unsupported")
 
 
@@ -69,7 +75,7 @@ class CaseLabel(Enum):
 
 
 class Case(BaseModel):
-    """
+    """Undefined case
 
     Fields:
         case_id(CaseType): default 9 case type plus one custom cases.
@@ -86,9 +92,9 @@ class Case(BaseModel):
     dataset: DatasetManager
 
     load_timeout: float | int
-    optimize_timeout: float | int | None
+    optimize_timeout: float | int | None = None
 
-    filter_rate: float | None
+    filter_rate: float | None = None
 
     @property
     def filters(self) -> dict | None:
@@ -115,20 +121,23 @@ class PerformanceCase(Case, BaseModel):
     load_timeout: float | int = config.LOAD_TIMEOUT_DEFAULT
     optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_DEFAULT
 
+
 class CapacityDim960(CapacityCase):
     case_id: CaseType = CaseType.CapacityDim960
     dataset: DatasetManager = Dataset.GIST.manager(100_000)
     name: str = "Capacity Test (960 Dim Repeated)"
-    description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension
-Number of inserted vectors will be
+    description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension
+vectors (GIST 100K vectors, <b>960 dimensions</b>) until it is fully loaded. Number of inserted vectors will be
+reported."""
 
 
 class CapacityDim128(CapacityCase):
     case_id: CaseType = CaseType.CapacityDim128
     dataset: DatasetManager = Dataset.SIFT.manager(500_000)
     name: str = "Capacity Test (128 Dim Repeated)"
-    description: str = """This case tests the vector database's loading capacity by repeatedly inserting small-dimension
-Number of inserted vectors will be
+    description: str = """This case tests the vector database's loading capacity by repeatedly inserting small-dimension
+vectors (SIFT 100K vectors, <b>128 dimensions</b>) until it is fully loaded. Number of inserted vectors will be
+reported."""
 
 
 class Performance768D10M(PerformanceCase):
@@ -238,6 +247,7 @@ Results will show index building time, recall, and maximum QPS."""
     load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K
     optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K
 
+
 class Performance1536D5M1P(PerformanceCase):
     case_id: CaseType = CaseType.Performance1536D5M1P
     filter_rate: float | int | None = 0.01
@@ -248,6 +258,7 @@ Results will show index building time, recall, and maximum QPS."""
     load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M
     optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M
 
+
 class Performance1536D500K99P(PerformanceCase):
     case_id: CaseType = CaseType.Performance1536D500K99P
     filter_rate: float | int | None = 0.99
@@ -258,6 +269,7 @@ Results will show index building time, recall, and maximum QPS."""
     load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K
     optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K
 
+
 class Performance1536D5M99P(PerformanceCase):
     case_id: CaseType = CaseType.Performance1536D5M99P
     filter_rate: float | int | None = 0.99
@@ -269,26 +281,80 @@ Results will show index building time, recall, and maximum QPS."""
     optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M
 
 
+class Performance1536D50K(PerformanceCase):
+    case_id: CaseType = CaseType.Performance1536D50K
+    filter_rate: float | int | None = None
+    dataset: DatasetManager = Dataset.OPENAI.manager(50_000)
+    name: str = "Search Performance Test (50K Dataset, 1536 Dim)"
+    description: str = """This case tests the search performance of a vector database with a medium 50K dataset (<b>OpenAI 50K vectors</b>, 1536 dimensions), at varying parallel levels.
+Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = 3600
+    optimize_timeout: float | int | None = 15 * 60
+
+
+def metric_type_map(s: str) -> MetricType:
+    if s.lower() == "cosine":
+        return MetricType.COSINE
+    if s.lower() == "l2" or s.lower() == "euclidean":
+        return MetricType.L2
+    if s.lower() == "ip":
+        return MetricType.IP
+    err_msg = f"Not support metric_type: {s}"
+    log.error(err_msg)
+    raise RuntimeError(err_msg)
+
+
+class PerformanceCustomDataset(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceCustomDataset
+    name: str = "Performance With Custom Dataset"
+    description: str = ""
+    dataset: DatasetManager
+
+    def __init__(
+        self,
+        name,
+        description,
+        load_timeout,
+        optimize_timeout,
+        dataset_config,
+        **kwargs,
+    ):
+        dataset_config = CustomDatasetConfig(**dataset_config)
+        dataset = CustomDataset(
+            name=dataset_config.name,
+            size=dataset_config.size,
+            dim=dataset_config.dim,
+            metric_type=metric_type_map(dataset_config.metric_type),
+            use_shuffled=dataset_config.use_shuffled,
+            with_gt=dataset_config.with_gt,
+            dir=dataset_config.dir,
+            file_num=dataset_config.file_count,
+        )
+        super().__init__(
+            name=name,
+            description=description,
+            load_timeout=load_timeout,
+            optimize_timeout=optimize_timeout,
+            dataset=DatasetManager(data=dataset),
+        )
+
+
 type2case = {
     CaseType.CapacityDim960: CapacityDim960,
     CaseType.CapacityDim128: CapacityDim128,
-
     CaseType.Performance768D100M: Performance768D100M,
     CaseType.Performance768D10M: Performance768D10M,
     CaseType.Performance768D1M: Performance768D1M,
-
     CaseType.Performance768D10M1P: Performance768D10M1P,
     CaseType.Performance768D1M1P: Performance768D1M1P,
     CaseType.Performance768D10M99P: Performance768D10M99P,
     CaseType.Performance768D1M99P: Performance768D1M99P,
-
     CaseType.Performance1536D500K: Performance1536D500K,
     CaseType.Performance1536D5M: Performance1536D5M,
-
     CaseType.Performance1536D500K1P: Performance1536D500K1P,
     CaseType.Performance1536D5M1P: Performance1536D5M1P,
-
     CaseType.Performance1536D500K99P: Performance1536D500K99P,
     CaseType.Performance1536D5M99P: Performance1536D5M99P,
-
+    CaseType.Performance1536D50K: Performance1536D50K,
+    CaseType.PerformanceCustomDataset: PerformanceCustomDataset,
 }
vectordb_bench/backend/clients/__init__.py
CHANGED
@@ -32,6 +32,8 @@ class DB(Enum):
     PgVectoRS = "PgVectoRS"
     Redis = "Redis"
     Chroma = "Chroma"
+    AWSOpenSearch = "OpenSearch"
+    Test = "test"
 
 
     @property
@@ -77,6 +79,10 @@ class DB(Enum):
             from .chroma.chroma import ChromaClient
             return ChromaClient
 
+        if self == DB.AWSOpenSearch:
+            from .aws_opensearch.aws_opensearch import AWSOpenSearch
+            return AWSOpenSearch
+
     @property
     def config_cls(self) -> Type[DBConfig]:
         """Import while in use"""
@@ -120,6 +126,10 @@ class DB(Enum):
             from .chroma.config import ChromaConfig
             return ChromaConfig
 
+        if self == DB.AWSOpenSearch:
+            from .aws_opensearch.config import AWSOpenSearchConfig
+            return AWSOpenSearchConfig
+
     def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
         if self == DB.Milvus:
             from .milvus.config import _milvus_case_config
@@ -149,6 +159,10 @@ class DB(Enum):
             from .pgvecto_rs.config import _pgvecto_rs_case_config
             return _pgvecto_rs_case_config.get(index_type)
 
+        if self == DB.AWSOpenSearch:
+            from .aws_opensearch.config import AWSOpenSearchIndexConfig
+            return AWSOpenSearchIndexConfig
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
 
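Note: all of the new dispatch points keep the OpenSearch imports local, so opensearch-py is only loaded when that backend is selected. A hedged usage sketch of resolving the new backend through the enum (the connection values are placeholders, and it assumes the remaining DBConfig fields such as db_label keep their defaults):

from vectordb_bench.backend.clients import DB

db = DB.AWSOpenSearch
config_cls = db.config_cls                # lazy-imports AWSOpenSearchConfig
case_config_cls = db.case_config_cls()    # -> AWSOpenSearchIndexConfig

cfg = config_cls(host="opensearch.example.com", user="admin", password="secret")
print(cfg.to_dict()["hosts"])  # [{'host': 'opensearch.example.com', 'port': 443}]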
vectordb_bench/backend/clients/api.py
CHANGED
@@ -47,7 +47,7 @@ class DBConfig(ABC, BaseModel):
     def not_empty_field(cls, v, field):
         if field.name == "db_label":
             return v
-        if isinstance(v, (str, SecretStr))
+        if not v and isinstance(v, (str, SecretStr)):
             raise ValueError("Empty string!")
         return v
 
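Note: the corrected guard raises only for values that are both string-typed and falsy, so populated strings and SecretStr values pass while empty strings are rejected. A standalone sketch of the fixed condition (it mirrors the validator body above, without the pydantic wiring):

from pydantic import SecretStr

def not_empty_field(v):
    # Reject empty str/SecretStr; anything truthy (or non-string) passes.
    if not v and isinstance(v, (str, SecretStr)):
        raise ValueError("Empty string!")
    return v

not_empty_field("localhost")          # returned unchanged
not_empty_field(SecretStr("s3cret"))  # returned unchanged
not_empty_field("")                   # raises ValueError("Empty string!")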
vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
ADDED
@@ -0,0 +1,159 @@
+import logging
+from contextlib import contextmanager
+import time
+from typing import Iterable, Type
+from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
+from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
+from opensearchpy import OpenSearch
+from opensearchpy.helpers import bulk
+
+log = logging.getLogger(__name__)
+
+
+class AWSOpenSearch(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: AWSOpenSearchIndexConfig,
+        index_name: str = "vdb_bench_index",  # must be lowercase
+        id_col_name: str = "id",
+        vector_col_name: str = "embedding",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.dim = dim
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.index_name = index_name
+        self.id_col_name = id_col_name
+        self.category_col_names = [
+            f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
+        ]
+        self.vector_col_name = vector_col_name
+
+        log.info(f"AWS_OpenSearch client config: {self.db_config}")
+        client = OpenSearch(**self.db_config)
+        if drop_old:
+            log.info(f"AWS_OpenSearch client drop old index: {self.index_name}")
+            is_existed = client.indices.exists(index=self.index_name)
+            if is_existed:
+                client.indices.delete(index=self.index_name)
+            self._create_index(client)
+
+    @classmethod
+    def config_cls(cls) -> AWSOpenSearchConfig:
+        return AWSOpenSearchConfig
+
+    @classmethod
+    def case_config_cls(
+        cls, index_type: IndexType | None = None
+    ) -> AWSOpenSearchIndexConfig:
+        return AWSOpenSearchIndexConfig
+
+    def _create_index(self, client: OpenSearch):
+        settings = {
+            "index": {
+                "knn": True,
+                # "number_of_shards": 5,
+                # "refresh_interval": "600s",
+            }
+        }
+        mappings = {
+            "properties": {
+                self.id_col_name: {"type": "integer"},
+                **{
+                    categoryCol: {"type": "keyword"}
+                    for categoryCol in self.category_col_names
+                },
+                self.vector_col_name: {
+                    "type": "knn_vector",
+                    "dimension": self.dim,
+                    "method": self.case_config.index_param(),
+                },
+            }
+        }
+        try:
+            client.indices.create(
+                index=self.index_name, body=dict(settings=settings, mappings=mappings)
+            )
+        except Exception as e:
+            log.warning(f"Failed to create index: {self.index_name} error: {str(e)}")
+            raise e from None
+
+    @contextmanager
+    def init(self) -> None:
+        """connect to elasticsearch"""
+        self.client = OpenSearch(**self.db_config)
+
+        yield
+        # self.client.transport.close()
+        self.client = None
+        del self.client
+
+    def insert_embeddings(
+        self,
+        embeddings: Iterable[list[float]],
+        metadata: list[int],
+        **kwargs,
+    ) -> tuple[int, Exception]:
+        """Insert the embeddings to the elasticsearch."""
+        assert self.client is not None, "should self.init() first"
+
+        insert_data = []
+        for i in range(len(embeddings)):
+            insert_data.append({"index": {"_index": self.index_name, "_id": metadata[i]}})
+            insert_data.append({self.vector_col_name: embeddings[i]})
+        try:
+            resp = self.client.bulk(insert_data)
+            log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
+            resp = self.client.indices.stats(self.index_name)
+            log.info(f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}")
+            return (len(embeddings), None)
+        except Exception as e:
+            log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}")
+            time.sleep(10)
+            return self.insert_embeddings(embeddings, metadata)
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+    ) -> list[int]:
+        """Get k most similar embeddings to query vector.
+
+        Args:
+            query(list[float]): query embedding to look up documents similar to.
+            k(int): Number of most similar embeddings to return. Defaults to 100.
+            filters(dict, optional): filtering expression to filter the data while searching.
+
+        Returns:
+            list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
+        """
+        assert self.client is not None, "should self.init() first"
+
+        body = {
+            "size": k,
+            "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
+        }
+        try:
+            resp = self.client.search(index=self.index_name, body=body)
+            log.info(f'Search took: {resp["took"]}')
+            log.info(f'Search shards: {resp["_shards"]}')
+            log.info(f'Search hits total: {resp["hits"]["total"]}')
+            result = [int(d["_id"]) for d in resp["hits"]["hits"]]
+            # log.info(f'success! length={len(res)}')
+
+            return result
+        except Exception as e:
+            log.warning(f"Failed to search: {self.index_name} error: {str(e)}")
+            raise e from None
+
+    def optimize(self):
+        """optimize will be called between insertion and search in performance cases."""
+        pass
+
+    def ready_to_load(self):
+        """ready_to_load will be called before load in load cases."""
+        pass
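Note: the client follows the VectorDB lifecycle the runners drive — construct (optionally dropping and recreating the index), enter init() for a live connection, bulk-insert, then search. A hedged usage sketch (the endpoint, credentials, and vectors are placeholders):

from vectordb_bench.backend.clients.aws_opensearch.aws_opensearch import AWSOpenSearch
from vectordb_bench.backend.clients.aws_opensearch.config import (
    AWSOpenSearchConfig,
    AWSOpenSearchIndexConfig,
)

db_config = AWSOpenSearchConfig(
    host="search-example.us-east-1.es.amazonaws.com",
    user="admin",
    password="admin-password",
).to_dict()

client = AWSOpenSearch(
    dim=4,
    db_config=db_config,
    db_case_config=AWSOpenSearchIndexConfig(),
    drop_old=True,  # delete any existing index, then recreate the knn_vector mapping
)

with client.init():  # init() is a contextmanager that opens the connection
    count, err = client.insert_embeddings(
        embeddings=[[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]],
        metadata=[0, 1],
    )
    ids = client.search_embedding(query=[0.1, 0.2, 0.3, 0.4], k=1)

Two quirks worth flagging in the code above: insert_embeddings retries itself after a 10-second sleep on any failure with no recursion bound, and search_embedding's docstring promises (id, score) tuples while the implementation returns bare ids.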
vectordb_bench/backend/clients/aws_opensearch/cli.py
ADDED
@@ -0,0 +1,44 @@
+from typing import Annotated, TypedDict, Unpack
+
+import click
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor2,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from .. import DB
+
+
+class AWSOpenSearchTypedDict(TypedDict):
+    host: Annotated[
+        str, click.option("--host", type=str, help="Db host", required=True)
+    ]
+    port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
+    user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
+    password: Annotated[str, click.option("--password", type=str, help="Db password")]
+
+
+class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2):
+    ...
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict)
+def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
+    from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
+    run(
+        db=DB.AWSOpenSearch,
+        db_config=AWSOpenSearchConfig(
+            host=parameters["host"],
+            port=parameters["port"],
+            user=parameters["user"],
+            password=SecretStr(parameters["password"]),
+        ),
+        db_case_config=AWSOpenSearchIndexConfig(
+        ),
+        **parameters,
+    )
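Note: in this CLI pattern each TypedDict field is Annotated with a ready-made click.option decorator, so one class doubles as both the option schema and the type of **parameters. A simplified, self-contained sketch of the idea (an illustration only — not the package's actual click_parameter_decorators_from_typed_dict):

from typing import Annotated, TypedDict, get_type_hints

import click


class ExampleTypedDict(TypedDict):
    host: Annotated[str, click.option("--host", type=str, required=True)]
    port: Annotated[int, click.option("--port", type=int, default=443)]


def decorators_from_typed_dict(td):
    def wrap(f):
        # Every Annotated metadata entry is already a click decorator; apply each.
        for hint in get_type_hints(td, include_extras=True).values():
            for meta in getattr(hint, "__metadata__", ()):
                f = meta(f)
        return f
    return wrap


@click.command()
@decorators_from_typed_dict(ExampleTypedDict)
def example(**parameters):
    click.echo(f"{parameters['host']}:{parameters['port']}")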
vectordb_bench/backend/clients/aws_opensearch/config.py
ADDED
@@ -0,0 +1,58 @@
+from enum import Enum
+from pydantic import SecretStr, BaseModel
+
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+
+
+class AWSOpenSearchConfig(DBConfig, BaseModel):
+    host: str = ""
+    port: int = 443
+    user: str = ""
+    password: SecretStr = ""
+
+    def to_dict(self) -> dict:
+        return {
+            "hosts": [{'host': self.host, 'port': self.port}],
+            "http_auth": (self.user, self.password.get_secret_value()),
+            "use_ssl": True,
+            "http_compress": True,
+            "verify_certs": True,
+            "ssl_assert_hostname": False,
+            "ssl_show_warn": False,
+            "timeout": 600,
+        }
+
+
+class AWSOS_Engine(Enum):
+    nmslib = "nmslib"
+    faiss = "faiss"
+    lucene = "Lucene"
+
+
+class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType = MetricType.L2
+    engine: AWSOS_Engine = AWSOS_Engine.nmslib
+    efConstruction: int = 360
+    M: int = 30
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.IP:
+            return "innerproduct"  # only support faiss / nmslib, not for Lucene.
+        elif self.metric_type == MetricType.COSINE:
+            return "cosinesimil"
+        return "l2"
+
+    def index_param(self) -> dict:
+        params = {
+            "name": "hnsw",
+            "space_type": self.parse_metric(),
+            "engine": self.engine.value,
+            "parameters": {
+                "ef_construction": self.efConstruction,
+                "m": self.M
+            }
+        }
+        return params
+
+    def search_param(self) -> dict:
+        return {}
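Note: index_param() is exactly what _create_index embeds as the "method" of the knn_vector mapping. A quick sketch of the dict it produces with the defaults above:

from vectordb_bench.backend.clients.aws_opensearch.config import AWSOpenSearchIndexConfig

cfg = AWSOpenSearchIndexConfig()  # defaults: L2 metric, nmslib engine, efConstruction=360, M=30
print(cfg.index_param())
# {'name': 'hnsw', 'space_type': 'l2', 'engine': 'nmslib',
#  'parameters': {'ef_construction': 360, 'm': 30}}

As the inline comment in parse_metric notes, the IP metric maps to "innerproduct", which the faiss and nmslib engines support but Lucene does not.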