vectordb-bench 1.0.4__py3-none-any.whl → 1.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. vectordb_bench/__init__.py +1 -0
  2. vectordb_bench/backend/cases.py +45 -1
  3. vectordb_bench/backend/clients/__init__.py +47 -0
  4. vectordb_bench/backend/clients/api.py +2 -0
  5. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +104 -40
  6. vectordb_bench/backend/clients/aws_opensearch/cli.py +52 -15
  7. vectordb_bench/backend/clients/aws_opensearch/config.py +27 -7
  8. vectordb_bench/backend/clients/hologres/cli.py +50 -0
  9. vectordb_bench/backend/clients/hologres/config.py +121 -0
  10. vectordb_bench/backend/clients/hologres/hologres.py +365 -0
  11. vectordb_bench/backend/clients/lancedb/lancedb.py +1 -0
  12. vectordb_bench/backend/clients/milvus/cli.py +29 -9
  13. vectordb_bench/backend/clients/milvus/config.py +2 -0
  14. vectordb_bench/backend/clients/milvus/milvus.py +1 -1
  15. vectordb_bench/backend/clients/oceanbase/cli.py +1 -0
  16. vectordb_bench/backend/clients/oceanbase/config.py +3 -1
  17. vectordb_bench/backend/clients/oceanbase/oceanbase.py +20 -4
  18. vectordb_bench/backend/clients/oss_opensearch/cli.py +155 -0
  19. vectordb_bench/backend/clients/oss_opensearch/config.py +157 -0
  20. vectordb_bench/backend/clients/oss_opensearch/oss_opensearch.py +582 -0
  21. vectordb_bench/backend/clients/oss_opensearch/run.py +166 -0
  22. vectordb_bench/backend/clients/pgdiskann/cli.py +45 -0
  23. vectordb_bench/backend/clients/pgdiskann/config.py +16 -0
  24. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +94 -26
  25. vectordb_bench/backend/clients/s3_vectors/config.py +41 -0
  26. vectordb_bench/backend/clients/s3_vectors/s3_vectors.py +171 -0
  27. vectordb_bench/backend/clients/tidb/cli.py +0 -4
  28. vectordb_bench/backend/clients/tidb/config.py +22 -2
  29. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -1
  30. vectordb_bench/backend/clients/zilliz_cloud/config.py +4 -1
  31. vectordb_bench/backend/dataset.py +70 -0
  32. vectordb_bench/backend/filter.py +17 -0
  33. vectordb_bench/backend/runner/mp_runner.py +4 -0
  34. vectordb_bench/backend/runner/rate_runner.py +23 -11
  35. vectordb_bench/backend/runner/read_write_runner.py +10 -9
  36. vectordb_bench/backend/runner/serial_runner.py +23 -7
  37. vectordb_bench/backend/task_runner.py +5 -4
  38. vectordb_bench/cli/cli.py +36 -0
  39. vectordb_bench/cli/vectordbbench.py +4 -0
  40. vectordb_bench/fig/custom_case_run_test.png +0 -0
  41. vectordb_bench/fig/custom_dataset.png +0 -0
  42. vectordb_bench/fig/homepage/bar-chart.png +0 -0
  43. vectordb_bench/fig/homepage/concurrent.png +0 -0
  44. vectordb_bench/fig/homepage/custom.png +0 -0
  45. vectordb_bench/fig/homepage/label_filter.png +0 -0
  46. vectordb_bench/fig/homepage/qp$.png +0 -0
  47. vectordb_bench/fig/homepage/run_test.png +0 -0
  48. vectordb_bench/fig/homepage/streaming.png +0 -0
  49. vectordb_bench/fig/homepage/table.png +0 -0
  50. vectordb_bench/fig/run_test_select_case.png +0 -0
  51. vectordb_bench/fig/run_test_select_db.png +0 -0
  52. vectordb_bench/fig/run_test_submit.png +0 -0
  53. vectordb_bench/frontend/components/check_results/filters.py +1 -4
  54. vectordb_bench/frontend/components/check_results/nav.py +2 -1
  55. vectordb_bench/frontend/components/concurrent/charts.py +5 -0
  56. vectordb_bench/frontend/components/int_filter/charts.py +60 -0
  57. vectordb_bench/frontend/components/streaming/data.py +7 -0
  58. vectordb_bench/frontend/components/welcome/welcomePrams.py +42 -4
  59. vectordb_bench/frontend/config/dbCaseConfigs.py +142 -16
  60. vectordb_bench/frontend/config/styles.py +4 -0
  61. vectordb_bench/frontend/pages/concurrent.py +1 -1
  62. vectordb_bench/frontend/pages/custom.py +1 -1
  63. vectordb_bench/frontend/pages/int_filter.py +56 -0
  64. vectordb_bench/frontend/pages/streaming.py +16 -3
  65. vectordb_bench/interface.py +5 -1
  66. vectordb_bench/metric.py +7 -0
  67. vectordb_bench/models.py +39 -4
  68. vectordb_bench/results/S3Vectors/result_20250722_standard_s3vectors.json +2509 -0
  69. vectordb_bench/results/getLeaderboardDataV2.py +23 -2
  70. vectordb_bench/results/leaderboard_v2.json +200 -0
  71. vectordb_bench/results/leaderboard_v2_streaming.json +128 -0
  72. {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/METADATA +40 -8
  73. {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/RECORD +77 -51
  74. {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/WHEEL +0 -0
  75. {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/entry_points.txt +0 -0
  76. {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/licenses/LICENSE +0 -0
  77. {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/oss_opensearch/run.py (new file)
@@ -0,0 +1,166 @@
+import logging
+import random
+import time
+
+from opensearchpy import OpenSearch
+
+log = logging.getLogger(__name__)
+
+_HOST = "xxxxxx.us-west-2.es.amazonaws.com"
+_PORT = 443
+_AUTH = ("admin", "xxxxxx")  # For testing only. Don't store credentials in code.
+
+_INDEX_NAME = "my-dsl-index"
+_BATCH = 100
+_ROWS = 100
+_DIM = 128
+_TOPK = 10
+
+
+def create_client():
+    return OpenSearch(
+        hosts=[{"host": _HOST, "port": _PORT}],
+        http_compress=True,  # enables gzip compression for request bodies
+        http_auth=_AUTH,
+        use_ssl=True,
+        verify_certs=True,
+        ssl_assert_hostname=False,
+        ssl_show_warn=False,
+    )
+
+
+def create_index(client: OpenSearch, index_name: str):
+    settings = {
+        "index": {
+            "knn": True,
+            "number_of_shards": 1,
+            "refresh_interval": "5s",
+        },
+    }
+    mappings = {
+        "properties": {
+            "embedding": {
+                "type": "knn_vector",
+                "dimension": _DIM,
+                "method": {
+                    "engine": "faiss",
+                    "name": "hnsw",
+                    "space_type": "l2",
+                    "parameters": {
+                        "ef_construction": 256,
+                        "m": 16,
+                    },
+                },
+            },
+        },
+    }
+
+    response = client.indices.create(
+        index=index_name,
+        body={"settings": settings, "mappings": mappings},
+    )
+    log.info("\nCreating index:")
+    log.info(response)
+
+
+def delete_index(client: OpenSearch, index_name: str):
+    response = client.indices.delete(index=index_name)
+    log.info("\nDeleting index:")
+    log.info(response)
+
+
+def bulk_insert(client: OpenSearch, index_name: str):
+    # Perform bulk operations: one action line plus one source line per document.
+    ids = list(range(_ROWS))
+    vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]
+
+    docs = []
+    for i in range(0, _ROWS, _BATCH):
+        docs.clear()
+        for j in range(_BATCH):
+            docs.append({"index": {"_index": index_name, "_id": ids[i + j]}})
+            docs.append({"embedding": vec[i + j]})
+        response = client.bulk(docs)
+        log.info(f"Adding documents: {len(response['items'])}, {response['errors']}")
+    response = client.indices.stats(index_name)
+    log.info(
+        f'Total document count in index: {response["_all"]["primaries"]["indexing"]["index_total"]}',
+    )
+
+
+def search(client: OpenSearch, index_name: str):
+    # Search for the document.
+    search_body = {
+        "size": _TOPK,
+        "query": {
+            "knn": {
+                "embedding": {
+                    "vector": [random.random() for _ in range(_DIM)],
+                    "k": _TOPK,
+                },
+            },
+        },
+    }
+    while True:
+        response = client.search(index=index_name, body=search_body)
+        log.info(f'\nSearch took: {response["took"]}')
+        log.info(f'\nSearch shards: {response["_shards"]}')
+        log.info(f'\nSearch hits total: {response["hits"]["total"]}')
+        result = response["hits"]["hits"]
+        if len(result) != 0:
+            log.info("\nSearch results:")
+            for hit in result:
+                log.info("%s %s", hit["_id"], hit["_score"])
+            break
+        log.info("\nSearch not ready, sleep 1s")
+        time.sleep(1)
+
+
+SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+WAITING_FOR_REFRESH_SEC = 30
+
+
+def optimize_index(client: OpenSearch, index_name: str):
+    log.info(f"Starting force merge for index {index_name}")
+    force_merge_endpoint = f"/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+    force_merge_task_id = client.transport.perform_request("POST", force_merge_endpoint)["task"]
+    while True:
+        time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
+        task_status = client.tasks.get(task_id=force_merge_task_id)
+        if task_status["completed"]:
+            break
+    log.info(f"Completed force merge for index {index_name}")
+
+
+def refresh_index(client: OpenSearch, index_name: str):
+    log.info(f"Starting refresh for index {index_name}")
+    while True:
+        try:
+            log.info("Starting the refresh index..")
+            client.indices.refresh(index=index_name)
+            break
+        except Exception as e:
+            log.info(
+                f"Refresh errored out. Sleeping for {WAITING_FOR_REFRESH_SEC} sec and then retrying: {e}",
+            )
+            time.sleep(WAITING_FOR_REFRESH_SEC)
+            continue
+    log.info(f"Completed refresh for index {index_name}")
+
+
+def main():
+    client = create_client()
+    try:
+        create_index(client, _INDEX_NAME)
+        bulk_insert(client, _INDEX_NAME)
+        optimize_index(client, _INDEX_NAME)
+        refresh_index(client, _INDEX_NAME)
+        search(client, _INDEX_NAME)
+        delete_index(client, _INDEX_NAME)
+    except Exception as e:
+        log.info(e)
+        delete_index(client, _INDEX_NAME)
+
+
+if __name__ == "__main__":
+    main()
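Note that run.py only reports through the module logger and never configures a handler itself, so running it unmodified prints nothing. A minimal driver sketch (the module path follows the file list above; the host and credentials remain the placeholders from the script):

import logging

from vectordb_bench.backend.clients.oss_opensearch import run

# Without this, the log.info() calls in run.py are silently dropped.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
run.main()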
vectordb_bench/backend/clients/pgdiskann/cli.py
@@ -5,6 +5,7 @@ import click
 from pydantic import SecretStr
 
 from vectordb_bench.backend.clients import DB
+from vectordb_bench.backend.clients.api import MetricType
 
 from ....cli.cli import (
     CommonTypedDict,
@@ -48,6 +49,15 @@ class PgDiskAnnTypedDict(CommonTypedDict):
             help="PgDiskAnn l_value_ib",
         ),
     ]
+    pq_param_num_chunks: Annotated[
+        int,
+        click.option(
+            "--pq-param-num-chunks",
+            type=int,
+            help="PgDiskAnn pq_param_num_chunks",
+            required=False,
+        ),
+    ]
     l_value_is: Annotated[
         float,
         click.option(
@@ -56,6 +66,37 @@ class PgDiskAnnTypedDict(CommonTypedDict):
             help="PgDiskAnn l_value_is",
         ),
     ]
+    reranking: Annotated[
+        bool | None,
+        click.option(
+            "--reranking/--skip-reranking",
+            type=bool,
+            help="Enable reranking for PQ search",
+            default=False,
+        ),
+    ]
+    reranking_metric: Annotated[
+        str | None,
+        click.option(
+            "--reranking-metric",
+            type=click.Choice(
+                [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD", "DP"]],
+            ),
+            help="Distance metric for reranking",
+            default="COSINE",
+            show_default=True,
+            required=False,
+        ),
+    ]
+    quantized_fetch_limit: Annotated[
+        int | None,
+        click.option(
+            "--quantized-fetch-limit",
+            type=int,
+            help="Limit of inner query in case of reranking",
+            required=False,
+        ),
+    ]
     maintenance_work_mem: Annotated[
         str | None,
         click.option(
@@ -98,7 +139,11 @@ def PgDiskAnn(
         db_case_config=PgDiskANNImplConfig(
             max_neighbors=parameters["max_neighbors"],
             l_value_ib=parameters["l_value_ib"],
+            pq_param_num_chunks=parameters["pq_param_num_chunks"],
             l_value_is=parameters["l_value_is"],
+            reranking=parameters["reranking"],
+            reranking_metric=parameters["reranking_metric"],
+            quantized_fetch_limit=parameters["quantized_fetch_limit"],
             max_parallel_workers=parameters["max_parallel_workers"],
             maintenance_work_mem=parameters["maintenance_work_mem"],
         ),
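These CLI flags flow straight through into PgDiskANNImplConfig, so the same setup can be built programmatically. A hedged sketch with illustrative values (the field names come from the hunks here and below; every value is arbitrary):

from vectordb_bench.backend.clients.pgdiskann.config import PgDiskANNImplConfig

# Illustrative values only; the CLI forwards these from the new flags.
case_config = PgDiskANNImplConfig(
    max_neighbors=32,
    l_value_ib=100,
    pq_param_num_chunks=96,        # --pq-param-num-chunks
    l_value_is=100.0,
    reranking=True,                # --reranking / --skip-reranking
    reranking_metric="COSINE",     # --reranking-metric
    quantized_fetch_limit=200,     # --quantized-fetch-limit
    max_parallel_workers=8,
    maintenance_work_mem="8GB",
)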
vectordb_bench/backend/clients/pgdiskann/config.py
@@ -60,6 +60,13 @@ class PgDiskANNIndexConfig(BaseModel, DBCaseConfig):
            return "<#>"
        return "<=>"
 
+    def parse_reranking_metric_fun_op(self) -> LiteralString:
+        if self.reranking_metric == MetricType.L2:
+            return "<->"
+        if self.reranking_metric == MetricType.IP:
+            return "<#>"
+        return "<=>"
+
    def parse_metric_fun_str(self) -> str:
        if self.metric_type == MetricType.L2:
            return "l2_distance"
@@ -115,7 +122,11 @@ class PgDiskANNImplConfig(PgDiskANNIndexConfig):
    index: IndexType = IndexType.DISKANN
    max_neighbors: int | None
    l_value_ib: int | None
+    pq_param_num_chunks: int | None
    l_value_is: float | None
+    reranking: bool | None = None
+    reranking_metric: str | None = None
+    quantized_fetch_limit: int | None = None
    maintenance_work_mem: str | None = None
    max_parallel_workers: int | None = None
 
@@ -126,6 +137,8 @@ class PgDiskANNImplConfig(PgDiskANNIndexConfig):
            "options": {
                "max_neighbors": self.max_neighbors,
                "l_value_ib": self.l_value_ib,
+                "pq_param_num_chunks": self.pq_param_num_chunks,
+                "product_quantized": str(self.reranking),
            },
            "maintenance_work_mem": self.maintenance_work_mem,
            "max_parallel_workers": self.max_parallel_workers,
@@ -135,6 +148,9 @@ class PgDiskANNImplConfig(PgDiskANNIndexConfig):
        return {
            "metric": self.parse_metric(),
            "metric_fun_op": self.parse_metric_fun_op(),
+            "reranking": self.reranking,
+            "reranking_metric_fun_op": self.parse_reranking_metric_fun_op(),
+            "quantized_fetch_limit": self.quantized_fetch_limit,
        }
 
    def session_param(self) -> dict:
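Since reranking_metric arrives from the CLI as a plain string ("L2", "IP", "COSINE"), the equality checks in parse_reranking_metric_fun_op only fire if MetricType compares equal to those strings, which holds as long as it is a str-valued enum. A quick sketch of the expected operator mapping (constructor values are illustrative):

# Assumed mapping per parse_reranking_metric_fun_op above; COSINE is the fallback.
for metric, op in (("L2", "<->"), ("IP", "<#>"), ("COSINE", "<=>")):
    cfg = PgDiskANNImplConfig(
        max_neighbors=32,
        l_value_ib=100,
        pq_param_num_chunks=96,
        l_value_is=100.0,
        reranking=True,
        reranking_metric=metric,
    )
    assert cfg.parse_reranking_metric_fun_op() == op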
vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
@@ -90,38 +90,83 @@ class PgDiskANN(VectorDB):
     def init(self) -> Generator[None, None, None]:
         self.conn, self.cursor = self._create_connection(**self.db_config)
 
-        # index configuration may have commands defined that we should set during each client session
         session_options: dict[str, Any] = self.case_config.session_param()
 
         if len(session_options) > 0:
             for setting_name, setting_val in session_options.items():
-                command = sql.SQL("SET {setting_name} " + "= {setting_val};").format(
-                    setting_name=sql.Identifier(setting_name),
-                    setting_val=sql.Identifier(str(setting_val)),
+                command = sql.SQL("SET {setting_name} = {setting_val};").format(
+                    setting_name=sql.Identifier(setting_name), setting_val=sql.Literal(setting_val)
                 )
                 log.debug(command.as_string(self.cursor))
                 self.cursor.execute(command)
             self.conn.commit()
 
-        self._filtered_search = sql.Composed(
-            [
-                sql.SQL(
-                    "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
-                ).format(table_name=sql.Identifier(self.table_name)),
-                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
-                sql.SQL(" %s::vector LIMIT %s::int"),
-            ],
-        )
+        search_params = self.case_config.search_param()
+
+        if search_params.get("reranking"):
+            # Reranking-enabled queries
+            self._filtered_search = sql.SQL(
+                """
+                SELECT i.id
+                FROM (
+                    SELECT id, embedding
+                    FROM public.{table_name}
+                    WHERE id >= %s
+                    ORDER BY embedding {metric_fun_op} %s::vector
+                    LIMIT {quantized_fetch_limit}::int
+                ) i
+                ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
+                LIMIT %s::int
+                """
+            ).format(
+                table_name=sql.Identifier(self.table_name),
+                metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
+                reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
+                quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
+            )
 
-        self._unfiltered_search = sql.Composed(
-            [
-                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
-                    sql.Identifier(self.table_name),
-                ),
-                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
-                sql.SQL(" %s::vector LIMIT %s::int"),
-            ],
-        )
+            self._unfiltered_search = sql.SQL(
+                """
+                SELECT i.id
+                FROM (
+                    SELECT id, embedding
+                    FROM public.{table_name}
+                    ORDER BY embedding {metric_fun_op} %s::vector
+                    LIMIT {quantized_fetch_limit}::int
+                ) i
+                ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
+                LIMIT %s::int
+                """
+            ).format(
+                table_name=sql.Identifier(self.table_name),
+                metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
+                reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
+                quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
+            )
+
+        else:
+            self._filtered_search = sql.Composed(
+                [
+                    sql.SQL(
+                        "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
+                    ).format(table_name=sql.Identifier(self.table_name)),
+                    sql.SQL(search_params["metric_fun_op"]),
+                    sql.SQL(" %s::vector LIMIT %s::int"),
+                ]
+            )
+
+            self._unfiltered_search = sql.Composed(
+                [
+                    sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format(
+                        table_name=sql.Identifier(self.table_name)
+                    ),
+                    sql.SQL(search_params["metric_fun_op"]),
+                    sql.SQL(" %s::vector LIMIT %s::int"),
+                ]
+            )
+
+        log.debug(f"Unfiltered search query={self._unfiltered_search.as_string(self.conn)}")
+        log.debug(f"Filtered search query={self._filtered_search.as_string(self.conn)}")
 
         try:
             yield
@@ -234,7 +279,7 @@
                 options.append(
                     sql.SQL("{option_name} = {val}").format(
                         option_name=sql.Identifier(option_name),
-                        val=sql.Identifier(str(option_val)),
+                        val=sql.Literal(option_val),
                     ),
                 )
 
@@ -314,16 +359,39 @@
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"
 
+        search_params = self.case_config.search_param()
+        is_reranking = search_params.get("reranking", False)
+
         q = np.asarray(query)
         if filters:
             gt = filters.get("id")
+            if is_reranking:
+                result = self.cursor.execute(
+                    self._filtered_search,
+                    (gt, q, q, k),
+                    prepare=True,
+                    binary=True,
+                )
+            else:
+                result = self.cursor.execute(
+                    self._filtered_search,
+                    (gt, q, k),
+                    prepare=True,
+                    binary=True,
+                )
+        elif is_reranking:
             result = self.cursor.execute(
-                self._filtered_search,
-                (gt, q, k),
+                self._unfiltered_search,
+                (q, q, k),
                 prepare=True,
                 binary=True,
             )
         else:
-            result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
+            result = self.cursor.execute(
+                self._unfiltered_search,
+                (q, k),
+                prepare=True,
+                binary=True,
+            )
 
         return [int(i[0]) for i in result.fetchall()]
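For a concrete picture of what the reranking branch executes, the unfiltered template above renders to roughly the following SQL (illustrative table name; the operators assume an L2 index metric with COSINE reranking and quantized_fetch_limit=200), bound with parameters (q, q, k) exactly as in the search hunk:

SELECT i.id
FROM (
    SELECT id, embedding
    FROM public.pgdiskann_table          -- illustrative name
    ORDER BY embedding <-> %s::vector    -- quantized ANN scan, L2
    LIMIT 200::int                       -- quantized_fetch_limit
) i
ORDER BY i.embedding <=> %s::vector      -- exact cosine rerank
LIMIT %s::int

The inner query over-fetches quantized_fetch_limit candidates by the index metric; the outer query reorders them by the reranking metric and keeps the top k.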
vectordb_bench/backend/clients/s3_vectors/config.py (new file)
@@ -0,0 +1,41 @@
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, MetricType
+
+
+class S3VectorsConfig(DBConfig):
+    region_name: str = "us-west-2"
+    access_key_id: SecretStr
+    secret_access_key: SecretStr
+    bucket_name: str
+    index_name: str = "vdbbench-index"
+
+    def to_dict(self) -> dict:
+        return {
+            "region_name": self.region_name,
+            "access_key_id": self.access_key_id.get_secret_value() if self.access_key_id else "",
+            "secret_access_key": self.secret_access_key.get_secret_value() if self.secret_access_key else "",
+            "bucket_name": self.bucket_name,
+            "index_name": self.index_name,
+        }
+
+
+class S3VectorsIndexConfig(DBCaseConfig, BaseModel):
+    """Base config for s3-vectors"""
+
+    metric_type: MetricType | None = None
+    data_type: str = "float32"
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.COSINE:
+            return "cosine"
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        msg = f"Unsupported metric type: {self.metric_type}"
+        raise ValueError(msg)
+
+    def index_param(self) -> dict:
+        return {}
+
+    def search_param(self) -> dict:
+        return {}
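A short sketch of how these two configs are consumed together (credentials, bucket, and metric are placeholders):

from pydantic import SecretStr

from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.s3_vectors.config import (
    S3VectorsConfig,
    S3VectorsIndexConfig,
)

db_config = S3VectorsConfig(
    access_key_id=SecretStr("AKIA..."),      # placeholder
    secret_access_key=SecretStr("xxxxxx"),   # placeholder
    bucket_name="my-vector-bucket",          # placeholder
)
case_config = S3VectorsIndexConfig(metric_type=MetricType.COSINE)

assert case_config.parse_metric() == "cosine"
kwargs = db_config.to_dict()  # secrets unwrapped; fed to S3Vectors(db_config=...)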
vectordb_bench/backend/clients/s3_vectors/s3_vectors.py (new file)
@@ -0,0 +1,171 @@
+"""Wrapper around AWS S3 Vectors over VectorDB"""
+
+import logging
+from collections.abc import Iterable
+from contextlib import contextmanager
+
+import boto3
+
+from vectordb_bench.backend.filter import Filter, FilterOp
+
+from ..api import VectorDB
+from .config import S3VectorsIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class S3Vectors(VectorDB):
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: S3VectorsIndexConfig,
+        drop_old: bool = False,
+        with_scalar_labels: bool = False,
+        **kwargs,
+    ):
+        """Initialize wrapper around the s3-vectors client."""
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.with_scalar_labels = with_scalar_labels
+
+        self.batch_size = 500
+
+        self._scalar_id_field = "id"
+        self._scalar_label_field = "label"
+        self._vector_field = "vector"
+
+        self.region_name = self.db_config.get("region_name")
+        self.access_key_id = self.db_config.get("access_key_id")
+        self.secret_access_key = self.db_config.get("secret_access_key")
+        self.bucket_name = self.db_config.get("bucket_name")
+        self.index_name = self.db_config.get("index_name")
+
+        client = boto3.client(
+            service_name="s3vectors",
+            region_name=self.region_name,
+            aws_access_key_id=self.access_key_id,
+            aws_secret_access_key=self.secret_access_key,
+        )
+
+        if drop_old:
+            # delete old index if it exists
+            response = client.list_indexes(vectorBucketName=self.bucket_name)
+            index_names = [index["indexName"] for index in response["indexes"]]
+            if self.index_name in index_names:
+                log.info(f"drop old index: {self.index_name}")
+                client.delete_index(vectorBucketName=self.bucket_name, indexName=self.index_name)
+
+            # create the index
+            client.create_index(
+                vectorBucketName=self.bucket_name,
+                indexName=self.index_name,
+                dataType=self.case_config.data_type,
+                dimension=dim,
+                distanceMetric=self.case_config.parse_metric(),
+            )
+
+        client.close()
+
+    @contextmanager
+    def init(self):
+        """
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+        self.client = boto3.client(
+            service_name="s3vectors",
+            region_name=self.region_name,
+            aws_access_key_id=self.access_key_id,
+            aws_secret_access_key=self.secret_access_key,
+        )
+
+        yield
+        self.client.close()
+
+    def optimize(self, **kwargs):
+        return
+
+    def need_normalize_cosine(self) -> bool:
+        """Whether this database needs the dataset normalized to support COSINE"""
+        return False
+
+    def insert_embeddings(
+        self,
+        embeddings: Iterable[list[float]],
+        metadata: list[int],
+        labels_data: list[str] | None = None,
+        **kwargs,
+    ) -> tuple[int, Exception | None]:
+        """Insert embeddings into s3-vectors. Should call self.init() first."""
+        # insert in batches of self.batch_size
+        assert self.client is not None
+        assert len(embeddings) == len(metadata)
+        insert_count = 0
+        try:
+            for batch_start_offset in range(0, len(embeddings), self.batch_size):
+                batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
+                insert_data = [
+                    {
+                        "key": str(metadata[i]),
+                        "data": {self.case_config.data_type: embeddings[i]},
+                        "metadata": (
+                            {self._scalar_label_field: labels_data[i], self._scalar_id_field: metadata[i]}
+                            if self.with_scalar_labels
+                            else {self._scalar_id_field: metadata[i]}
+                        ),
+                    }
+                    for i in range(batch_start_offset, batch_end_offset)
+                ]
+                self.client.put_vectors(
+                    vectorBucketName=self.bucket_name,
+                    indexName=self.index_name,
+                    vectors=insert_data,
+                )
+                insert_count += len(insert_data)
+        except Exception as e:
+            log.info(f"Failed to insert data: {e}")
+            return insert_count, e
+        return insert_count, None
+
+    def prepare_filter(self, filters: Filter):
+        if filters.type == FilterOp.NonFilter:
+            self.filter = None
+        elif filters.type == FilterOp.NumGE:
+            self.filter = {self._scalar_id_field: {"$gte": filters.int_value}}
+        elif filters.type == FilterOp.StrEqual:
+            self.filter = {self._scalar_label_field: filters.label_value}
+        else:
+            msg = f"Unsupported filter for S3Vectors: {filters}"
+            raise ValueError(msg)
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        timeout: int | None = None,
+    ) -> list[int]:
+        """Perform a search on a query embedding and return results."""
+        assert self.client is not None
+
+        # Perform the search.
+        res = self.client.query_vectors(
+            vectorBucketName=self.bucket_name,
+            indexName=self.index_name,
+            queryVector={"float32": query},
+            topK=k,
+            filter=self.filter,
+            returnDistance=False,
+            returnMetadata=False,
+        )
+
+        # Organize results.
+        return [int(result["key"]) for result in res["vectors"]]
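And a hedged end-to-end sketch of the wrapper itself, reusing the config objects from the previous sketch (dim and vectors are illustrative). One subtlety visible above: search_embedding reads self.filter, which only prepare_filter sets, so the benchmark always prepares a filter before searching:

db = S3Vectors(
    dim=4,
    db_config=db_config.to_dict(),
    db_case_config=case_config,
    drop_old=True,
)
with db.init():
    count, err = db.insert_embeddings(
        embeddings=[[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]],
        metadata=[0, 1],
    )
    assert err is None and count == 2

    db.filter = None  # what prepare_filter() sets for FilterOp.NonFilter
    ids = db.search_embedding([0.1, 0.2, 0.3, 0.4], k=2)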
vectordb_bench/backend/clients/tidb/cli.py
@@ -17,7 +17,6 @@ class TiDBTypedDict(CommonTypedDict):
             help="Username",
             default="root",
             show_default=True,
-            required=True,
         ),
     ]
     password: Annotated[
@@ -37,7 +36,6 @@
             type=str,
             default="127.0.0.1",
             show_default=True,
-            required=True,
             help="Db host",
         ),
     ]
@@ -48,7 +46,6 @@
             type=int,
             default=4000,
             show_default=True,
-            required=True,
             help="Db Port",
         ),
     ]
@@ -59,7 +56,6 @@
             type=str,
             default="test",
             show_default=True,
-            required=True,
             help="Db name",
         ),
     ]