vectordb-bench 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. vectordb_bench/__init__.py +1 -0
  2. vectordb_bench/backend/assembler.py +1 -1
  3. vectordb_bench/backend/cases.py +64 -18
  4. vectordb_bench/backend/clients/__init__.py +35 -0
  5. vectordb_bench/backend/clients/api.py +21 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
  7. vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
  8. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
  9. vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
  10. vectordb_bench/backend/clients/memorydb/cli.py +88 -0
  11. vectordb_bench/backend/clients/memorydb/config.py +54 -0
  12. vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
  13. vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
  14. vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
  15. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
  16. vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
  17. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +272 -0
  18. vectordb_bench/backend/dataset.py +27 -5
  19. vectordb_bench/cli/vectordbbench.py +7 -0
  20. vectordb_bench/custom/custom_case.json +18 -0
  21. vectordb_bench/frontend/components/check_results/charts.py +6 -6
  22. vectordb_bench/frontend/components/check_results/data.py +18 -11
  23. vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
  24. vectordb_bench/frontend/components/check_results/filters.py +20 -13
  25. vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
  26. vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
  27. vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
  28. vectordb_bench/frontend/components/concurrent/charts.py +26 -29
  29. vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
  30. vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
  31. vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
  32. vectordb_bench/frontend/components/custom/initStyle.py +15 -0
  33. vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
  34. vectordb_bench/frontend/components/run_test/caseSelector.py +50 -28
  35. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -19
  36. vectordb_bench/frontend/components/run_test/dbSelector.py +2 -14
  37. vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
  38. vectordb_bench/frontend/components/run_test/initStyle.py +16 -0
  39. vectordb_bench/frontend/components/run_test/submitTask.py +1 -1
  40. vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +311 -40
  41. vectordb_bench/frontend/{const → config}/styles.py +2 -0
  42. vectordb_bench/frontend/pages/concurrent.py +11 -18
  43. vectordb_bench/frontend/pages/custom.py +64 -0
  44. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
  45. vectordb_bench/frontend/pages/run_test.py +4 -0
  46. vectordb_bench/frontend/pages/tables.py +2 -2
  47. vectordb_bench/frontend/utils.py +17 -1
  48. vectordb_bench/frontend/vdb_benchmark.py +3 -3
  49. vectordb_bench/models.py +26 -10
  50. vectordb_bench/results/getLeaderboardData.py +1 -1
  51. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/METADATA +46 -15
  52. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/RECORD +57 -40
  53. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/WHEEL +1 -1
  54. /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
  55. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/LICENSE +0 -0
  56. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/entry_points.txt +0 -0
  57. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/top_level.txt +0 -0
vectordb_bench/__init__.py
@@ -35,6 +35,7 @@ class config:
 
 
     K_DEFAULT = 100 # default return top k nearest neighbors during search
+    CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")
 
     CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
     LOAD_TIMEOUT_DEFAULT = 2.5 * 3600 # 2.5h
vectordb_bench/backend/assembler.py
@@ -14,7 +14,7 @@ class Assembler:
     def assemble(cls, run_id, task: TaskConfig, source: DatasetSource) -> CaseRunner:
         c_cls = task.case_config.case_id.case_cls
 
-        c = c_cls()
+        c = c_cls(task.case_config.custom_case)
         if type(task.db_case_config) != EmptyDBCaseConfig:
             task.db_case_config.metric_type = c.dataset.data.metric_type
 
vectordb_bench/backend/cases.py
@@ -4,9 +4,13 @@ from enum import Enum, auto
 from typing import Type
 
 from vectordb_bench import config
+from vectordb_bench.backend.clients.api import MetricType
 from vectordb_bench.base import BaseModel
+from vectordb_bench.frontend.components.custom.getCustomConfig import (
+    CustomDatasetConfig,
+)
 
-from .dataset import Dataset, DatasetManager
+from .dataset import CustomDataset, Dataset, DatasetManager
 
 
 log = logging.getLogger(__name__)
vectordb_bench/backend/cases.py
@@ -44,25 +48,24 @@ class CaseType(Enum):
     Performance1536D50K = 50
 
     Custom = 100
+    PerformanceCustomDataset = 101
 
-    @property
     def case_cls(self, custom_configs: dict | None = None) -> Type["Case"]:
-        if self not in type2case:
-            raise NotImplementedError(f"Case {self} has not implemented. You can add it manually to vectordb_bench.backend.cases.type2case or define a custom_configs['custom_cls']")
-        return type2case[self]
+        if custom_configs is None:
+            return type2case.get(self)()
+        else:
+            return type2case.get(self)(**custom_configs)
 
-    @property
-    def case_name(self) -> str:
-        c = self.case_cls
+    def case_name(self, custom_configs: dict | None = None) -> str:
+        c = self.case_cls(custom_configs)
         if c is not None:
-            return c().name
+            return c.name
         raise ValueError("Case unsupported")
 
-    @property
-    def case_description(self) -> str:
-        c = self.case_cls
+    def case_description(self, custom_configs: dict | None = None) -> str:
+        c = self.case_cls(custom_configs)
         if c is not None:
-            return c().description
+            return c.description
         raise ValueError("Case unsupported")
 
 
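Note the API change here: case_cls, case_name, and case_description lose their @property decorators and become regular methods, and case_cls now returns an instantiated Case rather than the class itself (the Type["Case"] annotation is unchanged, but type2case.get(self)() is a constructor call). Callers written against 0.0.11 must switch from attribute access to calls; a minimal sketch of the new pattern:

    from vectordb_bench.backend.cases import CaseType

    # 0.0.11: property access returned the class, e.g.
    #     cls_ = CaseType.Performance768D1M.case_cls; case = cls_()
    # 0.0.13: a method call returns a ready-made Case instance.
    case = CaseType.Performance768D1M.case_cls()
    print(case.name, CaseType.Performance768D1M.case_description())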
vectordb_bench/backend/cases.py
@@ -289,26 +292,69 @@ Results will show index building time, recall, and maximum QPS."""
     optimize_timeout: float | int | None = 15 * 60
 
 
+def metric_type_map(s: str) -> MetricType:
+    if s.lower() == "cosine":
+        return MetricType.COSINE
+    if s.lower() == "l2" or s.lower() == "euclidean":
+        return MetricType.L2
+    if s.lower() == "ip":
+        return MetricType.IP
+    err_msg = f"Not support metric_type: {s}"
+    log.error(err_msg)
+    raise RuntimeError(err_msg)
+
+
+class PerformanceCustomDataset(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceCustomDataset
+    name: str = "Performance With Custom Dataset"
+    description: str = ""
+    dataset: DatasetManager
+
+    def __init__(
+        self,
+        name,
+        description,
+        load_timeout,
+        optimize_timeout,
+        dataset_config,
+        **kwargs,
+    ):
+        dataset_config = CustomDatasetConfig(**dataset_config)
+        dataset = CustomDataset(
+            name=dataset_config.name,
+            size=dataset_config.size,
+            dim=dataset_config.dim,
+            metric_type=metric_type_map(dataset_config.metric_type),
+            use_shuffled=dataset_config.use_shuffled,
+            with_gt=dataset_config.with_gt,
+            dir=dataset_config.dir,
+            file_num=dataset_config.file_count,
+        )
+        super().__init__(
+            name=name,
+            description=description,
+            load_timeout=load_timeout,
+            optimize_timeout=optimize_timeout,
+            dataset=DatasetManager(data=dataset),
+        )
+
+
 type2case = {
     CaseType.CapacityDim960: CapacityDim960,
     CaseType.CapacityDim128: CapacityDim128,
-
     CaseType.Performance768D100M: Performance768D100M,
     CaseType.Performance768D10M: Performance768D10M,
     CaseType.Performance768D1M: Performance768D1M,
-
     CaseType.Performance768D10M1P: Performance768D10M1P,
     CaseType.Performance768D1M1P: Performance768D1M1P,
     CaseType.Performance768D10M99P: Performance768D10M99P,
     CaseType.Performance768D1M99P: Performance768D1M99P,
-
     CaseType.Performance1536D500K: Performance1536D500K,
     CaseType.Performance1536D5M: Performance1536D5M,
-
     CaseType.Performance1536D500K1P: Performance1536D500K1P,
     CaseType.Performance1536D5M1P: Performance1536D5M1P,
-
     CaseType.Performance1536D500K99P: Performance1536D500K99P,
     CaseType.Performance1536D5M99P: Performance1536D5M99P,
     CaseType.Performance1536D50K: Performance1536D50K,
+    CaseType.PerformanceCustomDataset: PerformanceCustomDataset,
 }
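The dataset_config dict consumed above is validated as a CustomDatasetConfig, whose fields can be read off the constructor call. A hedged sketch of driving the new case programmatically (the values are illustrative; the authoritative schema lives in the bundled vectordb_bench/custom/custom_case.json and in getCustomConfig.py, neither of which is shown in this diff):

    from vectordb_bench.backend.cases import CaseType

    # Every key below mirrors a parameter visible in PerformanceCustomDataset.__init__
    # and the CustomDataset(...) call above; the concrete values are made up.
    custom_configs = {
        "name": "my-custom-run",
        "description": "",
        "load_timeout": 2.5 * 3600,
        "optimize_timeout": 15 * 60,
        "dataset_config": {
            "name": "my_dataset",
            "size": 1_000_000,
            "dim": 768,
            "metric_type": "cosine",  # routed through metric_type_map()
            "use_shuffled": False,
            "with_gt": True,
            "dir": "/data/my_dataset",
            "file_count": 10,
        },
    }
    case = CaseType.PerformanceCustomDataset.case_cls(custom_configs)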
vectordb_bench/backend/clients/__init__.py
@@ -30,8 +30,11 @@ class DB(Enum):
     WeaviateCloud = "WeaviateCloud"
     PgVector = "PgVector"
     PgVectoRS = "PgVectoRS"
+    PgVectorScale = "PgVectorScale"
     Redis = "Redis"
+    MemoryDB = "MemoryDB"
     Chroma = "Chroma"
+    AWSOpenSearch = "OpenSearch"
     Test = "test"
 
 
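Note that the new AWSOpenSearch member's name and value differ ("OpenSearch"), and Enum lookup goes by value:

    from vectordb_bench.backend.clients import DB

    assert DB.AWSOpenSearch.value == "OpenSearch"
    assert DB("OpenSearch") is DB.AWSOpenSearch  # DB("AWSOpenSearch") would raise ValueError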
vectordb_bench/backend/clients/__init__.py
@@ -69,15 +72,27 @@ class DB(Enum):
         if self == DB.PgVectoRS:
             from .pgvecto_rs.pgvecto_rs import PgVectoRS
             return PgVectoRS
+
+        if self == DB.PgVectorScale:
+            from .pgvectorscale.pgvectorscale import PgVectorScale
+            return PgVectorScale
 
         if self == DB.Redis:
             from .redis.redis import Redis
             return Redis
+
+        if self == DB.MemoryDB:
+            from .memorydb.memorydb import MemoryDB
+            return MemoryDB
 
         if self == DB.Chroma:
             from .chroma.chroma import ChromaClient
             return ChromaClient
 
+        if self == DB.AWSOpenSearch:
+            from .aws_opensearch.aws_opensearch import AWSOpenSearch
+            return AWSOpenSearch
+
     @property
     def config_cls(self) -> Type[DBConfig]:
         """Import while in use"""
vectordb_bench/backend/clients/__init__.py
@@ -113,14 +128,26 @@ class DB(Enum):
             from .pgvecto_rs.config import PgVectoRSConfig
             return PgVectoRSConfig
 
+        if self == DB.PgVectorScale:
+            from .pgvectorscale.config import PgVectorScaleConfig
+            return PgVectorScaleConfig
+
         if self == DB.Redis:
             from .redis.config import RedisConfig
             return RedisConfig
+
+        if self == DB.MemoryDB:
+            from .memorydb.config import MemoryDBConfig
+            return MemoryDBConfig
 
         if self == DB.Chroma:
             from .chroma.config import ChromaConfig
             return ChromaConfig
 
+        if self == DB.AWSOpenSearch:
+            from .aws_opensearch.config import AWSOpenSearchConfig
+            return AWSOpenSearchConfig
+
     def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
         if self == DB.Milvus:
             from .milvus.config import _milvus_case_config
vectordb_bench/backend/clients/__init__.py
@@ -150,6 +177,14 @@ class DB(Enum):
             from .pgvecto_rs.config import _pgvecto_rs_case_config
             return _pgvecto_rs_case_config.get(index_type)
 
+        if self == DB.AWSOpenSearch:
+            from .aws_opensearch.config import AWSOpenSearchIndexConfig
+            return AWSOpenSearchIndexConfig
+
+        if self == DB.PgVectorScale:
+            from .pgvectorscale.config import _pgvectorscale_case_config
+            return _pgvectorscale_case_config.get(index_type)
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
 
vectordb_bench/backend/clients/api.py
@@ -15,6 +15,7 @@ class MetricType(str, Enum):
 class IndexType(str, Enum):
     HNSW = "HNSW"
     DISKANN = "DISKANN"
+    STREAMING_DISKANN = "DISKANN"
     IVFFlat = "IVF_FLAT"
     IVFSQ8 = "IVF_SQ8"
     Flat = "FLAT"
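Since STREAMING_DISKANN is assigned the same value as DISKANN, Python's Enum machinery makes it an alias rather than a distinct member; a quick check of that behavior:

    from vectordb_bench.backend.clients.api import IndexType

    # Identical values ("DISKANN") make the second name an alias of the first.
    assert IndexType.STREAMING_DISKANN is IndexType.DISKANN
    # Aliases are skipped during iteration, so only DISKANN shows up.
    assert [m.name for m in IndexType if m.value == "DISKANN"] == ["DISKANN"]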
vectordb_bench/backend/clients/api.py
@@ -38,6 +39,22 @@ class DBConfig(ABC, BaseModel):
     """
 
     db_label: str = ""
+    version: str = ""
+    note: str = ""
+
+    @staticmethod
+    def common_short_configs() -> list[str]:
+        """
+        short input, such as `db_label`, `version`
+        """
+        return ["version", "db_label"]
+
+    @staticmethod
+    def common_long_configs() -> list[str]:
+        """
+        long input, such as `note`
+        """
+        return ["note"]
 
     @abstractmethod
     def to_dict(self) -> dict:
vectordb_bench/backend/clients/api.py
@@ -45,7 +62,10 @@ class DBConfig(ABC, BaseModel):
 
     @validator("*")
     def not_empty_field(cls, v, field):
-        if field.name == "db_label":
+        if (
+            field.name in cls.common_short_configs()
+            or field.name in cls.common_long_configs()
+        ):
             return v
         if not v and isinstance(v, (str, SecretStr)):
             raise ValueError("Empty string!")
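The relaxed validator now whitelists all three shared metadata fields (db_label, version, note) instead of only db_label, while still rejecting empty strings everywhere else. A minimal sketch of the effect, using a made-up HostConfig subclass (not part of the package):

    from vectordb_bench.backend.clients.api import DBConfig

    # Hypothetical subclass, defined here only to exercise the validator.
    class HostConfig(DBConfig):
        host: str = ""

        def to_dict(self) -> dict:
            return {"host": self.host}

    HostConfig(host="db.example.com", version="", note="")  # ok: shared fields may stay empty
    HostConfig(host="")  # pydantic ValidationError: "Empty string!"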
vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py (new file)
@@ -0,0 +1,159 @@
+import logging
+from contextlib import contextmanager
+import time
+from typing import Iterable, Type
+from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
+from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
+from opensearchpy import OpenSearch
+from opensearchpy.helpers import bulk
+
+log = logging.getLogger(__name__)
+
+
+class AWSOpenSearch(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: AWSOpenSearchIndexConfig,
+        index_name: str = "vdb_bench_index",  # must be lowercase
+        id_col_name: str = "id",
+        vector_col_name: str = "embedding",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.dim = dim
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.index_name = index_name
+        self.id_col_name = id_col_name
+        self.category_col_names = [
+            f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
+        ]
+        self.vector_col_name = vector_col_name
+
+        log.info(f"AWS_OpenSearch client config: {self.db_config}")
+        client = OpenSearch(**self.db_config)
+        if drop_old:
+            log.info(f"AWS_OpenSearch client drop old index: {self.index_name}")
+            is_existed = client.indices.exists(index=self.index_name)
+            if is_existed:
+                client.indices.delete(index=self.index_name)
+            self._create_index(client)
+
+    @classmethod
+    def config_cls(cls) -> AWSOpenSearchConfig:
+        return AWSOpenSearchConfig
+
+    @classmethod
+    def case_config_cls(
+        cls, index_type: IndexType | None = None
+    ) -> AWSOpenSearchIndexConfig:
+        return AWSOpenSearchIndexConfig
+
+    def _create_index(self, client: OpenSearch):
+        settings = {
+            "index": {
+                "knn": True,
+                # "number_of_shards": 5,
+                # "refresh_interval": "600s",
+            }
+        }
+        mappings = {
+            "properties": {
+                self.id_col_name: {"type": "integer"},
+                **{
+                    categoryCol: {"type": "keyword"}
+                    for categoryCol in self.category_col_names
+                },
+                self.vector_col_name: {
+                    "type": "knn_vector",
+                    "dimension": self.dim,
+                    "method": self.case_config.index_param(),
+                },
+            }
+        }
+        try:
+            client.indices.create(
+                index=self.index_name, body=dict(settings=settings, mappings=mappings)
+            )
+        except Exception as e:
+            log.warning(f"Failed to create index: {self.index_name} error: {str(e)}")
+            raise e from None
+
+    @contextmanager
+    def init(self) -> None:
+        """connect to elasticsearch"""
+        self.client = OpenSearch(**self.db_config)
+
+        yield
+        # self.client.transport.close()
+        self.client = None
+        del self.client
+
+    def insert_embeddings(
+        self,
+        embeddings: Iterable[list[float]],
+        metadata: list[int],
+        **kwargs,
+    ) -> tuple[int, Exception]:
+        """Insert the embeddings to the elasticsearch."""
+        assert self.client is not None, "should self.init() first"
+
+        insert_data = []
+        for i in range(len(embeddings)):
+            insert_data.append({"index": {"_index": self.index_name, "_id": metadata[i]}})
+            insert_data.append({self.vector_col_name: embeddings[i]})
+        try:
+            resp = self.client.bulk(insert_data)
+            log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
+            resp = self.client.indices.stats(self.index_name)
+            log.info(f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}")
+            return (len(embeddings), None)
+        except Exception as e:
+            log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}")
+            time.sleep(10)
+            return self.insert_embeddings(embeddings, metadata)
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+    ) -> list[int]:
+        """Get k most similar embeddings to query vector.
+
+        Args:
+            query(list[float]): query embedding to look up documents similar to.
+            k(int): Number of most similar embeddings to return. Defaults to 100.
+            filters(dict, optional): filtering expression to filter the data while searching.
+
+        Returns:
+            list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
+        """
+        assert self.client is not None, "should self.init() first"
+
+        body = {
+            "size": k,
+            "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
+        }
+        try:
+            resp = self.client.search(index=self.index_name, body=body)
+            log.info(f'Search took: {resp["took"]}')
+            log.info(f'Search shards: {resp["_shards"]}')
+            log.info(f'Search hits total: {resp["hits"]["total"]}')
+            result = [int(d["_id"]) for d in resp["hits"]["hits"]]
+            # log.info(f'success! length={len(res)}')
+
+            return result
+        except Exception as e:
+            log.warning(f"Failed to search: {self.index_name} error: {str(e)}")
+            raise e from None
+
+    def optimize(self):
+        """optimize will be called between insertion and search in performance cases."""
+        pass
+
+    def ready_to_load(self):
+        """ready_to_load will be called before load in load cases."""
+        pass
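For orientation, a hedged sketch of how the benchmark runner drives this client, using only the methods defined above (the endpoint is a placeholder; in a real run db_config comes from AWSOpenSearchConfig.to_dict()):

    from vectordb_bench.backend.clients.aws_opensearch.aws_opensearch import AWSOpenSearch
    from vectordb_bench.backend.clients.aws_opensearch.config import AWSOpenSearchIndexConfig

    db = AWSOpenSearch(
        dim=128,
        db_config={"hosts": [{"host": "search-xxx.us-west-2.es.amazonaws.com", "port": 443}]},
        db_case_config=AWSOpenSearchIndexConfig(),
        drop_old=True,  # _create_index() only runs on this path
    )
    with db.init():  # holds the OpenSearch connection for the lifetime of the block
        db.insert_embeddings(embeddings=[[0.1] * 128, [0.2] * 128], metadata=[0, 1])
        ids = db.search_embedding(query=[0.1] * 128, k=10)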
vectordb_bench/backend/clients/aws_opensearch/cli.py (new file)
@@ -0,0 +1,44 @@
+from typing import Annotated, TypedDict, Unpack
+
+import click
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor2,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from .. import DB
+
+
+class AWSOpenSearchTypedDict(TypedDict):
+    host: Annotated[
+        str, click.option("--host", type=str, help="Db host", required=True)
+    ]
+    port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
+    user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
+    password: Annotated[str, click.option("--password", type=str, help="Db password")]
+
+
+class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2):
+    ...
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict)
+def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
+    from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
+    run(
+        db=DB.AWSOpenSearch,
+        db_config=AWSOpenSearchConfig(
+            host=parameters["host"],
+            port=parameters["port"],
+            user=parameters["user"],
+            password=SecretStr(parameters["password"]),
+        ),
+        db_case_config=AWSOpenSearchIndexConfig(
+        ),
+        **parameters,
+    )
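Assuming click's default command naming (the function name lowercased), this registers an awsopensearch subcommand on the shared cli group, so a run would look roughly like vectordbbench awsopensearch --host <endpoint> --user admin --password <secret>, plus the HNSW flags contributed by HNSWFlavor2 (defined in vectordb_bench/cli/cli.py, outside this diff).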
vectordb_bench/backend/clients/aws_opensearch/config.py (new file)
@@ -0,0 +1,58 @@
+from enum import Enum
+from pydantic import SecretStr, BaseModel
+
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+
+
+class AWSOpenSearchConfig(DBConfig, BaseModel):
+    host: str = ""
+    port: int = 443
+    user: str = ""
+    password: SecretStr = ""
+
+    def to_dict(self) -> dict:
+        return {
+            "hosts": [{'host': self.host, 'port': self.port}],
+            "http_auth": (self.user, self.password.get_secret_value()),
+            "use_ssl": True,
+            "http_compress": True,
+            "verify_certs": True,
+            "ssl_assert_hostname": False,
+            "ssl_show_warn": False,
+            "timeout": 600,
+        }
+
+
+class AWSOS_Engine(Enum):
+    nmslib = "nmslib"
+    faiss = "faiss"
+    lucene = "Lucene"
+
+
+class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType = MetricType.L2
+    engine: AWSOS_Engine = AWSOS_Engine.nmslib
+    efConstruction: int = 360
+    M: int = 30
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.IP:
+            return "innerproduct"  # only support faiss / nmslib, not for Lucene.
+        elif self.metric_type == MetricType.COSINE:
+            return "cosinesimil"
+        return "l2"
+
+    def index_param(self) -> dict:
+        params = {
+            "name": "hnsw",
+            "space_type": self.parse_metric(),
+            "engine": self.engine.value,
+            "parameters": {
+                "ef_construction": self.efConstruction,
+                "m": self.M
+            }
+        }
+        return params
+
+    def search_param(self) -> dict:
+        return {}
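index_param() produces exactly the dict that _create_index() embeds under the "method" key of the vector mapping, and its output is fully determined by the config fields:

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.aws_opensearch.config import AWSOpenSearchIndexConfig

    cfg = AWSOpenSearchIndexConfig(metric_type=MetricType.COSINE, efConstruction=256, M=16)
    assert cfg.index_param() == {
        "name": "hnsw",
        "space_type": "cosinesimil",  # from parse_metric()
        "engine": "nmslib",           # default AWSOS_Engine.nmslib
        "parameters": {"ef_construction": 256, "m": 16},
    }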
vectordb_bench/backend/clients/aws_opensearch/run.py (new file)
@@ -0,0 +1,125 @@
+import time, random
+from opensearchpy import OpenSearch
+from opensearch_dsl import Search, Document, Text, Keyword
+
+_HOST = 'xxxxxx.us-west-2.es.amazonaws.com'
+_PORT = 443
+_AUTH = ('admin', 'xxxxxx')  # For testing only. Don't store credentials in code.
+
+_INDEX_NAME = 'my-dsl-index'
+_BATCH = 100
+_ROWS = 100
+_DIM = 128
+_TOPK = 10
+
+
+def create_client():
+    client = OpenSearch(
+        hosts=[{'host': _HOST, 'port': _PORT}],
+        http_compress=True,  # enables gzip compression for request bodies
+        http_auth=_AUTH,
+        use_ssl=True,
+        verify_certs=True,
+        ssl_assert_hostname=False,
+        ssl_show_warn=False,
+    )
+    return client
+
+
+def create_index(client, index_name):
+    settings = {
+        "index": {
+            "knn": True,
+            "number_of_shards": 1,
+            "refresh_interval": "5s",
+        }
+    }
+    mappings = {
+        "properties": {
+            "embedding": {
+                "type": "knn_vector",
+                "dimension": _DIM,
+                "method": {
+                    "engine": "nmslib",
+                    "name": "hnsw",
+                    "space_type": "l2",
+                    "parameters": {
+                        "ef_construction": 128,
+                        "m": 24,
+                    }
+                }
+            }
+        }
+    }
+
+    response = client.indices.create(index=index_name, body=dict(settings=settings, mappings=mappings))
+    print('\nCreating index:')
+    print(response)
+
+
+def delete_index(client, index_name):
+    response = client.indices.delete(index=index_name)
+    print('\nDeleting index:')
+    print(response)
+
+
+def bulk_insert(client, index_name):
+    # Perform bulk operations
+    ids = [i for i in range(_ROWS)]
+    vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]
+
+    docs = []
+    for i in range(0, _ROWS, _BATCH):
+        docs.clear()
+        for j in range(0, _BATCH):
+            docs.append({"index": {"_index": index_name, "_id": ids[i+j]}})
+            docs.append({"embedding": vec[i+j]})
+        response = client.bulk(docs)
+        print('\nAdding documents:', len(response['items']), response['errors'])
+    response = client.indices.stats(index_name)
+    print('\nTotal document count in index:', response['_all']['primaries']['indexing']['index_total'])
+
+
+def search(client, index_name):
+    # Search for the document.
+    search_body = {
+        "size": _TOPK,
+        "query": {
+            "knn": {
+                "embedding": {
+                    "vector": [random.random() for _ in range(_DIM)],
+                    "k": _TOPK,
+                }
+            }
+        }
+    }
+    while True:
+        response = client.search(index=index_name, body=search_body)
+        print(f'\nSearch took: {response["took"]}')
+        print(f'\nSearch shards: {response["_shards"]}')
+        print(f'\nSearch hits total: {response["hits"]["total"]}')
+        result = response["hits"]["hits"]
+        if len(result) != 0:
+            print('\nSearch results:')
+            for hit in response["hits"]["hits"]:
+                print(hit["_id"], hit["_score"])
+            break
+        else:
+            print('\nSearch not ready, sleep 1s')
+            time.sleep(1)
+
+
+def main():
+    client = create_client()
+    try:
+        create_index(client, _INDEX_NAME)
+        bulk_insert(client, _INDEX_NAME)
+        search(client, _INDEX_NAME)
+        delete_index(client, _INDEX_NAME)
+    except Exception as e:
+        print(e)
+        delete_index(client, _INDEX_NAME)
+
+
+if __name__ == '__main__':
+    main()
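This standalone smoke test is runnable directly (python run.py) once _HOST and _AUTH point at a real OpenSearch domain: it creates the index, bulk-loads 100 random 128-dimensional vectors, polls the kNN search until results come back, then deletes the index (the except branch also deletes it after a failure).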