vectordb-bench 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +56 -46
  5. vectordb_bench/backend/clients/__init__.py +101 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +26 -0
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +18 -0
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +345 -0
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +47 -0
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
  11. vectordb_bench/backend/clients/alloydb/cli.py +52 -35
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +8 -9
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +38 -36
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
  25. vectordb_bench/backend/clients/milvus/cli.py +62 -80
  26. vectordb_bench/backend/clients/milvus/config.py +31 -7
  27. vectordb_bench/backend/clients/milvus/milvus.py +23 -26
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +94 -58
  47. vectordb_bench/backend/clients/test/cli.py +1 -2
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +4 -5
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +30 -18
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +85 -34
  61. vectordb_bench/backend/runner/rate_runner.py +51 -23
  62. vectordb_bench/backend/runner/read_write_runner.py +140 -46
  63. vectordb_bench/backend/runner/serial_runner.py +99 -50
  64. vectordb_bench/backend/runner/util.py +4 -19
  65. vectordb_bench/backend/task_runner.py +95 -74
  66. vectordb_bench/backend/utils.py +17 -9
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +108 -83
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +56 -26
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +34 -42
  100. vectordb_bench-0.0.20.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.18.dist-info/RECORD +0 -131
  103. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/api.py

@@ -1,9 +1,8 @@
  from abc import ABC, abstractmethod
- from enum import Enum
- from typing import Any, Type
  from contextlib import contextmanager
+ from enum import Enum
 
- from pydantic import BaseModel, validator, SecretStr
+ from pydantic import BaseModel, SecretStr, validator
 
 
  class MetricType(str, Enum):
@@ -65,13 +64,10 @@ class DBConfig(ABC, BaseModel):
          raise NotImplementedError
 
      @validator("*")
-     def not_empty_field(cls, v, field):
-         if (
-             field.name in cls.common_short_configs()
-             or field.name in cls.common_long_configs()
-         ):
+     def not_empty_field(cls, v: any, field: any):
+         if field.name in cls.common_short_configs() or field.name in cls.common_long_configs():
              return v
-         if not v and isinstance(v, (str, SecretStr)):
+         if not v and isinstance(v, str | SecretStr):
              raise ValueError("Empty string!")
          return v
 
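One note on the validator change above: `isinstance(v, str | SecretStr)` uses PEP 604 union syntax, which `isinstance` accepts on Python 3.10+ and which behaves exactly like the old tuple form. A minimal sketch of the equivalence:

```python
# Runs on Python 3.10+: X | Y unions are accepted by isinstance()
# and are equivalent to the (X, Y) tuple form.
from pydantic import SecretStr

for value in ["", SecretStr(""), 0]:
    assert isinstance(value, str | SecretStr) == isinstance(value, (str, SecretStr))
print("union and tuple forms agree")
```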
@@ -204,6 +200,9 @@ class VectorDB(ABC):
          """
          raise NotImplementedError
 
+     def optimize_with_size(self, data_size: int):
+         self.optimize()
+
      # TODO: remove
      @abstractmethod
      def ready_to_load(self):
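The new `optimize_with_size(self, data_size: int)` hook defaults to delegating to `optimize()`, so clients that ignore the dataset size keep working unchanged. A hedged sketch of how a client might override it — `ExampleClient` and `build_index` are illustrative names, not part of this diff:

```python
from vectordb_bench.backend.clients.api import VectorDB

class ExampleClient(VectorDB):
    # Hypothetical override: pick index-build settings from the dataset size.
    def optimize_with_size(self, data_size: int):
        ef_construction = 512 if data_size > 10_000_000 else 128  # illustrative threshold
        self.build_index(ef_construction=ef_construction)  # assumed helper, not from this diff

    # ... the remaining VectorDB abstract methods would be implemented here.
```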
vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py

@@ -1,14 +1,18 @@
  import logging
- from contextlib import contextmanager
  import time
- from typing import Iterable, Type
- from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
- from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine
+ from collections.abc import Iterable
+ from contextlib import contextmanager
+
  from opensearchpy import OpenSearch
- from opensearchpy.helpers import bulk
+
+ from ..api import IndexType, VectorDB
+ from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine
 
  log = logging.getLogger(__name__)
 
+ WAITING_FOR_REFRESH_SEC = 30
+ WAITING_FOR_FORCE_MERGE_SEC = 30
+
 
  class AWSOpenSearch(VectorDB):
      def __init__(
@@ -17,7 +21,7 @@ class AWSOpenSearch(VectorDB):
          db_config: dict,
          db_case_config: AWSOpenSearchIndexConfig,
          index_name: str = "vdb_bench_index",  # must be lowercase
-         id_col_name: str = "id",
+         id_col_name: str = "_id",
          vector_col_name: str = "embedding",
          drop_old: bool = False,
          **kwargs,
@@ -27,9 +31,7 @@ class AWSOpenSearch(VectorDB):
          self.case_config = db_case_config
          self.index_name = index_name
          self.id_col_name = id_col_name
-         self.category_col_names = [
-             f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
-         ]
+         self.category_col_names = [f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]]
          self.vector_col_name = vector_col_name
 
          log.info(f"AWS_OpenSearch client config: {self.db_config}")
@@ -46,39 +48,32 @@ class AWSOpenSearch(VectorDB):
          return AWSOpenSearchConfig
 
      @classmethod
-     def case_config_cls(
-         cls, index_type: IndexType | None = None
-     ) -> AWSOpenSearchIndexConfig:
+     def case_config_cls(cls, index_type: IndexType | None = None) -> AWSOpenSearchIndexConfig:
          return AWSOpenSearchIndexConfig
 
      def _create_index(self, client: OpenSearch):
          settings = {
              "index": {
                  "knn": True,
-                 # "number_of_shards": 5,
-                 # "refresh_interval": "600s",
-             }
+             },
          }
          mappings = {
              "properties": {
-                 self.id_col_name: {"type": "integer"},
-                 **{
-                     categoryCol: {"type": "keyword"}
-                     for categoryCol in self.category_col_names
-                 },
+                 **{categoryCol: {"type": "keyword"} for categoryCol in self.category_col_names},
                  self.vector_col_name: {
                      "type": "knn_vector",
                      "dimension": self.dim,
                      "method": self.case_config.index_param(),
                  },
-             }
+             },
          }
          try:
              client.indices.create(
-                 index=self.index_name, body=dict(settings=settings, mappings=mappings)
+                 index=self.index_name,
+                 body={"settings": settings, "mappings": mappings},
              )
          except Exception as e:
-             log.warning(f"Failed to create index: {self.index_name} error: {str(e)}")
+             log.warning(f"Failed to create index: {self.index_name} error: {e!s}")
              raise e from None
 
      @contextmanager
@@ -87,7 +82,6 @@ class AWSOpenSearch(VectorDB):
          self.client = OpenSearch(**self.db_config)
 
          yield
-         # self.client.transport.close()
          self.client = None
          del self.client
 
@@ -102,16 +96,20 @@ class AWSOpenSearch(VectorDB):
 
          insert_data = []
          for i in range(len(embeddings)):
-             insert_data.append({"index": {"_index": self.index_name, "_id": metadata[i]}})
+             insert_data.append(
+                 {"index": {"_index": self.index_name, self.id_col_name: metadata[i]}},
+             )
              insert_data.append({self.vector_col_name: embeddings[i]})
          try:
              resp = self.client.bulk(insert_data)
              log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
              resp = self.client.indices.stats(self.index_name)
-             log.info(f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}")
+             log.info(
+                 f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}",
+             )
              return (len(embeddings), None)
          except Exception as e:
-             log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}")
+             log.warning(f"Failed to insert data: {self.index_name} error: {e!s}")
              time.sleep(10)
              return self.insert_embeddings(embeddings, metadata)
 
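For reference, `client.bulk` here takes the action/source pair format: each document contributes one action line and one source line, and passing a list of dicts lets opensearch-py serialize it to the newline-delimited bulk body. A sketch of the payload shape (index name and vectors are illustrative):

```python
# Sketch of the bulk payload shape used above: action entry, then source entry,
# repeated per document.
payload = []
for doc_id, vec in [(0, [0.1, 0.2]), (1, [0.3, 0.4])]:  # illustrative data
    payload.append({"index": {"_index": "vdb_bench_index", "_id": doc_id}})
    payload.append({"embedding": vec})
# client.bulk(payload)  # assumes an opensearchpy.OpenSearch client named `client`
```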
@@ -136,20 +134,23 @@ class AWSOpenSearch(VectorDB):
          body = {
              "size": k,
              "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
-             **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {})
+             **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {}),
          }
          try:
-             resp = self.client.search(index=self.index_name, body=body,size=k,_source=False,docvalue_fields=[self.id_col_name],stored_fields="_none_",filter_path=[f"hits.hits.fields.{self.id_col_name}"],)
+             resp = self.client.search(
+                 index=self.index_name,
+                 body=body,
+                 size=k,
+                 _source=False,
+                 docvalue_fields=[self.id_col_name],
+                 stored_fields="_none_",
+             )
              log.info(f'Search took: {resp["took"]}')
              log.info(f'Search shards: {resp["_shards"]}')
              log.info(f'Search hits total: {resp["hits"]["total"]}')
-             result = [h["fields"][self.id_col_name][0] for h in resp["hits"]["hits"]]
-             #result = [int(d["_id"]) for d in resp["hits"]["hits"]]
-             # log.info(f'success! length={len(res)}')
-
-             return result
+             return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]]
          except Exception as e:
-             log.warning(f"Failed to search: {self.index_name} error: {str(e)}")
+             log.warning(f"Failed to search: {self.index_name} error: {e!s}")
              raise e from None
 
      def optimize(self):
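With `_source=False` and `docvalue_fields`, each hit carries its ID under `fields` rather than `_source`, and docvalues come back as single-element lists of strings — hence the `int(...)` cast added above. A sketch of the hit shape being parsed (values illustrative):

```python
# Illustrative hit shape when requesting docvalue_fields=["_id"] with _source=False:
hit = {"_index": "vdb_bench_index", "fields": {"_id": ["42"]}}
doc_id = int(hit["fields"]["_id"][0])  # docvalues arrive as lists of strings
assert doc_id == 42
```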
@@ -164,37 +165,35 @@ class AWSOpenSearch(VectorDB):
 
      def _refresh_index(self):
          log.debug(f"Starting refresh for index {self.index_name}")
-         SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
          while True:
              try:
-                 log.info(f"Starting the Refresh Index..")
+                 log.info("Starting the Refresh Index..")
                  self.client.indices.refresh(index=self.index_name)
                  break
              except Exception as e:
                  log.info(
-                     f"Refresh errored out. Sleeping for {SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC} sec and then Retrying : {e}")
-                 time.sleep(SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC)
+                     f"Refresh errored out. Sleeping for {WAITING_FOR_REFRESH_SEC} sec and then Retrying : {e}",
+                 )
+                 time.sleep(WAITING_FOR_REFRESH_SEC)
                  continue
          log.debug(f"Completed refresh for index {self.index_name}")
 
      def _do_force_merge(self):
          log.debug(f"Starting force merge for index {self.index_name}")
-         force_merge_endpoint = f'/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false'
-         force_merge_task_id = self.client.transport.perform_request('POST', force_merge_endpoint)['task']
-         SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+         force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+         force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"]
          while True:
-             time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
+             time.sleep(WAITING_FOR_FORCE_MERGE_SEC)
              task_status = self.client.tasks.get(task_id=force_merge_task_id)
-             if task_status['completed']:
+             if task_status["completed"]:
                  break
          log.debug(f"Completed force merge for index {self.index_name}")
 
      def _load_graphs_to_memory(self):
          if self.case_config.engine != AWSOS_Engine.lucene:
              log.info("Calling warmup API to load graphs into memory")
-             warmup_endpoint = f'/_plugins/_knn/warmup/{self.index_name}'
-             self.client.transport.perform_request('GET', warmup_endpoint)
+             warmup_endpoint = f"/_plugins/_knn/warmup/{self.index_name}"
+             self.client.transport.perform_request("GET", warmup_endpoint)
 
      def ready_to_load(self):
          """ready_to_load will be called before load in load cases."""
-         pass
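The body of `optimize()` is elided from this hunk; judging from the helpers above, it presumably chains refresh, force merge, and graph warmup after ingestion. A sketch of such a sequence — the composition is an assumption, not shown in this diff:

```python
# Sketch (not from this diff): a post-ingest optimize pass chaining the
# maintenance helpers above -- refresh, segment force-merge, then graph warmup.
def optimize_pass(db):  # `db` is an AWSOpenSearch instance; name is illustrative
    db._refresh_index()
    db._do_force_merge()
    db._load_graphs_to_memory()
```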
vectordb_bench/backend/clients/aws_opensearch/cli.py

@@ -14,22 +14,20 @@ from .. import DB
 
 
  class AWSOpenSearchTypedDict(TypedDict):
-     host: Annotated[
-         str, click.option("--host", type=str, help="Db host", required=True)
-     ]
+     host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
      port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
      user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
      password: Annotated[str, click.option("--password", type=str, help="Db password")]
 
 
- class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2):
-     ...
+ class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ...
 
 
  @cli.command()
  @click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict)
  def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
      from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
+
      run(
          db=DB.AWSOpenSearch,
          db_config=AWSOpenSearchConfig(
@@ -38,7 +36,6 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
              user=parameters["user"],
              password=SecretStr(parameters["password"]),
          ),
-         db_case_config=AWSOpenSearchIndexConfig(
-         ),
+         db_case_config=AWSOpenSearchIndexConfig(),
          **parameters,
      )
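The `Annotated[..., click.option(...)]` TypedDict pattern above lets `click_parameter_decorators_from_typed_dict` derive CLI options from type hints. A self-contained sketch of the same idea, with `options_from_typed_dict` as an illustrative stand-in for the project's helper:

```python
from typing import Annotated, TypedDict, get_type_hints

import click

class MyOpts(TypedDict):
    host: Annotated[str, click.option("--host", required=True, help="Db host")]
    port: Annotated[int, click.option("--port", default=443, help="Db port")]

def options_from_typed_dict(td):
    # Pull the click.option decorator out of each Annotated field's metadata
    # and apply it to the command function.
    def wrap(f):
        for hint in get_type_hints(td, include_extras=True).values():
            f = hint.__metadata__[0](f)
        return f
    return wrap

@click.command()
@options_from_typed_dict(MyOpts)
def main(**params):
    click.echo(f"connecting to {params['host']}:{params['port']}")
```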
vectordb_bench/backend/clients/aws_opensearch/config.py

@@ -1,10 +1,13 @@
  import logging
  from enum import Enum
- from pydantic import SecretStr, BaseModel
 
- from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+ from pydantic import BaseModel, SecretStr
+
+ from ..api import DBCaseConfig, DBConfig, MetricType
 
  log = logging.getLogger(__name__)
+
+
  class AWSOpenSearchConfig(DBConfig, BaseModel):
      host: str = ""
      port: int = 443
@@ -13,7 +16,7 @@ class AWSOpenSearchConfig(DBConfig, BaseModel):
 
      def to_dict(self) -> dict:
          return {
-             "hosts": [{'host': self.host, 'port': self.port}],
+             "hosts": [{"host": self.host, "port": self.port}],
              "http_auth": (self.user, self.password.get_secret_value()),
              "use_ssl": True,
              "http_compress": True,
@@ -40,25 +43,26 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
      def parse_metric(self) -> str:
          if self.metric_type == MetricType.IP:
              return "innerproduct"
-         elif self.metric_type == MetricType.COSINE:
+         if self.metric_type == MetricType.COSINE:
              if self.engine == AWSOS_Engine.faiss:
-                 log.info(f"Using metric type as innerproduct because faiss doesn't support cosine as metric type for Opensearch")
+                 log.info(
+                     "Using innerproduct because faiss doesn't support cosine as metric type for Opensearch",
+                 )
                  return "innerproduct"
              return "cosinesimil"
          return "l2"
 
      def index_param(self) -> dict:
-         params = {
+         return {
              "name": "hnsw",
              "space_type": self.parse_metric(),
              "engine": self.engine.value,
              "parameters": {
                  "ef_construction": self.efConstruction,
                  "m": self.M,
-                 "ef_search": self.efSearch
-             }
+                 "ef_search": self.efSearch,
+             },
          }
-         return params
 
      def search_param(self) -> dict:
          return {}
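For context, `index_param()` supplies the `method` block embedded in the knn_vector mapping (see `_create_index` earlier). With `engine=faiss` and a cosine metric, which `parse_metric` maps to `innerproduct`, the result would look roughly like this — efConstruction=256, M=16, efSearch=256 are purely illustrative values:

```python
# Illustrative output of index_param() for engine=faiss, metric=COSINE
# (cosine falls back to innerproduct per parse_metric above):
method = {
    "name": "hnsw",
    "space_type": "innerproduct",
    "engine": "faiss",
    "parameters": {"ef_construction": 256, "m": 16, "ef_search": 256},
}
```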
vectordb_bench/backend/clients/aws_opensearch/run.py

@@ -1,12 +1,16 @@
- import time, random
+ import logging
+ import random
+ import time
+
  from opensearchpy import OpenSearch
- from opensearch_dsl import Search, Document, Text, Keyword
 
- _HOST = 'xxxxxx.us-west-2.es.amazonaws.com'
+ log = logging.getLogger(__name__)
+
+ _HOST = "xxxxxx.us-west-2.es.amazonaws.com"
  _PORT = 443
- _AUTH = ('admin', 'xxxxxx')  # For testing only. Don't store credentials in code.
+ _AUTH = ("admin", "xxxxxx")  # For testing only. Don't store credentials in code.
 
- _INDEX_NAME = 'my-dsl-index'
+ _INDEX_NAME = "my-dsl-index"
  _BATCH = 100
  _ROWS = 100
  _DIM = 128
@@ -14,25 +18,24 @@ _TOPK = 10
 
 
  def create_client():
-     client = OpenSearch(
-         hosts=[{'host': _HOST, 'port': _PORT}],
-         http_compress=True,  # enables gzip compression for request bodies
+     return OpenSearch(
+         hosts=[{"host": _HOST, "port": _PORT}],
+         http_compress=True,  # enables gzip compression for request bodies
          http_auth=_AUTH,
          use_ssl=True,
          verify_certs=True,
          ssl_assert_hostname=False,
          ssl_show_warn=False,
      )
-     return client
 
 
- def create_index(client, index_name):
+ def create_index(client: OpenSearch, index_name: str):
      settings = {
          "index": {
              "knn": True,
              "number_of_shards": 1,
              "refresh_interval": "5s",
-         }
+         },
      }
      mappings = {
          "properties": {
@@ -46,41 +49,46 @@ def create_index(client, index_name):
                      "parameters": {
                          "ef_construction": 256,
                          "m": 16,
-                     }
-                 }
-             }
-         }
+                     },
+                 },
+             },
+         },
      }
 
-     response = client.indices.create(index=index_name, body=dict(settings=settings, mappings=mappings))
-     print('\nCreating index:')
-     print(response)
+     response = client.indices.create(
+         index=index_name,
+         body={"settings": settings, "mappings": mappings},
+     )
+     log.info("\nCreating index:")
+     log.info(response)
 
 
- def delete_index(client, index_name):
+ def delete_index(client: OpenSearch, index_name: str):
      response = client.indices.delete(index=index_name)
-     print('\nDeleting index:')
-     print(response)
+     log.info("\nDeleting index:")
+     log.info(response)
 
 
- def bulk_insert(client, index_name):
+ def bulk_insert(client: OpenSearch, index_name: str):
      # Perform bulk operations
-     ids = [i for i in range(_ROWS)]
+     ids = list(range(_ROWS))
      vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]
 
      docs = []
      for i in range(0, _ROWS, _BATCH):
          docs.clear()
-         for j in range(0, _BATCH):
-             docs.append({"index": {"_index": index_name, "_id": ids[i+j]}})
-             docs.append({"embedding": vec[i+j]})
+         for j in range(_BATCH):
+             docs.append({"index": {"_index": index_name, "_id": ids[i + j]}})
+             docs.append({"embedding": vec[i + j]})
          response = client.bulk(docs)
-         print('\nAdding documents:', len(response['items']), response['errors'])
+         log.info(f"Adding documents: {len(response['items'])}, {response['errors']}")
      response = client.indices.stats(index_name)
-     print('\nTotal document count in index:', response['_all']['primaries']['indexing']['index_total'])
+     log.info(
+         f'Total document count in index: { response["_all"]["primaries"]["indexing"]["index_total"] }',
+     )
 
 
- def search(client, index_name):
+ def search(client: OpenSearch, index_name: str):
      # Search for the document.
      search_body = {
          "size": _TOPK,
@@ -89,53 +97,55 @@ def search(client, index_name):
              "embedding": {
                  "vector": [random.random() for _ in range(_DIM)],
                  "k": _TOPK,
-             }
-         }
-     }
+             },
+         },
+     },
      }
      while True:
          response = client.search(index=index_name, body=search_body)
-         print(f'\nSearch took: {response["took"]}')
-         print(f'\nSearch shards: {response["_shards"]}')
-         print(f'\nSearch hits total: {response["hits"]["total"]}')
+         log.info(f'\nSearch took: {response["took"]}')
+         log.info(f'\nSearch shards: {response["_shards"]}')
+         log.info(f'\nSearch hits total: {response["hits"]["total"]}')
          result = response["hits"]["hits"]
          if len(result) != 0:
-             print('\nSearch results:')
+             log.info("\nSearch results:")
              for hit in response["hits"]["hits"]:
-                 print(hit["_id"], hit["_score"])
+                 log.info(hit["_id"], hit["_score"])
              break
-         else:
-             print('\nSearch not ready, sleep 1s')
-             time.sleep(1)
-
- def optimize_index(client, index_name):
-     print(f"Starting force merge for index {index_name}")
-     force_merge_endpoint = f'/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false'
-     force_merge_task_id = client.transport.perform_request('POST', force_merge_endpoint)['task']
-     SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+         log.info("\nSearch not ready, sleep 1s")
+         time.sleep(1)
+
+
+ SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+ WAITINT_FOR_REFRESH_SEC = 30
+
+
+ def optimize_index(client: OpenSearch, index_name: str):
+     log.info(f"Starting force merge for index {index_name}")
+     force_merge_endpoint = f"/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+     force_merge_task_id = client.transport.perform_request("POST", force_merge_endpoint)["task"]
      while True:
          time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
          task_status = client.tasks.get(task_id=force_merge_task_id)
-         if task_status['completed']:
+         if task_status["completed"]:
              break
-     print(f"Completed force merge for index {index_name}")
+     log.info(f"Completed force merge for index {index_name}")
 
 
- def refresh_index(client, index_name):
-     print(f"Starting refresh for index {index_name}")
-     SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
+ def refresh_index(client: OpenSearch, index_name: str):
+     log.info(f"Starting refresh for index {index_name}")
      while True:
          try:
-             print(f"Starting the Refresh Index..")
+             log.info("Starting the Refresh Index..")
              client.indices.refresh(index=index_name)
              break
          except Exception as e:
-             print(
-                 f"Refresh errored out. Sleeping for {SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC} sec and then Retrying : {e}")
-             time.sleep(SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC)
+             log.info(
+                 f"Refresh errored out. Sleeping for {WAITINT_FOR_REFRESH_SEC} sec and then Retrying : {e}",
+             )
+             time.sleep(WAITINT_FOR_REFRESH_SEC)
              continue
-     print(f"Completed refresh for index {index_name}")
-
+     log.info(f"Completed refresh for index {index_name}")
 
 
  def main():
@@ -148,9 +158,9 @@ def main():
          search(client, _INDEX_NAME)
          delete_index(client, _INDEX_NAME)
      except Exception as e:
-         print(e)
+         log.info(e)
          delete_index(client, _INDEX_NAME)
 
 
- if __name__ == '__main__':
+ if __name__ == "__main__":
      main()
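Because run.py now reports through the `logging` module instead of `print`, running it standalone produces no output unless logging is configured. A minimal sketch (not part of the diff):

```python
import logging

# Without this, the log.info(...) calls above emit nothing when
# run.py is executed directly.
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
```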
vectordb_bench/backend/clients/chroma/chroma.py

@@ -1,55 +1,55 @@
- import chromadb
- import logging
+ import logging
  from contextlib import contextmanager
  from typing import Any
- from ..api import VectorDB, DBCaseConfig
+
+ import chromadb
+
+ from ..api import DBCaseConfig, VectorDB
 
  log = logging.getLogger(__name__)
+
+
  class ChromaClient(VectorDB):
-     """Chroma client for VectorDB.
+     """Chroma client for VectorDB.
      To set up Chroma in docker, see https://docs.trychroma.com/usage-guide
      or the instructions in tests/test_chroma.py
 
      To change to running in process, modify the HttpClient() in __init__() and init().
-     """
+     """
 
      def __init__(
-         self,
-         dim: int,
-         db_config: dict,
-         db_case_config: DBCaseConfig,
-         drop_old: bool = False,
-
-         **kwargs
-     ):
-
+         self,
+         dim: int,
+         db_config: dict,
+         db_case_config: DBCaseConfig,
+         drop_old: bool = False,
+         **kwargs,
+     ):
          self.db_config = db_config
          self.case_config = db_case_config
-         self.collection_name = 'example2'
+         self.collection_name = "example2"
 
-         client = chromadb.HttpClient(host=self.db_config["host"],
-                                      port=self.db_config["port"])
+         client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
          assert client.heartbeat() is not None
          if drop_old:
              try:
-                 client.reset()  # Reset the database
-             except:
+                 client.reset()  # Reset the database
+             except Exception:
                  drop_old = False
                  log.info(f"Chroma client drop_old collection: {self.collection_name}")
 
      @contextmanager
      def init(self) -> None:
-         """ create and destory connections to database.
+         """create and destory connections to database.
 
          Examples:
              >>> with self.init():
              >>>     self.insert_embeddings()
          """
-         #create connection
-         self.client = chromadb.HttpClient(host=self.db_config["host"],
-                                           port=self.db_config["port"])
-
-         self.collection = self.client.get_or_create_collection('example2')
+         # create connection
+         self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
+
+         self.collection = self.client.get_or_create_collection("example2")
          yield
          self.client = None
          self.collection = None
@@ -79,12 +79,12 @@ class ChromaClient(VectorDB):
          Returns:
              (int, Exception): number of embeddings inserted and exception if any
          """
-         ids=[str(i) for i in metadata]
-         metadata = [{"id": int(i)} for i in metadata]
+         ids = [str(i) for i in metadata]
+         metadata = [{"id": int(i)} for i in metadata]
          if len(embeddings) > 0:
              self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
          return len(embeddings), None
-
+
      def search_embedding(
          self,
          query: list[float],
@@ -100,17 +100,19 @@ class ChromaClient(VectorDB):
          kwargs: other arguments
 
          Returns:
-             Dict {ids: list[list[int]],
-                   embedding: list[list[float]]
+             Dict {ids: list[list[int]],
+                   embedding: list[list[float]]
                    distance: list[list[float]]}
          """
          if filters:
              # assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
              id_value = filters.get("id")
-             results = self.collection.query(query_embeddings=query, n_results=k,
-                                             where={"id": {"$gt": id_value}})
-             #return list of id's in results
-             return [int(i) for i in results.get('ids')[0]]
+             results = self.collection.query(
+                 query_embeddings=query,
+                 n_results=k,
+                 where={"id": {"$gt": id_value}},
+             )
+             # return list of id's in results
+             return [int(i) for i in results.get("ids")[0]]
          results = self.collection.query(query_embeddings=query, n_results=k)
-         return [int(i) for i in results.get('ids')[0]]
-
+         return [int(i) for i in results.get("ids")[0]]
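A caveat on the `client.reset()` call above: Chroma only honors reset when both server and client allow it, which is what the `except Exception` fallback to `drop_old = False` guards against. A hedged sketch of a client configured to permit reset, assuming a local Chroma server started with `ALLOW_RESET=TRUE` (host and port are illustrative):

```python
import chromadb
from chromadb.config import Settings

# Assumes a Chroma server running locally with ALLOW_RESET=TRUE;
# without it, reset() raises and the code above falls back to drop_old = False.
client = chromadb.HttpClient(
    host="localhost",
    port=8000,
    settings=Settings(allow_reset=True),
)
client.reset()  # wipes the database
```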