vectordb-bench 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. vectordb_bench/__init__.py +19 -5
  2. vectordb_bench/backend/assembler.py +1 -1
  3. vectordb_bench/backend/cases.py +93 -27
  4. vectordb_bench/backend/clients/__init__.py +14 -0
  5. vectordb_bench/backend/clients/api.py +1 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
  7. vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
  8. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
  9. vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
  10. vectordb_bench/backend/clients/milvus/cli.py +291 -0
  11. vectordb_bench/backend/clients/milvus/milvus.py +13 -6
  12. vectordb_bench/backend/clients/pgvector/cli.py +116 -0
  13. vectordb_bench/backend/clients/pgvector/config.py +1 -1
  14. vectordb_bench/backend/clients/pgvector/pgvector.py +7 -4
  15. vectordb_bench/backend/clients/redis/cli.py +74 -0
  16. vectordb_bench/backend/clients/test/cli.py +25 -0
  17. vectordb_bench/backend/clients/test/config.py +18 -0
  18. vectordb_bench/backend/clients/test/test.py +62 -0
  19. vectordb_bench/backend/clients/weaviate_cloud/cli.py +41 -0
  20. vectordb_bench/backend/clients/zilliz_cloud/cli.py +55 -0
  21. vectordb_bench/backend/dataset.py +27 -5
  22. vectordb_bench/backend/runner/mp_runner.py +14 -3
  23. vectordb_bench/backend/runner/serial_runner.py +7 -3
  24. vectordb_bench/backend/task_runner.py +76 -26
  25. vectordb_bench/cli/__init__.py +0 -0
  26. vectordb_bench/cli/cli.py +362 -0
  27. vectordb_bench/cli/vectordbbench.py +22 -0
  28. vectordb_bench/config-files/sample_config.yml +17 -0
  29. vectordb_bench/custom/custom_case.json +18 -0
  30. vectordb_bench/frontend/components/check_results/charts.py +6 -6
  31. vectordb_bench/frontend/components/check_results/data.py +23 -20
  32. vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
  33. vectordb_bench/frontend/components/check_results/filters.py +20 -13
  34. vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
  35. vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
  36. vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
  37. vectordb_bench/frontend/components/concurrent/charts.py +79 -0
  38. vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
  39. vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
  40. vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
  41. vectordb_bench/frontend/components/custom/initStyle.py +15 -0
  42. vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
  43. vectordb_bench/frontend/components/run_test/caseSelector.py +40 -28
  44. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -5
  45. vectordb_bench/frontend/components/run_test/dbSelector.py +8 -14
  46. vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
  47. vectordb_bench/frontend/components/run_test/initStyle.py +14 -0
  48. vectordb_bench/frontend/components/run_test/submitTask.py +13 -5
  49. vectordb_bench/frontend/components/tables/data.py +44 -0
  50. vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +140 -32
  51. vectordb_bench/frontend/{const → config}/styles.py +2 -0
  52. vectordb_bench/frontend/pages/concurrent.py +65 -0
  53. vectordb_bench/frontend/pages/custom.py +64 -0
  54. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
  55. vectordb_bench/frontend/pages/run_test.py +4 -0
  56. vectordb_bench/frontend/pages/tables.py +24 -0
  57. vectordb_bench/frontend/utils.py +17 -1
  58. vectordb_bench/frontend/vdb_benchmark.py +3 -3
  59. vectordb_bench/interface.py +21 -25
  60. vectordb_bench/metric.py +23 -1
  61. vectordb_bench/models.py +45 -1
  62. vectordb_bench/results/getLeaderboardData.py +1 -1
  63. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/METADATA +228 -14
  64. vectordb_bench-0.0.12.dist-info/RECORD +115 -0
  65. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/WHEEL +1 -1
  66. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/entry_points.txt +1 -0
  67. vectordb_bench-0.0.10.dist-info/RECORD +0 -88
  68. /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
  69. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/LICENSE +0 -0
  70. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,13 @@
1
- import environs
2
1
  import inspect
3
2
  import pathlib
4
- from . import log_util
5
3
 
4
+ import environs
5
+
6
+ from . import log_util
6
7
 
7
8
  env = environs.Env()
8
- env.read_env(".env")
9
+ env.read_env(".env", False)
10
+
9
11
 
10
12
  class config:
11
13
  ALIYUN_OSS_URL = "assets.zilliz.com.cn/benchmark/"
@@ -19,9 +21,21 @@ class config:
19
21
 
20
22
  DROP_OLD = env.bool("DROP_OLD", True)
21
23
  USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True)
22
- NUM_CONCURRENCY = [1, 5, 10, 15, 20, 25, 30, 35]
23
24
 
24
- RESULTS_LOCAL_DIR = pathlib.Path(__file__).parent.joinpath("results")
25
+ NUM_CONCURRENCY = env.list("NUM_CONCURRENCY", [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100], subcast=int )
26
+
27
+ CONCURRENCY_DURATION = 30
28
+
29
+ RESULTS_LOCAL_DIR = env.path(
30
+ "RESULTS_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("results")
31
+ )
32
+ CONFIG_LOCAL_DIR = env.path(
33
+ "CONFIG_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("config-files")
34
+ )
35
+
36
+
37
+ K_DEFAULT = 100 # default return top k nearest neighbors during search
38
+ CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")
25
39
 
26
40
  CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
27
41
  LOAD_TIMEOUT_DEFAULT = 2.5 * 3600 # 2.5h
@@ -14,7 +14,7 @@ class Assembler:
14
14
  def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner:
15
15
  c_cls = task.case_config.case_id.case_cls
16
16
 
17
- c = c_cls()
17
+ c = c_cls(task.case_config.custom_case)
18
18
  if type(task.db_case_config) != EmptyDBCaseConfig:
19
19
  task.db_case_config.metric_type = c.dataset.data.metric_type
20
20
 
@@ -1,17 +1,20 @@
1
1
  import typing
2
2
  import logging
3
3
  from enum import Enum, auto
4
+ from typing import Type
4
5
 
5
6
  from vectordb_bench import config
7
+ from vectordb_bench.backend.clients.api import MetricType
6
8
  from vectordb_bench.base import BaseModel
9
+ from vectordb_bench.frontend.components.custom.getCustomConfig import (
10
+ CustomDatasetConfig,
11
+ )
7
12
 
8
- from .dataset import Dataset, DatasetManager
13
+ from .dataset import CustomDataset, Dataset, DatasetManager
9
14
 
10
15
 
11
16
  log = logging.getLogger(__name__)
12
17
 
13
- Case = typing.TypeVar("Case")
14
-
15
18
 
16
19
  class CaseType(Enum):
17
20
  """
@@ -42,24 +45,27 @@ class CaseType(Enum):
42
45
  Performance1536D500K99P = 14
43
46
  Performance1536D5M99P = 15
44
47
 
48
+ Performance1536D50K = 50
49
+
45
50
  Custom = 100
51
+ PerformanceCustomDataset = 101
46
52
 
47
- @property
48
- def case_cls(self, custom_configs: dict | None = None) -> Case:
49
- return type2case.get(self)
53
def case_cls(self, custom_configs: dict | None = None) -> "Case":
    """Build the case object registered for this case type.

    Args:
        custom_configs: optional keyword arguments forwarded to the case
            constructor. When None, the case is built with its defaults.

    Returns:
        A constructed case object (an instance, not the class).

    Raises:
        ValueError: if no case is registered for this type in ``type2case``.
    """
    case_type = type2case.get(self)
    if case_type is None:
        # Without this guard the lookup result (None) would be called directly
        # and fail with an opaque "NoneType is not callable" TypeError.
        raise ValueError("Case unsupported")
    if custom_configs is None:
        return case_type()
    return case_type(**custom_configs)
50
58
 
51
- @property
52
- def case_name(self) -> str:
53
- c = self.case_cls
59
+ def case_name(self, custom_configs: dict | None = None) -> str:
60
+ c = self.case_cls(custom_configs)
54
61
  if c is not None:
55
- return c().name
62
+ return c.name
56
63
  raise ValueError("Case unsupported")
57
64
 
58
- @property
59
- def case_description(self) -> str:
60
- c = self.case_cls
65
+ def case_description(self, custom_configs: dict | None = None) -> str:
66
+ c = self.case_cls(custom_configs)
61
67
  if c is not None:
62
- return c().description
68
+ return c.description
63
69
  raise ValueError("Case unsupported")
64
70
 
65
71
 
@@ -69,7 +75,7 @@ class CaseLabel(Enum):
69
75
 
70
76
 
71
77
  class Case(BaseModel):
72
- """Undifined case
78
+ """Undefined case
73
79
 
74
80
  Fields:
75
81
  case_id(CaseType): default 9 case type plus one custom cases.
@@ -86,9 +92,9 @@ class Case(BaseModel):
86
92
  dataset: DatasetManager
87
93
 
88
94
  load_timeout: float | int
89
- optimize_timeout: float | int | None
95
+ optimize_timeout: float | int | None = None
90
96
 
91
- filter_rate: float | None
97
+ filter_rate: float | None = None
92
98
 
93
99
  @property
94
100
  def filters(self) -> dict | None:
@@ -115,20 +121,23 @@ class PerformanceCase(Case, BaseModel):
115
121
  load_timeout: float | int = config.LOAD_TIMEOUT_DEFAULT
116
122
  optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_DEFAULT
117
123
 
124
+
118
125
  class CapacityDim960(CapacityCase):
119
126
  case_id: CaseType = CaseType.CapacityDim960
120
127
  dataset: DatasetManager = Dataset.GIST.manager(100_000)
121
128
  name: str = "Capacity Test (960 Dim Repeated)"
122
- description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension vectors (GIST 100K vectors, <b>960 dimensions</b>) until it is fully loaded.
123
- Number of inserted vectors will be reported."""
129
+ description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension
130
+ vectors (GIST 100K vectors, <b>960 dimensions</b>) until it is fully loaded. Number of inserted vectors will be
131
+ reported."""
124
132
 
125
133
 
126
134
  class CapacityDim128(CapacityCase):
127
135
  case_id: CaseType = CaseType.CapacityDim128
128
136
  dataset: DatasetManager = Dataset.SIFT.manager(500_000)
129
137
  name: str = "Capacity Test (128 Dim Repeated)"
130
- description: str = """This case tests the vector database's loading capacity by repeatedly inserting small-dimension vectors (SIFT 100K vectors, <b>128 dimensions</b>) until it is fully loaded.
131
- Number of inserted vectors will be reported."""
138
+ description: str = """This case tests the vector database's loading capacity by repeatedly inserting small-dimension
139
+ vectors (SIFT 100K vectors, <b>128 dimensions</b>) until it is fully loaded. Number of inserted vectors will be
140
+ reported."""
132
141
 
133
142
 
134
143
  class Performance768D10M(PerformanceCase):
@@ -238,6 +247,7 @@ Results will show index building time, recall, and maximum QPS."""
238
247
  load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K
239
248
  optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K
240
249
 
250
+
241
251
  class Performance1536D5M1P(PerformanceCase):
242
252
  case_id: CaseType = CaseType.Performance1536D5M1P
243
253
  filter_rate: float | int | None = 0.01
@@ -248,6 +258,7 @@ Results will show index building time, recall, and maximum QPS."""
248
258
  load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M
249
259
  optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M
250
260
 
261
+
251
262
  class Performance1536D500K99P(PerformanceCase):
252
263
  case_id: CaseType = CaseType.Performance1536D500K99P
253
264
  filter_rate: float | int | None = 0.99
@@ -258,6 +269,7 @@ Results will show index building time, recall, and maximum QPS."""
258
269
  load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K
259
270
  optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K
260
271
 
272
+
261
273
  class Performance1536D5M99P(PerformanceCase):
262
274
  case_id: CaseType = CaseType.Performance1536D5M99P
263
275
  filter_rate: float | int | None = 0.99
@@ -269,26 +281,80 @@ Results will show index building time, recall, and maximum QPS."""
269
281
  optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M
270
282
 
271
283
 
284
class Performance1536D50K(PerformanceCase):
    # Search-performance case over the 50K-vector slice of the OpenAI dataset.
    case_id: CaseType = CaseType.Performance1536D50K
    # No filtering is applied in this case.
    filter_rate: float | int | None = None
    dataset: DatasetManager = Dataset.OPENAI.manager(50_000)
    name: str = "Search Performance Test (50K Dataset, 1536 Dim)"
    description: str = """This case tests the search performance of a vector database with a medium 50K dataset (<b>OpenAI 50K vectors</b>, 1536 dimensions), at varying parallel levels.
Results will show index building time, recall, and maximum QPS."""
    # Hard-coded rather than taken from `config`; presumably seconds like the
    # other timeouts (i.e. 1 hour) — TODO confirm.
    load_timeout: float | int = 3600
    # Presumably 15 minutes expressed in seconds — TODO confirm.
    optimize_timeout: float | int | None = 15 * 60
293
+
294
+
295
def metric_type_map(s: str) -> MetricType:
    """Translate a metric name into the matching ``MetricType`` member.

    Matching is case-insensitive; "l2" and "euclidean" both resolve to
    ``MetricType.L2``.

    Args:
        s: metric name, one of "cosine", "l2", "euclidean" or "ip".

    Returns:
        The corresponding ``MetricType``.

    Raises:
        RuntimeError: if the name is not recognised (also logged as an error).
    """
    known_metrics = {
        "cosine": MetricType.COSINE,
        "l2": MetricType.L2,
        "euclidean": MetricType.L2,
        "ip": MetricType.IP,
    }
    normalized = s.lower()
    if normalized in known_metrics:
        return known_metrics[normalized]

    err_msg = f"Not support metric_type: {s}"
    log.error(err_msg)
    raise RuntimeError(err_msg)
+
306
+
307
class PerformanceCustomDataset(PerformanceCase):
    # Performance case whose dataset is built at construction time from a
    # caller-supplied dataset description instead of a predefined Dataset.
    case_id: CaseType = CaseType.PerformanceCustomDataset
    name: str = "Performance With Custom Dataset"
    description: str = ""
    dataset: DatasetManager

    def __init__(
        self,
        name,
        description,
        load_timeout,
        optimize_timeout,
        dataset_config,
        **kwargs,
    ):
        """Create the case from explicit settings and a dataset description.

        Args:
            name: display name of the case.
            description: free-text description of the case.
            load_timeout: load timeout passed through to the base case.
            optimize_timeout: optimize timeout passed through to the base case.
            dataset_config: mapping of ``CustomDatasetConfig`` fields.
            **kwargs: accepted for call compatibility; not forwarded.
        """
        parsed_config = CustomDatasetConfig(**dataset_config)
        custom_data = CustomDataset(
            name=parsed_config.name,
            size=parsed_config.size,
            dim=parsed_config.dim,
            # The config carries the metric as text; convert it to the enum.
            metric_type=metric_type_map(parsed_config.metric_type),
            use_shuffled=parsed_config.use_shuffled,
            with_gt=parsed_config.with_gt,
            dir=parsed_config.dir,
            file_num=parsed_config.file_count,
        )
        manager = DatasetManager(data=custom_data)
        super().__init__(
            name=name,
            description=description,
            load_timeout=load_timeout,
            optimize_timeout=optimize_timeout,
            dataset=manager,
        )
340
+
341
+
272
342
  type2case = {
273
343
  CaseType.CapacityDim960: CapacityDim960,
274
344
  CaseType.CapacityDim128: CapacityDim128,
275
-
276
345
  CaseType.Performance768D100M: Performance768D100M,
277
346
  CaseType.Performance768D10M: Performance768D10M,
278
347
  CaseType.Performance768D1M: Performance768D1M,
279
-
280
348
  CaseType.Performance768D10M1P: Performance768D10M1P,
281
349
  CaseType.Performance768D1M1P: Performance768D1M1P,
282
350
  CaseType.Performance768D10M99P: Performance768D10M99P,
283
351
  CaseType.Performance768D1M99P: Performance768D1M99P,
284
-
285
352
  CaseType.Performance1536D500K: Performance1536D500K,
286
353
  CaseType.Performance1536D5M: Performance1536D5M,
287
-
288
354
  CaseType.Performance1536D500K1P: Performance1536D500K1P,
289
355
  CaseType.Performance1536D5M1P: Performance1536D5M1P,
290
-
291
356
  CaseType.Performance1536D500K99P: Performance1536D500K99P,
292
357
  CaseType.Performance1536D5M99P: Performance1536D5M99P,
293
-
358
+ CaseType.Performance1536D50K: Performance1536D50K,
359
+ CaseType.PerformanceCustomDataset: PerformanceCustomDataset,
294
360
  }
@@ -32,6 +32,8 @@ class DB(Enum):
32
32
  PgVectoRS = "PgVectoRS"
33
33
  Redis = "Redis"
34
34
  Chroma = "Chroma"
35
+ AWSOpenSearch = "OpenSearch"
36
+ Test = "test"
35
37
 
36
38
 
37
39
  @property
@@ -77,6 +79,10 @@ class DB(Enum):
77
79
  from .chroma.chroma import ChromaClient
78
80
  return ChromaClient
79
81
 
82
+ if self == DB.AWSOpenSearch:
83
+ from .aws_opensearch.aws_opensearch import AWSOpenSearch
84
+ return AWSOpenSearch
85
+
80
86
  @property
81
87
  def config_cls(self) -> Type[DBConfig]:
82
88
  """Import while in use"""
@@ -120,6 +126,10 @@ class DB(Enum):
120
126
  from .chroma.config import ChromaConfig
121
127
  return ChromaConfig
122
128
 
129
+ if self == DB.AWSOpenSearch:
130
+ from .aws_opensearch.config import AWSOpenSearchConfig
131
+ return AWSOpenSearchConfig
132
+
123
133
  def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
124
134
  if self == DB.Milvus:
125
135
  from .milvus.config import _milvus_case_config
@@ -149,6 +159,10 @@ class DB(Enum):
149
159
  from .pgvecto_rs.config import _pgvecto_rs_case_config
150
160
  return _pgvecto_rs_case_config.get(index_type)
151
161
 
162
+ if self == DB.AWSOpenSearch:
163
+ from .aws_opensearch.config import AWSOpenSearchIndexConfig
164
+ return AWSOpenSearchIndexConfig
165
+
152
166
  # DB.Pinecone, DB.Chroma, DB.Redis
153
167
  return EmptyDBCaseConfig
154
168
 
@@ -47,7 +47,7 @@ class DBConfig(ABC, BaseModel):
47
47
  def not_empty_field(cls, v, field):
48
48
  if field.name == "db_label":
49
49
  return v
50
- if isinstance(v, (str, SecretStr)) and len(v) == 0:
50
+ if not v and isinstance(v, (str, SecretStr)):
51
51
  raise ValueError("Empty string!")
52
52
  return v
53
53
 
@@ -0,0 +1,159 @@
1
+ import logging
2
+ from contextlib import contextmanager
3
+ import time
4
+ from typing import Iterable, Type
5
+ from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
6
+ from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
7
+ from opensearchpy import OpenSearch
8
+ from opensearchpy.helpers import bulk
9
+
10
+ log = logging.getLogger(__name__)
11
+
12
+
13
class AWSOpenSearch(VectorDB):
    """VectorDB client that stores and searches embeddings in an OpenSearch k-NN index.

    A short-lived client is used in ``__init__`` to (re)create the index; the
    client used for inserts and searches only exists inside the ``init()``
    context manager.
    """

    # Upper bound on bulk-insert attempts before the error is returned to the caller.
    _MAX_INSERT_ATTEMPTS = 60
    # Pause between failed bulk-insert attempts, in seconds.
    _INSERT_RETRY_DELAY = 10

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: AWSOpenSearchIndexConfig,
        index_name: str = "vdb_bench_index",  # must be lowercase
        id_col_name: str = "id",
        vector_col_name: str = "embedding",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store the settings and, when ``drop_old`` is set, recreate the index.

        Args:
            dim: dimension of the vectors stored in the index.
            db_config: keyword arguments used to construct the OpenSearch client.
            db_case_config: index configuration providing the k-NN method parameters.
            index_name: name of the index to use.
            id_col_name: name of the integer id field in the mapping.
            vector_col_name: name of the knn_vector field in the mapping.
            drop_old: when True, delete an existing index of the same name and
                create a fresh one.
            **kwargs: accepted for interface compatibility; unused.
        """
        self.dim = dim
        self.db_config = db_config
        self.case_config = db_case_config
        self.index_name = index_name
        self.id_col_name = id_col_name
        self.category_col_names = [
            f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
        ]
        self.vector_col_name = vector_col_name

        # "http_auth" carries the plaintext password; never write it to the log.
        loggable_config = {
            key: ("******" if key == "http_auth" else value)
            for key, value in self.db_config.items()
        }
        log.info(f"AWS_OpenSearch client config: {loggable_config}")
        client = OpenSearch(**self.db_config)
        if drop_old:
            log.info(f"AWS_OpenSearch client drop old index: {self.index_name}")
            if client.indices.exists(index=self.index_name):
                client.indices.delete(index=self.index_name)
            # The index is only (re)created on this path; without drop_old an
            # existing index is expected to be in place already.
            self._create_index(client)

    @classmethod
    def config_cls(cls) -> Type[AWSOpenSearchConfig]:
        """Return the connection-config class for this client."""
        return AWSOpenSearchConfig

    @classmethod
    def case_config_cls(
        cls, index_type: IndexType | None = None
    ) -> Type[AWSOpenSearchIndexConfig]:
        """Return the index-config class; ``index_type`` is not consulted."""
        return AWSOpenSearchIndexConfig

    def _create_index(self, client: OpenSearch):
        """Create the k-NN index with id, category and vector fields.

        Raises:
            Exception: whatever the client raises on failure, after logging it.
        """
        settings = {
            "index": {
                "knn": True,
                # "number_of_shards": 5,
                # "refresh_interval": "600s",
            }
        }
        mappings = {
            "properties": {
                self.id_col_name: {"type": "integer"},
                **{
                    categoryCol: {"type": "keyword"}
                    for categoryCol in self.category_col_names
                },
                self.vector_col_name: {
                    "type": "knn_vector",
                    "dimension": self.dim,
                    "method": self.case_config.index_param(),
                },
            }
        }
        try:
            client.indices.create(
                index=self.index_name, body=dict(settings=settings, mappings=mappings)
            )
        except Exception as e:
            log.warning(f"Failed to create index: {self.index_name} error: {str(e)}")
            raise

    @contextmanager
    def init(self) -> None:
        """Open a client for the duration of the ``with`` block.

        The client attribute is removed again on exit, including when the
        body of the ``with`` block raises.
        """
        self.client = OpenSearch(**self.db_config)
        try:
            yield
        finally:
            # NOTE(review): the transport is not closed explicitly here; confirm
            # whether dropping the reference is sufficient.
            del self.client

    def insert_embeddings(
        self,
        embeddings: Iterable[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception | None]:
        """Bulk-index the embeddings, using ``metadata`` values as document ids.

        A failed bulk request is retried up to ``_MAX_INSERT_ATTEMPTS`` times
        with ``_INSERT_RETRY_DELAY`` seconds between attempts.

        Args:
            embeddings: vectors to index.
            metadata: document ids, paired with ``embeddings`` by position.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``(count, None)`` on success, or ``(0, error)`` with the last
            exception once all attempts have failed.
        """
        assert self.client is not None, "should self.init() first"

        # The bulk body alternates an action line with the document it applies to.
        insert_data = []
        for doc_id, vector in zip(metadata, embeddings):
            insert_data.append({"index": {"_index": self.index_name, "_id": doc_id}})
            insert_data.append({self.vector_col_name: vector})
        doc_count = len(insert_data) // 2
        if doc_count == 0:
            # Nothing to send; avoid issuing (and retrying) an empty bulk request.
            return (0, None)

        last_error = None
        for attempt in range(1, self._MAX_INSERT_ATTEMPTS + 1):
            try:
                resp = self.client.bulk(body=insert_data)
            except Exception as e:
                last_error = e
                log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}")
                if attempt < self._MAX_INSERT_ATTEMPTS:
                    time.sleep(self._INSERT_RETRY_DELAY)
                continue

            log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
            if resp.get("errors"):
                # A bulk call can succeed as a request while individual items fail.
                log.warning(f"Bulk insert reported item errors: {self.index_name}")
            try:
                # Informational only: a stats failure must not trigger a re-insert.
                stats = self.client.indices.stats(index=self.index_name)
                log.info(f"Total document count in index: {stats['_all']['primaries']['indexing']['index_total']}")
            except Exception as e:
                log.warning(f"Failed to read index stats: {self.index_name} error: {str(e)}")
            return (doc_count, None)

        return (0, last_error)

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
    ) -> list[int]:
        """Get the ids of the k most similar embeddings to the query vector.

        Args:
            query(list[float]): query embedding to look up documents similar to.
            k(int): Number of most similar embeddings to return. Defaults to 100.
            filters(dict, optional): accepted for interface compatibility.
                NOTE(review): not applied to the query — confirm filtered
                cases are not expected to run against this client.

        Returns:
            list[int]: document ids of the hits, in the order returned by the search.

        Raises:
            Exception: whatever the client raises on failure, after logging it.
        """
        assert self.client is not None, "should self.init() first"

        body = {
            "size": k,
            "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
        }
        try:
            resp = self.client.search(index=self.index_name, body=body)
            # Debug level: this runs once per query inside the measured search path.
            log.debug(f'Search took: {resp["took"]}')
            log.debug(f'Search shards: {resp["_shards"]}')
            log.debug(f'Search hits total: {resp["hits"]["total"]}')
            return [int(d["_id"]) for d in resp["hits"]["hits"]]
        except Exception as e:
            log.warning(f"Failed to search: {self.index_name} error: {str(e)}")
            raise

    def optimize(self):
        """optimize will be called between insertion and search in performance cases."""
        pass

    def ready_to_load(self):
        """ready_to_load will be called before load in load cases."""
        pass
@@ -0,0 +1,44 @@
1
+ from typing import Annotated, TypedDict, Unpack
2
+
3
+ import click
4
+ from pydantic import SecretStr
5
+
6
+ from ....cli.cli import (
7
+ CommonTypedDict,
8
+ HNSWFlavor2,
9
+ cli,
10
+ click_parameter_decorators_from_typed_dict,
11
+ run,
12
+ )
13
+ from .. import DB
14
+
15
+
16
class AWSOpenSearchTypedDict(TypedDict):
    """Connection options for the AWSOpenSearch CLI command, one click option per key."""

    host: Annotated[
        str, click.option("--host", type=str, help="Db host", required=True)
    ]
    port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
    user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
    # Required: with no default, an omitted --password reaches the command as
    # None and is wrapped in SecretStr(None) instead of being rejected up front.
    password: Annotated[
        str, click.option("--password", type=str, help="Db password", required=True)
    ]
23
+
24
+
25
class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2):
    """Union of the common, connection and HNSWFlavor2 option sets; adds no keys of its own."""
27
+
28
+
29
@cli.command()
@click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict)
def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
    # CLI entry point: build the connection config from the parsed options and
    # hand everything to the shared `run` helper.
    # (Kept as comments: a docstring here would become the command's help text.)

    # Imported at call time, as in the rest of this command's setup.
    from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig

    connection_config = AWSOpenSearchConfig(
        host=parameters["host"],
        port=parameters["port"],
        user=parameters["user"],
        password=SecretStr(parameters["password"]),
    )
    # NOTE(review): the index config is built with its defaults; any options
    # contributed by HNSWFlavor2 are not applied to it here — confirm intended.
    index_config = AWSOpenSearchIndexConfig()
    run(
        db=DB.AWSOpenSearch,
        db_config=connection_config,
        db_case_config=index_config,
        **parameters,
    )
@@ -0,0 +1,58 @@
1
+ from enum import Enum
2
+ from pydantic import SecretStr, BaseModel
3
+
4
+ from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
5
+
6
+
7
class AWSOpenSearchConfig(DBConfig, BaseModel):
    # Connection settings for an OpenSearch endpoint using basic auth over SSL.
    host: str = ""
    port: int = 443
    user: str = ""
    # The default must itself be a SecretStr: to_dict() calls
    # get_secret_value() on this field, which a plain "" does not provide.
    password: SecretStr = SecretStr("")

    def to_dict(self) -> dict:
        """Return the connection options as a plain dict.

        The password is unwrapped into the ``http_auth`` tuple, so the result
        contains the secret in clear text and should not be logged as-is.
        """
        return {
            "hosts": [{'host': self.host, 'port': self.port}],
            "http_auth": (self.user, self.password.get_secret_value()),
            "use_ssl": True,
            "http_compress": True,
            "verify_certs": True,
            # NOTE(review): hostname assertion is switched off while certificate
            # verification stays on — confirm this is intentional.
            "ssl_assert_hostname": False,
            "ssl_show_warn": False,
            "timeout": 600,
        }
24
+
25
+
26
class AWSOS_Engine(Enum):
    # Engine choices for the index config; the member value is what gets sent
    # as the "engine" entry of the index method definition.
    nmslib = "nmslib"
    faiss = "faiss"
    # NOTE(review): capitalised unlike the other two values — confirm the
    # server accepts this spelling.
    lucene = "Lucene"
30
+
31
+
32
class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
    # HNSW index settings: distance metric, engine and the two build parameters.
    metric_type: MetricType = MetricType.L2
    engine: AWSOS_Engine = AWSOS_Engine.nmslib
    efConstruction: int = 360
    M: int = 30

    def parse_metric(self) -> str:
        """Return the space-type name for the configured metric ("l2" by default)."""
        metric = self.metric_type
        if metric == MetricType.COSINE:
            return "cosinesimil"
        if metric == MetricType.IP:
            # only support faiss / nmslib, not for Lucene.
            return "innerproduct"
        return "l2"

    def index_param(self) -> dict:
        """Return the HNSW method definition used for the vector field."""
        build_parameters = {
            "ef_construction": self.efConstruction,
            "m": self.M,
        }
        return {
            "name": "hnsw",
            "space_type": self.parse_metric(),
            "engine": self.engine.value,
            "parameters": build_parameters,
        }

    def search_param(self) -> dict:
        """Return search-time parameters; none are defined for this index."""
        return {}