vectordb-bench 0.0.23__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. vectordb_bench/backend/clients/__init__.py +33 -1
  2. vectordb_bench/backend/clients/api.py +1 -1
  3. vectordb_bench/backend/clients/chroma/chroma.py +2 -2
  4. vectordb_bench/backend/clients/clickhouse/cli.py +66 -0
  5. vectordb_bench/backend/clients/clickhouse/clickhouse.py +156 -0
  6. vectordb_bench/backend/clients/clickhouse/config.py +60 -0
  7. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +1 -1
  8. vectordb_bench/backend/clients/mariadb/cli.py +60 -45
  9. vectordb_bench/backend/clients/mariadb/config.py +11 -9
  10. vectordb_bench/backend/clients/mariadb/mariadb.py +52 -58
  11. vectordb_bench/backend/clients/milvus/cli.py +1 -19
  12. vectordb_bench/backend/clients/milvus/config.py +0 -1
  13. vectordb_bench/backend/clients/milvus/milvus.py +1 -1
  14. vectordb_bench/backend/clients/pgvector/cli.py +1 -2
  15. vectordb_bench/backend/clients/pinecone/pinecone.py +1 -1
  16. vectordb_bench/backend/clients/qdrant_cloud/config.py +1 -9
  17. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +1 -1
  18. vectordb_bench/backend/clients/tidb/config.py +6 -9
  19. vectordb_bench/backend/clients/tidb/tidb.py +17 -18
  20. vectordb_bench/backend/clients/vespa/cli.py +47 -0
  21. vectordb_bench/backend/clients/vespa/config.py +51 -0
  22. vectordb_bench/backend/clients/vespa/util.py +15 -0
  23. vectordb_bench/backend/clients/vespa/vespa.py +249 -0
  24. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
  25. vectordb_bench/cli/cli.py +20 -17
  26. vectordb_bench/cli/vectordbbench.py +5 -1
  27. vectordb_bench/frontend/config/dbCaseConfigs.py +58 -7
  28. vectordb_bench/frontend/config/styles.py +2 -0
  29. vectordb_bench/models.py +5 -6
  30. {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.24.dist-info}/METADATA +10 -2
  31. {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.24.dist-info}/RECORD +35 -28
  32. {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.24.dist-info}/WHEEL +1 -1
  33. {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.24.dist-info}/entry_points.txt +0 -0
  34. {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.24.dist-info/licenses}/LICENSE +0 -0
  35. {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.24.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/mariadb/mariadb.py

@@ -1,27 +1,25 @@
-from ..api import VectorDB
-
 import logging
 from contextlib import contextmanager
-from typing import Any, Optional, Tuple
-from ..api import VectorDB
-from .config import MariaDBConfigDict, MariaDBIndexConfig
-import numpy as np
 
 import mariadb
+import numpy as np
+
+from ..api import VectorDB
+from .config import MariaDBConfigDict, MariaDBIndexConfig
 
 log = logging.getLogger(__name__)
 
+
 class MariaDB(VectorDB):
     def __init__(
-        self,
-        dim: int,
-        db_config: MariaDBConfigDict,
-        db_case_config: MariaDBIndexConfig,
-        collection_name: str = "vec_collection",
-        drop_old: bool = False,
-        **kwargs,
-    ):
-
+        self,
+        dim: int,
+        db_config: MariaDBConfigDict,
+        db_case_config: MariaDBIndexConfig,
+        collection_name: str = "vec_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
         self.name = "MariaDB"
         self.db_config = db_config
         self.case_config = db_case_config

@@ -31,7 +29,7 @@ class MariaDB(VectorDB):
 
         # construct basic units
         self.conn, self.cursor = self._create_connection(**self.db_config)
-
+
         if drop_old:
             self._drop_db()
             self._create_db_table(dim)

@@ -41,9 +39,8 @@ class MariaDB(VectorDB):
         self.cursor = None
         self.conn = None
 
-
     @staticmethod
-    def _create_connection(**kwargs) -> Tuple[mariadb.Connection, mariadb.Cursor]:
+    def _create_connection(**kwargs) -> tuple[mariadb.Connection, mariadb.Cursor]:
        conn = mariadb.connect(**kwargs)
        cursor = conn.cursor()
 

@@ -52,7 +49,6 @@ class MariaDB(VectorDB):
 
        return conn, cursor
 
-
    def _drop_db(self):
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

@@ -77,24 +73,23 @@ class MariaDB(VectorDB):
            log.info(f"{self.name} client create table : {self.table_name}")
            self.cursor.execute(f"USE {self.db_name}")
 
-            self.cursor.execute(f"""
+            self.cursor.execute(
+                f"""
                CREATE TABLE {self.table_name} (
                    id INT PRIMARY KEY,
                    v VECTOR({self.dim}) NOT NULL
                ) ENGINE={index_param["storage_engine"]}
-                """)
+                """
+            )
            self.cursor.execute("COMMIT")
 
        except Exception as e:
-            log.warning(
-                f"Failed to create table: {self.table_name} error: {e}"
-            )
+            log.warning(f"Failed to create table: {self.table_name} error: {e}")
            raise e from None
 
-
    @contextmanager
-    def init(self) -> None:
-        """ create and destory connections to database.
+    def init(self):
+        """create and destory connections to database.
 
        Examples:
            >>> with self.init():

@@ -109,15 +104,21 @@ class MariaDB(VectorDB):
        self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")
 
        if index_param["index_type"] == "HNSW":
-            if index_param["max_cache_size"] != None:
-                self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param["max_cache_size"]}")
-            if search_param["ef_search"] != None:
-                self.cursor.execute(f"SET mhnsw_ef_search = {search_param["ef_search"]}")
+            if index_param["max_cache_size"] is not None:
+                self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param['max_cache_size']}")
+            if search_param["ef_search"] is not None:
+                self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
        self.cursor.execute("COMMIT")
 
-        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"
-        self.select_sql = f"SELECT id FROM {self.db_name}.{self.table_name} ORDER by vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d"
-        self.select_sql_with_filter = f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d ORDER by vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d"
+        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"  # noqa: S608
+        self.select_sql = (
+            f"SELECT id FROM {self.db_name}.{self.table_name}"  # noqa: S608
+            f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
+        )
+        self.select_sql_with_filter = (
+            f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d "  # noqa: S608
+            f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
+        )
 
        try:
            yield

@@ -126,7 +127,6 @@ class MariaDB(VectorDB):
            self.conn.close()
            self.cursor = None
            self.conn = None
-
 
    def ready_to_load(self) -> bool:
        pass

@@ -139,33 +139,31 @@ class MariaDB(VectorDB):
 
        try:
            index_options = f"DISTANCE={index_param['metric_type']}"
-            if index_param["index_type"] == "HNSW" and index_param["M"] != None:
+            if index_param["index_type"] == "HNSW" and index_param["M"] is not None:
                index_options += f" M={index_param['M']}"
 
-            self.cursor.execute(f"""
+            self.cursor.execute(
+                f"""
                ALTER TABLE {self.db_name}.{self.table_name}
                ADD VECTOR KEY v(v) {index_options}
-                """)
+                """
+            )
            self.cursor.execute("COMMIT")
 
        except Exception as e:
-            log.warning(
-                f"Failed to create index: {self.table_name} error: {e}"
-            )
+            log.warning(f"Failed to create index: {self.table_name} error: {e}")
            raise e from None
 
-        pass
-
    @staticmethod
-    def vector_to_hex(v):
-        return np.array(v, 'float32').tobytes()
+    def vector_to_hex(v):  # noqa: ANN001
+        return np.array(v, "float32").tobytes()
 
    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
-        **kwargs: Any,
-    ) -> Tuple[int, Optional[Exception]]:
+        **kwargs,
+    ) -> tuple[int, Exception]:
        """Insert embeddings into the database.
        Should call self.init() first.
        """

@@ -178,7 +176,7 @@ class MariaDB(VectorDB):
 
            batch_data = []
            for i, row in enumerate(metadata_arr):
-                batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])));
+                batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])))
 
            self.cursor.executemany(self.insert_sql, batch_data)
            self.cursor.execute("COMMIT")

@@ -186,11 +184,8 @@ class MariaDB(VectorDB):
 
            return len(metadata), None
        except Exception as e:
-            log.warning(
-                f"Failed to insert data into Vector table ({self.table_name}), error: {e}"
-            )
+            log.warning(f"Failed to insert data into Vector table ({self.table_name}), error: {e}")
            return 0, e
-
 
    def search_embedding(
        self,

@@ -198,17 +193,16 @@ class MariaDB(VectorDB):
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
-        **kwargs: Any,
-    ) -> (list[int]):
+        **kwargs,
+    ) -> list[int]:
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
 
-        search_param = self.case_config.search_param()
+        search_param = self.case_config.search_param()  # noqa: F841
 
        if filters:
-            self.cursor.execute(self.select_sql_with_filter, (filters.get('id'), self.vector_to_hex(query), k))
+            self.cursor.execute(self.select_sql_with_filter, (filters.get("id"), self.vector_to_hex(query), k))
        else:
            self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k))
 
-        return [id for id, in self.cursor.fetchall()]
-
+        return [id for (id,) in self.cursor.fetchall()]  # noqa: A001
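Note on the MariaDB client above: vectors are bound into the INSERT/SELECT statements as packed bytes produced by vector_to_hex. A minimal standalone sketch of that packing (the example values are arbitrary, and the byte order is whatever numpy's native order is):

    import numpy as np

    def vector_to_hex(v):
        # Same packing as MariaDB.vector_to_hex above: raw float32 bytes.
        return np.array(v, "float32").tobytes()

    packed = vector_to_hex([0.25, -1.0, 3.5])
    print(len(packed))  # 12 bytes: 3 floats x 4 bytes each
    # These bytes are what gets bound as the %s vector parameter in the SQL above.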
vectordb_bench/backend/clients/milvus/cli.py

@@ -194,25 +194,6 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
        **parameters,
    )
 
-@cli.command()
-@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
-def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
-    from .config import GPUBruteForceConfig, MilvusConfig
-
-    run(
-        db=DBTYPE,
-        db_config=MilvusConfig(
-            db_label=parameters["db_label"],
-            uri=SecretStr(parameters["uri"]),
-            user=parameters["user_name"],
-            password=SecretStr(parameters["password"]),
-        ),
-        db_case_config=GPUBruteForceConfig(
-            metric_type=parameters["metric_type"],
-            limit=parameters["limit"],  # top-k for search
-        ),
-        **parameters,
-    )
 
 class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict):
     metric_type: Annotated[

@@ -224,6 +205,7 @@ class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict):
        click.option("--limit", type=int, required=True, help="Top-k limit for search"),
    ]
 
+
 @cli.command()
 @click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
 def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
vectordb_bench/backend/clients/milvus/config.py

@@ -215,7 +215,6 @@ class GPUBruteForceConfig(MilvusIndexConfig, DBCaseConfig):
        }
 
 
-
 class GPUIVFPQConfig(MilvusIndexConfig, DBCaseConfig):
     nlist: int = 1024
     m: int = 0
vectordb_bench/backend/clients/milvus/milvus.py

@@ -155,7 +155,7 @@ class Milvus(VectorDB):
        embeddings: Iterable[list[float]],
        metadata: list[int],
        **kwargs,
-    ) -> (int, Exception):
+    ) -> tuple[int, Exception]:
        """Insert embeddings into Milvus. should call self.init() first"""
        # use the first insert_embeddings to init collection
        assert self.col is not None
vectordb_bench/backend/clients/pgvector/cli.py

@@ -18,8 +18,7 @@ from ....cli.cli import (
 )
 
 
-# ruff: noqa
-def set_default_quantized_fetch_limit(ctx: any, param: any, value: any):
+def set_default_quantized_fetch_limit(ctx: any, param: any, value: any):  # noqa: ARG001
    if ctx.params.get("reranking") and value is None:
        # ef_search is the default value for quantized_fetch_limit as it's bound by ef_search.
        # 100 is default value for quantized_fetch_limit for IVFFlat.
vectordb_bench/backend/clients/pinecone/pinecone.py

@@ -67,7 +67,7 @@ class Pinecone(VectorDB):
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs,
-    ) -> (int, Exception):
+    ) -> tuple[int, Exception]:
        assert len(embeddings) == len(metadata)
        insert_count = 0
        try:
vectordb_bench/backend/clients/qdrant_cloud/config.py

@@ -1,4 +1,4 @@
-from pydantic import BaseModel, SecretStr, validator
+from pydantic import BaseModel, SecretStr
 
 from ..api import DBCaseConfig, DBConfig, MetricType
 

@@ -20,14 +20,6 @@ class QdrantConfig(DBConfig):
            "url": self.url.get_secret_value(),
        }
 
-    @validator("*")
-    def not_empty_field(cls, v: any, field: any):
-        if field.name in ["api_key", "db_label"]:
-            return v
-        if isinstance(v, str | SecretStr) and len(v) == 0:
-            raise ValueError("Empty string!")
-        return v
-
 
 class QdrantIndexConfig(BaseModel, DBCaseConfig):
     metric_type: MetricType | None = None
vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py

@@ -111,7 +111,7 @@ class QdrantCloud(VectorDB):
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs,
-    ) -> (int, Exception):
+    ) -> tuple[int, Exception]:
        """Insert embeddings into Milvus. should call self.init() first"""
        assert self.qdrant_client is not None
        try:
vectordb_bench/backend/clients/tidb/config.py

@@ -1,5 +1,6 @@
-from pydantic import SecretStr, BaseModel, validator
-from ..api import DBConfig, DBCaseConfig, MetricType
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, MetricType
 
 
 class TiDBConfig(DBConfig):

@@ -10,10 +11,6 @@ class TiDBConfig(DBConfig):
    db_name: str = "test"
    ssl: bool = False
 
-    @validator("*")
-    def not_empty_field(cls, v: any, field: any):
-        return v
-
    def to_dict(self) -> dict:
        pwd_str = self.password.get_secret_value()
        return {

@@ -33,10 +30,10 @@ class TiDBIndexConfig(BaseModel, DBCaseConfig):
    def get_metric_fn(self) -> str:
        if self.metric_type == MetricType.L2:
            return "vec_l2_distance"
-        elif self.metric_type == MetricType.COSINE:
+        if self.metric_type == MetricType.COSINE:
            return "vec_cosine_distance"
-        else:
-            raise ValueError(f"Unsupported metric type: {self.metric_type}")
+        msg = f"Unsupported metric type: {self.metric_type}"
+        raise ValueError(msg)
 
    def index_param(self) -> dict:
        return {
vectordb_bench/backend/clients/tidb/tidb.py

@@ -3,7 +3,7 @@ import io
 import logging
 import time
 from contextlib import contextmanager
-from typing import Any, Optional, Tuple
+from typing import Any
 
 import pymysql
 

@@ -62,7 +62,7 @@ class TiDB(VectorDB):
                conn.commit()
        except Exception as e:
            log.warning("Failed to drop table: %s error: %s", self.table_name, e)
-            raise e
+            raise
 
    def _create_table(self):
        try:

@@ -80,7 +80,7 @@ class TiDB(VectorDB):
                conn.commit()
        except Exception as e:
            log.warning("Failed to create table: %s error: %s", self.table_name, e)
-            raise e
+            raise
 
    def ready_to_load(self) -> bool:
        pass

@@ -122,25 +122,25 @@ class TiDB(VectorDB):
                    f"""
                    SELECT PROGRESS FROM information_schema.tiflash_replica
                    WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}"
-                    """
+                    """  # noqa: S608
                )
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to check TiFlash replica progress: %s", e)
-            raise e
+            raise
 
    def _optimize_wait_tiflash_catch_up(self):
        try:
            with self._get_connection() as (conn, cursor):
                cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"')
                conn.commit()
-                cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
+                cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")  # noqa: S608
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to wait TiFlash to catch up: %s", e)
-            raise e
+            raise
 
    def _optimize_compact_tiflash(self):
        try:

@@ -149,7 +149,7 @@ class TiDB(VectorDB):
                conn.commit()
        except Exception as e:
            log.warning("Failed to compact table: %s", e)
-            raise e
+            raise
 
    def _optimize_get_tiflash_index_pending_rows(self):
        try:

@@ -160,13 +160,13 @@ class TiDB(VectorDB):
                    SELECT SUM(ROWS_STABLE_NOT_INDEXED)
                    FROM information_schema.tiflash_indexes
                    WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}"
-                    """
+                    """  # noqa: S608
                )
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to read TiFlash index pending rows: %s", e)
-            raise e
+            raise
 
    def _insert_embeddings_serial(
        self,

@@ -178,29 +178,28 @@ class TiDB(VectorDB):
        try:
            with self._get_connection() as (conn, cursor):
                buf = io.StringIO()
-                buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ")
+                buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ")  # noqa: S608
                for i in range(offset, offset + size):
                    if i > offset:
                        buf.write(",")
-                    buf.write(f'({metadata[i]}, "{str(embeddings[i])}")')
+                    buf.write(f'({metadata[i]}, "{embeddings[i]!s}")')
                cursor.execute(buf.getvalue())
                conn.commit()
        except Exception as e:
            log.warning("Failed to insert data into table: %s", e)
-            raise e
+            raise
 
    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
-    ) -> Tuple[int, Optional[Exception]]:
+    ) -> tuple[int, Exception]:
        workers = 10
        # Avoid exceeding MAX_ALLOWED_PACKET (default=64MB)
        max_batch_size = 64 * 1024 * 1024 // 24 // self.dim
        batch_size = len(embeddings) // workers
-        if batch_size > max_batch_size:
-            batch_size = max_batch_size
+        batch_size = min(batch_size, max_batch_size)
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = []
            for i in range(0, len(embeddings), batch_size):

@@ -227,8 +226,8 @@ class TiDB(VectorDB):
        self.cursor.execute(
            f"""
            SELECT id FROM {self.table_name}
-            ORDER BY {self.search_fn}(embedding, "{str(query)}") LIMIT {k};
-            """
+            ORDER BY {self.search_fn}(embedding, "{query!s}") LIMIT {k};
+            """  # noqa: S608
        )
        result = self.cursor.fetchall()
        return [int(i[0]) for i in result]
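A worked example of the TiDB batch-size cap above, assuming 768-dimensional vectors (an arbitrary choice). The constant 24 appears to budget roughly 24 bytes of SQL text per vector element, so the cap keeps each textual INSERT under MAX_ALLOWED_PACKET:

    # Sketch of the arithmetic in insert_embeddings above (dim=768 is an assumption).
    max_allowed_packet = 64 * 1024 * 1024              # 67_108_864 bytes
    max_batch_size = max_allowed_packet // 24 // 768   # -> 3640 rows per INSERT statement
    print(max_batch_size)                              # 3640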
vectordb_bench/backend/clients/vespa/cli.py (new file)

@@ -0,0 +1,47 @@
+from typing import Annotated, Unpack
+
+import click
+from pydantic import SecretStr
+
+from vectordb_bench.backend.clients import DB
+from vectordb_bench.cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor1,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+
+
+class VespaTypedDict(CommonTypedDict, HNSWFlavor1):
+    uri: Annotated[
+        str,
+        click.option("--uri", "-u", type=str, help="uri connection string", default="http://127.0.0.1"),
+    ]
+    port: Annotated[
+        int,
+        click.option("--port", "-p", type=int, help="connection port", default=8080),
+    ]
+    quantization: Annotated[
+        str, click.option("--quantization", type=click.Choice(["none", "binary"], case_sensitive=False), default="none")
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(VespaTypedDict)
+def Vespa(**params: Unpack[VespaTypedDict]):
+    from .config import VespaConfig, VespaHNSWConfig
+
+    case_params = {
+        "quantization_type": params["quantization"],
+        "M": params["m"],
+        "efConstruction": params["ef_construction"],
+        "ef": params["ef_search"],
+    }
+
+    run(
+        db=DB.Vespa,
+        db_config=VespaConfig(url=SecretStr(params["uri"]), port=params["port"]),
+        db_case_config=VespaHNSWConfig(**{k: v for k, v in case_params.items() if v}),
+        **params,
+    )
vectordb_bench/backend/clients/vespa/config.py (new file)

@@ -0,0 +1,51 @@
+from typing import Literal, TypeAlias
+
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, MetricType
+
+VespaMetric: TypeAlias = Literal["euclidean", "angular", "dotproduct", "prenormalized-angular", "hamming", "geodegrees"]
+
+VespaQuantizationType: TypeAlias = Literal["none", "binary"]
+
+
+class VespaConfig(DBConfig):
+    url: SecretStr = "http://127.0.0.1"
+    port: int = 8080
+
+    def to_dict(self):
+        return {
+            "url": self.url.get_secret_value(),
+            "port": self.port,
+        }
+
+
+class VespaHNSWConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType = MetricType.COSINE
+    quantization_type: VespaQuantizationType = "none"
+    M: int = 16
+    efConstruction: int = 200
+    ef: int = 100
+
+    def index_param(self) -> dict:
+        return {
+            "distance_metric": self.parse_metric(self.metric_type),
+            "max_links_per_node": self.M,
+            "neighbors_to_explore_at_insert": self.efConstruction,
+        }
+
+    def search_param(self) -> dict:
+        return {}
+
+    def parse_metric(self, metric_type: MetricType) -> VespaMetric:
+        match metric_type:
+            case MetricType.COSINE:
+                return "angular"
+            case MetricType.L2:
+                return "euclidean"
+            case MetricType.DP | MetricType.IP:
+                return "dotproduct"
+            case MetricType.HAMMING:
+                return "hamming"
+            case _:
+                raise NotImplementedError
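A minimal sketch of how the new Vespa case config translates a benchmark metric into Vespa HNSW index settings, assuming vectordb-bench 0.0.24 is installed (the metric and parameter values are arbitrary):

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.vespa.config import VespaHNSWConfig

    # COSINE maps to Vespa's "angular" distance; L2 maps to "euclidean".
    cfg = VespaHNSWConfig(metric_type=MetricType.L2, M=32, efConstruction=400)
    print(cfg.index_param())
    # {'distance_metric': 'euclidean', 'max_links_per_node': 32, 'neighbors_to_explore_at_insert': 400}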
vectordb_bench/backend/clients/vespa/util.py (new file)

@@ -0,0 +1,15 @@
+"""Utility functions for supporting binary quantization
+
+From https://docs.vespa.ai/en/binarizing-vectors.html#appendix-conversion-to-int8
+"""
+
+import numpy as np
+
+
+def binarize_tensor(tensor: list[float]) -> list[int]:
+    """
+    Binarize a floating-point list by thresholding at zero
+    and packing the bits into bytes.
+    """
+    tensor = np.array(tensor)
+    return np.packbits(np.where(tensor > 0, 1, 0), axis=0).astype(np.int8).tolist()
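To illustrate the bit packing used for binary quantization, a small standalone example mirroring binarize_tensor above (the input vector is arbitrary; eight floats collapse into one signed int8):

    import numpy as np

    v = [0.5, -0.2, 1.1, 0.0, 0.3, -0.7, 2.0, 0.9]
    bits = np.where(np.array(v) > 0, 1, 0)            # [1, 0, 1, 0, 1, 0, 1, 1]
    packed = np.packbits(bits, axis=0).astype(np.int8).tolist()
    print(packed)  # [-85]: bit pattern 10101011 reinterpreted as a signed int8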