vectordb-bench 0.0.1__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. vectordb_bench/__init__.py +30 -0
  2. vectordb_bench/__main__.py +39 -0
  3. vectordb_bench/backend/__init__.py +0 -0
  4. vectordb_bench/backend/assembler.py +57 -0
  5. vectordb_bench/backend/cases.py +124 -0
  6. vectordb_bench/backend/clients/__init__.py +57 -0
  7. vectordb_bench/backend/clients/api.py +179 -0
  8. vectordb_bench/backend/clients/elastic_cloud/config.py +56 -0
  9. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +152 -0
  10. vectordb_bench/backend/clients/milvus/config.py +123 -0
  11. vectordb_bench/backend/clients/milvus/milvus.py +182 -0
  12. vectordb_bench/backend/clients/pinecone/config.py +15 -0
  13. vectordb_bench/backend/clients/pinecone/pinecone.py +113 -0
  14. vectordb_bench/backend/clients/qdrant_cloud/config.py +16 -0
  15. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +169 -0
  16. vectordb_bench/backend/clients/weaviate_cloud/config.py +45 -0
  17. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +151 -0
  18. vectordb_bench/backend/clients/zilliz_cloud/config.py +34 -0
  19. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +35 -0
  20. vectordb_bench/backend/dataset.py +393 -0
  21. vectordb_bench/backend/result_collector.py +15 -0
  22. vectordb_bench/backend/runner/__init__.py +12 -0
  23. vectordb_bench/backend/runner/mp_runner.py +124 -0
  24. vectordb_bench/backend/runner/serial_runner.py +164 -0
  25. vectordb_bench/backend/task_runner.py +290 -0
  26. vectordb_bench/backend/utils.py +85 -0
  27. vectordb_bench/base.py +6 -0
  28. vectordb_bench/frontend/components/check_results/charts.py +175 -0
  29. vectordb_bench/frontend/components/check_results/data.py +86 -0
  30. vectordb_bench/frontend/components/check_results/filters.py +97 -0
  31. vectordb_bench/frontend/components/check_results/headerIcon.py +18 -0
  32. vectordb_bench/frontend/components/check_results/nav.py +21 -0
  33. vectordb_bench/frontend/components/check_results/priceTable.py +48 -0
  34. vectordb_bench/frontend/components/run_test/autoRefresh.py +10 -0
  35. vectordb_bench/frontend/components/run_test/caseSelector.py +87 -0
  36. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +47 -0
  37. vectordb_bench/frontend/components/run_test/dbSelector.py +36 -0
  38. vectordb_bench/frontend/components/run_test/generateTasks.py +21 -0
  39. vectordb_bench/frontend/components/run_test/hideSidebar.py +10 -0
  40. vectordb_bench/frontend/components/run_test/submitTask.py +69 -0
  41. vectordb_bench/frontend/const.py +391 -0
  42. vectordb_bench/frontend/pages/qps_with_price.py +60 -0
  43. vectordb_bench/frontend/pages/run_test.py +59 -0
  44. vectordb_bench/frontend/utils.py +6 -0
  45. vectordb_bench/frontend/vdb_benchmark.py +42 -0
  46. vectordb_bench/interface.py +239 -0
  47. vectordb_bench/log_util.py +103 -0
  48. vectordb_bench/metric.py +53 -0
  49. vectordb_bench/models.py +234 -0
  50. vectordb_bench/results/result_20230609_standard.json +3228 -0
  51. vectordb_bench-0.0.1.dist-info/LICENSE +21 -0
  52. vectordb_bench-0.0.1.dist-info/METADATA +226 -0
  53. vectordb_bench-0.0.1.dist-info/RECORD +56 -0
  54. vectordb_bench-0.0.1.dist-info/WHEEL +5 -0
  55. vectordb_bench-0.0.1.dist-info/entry_points.txt +2 -0
  56. vectordb_bench-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,123 @@
1
+ from pydantic import BaseModel, SecretStr
2
+ from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
3
+
4
+
5
class MilvusConfig(DBConfig, BaseModel):
    """Connection settings for a Milvus server."""

    # The default is a plain string rather than a SecretStr instance; pydantic
    # does not validate defaults, so to_dict() must cope with both forms.
    uri: SecretStr | None = "http://localhost:19530"

    def to_dict(self) -> dict:
        """Return the connection settings as a plain dict with the secret revealed.

        A validated value is a ``SecretStr`` and is unwrapped; the unvalidated
        default (a plain ``str``) and ``None`` are passed through unchanged
        instead of failing on a missing ``get_secret_value`` attribute.
        """
        uri = self.uri
        if isinstance(uri, SecretStr):
            uri = uri.get_secret_value()
        return {"uri": uri}
10
+
11
+
12
+
13
class MilvusIndexConfig(BaseModel):
    """Fields and helpers shared by every Milvus index configuration."""

    index: IndexType
    metric_type: MetricType | None = None

    def parse_metric(self) -> str:
        """Return the metric name used in the Milvus index/search parameters.

        An unset metric yields an empty string, COSINE is reported as L2, and
        every other metric is returned by its enum value.
        """
        metric = self.metric_type
        if not metric:
            return ""

        effective = MetricType.L2 if metric == MetricType.COSINE else metric
        return effective.value
26
+
27
+
28
class AutoIndexConfig(MilvusIndexConfig, DBCaseConfig):
    """Case config for the AUTOINDEX index type; it carries no tuning knobs."""

    index: IndexType = IndexType.AUTOINDEX

    def index_param(self) -> dict:
        """Build the index-creation parameters (empty ``params``)."""
        return dict(
            metric_type=self.parse_metric(),
            index_type=self.index.value,
            params={},
        )

    def search_param(self) -> dict:
        """Build the search parameters, which hold only the metric type."""
        return dict(metric_type=self.parse_metric())
42
+
43
class HNSWConfig(MilvusIndexConfig, DBCaseConfig):
    """Case config for the HNSW index type.

    ``M`` and ``efConstruction`` go into the index-creation parameters and
    ``ef`` into the search parameters.
    """

    M: int
    efConstruction: int
    ef: int | None = None
    index: IndexType = IndexType.HNSW

    def index_param(self) -> dict:
        """Build the index-creation parameters."""
        metric = self.parse_metric()
        build_params = {"M": self.M, "efConstruction": self.efConstruction}
        return {
            "metric_type": metric,
            "index_type": self.index.value,
            "params": build_params,
        }

    def search_param(self) -> dict:
        """Build the search parameters; ``ef`` is forwarded as-is, even when None."""
        metric = self.parse_metric()
        return {"metric_type": metric, "params": {"ef": self.ef}}
61
+
62
+
63
class DISKANNConfig(MilvusIndexConfig, DBCaseConfig):
    """Case config for the DISKANN index type; only ``search_list`` is tunable."""

    search_list: int | None = None
    index: IndexType = IndexType.DISKANN

    def index_param(self) -> dict:
        """Build the index-creation parameters (empty ``params``)."""
        return dict(
            metric_type=self.parse_metric(),
            index_type=self.index.value,
            params={},
        )

    def search_param(self) -> dict:
        """Build the search parameters; ``search_list`` is forwarded as-is."""
        metric = self.parse_metric()
        return {"metric_type": metric, "params": {"search_list": self.search_list}}
79
+
80
+
81
class IVFFlatConfig(MilvusIndexConfig, DBCaseConfig):
    """Case config for the IVF_FLAT index type.

    ``nlist`` goes into the index-creation parameters and ``nprobe`` into the
    search parameters.
    """

    nlist: int
    nprobe: int | None = None
    index: IndexType = IndexType.IVFFlat

    def index_param(self) -> dict:
        """Build the index-creation parameters."""
        metric = self.parse_metric()
        return {
            "metric_type": metric,
            "index_type": self.index.value,
            "params": {"nlist": self.nlist},
        }

    def search_param(self) -> dict:
        """Build the search parameters; ``nprobe`` is forwarded as-is."""
        metric = self.parse_metric()
        return {"metric_type": metric, "params": {"nprobe": self.nprobe}}
98
+
99
+
100
class FLATConfig(MilvusIndexConfig, DBCaseConfig):
    """Case config for the FLAT index type; it carries no tuning knobs."""

    index: IndexType = IndexType.Flat

    def index_param(self) -> dict:
        """Build the index-creation parameters (empty ``params``)."""
        return dict(
            metric_type=self.parse_metric(),
            index_type=self.index.value,
            params={},
        )

    def search_param(self) -> dict:
        """Build the search parameters (empty ``params``)."""
        return dict(metric_type=self.parse_metric(), params={})
115
+
116
# Lookup table from an index type to the case-config class that describes it.
# Consumers use ``.get(index_type)``, so an unlisted index type yields None.
_milvus_case_config = {
    IndexType.AUTOINDEX: AutoIndexConfig,
    IndexType.HNSW: HNSWConfig,
    IndexType.DISKANN: DISKANNConfig,
    IndexType.IVFFlat: IVFFlatConfig,
    IndexType.Flat: FLATConfig,
}
123
+
@@ -0,0 +1,182 @@
1
+ """Wrapper around the Milvus vector database over VectorDB"""
2
+
3
+ import logging
4
+ from contextlib import contextmanager
5
+ from typing import Any, Iterable, Type
6
+
7
+ from pymilvus import Collection, utility
8
+ from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException
9
+
10
+ from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
11
+ from .config import MilvusConfig, _milvus_case_config
12
+
13
+
14
+ log = logging.getLogger(__name__)
15
+
16
+
17
+ class Milvus(VectorDB):
18
+ def __init__(
19
+ self,
20
+ dim: int,
21
+ db_config: dict,
22
+ db_case_config: DBCaseConfig,
23
+ collection_name: str = "VectorDBBenchCollection",
24
+ drop_old: bool = False,
25
+ name: str = "Milvus",
26
+ ):
27
+ """Initialize wrapper around the milvus vector database."""
28
+ self.name = name
29
+ self.db_config = db_config
30
+ self.case_config = db_case_config
31
+ self.collection_name = collection_name
32
+
33
+ self._primary_field = "pk"
34
+ self._scalar_field = "id"
35
+ self._vector_field = "vector"
36
+ self._index_name = "vector_idx"
37
+
38
+ from pymilvus import connections
39
+ connections.connect(**self.db_config, timeout=30)
40
+ if drop_old and utility.has_collection(self.collection_name):
41
+ log.info(f"{self.name} client drop_old collection: {self.collection_name}")
42
+ utility.drop_collection(self.collection_name)
43
+
44
+ if not utility.has_collection(self.collection_name):
45
+ fields = [
46
+ FieldSchema(self._primary_field, DataType.INT64, is_primary=True),
47
+ FieldSchema(self._scalar_field, DataType.INT64),
48
+ FieldSchema(self._vector_field, DataType.FLOAT_VECTOR, dim=dim)
49
+ ]
50
+
51
+ log.info(f"{self.name} create collection: {self.collection_name}")
52
+
53
+ # Create the collection
54
+ coll = Collection(
55
+ name=self.collection_name,
56
+ schema=CollectionSchema(fields),
57
+ consistency_level="Session",
58
+ )
59
+
60
+ # self._pre_load(coll)
61
+
62
+ connections.disconnect("default")
63
+
64
+ @classmethod
65
+ def config_cls(cls) -> Type[DBConfig]:
66
+ return MilvusConfig
67
+
68
+ @classmethod
69
+ def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
70
+ return _milvus_case_config.get(index_type)
71
+
72
+
73
+ @contextmanager
74
+ def init(self) -> None:
75
+ """
76
+ Examples:
77
+ >>> with self.init():
78
+ >>> self.insert_embeddings()
79
+ >>> self.search_embedding()
80
+ """
81
+ from pymilvus import connections
82
+ self.col: Collection | None = None
83
+
84
+ connections.connect(**self.db_config, timeout=60)
85
+ # Grab the existing colection with connections
86
+ self.col = Collection(self.collection_name)
87
+
88
+ yield
89
+ connections.disconnect("default")
90
+
91
+ def _pre_load(self, coll: Collection):
92
+ if not coll.has_index(index_name=self._index_name):
93
+ log.info(f"{self.name} create index and load")
94
+ try:
95
+ coll.create_index(
96
+ self._vector_field,
97
+ self.case_config.index_param(),
98
+ index_name=self._index_name,
99
+ )
100
+
101
+ coll.load()
102
+ except Exception as e:
103
+ log.warning(f"{self.name} pre load error: {e}")
104
+ raise e from None
105
+
106
+ def _optimize(self):
107
+ log.info(f"{self.name} optimizing before search")
108
+ try:
109
+ self.col.flush()
110
+ self.col.compact()
111
+ self.col.wait_for_compaction_completed()
112
+
113
+ # wait for index done and load refresh
114
+ self.col.create_index(
115
+ self._vector_field,
116
+ self.case_config.index_param(),
117
+ index_name=self._index_name,
118
+ )
119
+ utility.wait_for_index_building_complete(self.collection_name)
120
+ self.col.load()
121
+ # self.col.load(_refresh=True)
122
+ # utility.wait_for_loading_complete(self.collection_name)
123
+ # import time; time.sleep(10)
124
+ except Exception as e:
125
+ log.warning(f"{self.name} optimize error: {e}")
126
+ raise e from None
127
+
128
+ def ready_to_load(self):
129
+ assert self.col, "Please call self.init() before"
130
+ self._pre_load(self.col)
131
+ pass
132
+
133
+ def ready_to_search(self):
134
+ assert self.col, "Please call self.init() before"
135
+ self._optimize()
136
+
137
+ def insert_embeddings(
138
+ self,
139
+ embeddings: Iterable[list[float]],
140
+ metadata: list[int],
141
+ **kwargs: Any,
142
+ ) -> int:
143
+ """Insert embeddings into Milvus. should call self.init() first"""
144
+ # use the first insert_embeddings to init collection
145
+ assert self.col is not None
146
+ insert_data = [
147
+ metadata,
148
+ metadata,
149
+ embeddings,
150
+ ]
151
+
152
+ try:
153
+ res = self.col.insert(insert_data, **kwargs)
154
+ return len(res.primary_keys)
155
+ except MilvusException as e:
156
+ log.warning("Failed to insert data")
157
+ raise e from None
158
+
159
+ def search_embedding(
160
+ self,
161
+ query: list[float],
162
+ k: int = 100,
163
+ filters: dict | None = None,
164
+ timeout: int | None = None,
165
+ ) -> list[int]:
166
+ """Perform a search on a query embedding and return results."""
167
+ assert self.col is not None
168
+
169
+ expr = f"{self._scalar_field} {filters.get('metadata')}" if filters else ""
170
+
171
+ # Perform the search.
172
+ res = self.col.search(
173
+ data=[query],
174
+ anns_field=self._vector_field,
175
+ param=self.case_config.search_param(),
176
+ limit=k,
177
+ expr=expr,
178
+ )
179
+
180
+ # Organize results.
181
+ ret = [result.id for result in res[0]]
182
+ return ret
@@ -0,0 +1,15 @@
1
+ from pydantic import BaseModel, SecretStr
2
+ from ..api import DBConfig
3
+
4
+
5
class PineconeConfig(DBConfig, BaseModel):
    """Connection settings for a Pinecone index."""

    api_key: SecretStr | None = None
    environment: SecretStr | None = None
    index_name: str

    def to_dict(self) -> dict:
        """Return the settings as a plain dict with both secrets revealed."""
        api_key = self.api_key.get_secret_value()
        environment = self.environment.get_secret_value()
        return dict(
            api_key=api_key,
            environment=environment,
            index_name=self.index_name,
        )
@@ -0,0 +1,113 @@
1
+ """Wrapper around the Pinecone vector database over VectorDB"""
2
+
3
+ import logging
4
+ from contextlib import contextmanager
5
+ from typing import Any, Type
6
+
7
+ from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
8
+ from .config import PineconeConfig
9
+
10
+
11
+ log = logging.getLogger(__name__)
12
+
13
PINECONE_MAX_NUM_PER_BATCH = 1000
PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB


class Pinecone(VectorDB):
    """VectorDB client that stores and searches embeddings in a Pinecone index.

    The index itself must already exist; this client never creates it. Data
    operations must run inside the ``init()`` context manager, which binds
    ``self.index``.
    """

    def __init__(
        self,
        dim,
        db_config: dict,
        db_case_config: DBCaseConfig,
        drop_old: bool = False,
    ):
        """Initialize wrapper around the Pinecone vector database.

        Args:
            dim: dimension of the vectors that will be inserted.
            db_config: dict with ``index_name``, ``api_key`` and ``environment``.
            db_case_config: unused by this client.
            drop_old: when true, verify the index exists with a matching
                dimension and delete all of its vectors.

        Raises:
            ValueError: with ``drop_old`` set, if the index does not exist or
                its dimension differs from ``dim``.
        """
        self.index_name = db_config["index_name"]
        self.api_key = db_config["api_key"]
        self.environment = db_config["environment"]
        # Cap each upsert by both count and an estimated payload size.
        # NOTE(review): the ``dim * 5`` per-vector size estimate is undocumented.
        # Clamp to at least 1 so a very large ``dim`` cannot produce a zero step.
        self.batch_size = max(
            1,
            int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)),
        )
        self._metadata_key = "meta"
        self.index = None

        # Pinecone will make connections with the server on import,
        # so place the import here.
        import pinecone
        pinecone.init(api_key=self.api_key, environment=self.environment)
        if drop_old:
            if self.index_name not in pinecone.list_indexes():
                raise ValueError(
                    f"Pinecone index {self.index_name} does not exist")

            index = pinecone.Index(self.index_name)
            try:
                index_dim = index.describe_index_stats()["dimension"]
                if index_dim != dim:
                    raise ValueError(
                        f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}")
                log.info(
                    f"Pinecone client delete old index: {self.index_name}")
                index.delete(delete_all=True)
            finally:
                # Close the handle on the mismatch path as well.
                index.close()

    @classmethod
    def config_cls(cls) -> Type[DBConfig]:
        """Return the connection-config class for this client."""
        return PineconeConfig

    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
        """Return the case-config class; Pinecone has no per-case settings."""
        return EmptyDBCaseConfig

    @contextmanager
    def init(self) -> None:
        """Bind ``self.index`` for the duration of the block and close it after."""
        import pinecone
        pinecone.init(api_key=self.api_key, environment=self.environment)
        self.index = pinecone.Index(self.index_name)
        try:
            yield
        finally:
            # Close the handle even if the body raised.
            self.index.close()
            self.index = None

    def ready_to_load(self):
        """No preparation is needed before inserting."""
        pass

    def ready_to_search(self):
        """No preparation is needed before searching."""
        pass

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> int:
        """Upsert embeddings in batches of ``self.batch_size``.

        Each vector is stored under ``str(metadata[i])`` with the integer
        kept as metadata under ``self._metadata_key``. Extra keyword arguments
        are accepted for parity with the other clients and are ignored.

        Returns:
            The number of embeddings given.
        """
        assert len(embeddings) == len(metadata)
        total = len(embeddings)
        for start in range(0, total, self.batch_size):
            end = start + self.batch_size  # slicing clamps at the list end
            batch = [
                (str(meta), embedding, {self._metadata_key: meta})
                for meta, embedding in zip(metadata[start:end], embeddings[start:end])
            ]
            self.index.upsert(batch)
        return total

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> list[int]:
        """Search for the ``k`` nearest neighbours of ``query``.

        Args:
            query: the query vector.
            k: maximum number of results.
            filters: when truthy, restricts hits to metadata values
                ``>= filters["id"]``.
            timeout: accepted for interface compatibility; not forwarded.

        Returns:
            The integer ids of the matches.
        """
        if filters:
            pinecone_filters = {self._metadata_key: {"$gte": filters["id"]}}
        else:
            pinecone_filters = {}

        try:
            res = self.index.query(
                top_k=k,
                vector=query,
                filter=pinecone_filters,
            )['matches']
        except Exception as e:
            log.warning(f"Error querying index: {e}")
            raise
        return [int(one_res['id']) for one_res in res]
@@ -0,0 +1,16 @@
1
+ from pydantic import BaseModel, SecretStr
2
+
3
+ from ..api import DBConfig
4
+
5
+
6
class QdrantConfig(DBConfig, BaseModel):
    """Connection settings for a Qdrant Cloud cluster."""

    url: SecretStr | None = None
    api_key: SecretStr | None = None
    prefer_grpc: bool = True

    def to_dict(self) -> dict:
        """Return the settings as a plain dict with both secrets revealed."""
        url = self.url.get_secret_value()
        api_key = self.api_key.get_secret_value()
        return dict(url=url, api_key=api_key, prefer_grpc=self.prefer_grpc)
@@ -0,0 +1,169 @@
1
+ """Wrapper around the QdrantCloud vector database over VectorDB"""
2
+
3
+ import logging
4
+ import time
5
+ from contextlib import contextmanager
6
+ from typing import Any, Type
7
+
8
+ from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
9
+ from .config import QdrantConfig
10
+ from qdrant_client.http.models import (
11
+ CollectionStatus,
12
+ Distance,
13
+ VectorParams,
14
+ PayloadSchemaType,
15
+ Batch,
16
+ Filter,
17
+ FieldCondition,
18
+ Range,
19
+ )
20
+
21
+ from qdrant_client import QdrantClient
22
+
23
+
24
+ log = logging.getLogger(__name__)
25
+
26
+
27
class QdrantCloud(VectorDB):
    """VectorDB client that stores and searches embeddings in Qdrant Cloud.

    The constructor creates the collection with a temporary client. Data
    operations must run inside the ``init()`` context manager, which binds
    ``self.qdrant_client``.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig,
        collection_name: str = "QdrantCloudCollection",
        drop_old: bool = False,
    ):
        """Initialize wrapper around the QdrantCloud vector database.

        Args:
            dim: dimension of the vectors stored in the collection.
            db_config: keyword arguments forwarded to ``QdrantClient``.
            db_case_config: stored as ``self.case_config``.
            collection_name: name of the collection to use.
            drop_old: delete the collection before (re)creating it.
        """
        self.db_config = db_config
        self.case_config = db_case_config
        self.collection_name = collection_name

        self._primary_field = "pk"
        self._vector_field = "vector"

        # Defined up front so that using the client before init() trips the
        # explicit assertions below instead of raising AttributeError.
        self.qdrant_client = None

        tmp_client = QdrantClient(**self.db_config)
        if drop_old:
            log.info(f"QdrantCloud client drop_old collection: {self.collection_name}")
            tmp_client.delete_collection(self.collection_name)

        self._create_collection(dim, tmp_client)
        tmp_client = None

    @classmethod
    def config_cls(cls) -> Type[DBConfig]:
        """Return the connection-config class for this client."""
        return QdrantConfig

    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
        """Return the case-config class; Qdrant Cloud has no per-case settings."""
        return EmptyDBCaseConfig

    @contextmanager
    def init(self) -> None:
        """Bind ``self.qdrant_client`` for the duration of the block.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        self.qdrant_client = QdrantClient(**self.db_config)
        try:
            yield
        finally:
            # Drop the client even if the body raised. The attribute is kept
            # (as None) so later misuse hits the assertions, not AttributeError.
            self.qdrant_client = None

    def ready_to_load(self):
        """No preparation is needed before inserting."""
        pass

    def ready_to_search(self):
        """Block until the collection status is GREEN; call inside ``init()``."""
        assert self.qdrant_client, "Please call self.init() before"
        # Poll until the vectors are fully indexed.
        SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
        try:
            while True:
                info = self.qdrant_client.get_collection(self.collection_name)
                # The pause happens after every poll, including the last one.
                time.sleep(SECONDS_WAITING_FOR_INDEXING_API_CALL)
                if info.status == CollectionStatus.GREEN:
                    log.info(f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, Collection status: {info.status}")
                    return
        except Exception as e:
            log.warning(f"QdrantCloud ready to search error: {e}")
            raise

    def _create_collection(self, dim, qdrant_client: QdrantClient):
        """Create the collection and an integer payload index on the pk field.

        An error whose text contains "already exists!" is treated as success;
        any other error is logged and re-raised.
        """
        log.info(f"Create collection: {self.collection_name}")

        try:
            qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=dim, distance=Distance.EUCLID),
            )

            qdrant_client.create_payload_index(
                collection_name=self.collection_name,
                field_name=self._primary_field,
                field_schema=PayloadSchemaType.INTEGER,
            )

        except Exception as e:
            if "already exists!" in str(e):
                return
            log.warning(f"Failed to create collection: {self.collection_name} error: {e}")
            raise

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> int:
        """Insert embeddings into Qdrant; must be called inside ``init()``.

        ``metadata`` is used both as the point ids and as the payload value
        stored under the pk field.

        Returns:
            ``len(metadata)``; the upsert response itself is not inspected.
        """
        assert self.qdrant_client is not None
        try:
            # TODO: counts
            self.qdrant_client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=Batch(
                    ids=metadata,
                    payloads=[{self._primary_field: v} for v in metadata],
                    vectors=embeddings,
                ),
            )
            return len(metadata)
        except Exception as e:
            log.warning(f"Failed to insert data, {e}")
            raise

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> list[int]:
        """Search for the ``k`` nearest neighbours of ``query``.

        Should call self.init() first.

        Args:
            query: the query vector.
            k: maximum number of results.
            filters: when truthy, restricts hits to pk values strictly
                greater than ``filters.get("id")``.
            timeout: accepted for interface compatibility; not forwarded.

        Returns:
            The ids of the hits.
        """
        assert self.qdrant_client is not None

        query_filter = None
        if filters:
            # NOTE(review): this uses a strict ``gt`` while the Pinecone
            # client filters with ``$gte`` — confirm which bound is intended.
            query_filter = Filter(
                must=[FieldCondition(
                    key=self._primary_field,
                    range=Range(gt=filters.get('id')),
                )]
            )

        res = self.qdrant_client.search(
            collection_name=self.collection_name,
            query_vector=query,
            limit=k,
            query_filter=query_filter,
        )

        return [result.id for result in res]
@@ -0,0 +1,45 @@
1
+ from pydantic import BaseModel, SecretStr
2
+ import weaviate
3
+
4
+ from ..api import DBConfig, DBCaseConfig, MetricType
5
+
6
+
7
class WeaviateConfig(DBConfig, BaseModel):
    """Connection settings for a Weaviate Cloud instance."""

    url: SecretStr | None = None
    api_key: SecretStr | None = None

    def to_dict(self) -> dict:
        """Return the revealed URL plus an ``AuthApiKey`` built from the API key."""
        url = self.url.get_secret_value()
        auth = weaviate.AuthApiKey(api_key=self.api_key.get_secret_value())
        return dict(url=url, auth_client_secret=auth)
16
+
17
+
18
class WeaviateIndexConfig(BaseModel, DBCaseConfig):
    """Index and search settings for a Weaviate vector index."""

    metric_type: MetricType | None = None
    ef: int | None = -1
    efConstruction: int | None = None
    maxConnections: int | None = None

    def parse_metric(self) -> str:
        """Return the distance name: L2 -> "l2-squared", IP -> "dot", else "cosine"."""
        metric = self.metric_type
        if metric == MetricType.L2:
            return "l2-squared"
        return "dot" if metric == MetricType.IP else "cosine"

    def index_param(self) -> dict:
        """Build the index parameters.

        ``maxConnections`` and ``efConstruction`` are included only when both
        are set; otherwise only the distance is returned.
        """
        params = {"distance": self.parse_metric()}
        if self.maxConnections is not None and self.efConstruction is not None:
            params["maxConnections"] = self.maxConnections
            params["efConstruction"] = self.efConstruction
        return params

    def search_param(self) -> dict:
        """Build the search parameters, which hold only ``ef``."""
        return dict(ef=self.ef)