vectordb-bench 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. vectordb_bench/__init__.py +14 -3
  2. vectordb_bench/backend/assembler.py +2 -2
  3. vectordb_bench/backend/cases.py +146 -57
  4. vectordb_bench/backend/clients/__init__.py +6 -1
  5. vectordb_bench/backend/clients/api.py +23 -11
  6. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  7. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +11 -9
  8. vectordb_bench/backend/clients/milvus/config.py +2 -3
  9. vectordb_bench/backend/clients/milvus/milvus.py +32 -19
  10. vectordb_bench/backend/clients/pgvector/config.py +49 -0
  11. vectordb_bench/backend/clients/pgvector/pgvector.py +171 -0
  12. vectordb_bench/backend/clients/pinecone/config.py +3 -3
  13. vectordb_bench/backend/clients/pinecone/pinecone.py +19 -13
  14. vectordb_bench/backend/clients/qdrant_cloud/config.py +23 -6
  15. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +12 -13
  16. vectordb_bench/backend/clients/weaviate_cloud/config.py +3 -3
  17. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +9 -8
  18. vectordb_bench/backend/clients/zilliz_cloud/config.py +5 -4
  19. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +3 -1
  20. vectordb_bench/backend/dataset.py +100 -162
  21. vectordb_bench/backend/result_collector.py +2 -2
  22. vectordb_bench/backend/runner/mp_runner.py +29 -13
  23. vectordb_bench/backend/runner/serial_runner.py +98 -36
  24. vectordb_bench/backend/task_runner.py +43 -48
  25. vectordb_bench/frontend/components/check_results/charts.py +10 -21
  26. vectordb_bench/frontend/components/check_results/data.py +31 -15
  27. vectordb_bench/frontend/components/check_results/expanderStyle.py +37 -0
  28. vectordb_bench/frontend/components/check_results/filters.py +61 -33
  29. vectordb_bench/frontend/components/check_results/footer.py +8 -0
  30. vectordb_bench/frontend/components/check_results/headerIcon.py +8 -4
  31. vectordb_bench/frontend/components/check_results/nav.py +7 -6
  32. vectordb_bench/frontend/components/check_results/priceTable.py +3 -2
  33. vectordb_bench/frontend/components/check_results/stPageConfig.py +18 -0
  34. vectordb_bench/frontend/components/get_results/saveAsImage.py +50 -0
  35. vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
  36. vectordb_bench/frontend/components/run_test/caseSelector.py +19 -16
  37. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +20 -7
  38. vectordb_bench/frontend/components/run_test/dbSelector.py +5 -5
  39. vectordb_bench/frontend/components/run_test/hideSidebar.py +4 -6
  40. vectordb_bench/frontend/components/run_test/submitTask.py +16 -10
  41. vectordb_bench/frontend/const/dbCaseConfigs.py +291 -0
  42. vectordb_bench/frontend/const/dbPrices.py +6 -0
  43. vectordb_bench/frontend/const/styles.py +58 -0
  44. vectordb_bench/frontend/pages/{qps_with_price.py → quries_per_dollar.py} +24 -17
  45. vectordb_bench/frontend/pages/run_test.py +17 -11
  46. vectordb_bench/frontend/vdb_benchmark.py +19 -12
  47. vectordb_bench/metric.py +19 -10
  48. vectordb_bench/models.py +14 -40
  49. vectordb_bench/results/dbPrices.json +32 -0
  50. vectordb_bench/results/getLeaderboardData.py +52 -0
  51. vectordb_bench/results/leaderboard.json +1 -0
  52. vectordb_bench/results/{result_20230609_standard.json → result_20230705_standard.json} +1910 -897
  53. {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/METADATA +107 -27
  54. vectordb_bench-0.0.3.dist-info/RECORD +67 -0
  55. vectordb_bench/frontend/const.py +0 -391
  56. vectordb_bench-0.0.1.dist-info/RECORD +0 -56
  57. {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/LICENSE +0 -0
  58. {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/WHEEL +0 -0
  59. {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/entry_points.txt +0 -0
  60. {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,49 @@
1
+ from pydantic import BaseModel, SecretStr
2
+ from ..api import DBConfig, DBCaseConfig, MetricType
3
+
4
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"


class PgVectorConfig(DBConfig):
    """Connection settings for a PostgreSQL server with the pgvector extension.

    ``to_dict`` returns ``{"url": <postgresql connection URL>}``. The PgVector
    client splats its ``db_config`` into ``sqlalchemy.create_engine``, so this
    mapping is presumably what it receives.
    """

    # Wrapped in SecretStr explicitly: pydantic does not validate default
    # values, so a bare "postgres" string would remain a plain str and
    # ``get_secret_value()`` would raise AttributeError.
    user_name: SecretStr = SecretStr("postgres")
    password: SecretStr
    url: SecretStr  # the part between "@" and "/" in the URL, i.e. host[:port]
    db_name: str

    def to_dict(self) -> dict:
        """Render the settings as ``{"url": "postgresql://user:pwd@url/db_name"}``.

        The user name and password are percent-encoded so characters such as
        ``@``, ``:``, ``/`` or ``%`` inside them cannot break the URL structure.
        """
        from urllib.parse import quote

        # safe="" also escapes "/"; quote (not quote_plus) keeps spaces as %20,
        # which decodes correctly whether the parser uses unquote or unquote_plus.
        user_str = quote(self.user_name.get_secret_value(), safe="")
        pwd_str = quote(self.password.get_secret_value(), safe="")
        url_str = self.url.get_secret_value()
        return {
            "url": POSTGRE_URL_PLACEHOLDER % (user_str, pwd_str, url_str, self.db_name)
        }
19
+
20
+ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
21
+ metric_type: MetricType | None = None
22
+ lists: int | None = 1000
23
+ probes: int | None = 10
24
+
25
+ def parse_metric(self) -> str:
26
+ if self.metric_type == MetricType.L2:
27
+ return "vector_l2_ops"
28
+ elif self.metric_type == MetricType.IP:
29
+ return "vector_ip_ops"
30
+ return "vector_cosine_ops"
31
+
32
+ def parse_metric_fun_str(self) -> str:
33
+ if self.metric_type == MetricType.L2:
34
+ return "l2_distance"
35
+ elif self.metric_type == MetricType.IP:
36
+ return "max_inner_product"
37
+ return "cosine_distance"
38
+
39
+ def index_param(self) -> dict:
40
+ return {
41
+ "lists" : self.lists,
42
+ "metric" : self.parse_metric()
43
+ }
44
+
45
+ def search_param(self) -> dict:
46
+ return {
47
+ "probes" : self.probes,
48
+ "metric_fun" : self.parse_metric_fun_str()
49
+ }
@@ -0,0 +1,171 @@
1
+ """Wrapper around the Pgvector vector database over VectorDB"""
2
+
3
+ import logging
4
+ import time
5
+ from contextlib import contextmanager
6
+ from typing import Any, Type
7
+ from functools import wraps
8
+
9
+ from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType
10
+ from pgvector.sqlalchemy import Vector
11
+ from .config import PgVectorConfig, PgVectorIndexConfig
12
+ from sqlalchemy import (
13
+ MetaData,
14
+ create_engine,
15
+ insert,
16
+ select,
17
+ Index,
18
+ Table,
19
+ text,
20
+ Column,
21
+ Float,
22
+ Integer
23
+ )
24
+ from sqlalchemy.orm import (
25
+ declarative_base,
26
+ mapped_column,
27
+ Session
28
+ )
29
+
30
log = logging.getLogger(__name__)


class PgVector(VectorDB):
    """VectorDB client for PostgreSQL with the pgvector extension, via SQLAlchemy.

    Vectors live in one table with an integer primary key column (``id``) and a
    ``Vector(dim)`` column (``embedding``). An IVFFlat index is built on the
    vector column using the case config's ``index_param()``.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig,
        collection_name: str = "PgVectorCollection",
        drop_old: bool = False,
        **kwargs,
    ):
        """Install the vector extension and optionally recreate the table.

        Args:
            dim: Dimensionality of the stored vectors.
            db_config: Keyword arguments for ``sqlalchemy.create_engine``
                (e.g. ``{"url": ...}``).
            db_case_config: Index/search parameters; must provide
                ``index_param()`` and ``search_param()``.
            collection_name: Name of the table holding the vectors.
            drop_old: If true, drop this table (and only this table), then
                recreate it together with its vector index.
            **kwargs: Accepted for interface compatibility; unused.
        """
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim

        self._index_name = "pqvector_index"
        self._primary_field = "id"
        self._vector_field = "embedding"

        pg_engine = create_engine(**self.db_config)
        try:
            pq_metadata = declarative_base().metadata
            pq_metadata.reflect(pg_engine)

            # pgvector's `vector` type and operators come from this extension.
            with pg_engine.connect() as conn:
                conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
                conn.commit()

            # Defining the Table registers it in pq_metadata, so a
            # "name in pq_metadata.tables" check here would always be true.
            self.pg_table = self._get_table_schema(pq_metadata)
            if drop_old:
                log.info(f"Pgvector client drop table : {self.table_name}")
                # Drop only the benchmark table. pq_metadata.drop_all() would
                # also drop every other table reflected from the database above.
                self.pg_table.drop(pg_engine, checkfirst=True)
                self._create_table(dim, pg_engine)
        finally:
            # This engine is only used for setup; init() creates its own.
            pg_engine.dispose()

    @classmethod
    def config_cls(cls) -> Type[DBConfig]:
        return PgVectorConfig

    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
        return PgVectorIndexConfig

    @contextmanager
    def init(self) -> None:
        """Open the engine and ORM session used by insert and search calls.

        Both are released when the ``with`` block exits, including when it
        exits via an exception.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        self.pg_engine = create_engine(**self.db_config)
        self.pg_session = Session(self.pg_engine)
        try:
            pq_metadata = declarative_base().metadata
            pq_metadata.reflect(self.pg_engine)
            self.pg_table = self._get_table_schema(pq_metadata)
            yield
        finally:
            self.pg_session.close()
            self.pg_engine.dispose()
            del self.pg_session
            del self.pg_engine

    def ready_to_load(self):
        """No-op: the table and index are created in ``__init__``."""
        pass

    def optimize(self):
        """No-op for pgvector."""
        pass

    def ready_to_search(self):
        """No-op for pgvector."""
        pass

    def _get_table_schema(self, pq_metadata):
        """Define (or redefine) the benchmark table on ``pq_metadata``."""
        return Table(
            self.table_name,
            pq_metadata,
            Column(self._primary_field, Integer, primary_key=True),
            Column(self._vector_field, Vector(self.dim)),
            extend_existing=True,
        )

    def _create_index(self, pg_engine):
        """Drop any existing vector index of this name and build a fresh IVFFlat one."""
        index_param = self.case_config.index_param()
        index = Index(
            self._index_name,
            self.pg_table.c[self._vector_field],
            postgresql_using="ivfflat",
            postgresql_with={"lists": index_param["lists"]},
            postgresql_ops={self._vector_field: index_param["metric"]},
        )
        index.drop(pg_engine, checkfirst=True)
        index.create(pg_engine)

    def _create_table(self, dim: int, pg_engine):
        """Create the table if missing, then (re)build its vector index.

        ``dim`` is unused (the table schema already carries it) but kept for
        signature compatibility. Errors are logged and re-raised.
        """
        try:
            self.pg_table.create(bind=pg_engine, checkfirst=True)
            self._create_index(pg_engine)
        except Exception as e:
            log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}")
            raise

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> (int, Exception):
        """Insert one row per id in ``metadata``, paired with ``embeddings[i]``.

        Must be called inside ``init()``.

        Returns:
            ``(len(metadata), None)`` on success, ``(0, error)`` on failure.
            The batch is committed as a whole or rolled back as a whole.
        """
        if not metadata:
            # Nothing to insert; avoid issuing an INSERT with no parameter sets.
            return 0, None
        try:
            items = [
                {self._primary_field: pk, self._vector_field: embeddings[i]}
                for i, pk in enumerate(metadata)
            ]
            self.pg_session.execute(insert(self.pg_table), items)
            self.pg_session.commit()
            return len(metadata), None
        except Exception as e:
            log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
            # Without a rollback the session stays in a failed transaction and
            # every later call on it raises instead of running.
            try:
                self.pg_session.rollback()
            except Exception as rollback_err:
                log.warning(f"Rollback after failed pgvector insert also failed: {rollback_err}")
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
    ) -> list[int]:
        """Return the ids of the ``k`` rows nearest to ``query``.

        Must be called inside ``init()``. If ``filters`` is truthy, only rows
        with ``id > filters.get("id")`` are considered. ``timeout`` is unused.
        """
        assert self.pg_table is not None
        search_param = self.case_config.search_param()
        # Set probes through the session itself so it applies to the same
        # connection/transaction that runs the query below. A separate pooled
        # connection is not guaranteed to be the one the session uses.
        self.pg_session.execute(
            text(f"SET ivfflat.probes = {int(search_param['probes'])}")
        )
        vector_col = self.pg_table.c[self._vector_field]
        id_col = self.pg_table.c[self._primary_field]
        op_fun = getattr(vector_col, search_param["metric_fun"])
        # Select only the id column: the stored vectors are never returned, so
        # fetching them would only add transfer cost to every query.
        stmt = select(id_col).order_by(op_fun(query))
        if filters:
            stmt = stmt.filter(id_col > filters.get("id"))
        res = self.pg_session.scalars(stmt.limit(k))
        return list(res)
171
+
@@ -2,9 +2,9 @@ from pydantic import BaseModel, SecretStr
2
2
  from ..api import DBConfig
3
3
 
4
4
 
5
- class PineconeConfig(DBConfig, BaseModel):
6
- api_key: SecretStr | None = None
7
- environment: SecretStr | None = None
5
+ class PineconeConfig(DBConfig):
6
+ api_key: SecretStr
7
+ environment: SecretStr
8
8
  index_name: str
9
9
 
10
10
  def to_dict(self) -> dict:
@@ -2,7 +2,7 @@
2
2
 
3
3
  import logging
4
4
  from contextlib import contextmanager
5
- from typing import Any, Type
5
+ from typing import Type
6
6
 
7
7
  from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
8
8
  from .config import PineconeConfig
@@ -20,6 +20,7 @@ class Pinecone(VectorDB):
20
20
  db_config: dict,
21
21
  db_case_config: DBCaseConfig,
22
22
  drop_old: bool = False,
23
+ **kwargs,
23
24
  ):
24
25
  """Initialize wrapper around the milvus vector database."""
25
26
  self.index_name = db_config["index_name"]
@@ -69,24 +70,30 @@ class Pinecone(VectorDB):
69
70
  def ready_to_load(self):
70
71
  pass
71
72
 
72
- def ready_to_search(self):
73
+ def optimize(self):
73
74
  pass
74
75
 
75
76
  def insert_embeddings(
76
77
  self,
77
78
  embeddings: list[list[float]],
78
79
  metadata: list[int],
79
- ) -> list[str]:
80
+ **kwargs,
81
+ ) -> (int, Exception):
80
82
  assert len(embeddings) == len(metadata)
81
- for batch_start_offset in range(0, len(embeddings), self.batch_size):
82
- batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
83
- insert_datas = []
84
- for i in range(batch_start_offset, batch_end_offset):
85
- insert_data = (str(metadata[i]), embeddings[i], {
86
- self._metadata_key: metadata[i]})
87
- insert_datas.append(insert_data)
88
- self.index.upsert(insert_datas)
89
- return len(embeddings)
83
+ insert_count = 0
84
+ try:
85
+ for batch_start_offset in range(0, len(embeddings), self.batch_size):
86
+ batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
87
+ insert_datas = []
88
+ for i in range(batch_start_offset, batch_end_offset):
89
+ insert_data = (str(metadata[i]), embeddings[i], {
90
+ self._metadata_key: metadata[i]})
91
+ insert_datas.append(insert_data)
92
+ self.index.upsert(insert_datas)
93
+ insert_count += batch_end_offset - batch_start_offset
94
+ except Exception as e:
95
+ return (insert_count, e)
96
+ return (len(embeddings), None)
90
97
 
91
98
  def search_embedding(
92
99
  self,
@@ -94,7 +101,6 @@ class Pinecone(VectorDB):
94
101
  k: int = 100,
95
102
  filters: dict | None = None,
96
103
  timeout: int | None = None,
97
- **kwargs: Any,
98
104
  ) -> list[tuple[int, float]]:
99
105
  if filters is None:
100
106
  pinecone_filters = {}
@@ -1,16 +1,33 @@
1
1
  from pydantic import BaseModel, SecretStr
2
2
 
3
- from ..api import DBConfig
3
+ from ..api import DBConfig, DBCaseConfig, MetricType
4
+ from qdrant_client.models import Distance
4
5
 
5
6
 
6
- class QdrantConfig(DBConfig, BaseModel):
7
- url: SecretStr | None = None
8
- api_key: SecretStr | None = None
9
- prefer_grpc: bool = True
7
+ class QdrantConfig(DBConfig):
8
+ url: SecretStr
9
+ api_key: SecretStr
10
10
 
11
11
  def to_dict(self) -> dict:
12
12
  return {
13
13
  "url": self.url.get_secret_value(),
14
14
  "api_key": self.api_key.get_secret_value(),
15
- "prefer_grpc": self.prefer_grpc,
15
+ "prefer_grpc": True,
16
16
  }
17
+
18
+ class QdrantIndexConfig(BaseModel, DBCaseConfig):
19
+ metric_type: MetricType | None = None
20
+
21
+ def parse_metric(self) -> str:
22
+ if self.metric_type == MetricType.L2:
23
+ return Distance.EUCLID
24
+ elif self.metric_type == MetricType.IP:
25
+ return Distance.DOT
26
+ return Distance.COSINE
27
+
28
+ def index_param(self) -> dict:
29
+ params = {"distance": self.parse_metric()}
30
+ return params
31
+
32
+ def search_param(self) -> dict:
33
+ return {}
@@ -3,13 +3,12 @@
3
3
  import logging
4
4
  import time
5
5
  from contextlib import contextmanager
6
- from typing import Any, Type
6
+ from typing import Type
7
7
 
8
- from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
9
- from .config import QdrantConfig
8
+ from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType
9
+ from .config import QdrantConfig, QdrantIndexConfig
10
10
  from qdrant_client.http.models import (
11
11
  CollectionStatus,
12
- Distance,
13
12
  VectorParams,
14
13
  PayloadSchemaType,
15
14
  Batch,
@@ -32,6 +31,7 @@ class QdrantCloud(VectorDB):
32
31
  db_case_config: DBCaseConfig,
33
32
  collection_name: str = "QdrantCloudCollection",
34
33
  drop_old: bool = False,
34
+ **kwargs,
35
35
  ):
36
36
  """Initialize wrapper around the QdrantCloud vector database."""
37
37
  self.db_config = db_config
@@ -55,7 +55,7 @@ class QdrantCloud(VectorDB):
55
55
 
56
56
  @classmethod
57
57
  def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
58
- return EmptyDBCaseConfig
58
+ return QdrantIndexConfig
59
59
 
60
60
  @contextmanager
61
61
  def init(self) -> None:
@@ -74,7 +74,7 @@ class QdrantCloud(VectorDB):
74
74
  pass
75
75
 
76
76
 
77
- def ready_to_search(self):
77
+ def optimize(self):
78
78
  assert self.qdrant_client, "Please call self.init() before"
79
79
  # wait for vectors to be fully indexed
80
80
  SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
@@ -97,7 +97,7 @@ class QdrantCloud(VectorDB):
97
97
  try:
98
98
  qdrant_client.create_collection(
99
99
  collection_name=self.collection_name,
100
- vectors_config=VectorParams(size=dim, distance=Distance.EUCLID)
100
+ vectors_config=VectorParams(size=dim, distance=self.case_config.index_param()["distance"])
101
101
  )
102
102
 
103
103
  qdrant_client.create_payload_index(
@@ -116,8 +116,8 @@ class QdrantCloud(VectorDB):
116
116
  self,
117
117
  embeddings: list[list[float]],
118
118
  metadata: list[int],
119
- **kwargs: Any,
120
- ) -> list[str]:
119
+ **kwargs,
120
+ ) -> (int, Exception):
121
121
  """Insert embeddings into Milvus. should call self.init() first"""
122
122
  assert self.qdrant_client is not None
123
123
  try:
@@ -127,11 +127,11 @@ class QdrantCloud(VectorDB):
127
127
  wait=True,
128
128
  points=Batch(ids=metadata, payloads=[{self._primary_field: v} for v in metadata], vectors=embeddings)
129
129
  )
130
-
131
- return len(metadata)
132
130
  except Exception as e:
133
131
  log.info(f"Failed to insert data, {e}")
134
- raise e from None
132
+ return 0, e
133
+ else:
134
+ return len(metadata), None
135
135
 
136
136
  def search_embedding(
137
137
  self,
@@ -139,7 +139,6 @@ class QdrantCloud(VectorDB):
139
139
  k: int = 100,
140
140
  filters: dict | None = None,
141
141
  timeout: int | None = None,
142
- **kwargs: Any,
143
142
  ) -> list[int]:
144
143
  """Perform a search on a query embedding and return results with score.
145
144
  Should call self.init() first.
@@ -4,9 +4,9 @@ import weaviate
4
4
  from ..api import DBConfig, DBCaseConfig, MetricType
5
5
 
6
6
 
7
- class WeaviateConfig(DBConfig, BaseModel):
8
- url: SecretStr | None = None
9
- api_key: SecretStr | None = None
7
+ class WeaviateConfig(DBConfig):
8
+ url: SecretStr
9
+ api_key: SecretStr
10
10
 
11
11
  def to_dict(self) -> dict:
12
12
  return {
@@ -1,7 +1,7 @@
1
1
  """Wrapper around the Weaviate vector database over VectorDB"""
2
2
 
3
3
  import logging
4
- from typing import Any, Iterable, Type
4
+ from typing import Iterable, Type
5
5
  from contextlib import contextmanager
6
6
 
7
7
  from weaviate.exceptions import WeaviateBaseError
@@ -21,6 +21,7 @@ class WeaviateCloud(VectorDB):
21
21
  db_case_config: DBCaseConfig,
22
22
  collection_name: str = "VectorDBBenchCollection",
23
23
  drop_old: bool = False,
24
+ **kwargs,
24
25
  ):
25
26
  """Initialize wrapper around the weaviate vector database."""
26
27
  self.db_config = db_config
@@ -70,7 +71,7 @@ class WeaviateCloud(VectorDB):
70
71
  """Should call insert first, do nothing"""
71
72
  pass
72
73
 
73
- def ready_to_search(self):
74
+ def optimize(self):
74
75
  assert self.client.schema.exists(self.collection_name)
75
76
  self.client.schema.update_config(self.collection_name, {"vectorIndexConfig": self.case_config.search_param() } )
76
77
 
@@ -98,11 +99,11 @@ class WeaviateCloud(VectorDB):
98
99
  self,
99
100
  embeddings: Iterable[list[float]],
100
101
  metadata: list[int],
101
- **kwargs: Any,
102
- ) -> int:
102
+ **kwargs,
103
+ ) -> (int, Exception):
103
104
  """Insert embeddings into Weaviate"""
104
105
  assert self.client.schema.exists(self.collection_name)
105
-
106
+ insert_count = 0
106
107
  try:
107
108
  with self.client.batch as batch:
108
109
  batch.batch_size = len(metadata)
@@ -114,10 +115,11 @@ class WeaviateCloud(VectorDB):
114
115
  class_name=self.collection_name,
115
116
  vector=embeddings[i]
116
117
  ))
117
- return len(res)
118
+ insert_count += 1
119
+ return (len(res), None)
118
120
  except WeaviateBaseError as e:
119
121
  log.warning(f"Failed to insert data, error: {str(e)}")
120
- raise e from None
122
+ return (insert_count, e)
121
123
 
122
124
  def search_embedding(
123
125
  self,
@@ -125,7 +127,6 @@ class WeaviateCloud(VectorDB):
125
127
  k: int = 100,
126
128
  filters: dict | None = None,
127
129
  timeout: int | None = None,
128
- **kwargs: Any,
129
130
  ) -> list[int]:
130
131
  """Perform a search on a query embedding and return results with distance.
131
132
  Should call self.init() first.
@@ -1,12 +1,13 @@
1
- from pydantic import BaseModel, SecretStr
1
+ from pydantic import SecretStr
2
+
2
3
  from ..api import DBCaseConfig, DBConfig
3
4
  from ..milvus.config import MilvusIndexConfig, IndexType
4
5
 
5
6
 
6
- class ZillizCloudConfig(DBConfig, BaseModel):
7
- uri: SecretStr | None = None
7
+ class ZillizCloudConfig(DBConfig):
8
+ uri: SecretStr
8
9
  user: str
9
- password: SecretStr | None = None
10
+ password: SecretStr
10
11
 
11
12
  def to_dict(self) -> dict:
12
13
  return {
@@ -14,7 +14,8 @@ class ZillizCloud(Milvus):
14
14
  db_case_config: DBCaseConfig,
15
15
  collection_name: str = "ZillizCloudVectorDBBench",
16
16
  drop_old: bool = False,
17
- name: str = "ZillizCloud"
17
+ name: str = "ZillizCloud",
18
+ **kwargs,
18
19
  ):
19
20
  super().__init__(
20
21
  dim=dim,
@@ -23,6 +24,7 @@ class ZillizCloud(Milvus):
23
24
  collection_name=collection_name,
24
25
  drop_old=drop_old,
25
26
  name=name,
27
+ **kwargs,
26
28
  )
27
29
 
28
30
  @classmethod