veadk-python 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of veadk-python might be problematic.

Files changed (75)
  1. veadk/agent.py +3 -2
  2. veadk/auth/veauth/opensearch_veauth.py +75 -0
  3. veadk/auth/veauth/postgresql_veauth.py +75 -0
  4. veadk/cli/cli.py +3 -1
  5. veadk/cli/cli_eval.py +160 -0
  6. veadk/cli/cli_prompt.py +9 -2
  7. veadk/cli/cli_web.py +6 -1
  8. veadk/configs/database_configs.py +43 -0
  9. veadk/configs/model_configs.py +32 -0
  10. veadk/consts.py +11 -4
  11. veadk/evaluation/adk_evaluator/adk_evaluator.py +5 -2
  12. veadk/evaluation/base_evaluator.py +95 -68
  13. veadk/evaluation/deepeval_evaluator/deepeval_evaluator.py +23 -15
  14. veadk/evaluation/eval_set_recorder.py +2 -2
  15. veadk/integrations/ve_prompt_pilot/ve_prompt_pilot.py +9 -3
  16. veadk/integrations/ve_tls/utils.py +1 -2
  17. veadk/integrations/ve_tls/ve_tls.py +9 -5
  18. veadk/integrations/ve_tos/ve_tos.py +542 -68
  19. veadk/knowledgebase/backends/base_backend.py +59 -0
  20. veadk/knowledgebase/backends/in_memory_backend.py +82 -0
  21. veadk/knowledgebase/backends/opensearch_backend.py +136 -0
  22. veadk/knowledgebase/backends/redis_backend.py +144 -0
  23. veadk/knowledgebase/backends/utils.py +91 -0
  24. veadk/knowledgebase/backends/vikingdb_knowledge_backend.py +524 -0
  25. veadk/{database/__init__.py → knowledgebase/entry.py} +10 -2
  26. veadk/knowledgebase/knowledgebase.py +120 -139
  27. veadk/memory/__init__.py +22 -0
  28. veadk/memory/long_term_memory.py +124 -41
  29. veadk/{database/base_database.py → memory/long_term_memory_backends/base_backend.py} +10 -22
  30. veadk/memory/long_term_memory_backends/in_memory_backend.py +65 -0
  31. veadk/memory/long_term_memory_backends/mem0_backend.py +129 -0
  32. veadk/memory/long_term_memory_backends/opensearch_backend.py +120 -0
  33. veadk/memory/long_term_memory_backends/redis_backend.py +127 -0
  34. veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +148 -0
  35. veadk/memory/short_term_memory.py +80 -72
  36. veadk/memory/short_term_memory_backends/base_backend.py +31 -0
  37. veadk/memory/short_term_memory_backends/mysql_backend.py +41 -0
  38. veadk/memory/short_term_memory_backends/postgresql_backend.py +41 -0
  39. veadk/memory/short_term_memory_backends/sqlite_backend.py +48 -0
  40. veadk/runner.py +12 -19
  41. veadk/tools/builtin_tools/generate_image.py +355 -0
  42. veadk/tools/builtin_tools/image_edit.py +56 -16
  43. veadk/tools/builtin_tools/image_generate.py +51 -15
  44. veadk/tools/builtin_tools/video_generate.py +41 -41
  45. veadk/tools/builtin_tools/web_scraper.py +1 -1
  46. veadk/tools/builtin_tools/web_search.py +7 -7
  47. veadk/tools/load_knowledgebase_tool.py +2 -8
  48. veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py +21 -3
  49. veadk/tracing/telemetry/exporters/apmplus_exporter.py +24 -6
  50. veadk/tracing/telemetry/exporters/cozeloop_exporter.py +2 -0
  51. veadk/tracing/telemetry/exporters/inmemory_exporter.py +22 -8
  52. veadk/tracing/telemetry/exporters/tls_exporter.py +2 -0
  53. veadk/tracing/telemetry/opentelemetry_tracer.py +13 -10
  54. veadk/tracing/telemetry/telemetry.py +66 -63
  55. veadk/utils/misc.py +15 -0
  56. veadk/version.py +1 -1
  57. {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/METADATA +28 -5
  58. {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/RECORD +65 -56
  59. veadk/database/database_adapter.py +0 -533
  60. veadk/database/database_factory.py +0 -80
  61. veadk/database/kv/redis_database.py +0 -159
  62. veadk/database/local_database.py +0 -62
  63. veadk/database/relational/mysql_database.py +0 -173
  64. veadk/database/vector/opensearch_vector_database.py +0 -263
  65. veadk/database/vector/type.py +0 -50
  66. veadk/database/viking/__init__.py +0 -13
  67. veadk/database/viking/viking_database.py +0 -638
  68. veadk/database/viking/viking_memory_db.py +0 -525
  69. /veadk/{database/kv → knowledgebase/backends}/__init__.py +0 -0
  70. /veadk/{database/relational → memory/long_term_memory_backends}/__init__.py +0 -0
  71. /veadk/{database/vector → memory/short_term_memory_backends}/__init__.py +0 -0
  72. {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/WHEEL +0 -0
  73. {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/entry_points.txt +0 -0
  74. {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/licenses/LICENSE +0 -0
  75. {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/top_level.txt +0 -0
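
The rename entries above (items 25, 29, and 69-71) show how the old veadk/database package was folded into the new knowledgebase and memory backend packages. As a quick orientation aid, the path moves visible in this file list are summarized below; only the paths are taken from the diff, so nothing is implied here about the class names or APIs inside the new modules.

    # Path moves taken verbatim from the file list above (0.2.7 -> 0.2.9).
    # Purely informational; the new modules' contents are not shown in this diff.
    MODULE_MOVES = {
        "veadk/database/__init__.py": "veadk/knowledgebase/entry.py",
        "veadk/database/base_database.py": "veadk/memory/long_term_memory_backends/base_backend.py",
        "veadk/database/kv/__init__.py": "veadk/knowledgebase/backends/__init__.py",
        "veadk/database/relational/__init__.py": "veadk/memory/long_term_memory_backends/__init__.py",
        "veadk/database/vector/__init__.py": "veadk/memory/short_term_memory_backends/__init__.py",
    }
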
veadk/database/kv/redis_database.py (deleted)
@@ -1,159 +0,0 @@
- # Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from __future__ import annotations
-
- from typing import Any
-
- import redis
- from pydantic import BaseModel, Field
- from typing_extensions import override
-
- from veadk.config import getenv
- from veadk.utils.logger import get_logger
-
- from ..base_database import BaseDatabase
-
- logger = get_logger(__name__)
-
-
- class RedisDatabaseConfig(BaseModel):
-     host: str = Field(
-         default_factory=lambda: getenv("DATABASE_REDIS_HOST"),
-         description="Redis host",
-     )
-     port: int = Field(
-         default_factory=lambda: int(getenv("DATABASE_REDIS_PORT")),
-         description="Redis port",
-     )
-     db: int = Field(
-         default_factory=lambda: int(getenv("DATABASE_REDIS_DB")),
-         description="Redis db",
-     )
-     password: str = Field(
-         default_factory=lambda: getenv("DATABASE_REDIS_PASSWORD"),
-         description="Redis password",
-     )
-     decode_responses: bool = Field(
-         default=True,
-         description="Redis decode responses",
-     )
-
-
- class RedisDatabase(BaseModel, BaseDatabase):
-     config: RedisDatabaseConfig = Field(default_factory=RedisDatabaseConfig)
-
-     def model_post_init(self, context: Any, /) -> None:
-         try:
-             self._client = redis.StrictRedis(
-                 host=self.config.host,
-                 port=self.config.port,
-                 db=self.config.db,
-                 password=self.config.password,
-                 decode_responses=self.config.decode_responses,
-             )
-
-             self._client.ping()
-             logger.info("Connected to Redis successfully.")
-         except Exception as e:
-             logger.error(f"Failed to connect to Redis: {e}")
-             raise e
-
-     @override
-     def add(self, key: str, value: str, **kwargs):
-         try:
-             self._client.rpush(key, value)
-         except Exception as e:
-             logger.error(f"Failed to add value to Redis list key `{key}`: {e}")
-             raise e
-
-     @override
-     def query(self, key: str, query: str = "", **kwargs) -> list:
-         try:
-             result = self._client.lrange(key, 0, -1)
-             return result  # type: ignore
-         except Exception as e:
-             logger.error(f"Failed to search from Redis list key '{key}': {e}")
-             raise e
-
-     @override
-     def delete(self, **kwargs):
-         """Delete Redis list key based on app_name, user_id and session_id, or directly by key."""
-         key = kwargs.get("key")
-         if key is None:
-             app_name = kwargs.get("app_name")
-             user_id = kwargs.get("user_id")
-             session_id = kwargs.get("session_id")
-             key = f"{app_name}:{user_id}:{session_id}"
-
-         try:
-             # For simple key deletion
-             # We use sync Redis client to delete the key
-             # so the result will be `int`
-             result = self._client.delete(key)
-
-             if result > 0:  # type: ignore
-                 logger.info(f"Deleted key `{key}` from Redis.")
-             else:
-                 logger.info(f"Key `{key}` not found in Redis. Skipping deletion.")
-         except Exception as e:
-             logger.error(f"Failed to delete key `{key}`: {e}")
-             raise e
-
-     def delete_doc(self, key: str, id: str) -> bool:
-         """Delete a specific document by ID from a Redis list.
-
-         Args:
-             key: The Redis key (list) to delete from
-             id: The ID of the document to delete
-
-         Returns:
-             bool: True if deletion was successful, False otherwise
-         """
-         try:
-             # Get all items in the list
-             items = self._client.lrange(key, 0, -1)
-
-             # Find the index of the item to delete
-             for i, item in enumerate(items):
-                 # Assuming the item is stored as a JSON string with an 'id' field
-                 # If it's just the content, we'll use the list index as ID
-                 if str(i) == id:
-                     self._client.lrem(key, 1, item)
-                     return True
-
-             logger.warning(f"Document with id {id} not found in key {key}")
-             return False
-         except Exception as e:
-             logger.error(f"Failed to delete document with id {id} from key {key}: {e}")
-             return False
-
-     def list_docs(self, key: str) -> list[dict]:
-         """List all documents in a Redis list.
-
-         Args:
-             key: The Redis key (list) to list documents from
-
-         Returns:
-             list[dict]: List of documents with id and content
-         """
-         try:
-             items = self._client.lrange(key, 0, -1)
-             return [
-                 {"id": str(i), "content": item, "metadata": {}}
-                 for i, item in enumerate(items)
-             ]
-         except Exception as e:
-             logger.error(f"Failed to list documents from key {key}: {e}")
-             return []
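
For readers auditing this removal, here is a minimal usage sketch of the RedisDatabase class deleted above. It assumes veadk-python 0.2.7, the import path veadk.database.kv.redis_database implied by the file list, a reachable Redis server, and that veadk.config.getenv resolves the DATABASE_REDIS_* variables from the process environment; it is an illustrative sketch, not part of the diff.

    # Hypothetical usage of the removed 0.2.7 RedisDatabase (gone in 0.2.9).
    import os

    # RedisDatabaseConfig reads its connection settings via veadk.config.getenv;
    # assuming that falls back to environment variables, set placeholders first.
    os.environ.setdefault("DATABASE_REDIS_HOST", "localhost")
    os.environ.setdefault("DATABASE_REDIS_PORT", "6379")
    os.environ.setdefault("DATABASE_REDIS_DB", "0")
    os.environ.setdefault("DATABASE_REDIS_PASSWORD", "")

    from veadk.database.kv.redis_database import RedisDatabase

    db = RedisDatabase()  # connects and pings Redis in model_post_init
    db.add("demo_app:alice:session1", "hello world")  # RPUSH onto a list key
    print(db.query("demo_app:alice:session1"))        # LRANGE 0 -1 -> ["hello world"]
    print(db.list_docs("demo_app:alice:session1"))    # [{"id": "0", "content": ..., "metadata": {}}]
    db.delete(app_name="demo_app", user_id="alice", session_id="session1")  # deletes the composed key
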
veadk/database/local_database.py (deleted)
@@ -1,62 +0,0 @@
- # Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Any
-
- from .base_database import BaseDatabase
-
-
- class LocalDataBase(BaseDatabase):
-     """This database is only for basic demonstration.
-     It does not support the vector search function, and the `search` function will return all data.
-     """
-
-     def __init__(self, **kwargs):
-         super().__init__()
-         self.data = {}
-         self._type = "local"
-         self._next_id = 0  # Used to generate unique IDs
-
-     def add_texts(self, texts: list[str], **kwargs):
-         for text in texts:
-             self.data[str(self._next_id)] = text
-             self._next_id += 1
-
-     def is_empty(self):
-         return len(self.data) == 0
-
-     def query(self, query: str, **kwargs: Any) -> list[str]:
-         return list(self.data.values())
-
-     def delete(self, **kwargs: Any):
-         self.data = {}
-         return True
-
-     def add(self, texts: list[str], **kwargs: Any):
-         return self.add_texts(texts)
-
-     def list_docs(self, **kwargs: Any) -> list[dict]:
-         return [
-             {"id": id, "content": content, "metadata": {}}
-             for id, content in self.data.items()
-         ]
-
-     def delete_doc(self, id: str, **kwargs: Any):
-         if id not in self.data:
-             raise ValueError(f"id {id} not found")
-         try:
-             del self.data[id]
-             return True
-         except Exception:
-             return False
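
Likewise, a minimal sketch of the LocalDataBase demo store removed above, assuming veadk-python 0.2.7 and the import path veadk.database.local_database; it needs no external services and is illustrative only.

    # Hypothetical usage of the removed 0.2.7 LocalDataBase (in-memory demo store).
    from veadk.database.local_database import LocalDataBase

    db = LocalDataBase()
    db.add_texts(["first note", "second note"])
    print(db.is_empty())         # False
    print(db.query("anything"))  # returns every stored text; no vector search
    print(db.list_docs())        # [{"id": "0", ...}, {"id": "1", ...}]
    db.delete_doc("0")           # remove a single entry by id
    db.delete()                  # clear everything
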
veadk/database/relational/mysql_database.py (deleted)
@@ -1,173 +0,0 @@
- # Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from __future__ import annotations
-
- from typing import Any
-
- import pymysql
- from pydantic import BaseModel, Field
- from typing_extensions import override
-
- from veadk.config import getenv
- from veadk.utils.logger import get_logger
-
- from ..base_database import BaseDatabase
-
- logger = get_logger(__name__)
-
-
- class MysqlDatabaseConfig(BaseModel):
-     host: str = Field(
-         default_factory=lambda: getenv("DATABASE_MYSQL_HOST"),
-         description="Mysql host",
-     )
-     user: str = Field(
-         default_factory=lambda: getenv("DATABASE_MYSQL_USER"),
-         description="Mysql user",
-     )
-     password: str = Field(
-         default_factory=lambda: getenv("DATABASE_MYSQL_PASSWORD"),
-         description="Mysql password",
-     )
-     database: str = Field(
-         default_factory=lambda: getenv("DATABASE_MYSQL_DATABASE"),
-         description="Mysql database",
-     )
-     charset: str = Field(
-         default_factory=lambda: getenv("DATABASE_MYSQL_CHARSET", "utf8mb4"),
-         description="Mysql charset",
-     )
-
-
- class MysqlDatabase(BaseModel, BaseDatabase):
-     config: MysqlDatabaseConfig = Field(default_factory=MysqlDatabaseConfig)
-
-     def model_post_init(self, context: Any, /) -> None:
-         self._connection = pymysql.connect(
-             host=self.config.host,
-             user=self.config.user,
-             password=self.config.password,
-             database=self.config.database,
-             charset=self.config.charset,
-             cursorclass=pymysql.cursors.DictCursor,
-         )
-         self._connection.ping()
-         logger.info("Connected to MySQL successfully.")
-
-         self._type = "mysql"
-
-     def table_exists(self, table: str) -> bool:
-         with self._connection.cursor() as cursor:
-             cursor.execute(
-                 "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s",
-                 (self.config.database, table),
-             )
-             result = cursor.fetchone()
-             return result is not None
-
-     @override
-     def add(self, sql: str, params=None, **kwargs):
-         with self._connection.cursor() as cursor:
-             cursor.execute(sql, params)
-         self._connection.commit()
-
-     @override
-     def query(self, sql: str, params=None, **kwargs) -> tuple[dict[str, Any], ...]:
-         with self._connection.cursor() as cursor:
-             cursor.execute(sql, params)
-             return cursor.fetchall()
-
-     @override
-     def delete(self, **kwargs):
-         table = kwargs.get("table")
-         if table is None:
-             app_name = kwargs.get("app_name", "default")
-             table = app_name
-
-         if not self.table_exists(table):
-             logger.warning(f"Table {table} does not exist. Skipping delete operation.")
-             return
-
-         try:
-             with self._connection.cursor() as cursor:
-                 # Drop the table directly
-                 sql = f"DROP TABLE `{table}`"
-                 cursor.execute(sql)
-             self._connection.commit()
-             logger.info(f"Dropped table {table}")
-         except Exception as e:
-             logger.error(f"Failed to drop table {table}: {e}")
-             raise e
-
-     def delete_doc(self, table: str, ids: list[int]) -> bool:
-         """Delete documents by IDs from a MySQL table.
-
-         Args:
-             table: The table name to delete from
-             ids: List of document IDs to delete
-
-         Returns:
-             bool: True if deletion was successful, False otherwise
-         """
-         if not self.table_exists(table):
-             logger.warning(f"Table {table} does not exist. Skipping delete operation.")
-             return False
-
-         if not ids:
-             return True  # Nothing to delete
-
-         try:
-             with self._connection.cursor() as cursor:
-                 # Create placeholders for the IDs
-                 placeholders = ",".join(["%s"] * len(ids))
-                 sql = f"DELETE FROM `{table}` WHERE id IN ({placeholders})"
-                 cursor.execute(sql, ids)
-             self._connection.commit()
-             logger.info(f"Deleted {cursor.rowcount} documents from table {table}")
-             return True
-         except Exception as e:
-             logger.error(f"Failed to delete documents from table {table}: {e}")
-             return False
-
-     def list_docs(self, table: str, offset: int = 0, limit: int = 100) -> list[dict]:
-         """List documents from a MySQL table.
-
-         Args:
-             table: The table name to list documents from
-             offset: Offset for pagination
-             limit: Limit for pagination
-
-         Returns:
-             list[dict]: List of documents with id and content
-         """
-         if not self.table_exists(table):
-             logger.warning(f"Table {table} does not exist. Returning empty list.")
-             return []
-
-         try:
-             with self._connection.cursor() as cursor:
-                 sql = f"SELECT id, data FROM `{table}` ORDER BY created_at DESC LIMIT %s OFFSET %s"
-                 cursor.execute(sql, (limit, offset))
-                 results = cursor.fetchall()
-                 return [
-                     {"id": str(row["id"]), "content": row["data"], "metadata": {}}
-                     for row in results
-                 ]
-         except Exception as e:
-             logger.error(f"Failed to list documents from table {table}: {e}")
-             return []
-
-     def is_empty(self):
-         pass
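
Again for reference, a minimal sketch of the MysqlDatabase class removed above. It assumes veadk-python 0.2.7, the path veadk.database.relational.mysql_database, a reachable MySQL instance with the placeholder credentials below, and that veadk.config.getenv resolves the DATABASE_MYSQL_* variables from the process environment; the table layout (id, data, created_at) mirrors what list_docs expects. Illustrative only.

    # Hypothetical usage of the removed 0.2.7 MysqlDatabase (gone in 0.2.9).
    import os

    os.environ.setdefault("DATABASE_MYSQL_HOST", "127.0.0.1")
    os.environ.setdefault("DATABASE_MYSQL_USER", "root")
    os.environ.setdefault("DATABASE_MYSQL_PASSWORD", "secret")      # placeholder credentials
    os.environ.setdefault("DATABASE_MYSQL_DATABASE", "veadk_demo")

    from veadk.database.relational.mysql_database import MysqlDatabase

    db = MysqlDatabase()  # connects and pings MySQL in model_post_init
    db.add(
        "CREATE TABLE IF NOT EXISTS notes ("
        "id INT AUTO_INCREMENT PRIMARY KEY, data TEXT, "
        "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"
    )
    db.add("INSERT INTO notes (data) VALUES (%s)", ("hello",))
    print(db.query("SELECT id, data FROM notes"))  # tuple of DictCursor rows
    print(db.list_docs("notes"))                   # reads the id/data/created_at columns
    db.delete_doc("notes", [1])                    # delete rows by primary-key list
    db.delete(table="notes")                       # drops the whole table
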
veadk/database/vector/opensearch_vector_database.py (deleted)
@@ -1,263 +0,0 @@
- # Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from __future__ import annotations
-
- import os
- from typing import Any, Literal, Optional
-
- from opensearchpy import OpenSearch, Urllib3HttpConnection, helpers
- from pydantic import BaseModel, Field, PrivateAttr
- from typing_extensions import override
-
- from veadk.config import getenv
- from veadk.utils.logger import get_logger
-
- from ..base_database import BaseDatabase
- from .type import Embeddings
-
- logger = get_logger(__name__)
-
-
- class OpenSearchVectorDatabaseConfig(BaseModel):
-     host: str = Field(
-         default_factory=lambda: getenv("DATABASE_OPENSEARCH_HOST"),
-         description="OpenSearch host",
-     )
-
-     port: str | int = Field(
-         default_factory=lambda: getenv("DATABASE_OPENSEARCH_PORT"),
-         description="OpenSearch port",
-     )
-
-     username: Optional[str] = Field(
-         default_factory=lambda: getenv("DATABASE_OPENSEARCH_USERNAME"),
-         description="OpenSearch username",
-     )
-
-     password: Optional[str] = Field(
-         default_factory=lambda: getenv("DATABASE_OPENSEARCH_PASSWORD"),
-         description="OpenSearch password",
-     )
-
-     secure: bool = Field(default=True, description="Whether enable SSL")
-
-     verify_certs: bool = Field(default=False, description="Whether verify SSL certs")
-
-     auth_method: Literal["basic", "aws_managed_iam"] = Field(
-         default="basic", description="OpenSearch auth method"
-     )
-
-     def to_opensearch_params(self) -> dict[str, Any]:
-         params = {
-             "hosts": [{"host": self.host, "port": int(self.port)}],
-             "use_ssl": self.secure,
-             "verify_certs": self.verify_certs,
-             "connection_class": Urllib3HttpConnection,
-             "pool_maxsize": 20,
-         }
-         ca_cert_path = os.getenv("OPENSEARCH_CA_CERT")
-         if self.verify_certs and ca_cert_path:
-             params["ca_certs"] = ca_cert_path
-
-         params["http_auth"] = (self.username, self.password)
-
-         return params
-
-
- class OpenSearchVectorDatabase(BaseModel, BaseDatabase):
-     config: OpenSearchVectorDatabaseConfig = Field(
-         default_factory=OpenSearchVectorDatabaseConfig
-     )
-
-     _embedding_client: Embeddings = PrivateAttr()
-     _opensearch_client: OpenSearch = PrivateAttr()
-
-     def model_post_init(self, context: Any, /) -> None:
-         self._embedding_client = Embeddings()
-         self._opensearch_client = OpenSearch(**self.config.to_opensearch_params())
-
-         self._type = "opensearch"
-
-     def _get_settings(self) -> dict:
-         settings = {"index": {"knn": True}}
-         return settings
-
-     def _get_mappings(self, dim: int = 2560) -> dict:
-         mappings = {
-             "properties": {
-                 "page_content": {
-                     "type": "text",
-                 },
-                 "vector": {
-                     "type": "knn_vector",
-                     "dimension": dim,
-                     "method": {
-                         "name": "hnsw",
-                         "space_type": "l2",
-                         "engine": "faiss",
-                         "parameters": {"ef_construction": 64, "m": 8},
-                     },
-                 },
-             }
-         }
-         return mappings
-
-     def create_collection(
-         self,
-         collection_name: str,
-         embedding_dim: int,
-     ):
-         if not self._opensearch_client.indices.exists(index=collection_name):
-             self._opensearch_client.indices.create(
-                 index=collection_name,
-                 body={
-                     "mappings": self._get_mappings(dim=embedding_dim),
-                     "settings": self._get_settings(),
-                 },
-             )
-         else:
-             logger.warning(f"Collection {collection_name} already exists.")
-
-         self._opensearch_client.indices.refresh(index=collection_name)
-         return
-
-     def _search_by_vector(
-         self, collection_name: str, query_vector: list[float], **kwargs: Any
-     ) -> list[str]:
-         top_k = kwargs.get("top_k", 5)
-         query = {
-             "size": top_k,
-             "query": {"knn": {"vector": {"vector": query_vector, "k": top_k}}},
-         }
-         response = self._opensearch_client.search(index=collection_name, body=query)
-
-         result_list = []
-         for hit in response["hits"]["hits"]:
-             result_list.append(hit["_source"]["page_content"])
-
-         return result_list
-
-     def get_health(self):
-         response = self._opensearch_client.cat.health()
-         logger.info(response)
-
-     def add(self, texts: list[str], **kwargs):
-         collection_name = kwargs.get("collection_name")
-         assert collection_name is not None, "Collection name is required."
-         if not self._opensearch_client.indices.exists(index=collection_name):
-             self.create_collection(
-                 embedding_dim=self._embedding_client.get_embedding_dim(),
-                 collection_name=collection_name,
-             )
-
-         actions = []
-         embeddings = self._embedding_client.embed_documents(texts)
-         for i in range(len(texts)):
-             action = {
-                 "_op_type": "index",
-                 "_index": collection_name,
-                 "_source": {
-                     "page_content": texts[i],
-                     "vector": embeddings[i],
-                 },
-             }
-             actions.append(action)
-
-         helpers.bulk(
-             client=self._opensearch_client,
-             actions=actions,
-             timeout=30,
-             max_retries=3,
-         )
-
-         self._opensearch_client.indices.refresh(index=collection_name)
-         return
-
-     @override
-     def query(self, query: str, **kwargs: Any) -> list[str]:
-         collection_name = kwargs.get("collection_name")
-         top_k = kwargs.get("top_k", 5)
-         assert collection_name is not None, "Collection name is required."
-         if not self._opensearch_client.indices.exists(index=collection_name):
-             logger.warning(
-                 f"querying {query}, but collection {collection_name} does not exist. return a empty list."
-             )
-             return []
-         query_vector = self._embedding_client.embed_query(query)
-         return self._search_by_vector(
-             collection_name=collection_name, query_vector=query_vector, top_k=top_k
-         )
-
-     @override
-     def delete(self, collection_name: str, **kwargs: Any):
-         """drop index"""
-         if not self._opensearch_client.indices.exists(index=collection_name):
-             raise ValueError(f"Collection {collection_name} does not exist.")
-         self._opensearch_client.indices.delete(index=collection_name)
-
-     def is_empty(self, collection_name: str) -> bool:
-         response = self._opensearch_client.count(index=collection_name)
-         return response["count"] == 0
-
-     def collection_exists(self, collection_name: str) -> bool:
-         return self._opensearch_client.indices.exists(index=collection_name)
-
-     def list_all_collection(self) -> list:
-         """List all index name of OpenSearch."""
-         response = self._opensearch_client.indices.get_alias()
-         return list(response.keys())
-
-     def list_docs(
-         self, collection_name: str, offset: int = 0, limit: int = 10000
-     ) -> list[dict]:
-         """Match all docs in one index of OpenSearch"""
-         if not self.collection_exists(collection_name):
-             logger.warning(
-                 f"Get all docs, but collection {collection_name} does not exist. return a empty list."
-             )
-             return []
-
-         query = {"size": limit, "from": offset, "query": {"match_all": {}}}
-         response = self._opensearch_client.search(index=collection_name, body=query)
-         return [
-             {
-                 "id": hit["_id"],
-                 "content": hit["_source"]["page_content"],
-                 "metadata": {},
-             }
-             for hit in response["hits"]["hits"]
-         ]
-
-     def delete_by_query(self, collection_name: str, query: str) -> Any:
-         """Delete docs by query in one index of OpenSearch"""
-         if not self.collection_exists(collection_name):
-             raise ValueError(f"Collection {collection_name} does not exist.")
-
-         query_payload = {"query": {"match": {"page_content": query}}}
-         response = self._opensearch_client.delete_by_query(
-             index=collection_name, body=query_payload
-         )
-
-         self._opensearch_client.indices.refresh(index=collection_name)
-         return response
-
-     def delete_by_id(self, collection_name: str, id: str):
-         """Delete docs by id in index of OpenSearch"""
-         if not self.collection_exists(collection_name):
-             raise ValueError(f"Collection {collection_name} does not exist.")
-
-         response = self._opensearch_client.delete(index=collection_name, id=id)
-         self._opensearch_client.indices.refresh(index=collection_name)
-         return response
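
Finally, a minimal sketch of the OpenSearchVectorDatabase class removed above. It assumes veadk-python 0.2.7, the path veadk.database.vector.opensearch_vector_database, a reachable OpenSearch cluster, the DATABASE_OPENSEARCH_* variables below, and that the Embeddings helper from veadk.database.vector.type (also deleted in this release, but not shown here) is configured with a working embedding model. Illustrative only.

    # Hypothetical usage of the removed 0.2.7 OpenSearchVectorDatabase (gone in 0.2.9).
    import os

    os.environ.setdefault("DATABASE_OPENSEARCH_HOST", "localhost")
    os.environ.setdefault("DATABASE_OPENSEARCH_PORT", "9200")
    os.environ.setdefault("DATABASE_OPENSEARCH_USERNAME", "admin")   # placeholder credentials
    os.environ.setdefault("DATABASE_OPENSEARCH_PASSWORD", "admin")

    from veadk.database.vector.opensearch_vector_database import OpenSearchVectorDatabase

    db = OpenSearchVectorDatabase()  # builds the client from the DATABASE_OPENSEARCH_* settings
    db.add(
        ["VeADK agents can use tracing.", "VeADK agents can be evaluated."],
        collection_name="demo_docs",  # index is created on first add
    )
    print(db.query("What can VeADK agents do?", collection_name="demo_docs", top_k=2))
    print(db.list_docs("demo_docs"))  # match_all over the index
    db.delete("demo_docs")            # drops the index
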