veadk-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of veadk-python might be problematic. Click here for more details.
- veadk/__init__.py +31 -0
- veadk/a2a/__init__.py +13 -0
- veadk/a2a/agent_card.py +45 -0
- veadk/a2a/remote_ve_agent.py +19 -0
- veadk/a2a/ve_a2a_server.py +77 -0
- veadk/a2a/ve_agent_executor.py +78 -0
- veadk/a2a/ve_task_store.py +37 -0
- veadk/agent.py +253 -0
- veadk/cli/__init__.py +13 -0
- veadk/cli/main.py +278 -0
- veadk/cli/services/agentpilot/__init__.py +17 -0
- veadk/cli/services/agentpilot/agentpilot.py +77 -0
- veadk/cli/services/veapig/__init__.py +17 -0
- veadk/cli/services/veapig/apig.py +224 -0
- veadk/cli/services/veapig/apig_utils.py +332 -0
- veadk/cli/services/vefaas/__init__.py +17 -0
- veadk/cli/services/vefaas/template/deploy.py +44 -0
- veadk/cli/services/vefaas/template/src/app.py +30 -0
- veadk/cli/services/vefaas/template/src/config.py +58 -0
- veadk/cli/services/vefaas/vefaas.py +346 -0
- veadk/cli/services/vefaas/vefaas_utils.py +408 -0
- veadk/cli/services/vetls/__init__.py +17 -0
- veadk/cli/services/vetls/vetls.py +87 -0
- veadk/cli/studio/__init__.py +13 -0
- veadk/cli/studio/agent_processor.py +247 -0
- veadk/cli/studio/fast_api.py +232 -0
- veadk/cli/studio/model.py +116 -0
- veadk/cloud/__init__.py +13 -0
- veadk/cloud/cloud_agent_engine.py +144 -0
- veadk/cloud/cloud_app.py +123 -0
- veadk/cloud/template/app.py +30 -0
- veadk/cloud/template/config.py +55 -0
- veadk/config.py +131 -0
- veadk/consts.py +17 -0
- veadk/database/__init__.py +17 -0
- veadk/database/base_database.py +45 -0
- veadk/database/database_factory.py +80 -0
- veadk/database/kv/__init__.py +13 -0
- veadk/database/kv/redis_database.py +109 -0
- veadk/database/local_database.py +43 -0
- veadk/database/relational/__init__.py +13 -0
- veadk/database/relational/mysql_database.py +114 -0
- veadk/database/vector/__init__.py +13 -0
- veadk/database/vector/opensearch_vector_database.py +205 -0
- veadk/database/vector/type.py +50 -0
- veadk/database/viking/__init__.py +13 -0
- veadk/database/viking/viking_database.py +378 -0
- veadk/database/viking/viking_memory_db.py +521 -0
- veadk/evaluation/__init__.py +17 -0
- veadk/evaluation/adk_evaluator/__init__.py +13 -0
- veadk/evaluation/adk_evaluator/adk_evaluator.py +291 -0
- veadk/evaluation/base_evaluator.py +242 -0
- veadk/evaluation/deepeval_evaluator/__init__.py +17 -0
- veadk/evaluation/deepeval_evaluator/deepeval_evaluator.py +223 -0
- veadk/evaluation/eval_set_file_loader.py +28 -0
- veadk/evaluation/eval_set_recorder.py +91 -0
- veadk/evaluation/utils/prometheus.py +142 -0
- veadk/knowledgebase/__init__.py +17 -0
- veadk/knowledgebase/knowledgebase.py +83 -0
- veadk/knowledgebase/knowledgebase_database_adapter.py +259 -0
- veadk/memory/__init__.py +13 -0
- veadk/memory/long_term_memory.py +119 -0
- veadk/memory/memory_database_adapter.py +235 -0
- veadk/memory/short_term_memory.py +124 -0
- veadk/memory/short_term_memory_processor.py +90 -0
- veadk/prompts/__init__.py +13 -0
- veadk/prompts/agent_default_prompt.py +30 -0
- veadk/prompts/prompt_evaluator.py +20 -0
- veadk/prompts/prompt_memory_processor.py +55 -0
- veadk/prompts/prompt_optimization.py +158 -0
- veadk/runner.py +252 -0
- veadk/tools/__init__.py +13 -0
- veadk/tools/builtin_tools/__init__.py +13 -0
- veadk/tools/builtin_tools/lark.py +67 -0
- veadk/tools/builtin_tools/las.py +23 -0
- veadk/tools/builtin_tools/vesearch.py +49 -0
- veadk/tools/builtin_tools/web_scraper.py +76 -0
- veadk/tools/builtin_tools/web_search.py +192 -0
- veadk/tools/demo_tools.py +58 -0
- veadk/tools/load_knowledgebase_tool.py +144 -0
- veadk/tools/sandbox/__init__.py +13 -0
- veadk/tools/sandbox/browser_sandbox.py +27 -0
- veadk/tools/sandbox/code_sandbox.py +30 -0
- veadk/tools/sandbox/computer_sandbox.py +27 -0
- veadk/tracing/__init__.py +13 -0
- veadk/tracing/base_tracer.py +172 -0
- veadk/tracing/telemetry/__init__.py +13 -0
- veadk/tracing/telemetry/exporters/__init__.py +13 -0
- veadk/tracing/telemetry/exporters/apiserver_exporter.py +60 -0
- veadk/tracing/telemetry/exporters/apmplus_exporter.py +101 -0
- veadk/tracing/telemetry/exporters/base_exporter.py +28 -0
- veadk/tracing/telemetry/exporters/cozeloop_exporter.py +69 -0
- veadk/tracing/telemetry/exporters/inmemory_exporter.py +88 -0
- veadk/tracing/telemetry/exporters/tls_exporter.py +78 -0
- veadk/tracing/telemetry/metrics/__init__.py +13 -0
- veadk/tracing/telemetry/metrics/opentelemetry_metrics.py +73 -0
- veadk/tracing/telemetry/opentelemetry_tracer.py +167 -0
- veadk/types.py +23 -0
- veadk/utils/__init__.py +13 -0
- veadk/utils/logger.py +59 -0
- veadk/utils/misc.py +33 -0
- veadk/utils/patches.py +85 -0
- veadk/utils/volcengine_sign.py +199 -0
- veadk/version.py +15 -0
- veadk_python-0.1.0.dist-info/METADATA +124 -0
- veadk_python-0.1.0.dist-info/RECORD +110 -0
- veadk_python-0.1.0.dist-info/WHEEL +5 -0
- veadk_python-0.1.0.dist-info/entry_points.txt +2 -0
- veadk_python-0.1.0.dist-info/licenses/LICENSE +201 -0
- veadk_python-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import redis
|
|
20
|
+
from pydantic import BaseModel, Field, PrivateAttr
|
|
21
|
+
from typing_extensions import override
|
|
22
|
+
|
|
23
|
+
from veadk.config import getenv
|
|
24
|
+
from veadk.utils.logger import get_logger
|
|
25
|
+
|
|
26
|
+
from ..base_database import BaseDatabase
|
|
27
|
+
|
|
28
|
+
# Module-level logger used by the Redis database classes below.
logger = get_logger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class RedisDatabaseConfig(BaseModel):
    """Connection settings for :class:`RedisDatabase`.

    Every default except ``decode_responses`` comes from a
    ``DATABASE_REDIS_*`` variable looked up through ``veadk.config.getenv``.
    The lookups run once, when this class body is executed (module import),
    so later changes to the environment do not alter the defaults.
    """

    host: str = Field(getenv("DATABASE_REDIS_HOST"), description="Redis host")
    # NOTE(review): pydantic does not validate default values unless
    # `validate_default` is enabled, so `port` and `db` keep whatever getenv
    # returned (presumably strings) when left at their defaults — confirm.
    port: int = Field(getenv("DATABASE_REDIS_PORT"), description="Redis port")
    db: int = Field(getenv("DATABASE_REDIS_DB"), description="Redis db")
    password: str = Field(
        getenv("DATABASE_REDIS_PASSWORD"), description="Redis password"
    )
    decode_responses: bool = Field(True, description="Redis decode responses")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class RedisDatabase(BaseModel, BaseDatabase):
    """Database backed by Redis lists.

    Every ``key`` addresses one Redis list: :meth:`add` appends a value with
    ``RPUSH`` and :meth:`query` returns the whole list with ``LRANGE 0 -1``.
    """

    config: RedisDatabaseConfig = Field(default_factory=RedisDatabaseConfig)
    _client: redis.Redis = PrivateAttr(default=None)

    def model_post_init(self, context: Any, /) -> None:
        """Create the Redis client and check connectivity with ``PING``.

        Raises:
            Exception: re-raised from ``redis`` when the client cannot be
                created or the server does not answer ``PING``.
        """
        try:
            self._client = redis.StrictRedis(
                host=self.config.host,
                port=self.config.port,
                db=self.config.db,
                password=self.config.password,
                decode_responses=self.config.decode_responses,
            )
            self._client.ping()
            logger.info("Connected to Redis successfully.")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            raise
        # Tag the backend type, as the sibling databases in this package do
        # ("local", "mysql", "opensearch").
        self._type = "redis"

    @override
    def add(self, key: str, value: str, **kwargs):
        """Append ``value`` to the Redis list stored at ``key``.

        Raises:
            Exception: re-raised from ``redis`` after logging.
        """
        try:
            self._client.rpush(key, value)
        except Exception as e:
            logger.error(f"Failed to add value to Redis list key `{key}`: {e}")
            raise

    @override
    def query(self, key: str, query: str = "", **kwargs) -> list[str]:
        """Return every element of the Redis list at ``key``.

        ``query`` is accepted for interface compatibility but is not used:
        no filtering is performed.

        Raises:
            Exception: re-raised from ``redis`` after logging.
        """
        try:
            return self._client.lrange(key, 0, -1)
        except Exception as e:
            logger.error(f"Failed to search from Redis list key '{key}': {e}")
            raise

    @override
    def delete(self, **kwargs):
        """Delete Redis list key based on app_name, user_id and session_id, or directly by key.

        ``kwargs["key"]`` is used when given; otherwise the key is built as
        ``"{app_name}:{user_id}:{session_id}"`` from the matching kwargs
        (any missing part is formatted as ``None``). A missing key is logged
        and skipped.

        Raises:
            Exception: re-raised from ``redis`` after logging.
        """
        key = kwargs.get("key")
        if key is None:
            app_name = kwargs.get("app_name")
            user_id = kwargs.get("user_id")
            session_id = kwargs.get("session_id")
            key = f"{app_name}:{user_id}:{session_id}"

        try:
            # DEL returns the number of keys actually removed.
            result = self._client.delete(key)
            if result > 0:
                logger.info(f"Deleted key `{key}` from Redis.")
            else:
                logger.info(f"Key `{key}` not found in Redis. Skipping deletion.")
        except Exception as e:
            logger.error(f"Failed to delete key `{key}`: {e}")
            raise
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
from .base_database import BaseDatabase
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LocalDataBase(BaseDatabase):
    """In-memory database intended only for basic demonstration.

    It does not support vector search: :meth:`query` ignores the query text
    and returns every stored text.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.data = []
        self._type = "local"

    def add_texts(self, texts: list[str], **kwargs):
        """Append ``texts`` to the store.

        A bare string is stored as a single entry; ``list.extend`` on a
        ``str`` would otherwise store each character separately.
        """
        if isinstance(texts, str):
            texts = [texts]
        self.data.extend(texts)

    def is_empty(self):
        """Return True when nothing has been stored."""
        return not self.data

    def query(self, query: str, **kwargs: Any) -> list[str]:
        """Return all stored texts; ``query`` is ignored.

        A copy is returned so callers cannot mutate the store through it.
        """
        return list(self.data)

    def delete(self, **kwargs: Any):
        """Remove all stored texts."""
        self.data = []

    def add(self, texts: list[str], **kwargs: Any):
        """Alias for :meth:`add_texts`."""
        return self.add_texts(texts)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import pymysql
|
|
20
|
+
from pydantic import BaseModel, Field, PrivateAttr
|
|
21
|
+
from typing_extensions import override
|
|
22
|
+
|
|
23
|
+
from veadk.config import getenv
|
|
24
|
+
from veadk.utils.logger import get_logger
|
|
25
|
+
|
|
26
|
+
from ..base_database import BaseDatabase
|
|
27
|
+
|
|
28
|
+
# Module-level logger used by the MySQL database classes below.
logger = get_logger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MysqlDatabaseConfig(BaseModel):
    """Connection settings for :class:`MysqlDatabase`.

    Defaults come from ``DATABASE_MYSQL_*`` variables looked up through
    ``veadk.config.getenv`` once, when this class body is executed (module
    import). For ``charset``, ``"utf8mb4"`` is passed to getenv as the
    fallback argument.
    """

    host: str = Field(getenv("DATABASE_MYSQL_HOST"), description="Mysql host")
    user: str = Field(getenv("DATABASE_MYSQL_USER"), description="Mysql user")
    password: str = Field(
        getenv("DATABASE_MYSQL_PASSWORD"), description="Mysql password"
    )
    database: str = Field(
        getenv("DATABASE_MYSQL_DATABASE"), description="Mysql database"
    )
    charset: str = Field(
        getenv("DATABASE_MYSQL_CHARSET", "utf8mb4"), description="Mysql charset"
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class MysqlDatabase(BaseModel, BaseDatabase):
    """MySQL-backed database that uses a single ``pymysql`` connection.

    :meth:`add` and :meth:`query` execute caller-supplied SQL (with optional
    bound ``params``); :meth:`delete` drops a whole table.
    """

    config: MysqlDatabaseConfig = Field(default_factory=MysqlDatabaseConfig)

    _connection: pymysql.Connection = PrivateAttr(default=None)

    def model_post_init(self, context: Any, /) -> None:
        """Open the connection; rows are returned as dicts (``DictCursor``)."""
        self._connection = pymysql.connect(
            host=self.config.host,
            user=self.config.user,
            password=self.config.password,
            database=self.config.database,
            charset=self.config.charset,
            cursorclass=pymysql.cursors.DictCursor,
        )
        self._type = "mysql"

    def table_exists(self, table: str) -> bool:
        """Return True if ``table`` exists in the configured database."""
        with self._connection.cursor() as cursor:
            cursor.execute(
                "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s",
                (self.config.database, table),
            )
            result = cursor.fetchone()
            return result is not None

    @override
    def add(self, sql: str, params=None, **kwargs):
        """Execute a write statement and commit it.

        On failure the transaction is rolled back, so the shared connection
        is not left holding a half-applied transaction. The original error
        is then re-raised.
        """
        try:
            with self._connection.cursor() as cursor:
                cursor.execute(sql, params)
            self._connection.commit()
        except Exception:
            self._connection.rollback()
            raise

    @override
    def query(self, sql: str, params=None, **kwargs) -> list[str]:
        """Execute ``sql`` and return all rows (dicts, via ``DictCursor``)."""
        # NOTE(review): reads are never committed; with pymysql's default
        # autocommit=False, repeated SELECTs on this connection may keep
        # seeing an old transaction snapshot — confirm this is acceptable.
        with self._connection.cursor() as cursor:
            cursor.execute(sql, params)
            return cursor.fetchall()

    @override
    def delete(self, **kwargs):
        """Drop a table.

        The table name is ``kwargs["table"]`` if given, otherwise
        ``kwargs["app_name"]`` (default ``"default"``). A missing table is
        logged and skipped.

        Raises:
            Exception: re-raised from ``pymysql`` after logging.
        """
        table = kwargs.get("table")
        if table is None:
            table = kwargs.get("app_name", "default")

        if not self.table_exists(table):
            logger.warning(f"Table {table} does not exist. Skipping delete operation.")
            return

        # Identifiers cannot be bound as query parameters, so quote the name
        # with backticks and double any embedded backtick (MySQL escaping).
        quoted_table = str(table).replace("`", "``")
        try:
            with self._connection.cursor() as cursor:
                cursor.execute(f"DROP TABLE `{quoted_table}`")
            self._connection.commit()
            logger.info(f"Dropped table {table}")
        except Exception as e:
            logger.error(f"Failed to drop table {table}: {e}")
            raise

    def is_empty(self):
        """Not implemented for MySQL; always returns ``None``."""
        # NOTE(review): callers treating this as a bool see a falsy value
        # (i.e. "not empty"); confirm that is intended.
        return None
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
from typing import Any, Literal, Optional
|
|
19
|
+
|
|
20
|
+
from opensearchpy import OpenSearch, Urllib3HttpConnection, helpers
|
|
21
|
+
from pydantic import BaseModel, Field, PrivateAttr
|
|
22
|
+
from typing_extensions import override
|
|
23
|
+
|
|
24
|
+
from veadk.config import getenv
|
|
25
|
+
from veadk.utils.logger import get_logger
|
|
26
|
+
|
|
27
|
+
from ..base_database import BaseDatabase
|
|
28
|
+
from .type import Embeddings
|
|
29
|
+
|
|
30
|
+
# Module-level logger used by the OpenSearch database classes below.
logger = get_logger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class OpenSearchVectorDatabaseConfig(BaseModel):
    """Connection settings for :class:`OpenSearchVectorDatabase`.

    ``host``, ``port``, ``username`` and ``password`` default to the
    ``DATABASE_OPENSEARCH_*`` variables, looked up through
    ``veadk.config.getenv`` when this class body is executed (module import).
    """

    host: str = Field(
        default=getenv("DATABASE_OPENSEARCH_HOST"),
        description="OpenSearch host",
    )
    port: str | int = Field(
        default=getenv("DATABASE_OPENSEARCH_PORT"),
        description="OpenSearch port",
    )
    username: Optional[str] = Field(
        default=getenv("DATABASE_OPENSEARCH_USERNAME"),
        description="OpenSearch username",
    )
    password: Optional[str] = Field(
        default=getenv("DATABASE_OPENSEARCH_PASSWORD"),
        description="OpenSearch password",
    )
    secure: bool = Field(default=True, description="Whether enable SSL")
    verify_certs: bool = Field(default=False, description="Whether verify SSL certs")
    auth_method: Literal["basic", "aws_managed_iam"] = Field(
        default="basic", description="OpenSearch auth method"
    )

    def to_opensearch_params(self) -> dict[str, Any]:
        """Build keyword arguments for ``opensearchpy.OpenSearch``.

        ``ca_certs`` is taken from the ``OPENSEARCH_CA_CERT`` environment
        variable, and only when certificate verification is enabled.

        Basic auth (``http_auth``) is set only when both ``username`` and
        ``password`` are present. opensearch-py joins the tuple with ``":"``,
        which raises ``TypeError`` on ``None`` members.
        """
        params: dict[str, Any] = {
            "hosts": [{"host": self.host, "port": int(self.port)}],
            "use_ssl": self.secure,
            "verify_certs": self.verify_certs,
            "connection_class": Urllib3HttpConnection,
            "pool_maxsize": 20,
        }
        ca_cert_path = os.getenv("OPENSEARCH_CA_CERT")
        if self.verify_certs and ca_cert_path:
            params["ca_certs"] = ca_cert_path

        # NOTE(review): `auth_method="aws_managed_iam"` is not handled here;
        # basic credentials are applied regardless of `auth_method`.
        if self.username is not None and self.password is not None:
            params["http_auth"] = (self.username, self.password)

        return params
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class OpenSearchVectorDatabase(BaseModel, BaseDatabase):
    """Vector store backed by OpenSearch k-NN indices.

    Each collection is an OpenSearch index whose documents store the text in
    ``page_content`` and its embedding in ``vector``. Embeddings come from
    :class:`Embeddings`.
    """

    config: OpenSearchVectorDatabaseConfig = Field(
        default_factory=OpenSearchVectorDatabaseConfig
    )

    _embedding_client: Embeddings = PrivateAttr()
    _opensearch_client: OpenSearch = PrivateAttr()

    def model_post_init(self, context: Any, /) -> None:
        self._embedding_client = Embeddings()
        self._opensearch_client = OpenSearch(**self.config.to_opensearch_params())

        self._type = "opensearch"

    def _get_settings(self) -> dict:
        """Index settings that enable k-NN search."""
        settings = {"index": {"knn": True}}
        return settings

    def _get_mappings(self, dim: int = 2560) -> dict:
        """Mappings: a text field plus a ``dim``-sized HNSW (faiss, L2) vector."""
        mappings = {
            "properties": {
                "page_content": {
                    "type": "text",
                },
                "vector": {
                    "type": "knn_vector",
                    "dimension": dim,
                    "method": {
                        "name": "hnsw",
                        "space_type": "l2",
                        "engine": "faiss",
                        "parameters": {"ef_construction": 64, "m": 8},
                    },
                },
            }
        }
        return mappings

    def create_collection(
        self,
        collection_name: str,
        embedding_dim: int,
    ):
        """Create the index if missing (warning if it exists), then refresh it."""
        if not self._opensearch_client.indices.exists(index=collection_name):
            self._opensearch_client.indices.create(
                index=collection_name,
                body={
                    "mappings": self._get_mappings(dim=embedding_dim),
                    "settings": self._get_settings(),
                },
            )
        else:
            logger.warning(f"Collection {collection_name} already exists.")

        self._opensearch_client.indices.refresh(index=collection_name)
        return

    def _search_by_vector(
        self, collection_name: str, query_vector: list[float], **kwargs: Any
    ) -> list[str]:
        """Return the ``page_content`` of the ``top_k`` (default 5) nearest hits."""
        top_k = kwargs.get("top_k", 5)
        query = {
            "size": top_k,
            "query": {"knn": {"vector": {"vector": query_vector, "k": top_k}}},
        }
        response = self._opensearch_client.search(index=collection_name, body=query)

        return [hit["_source"]["page_content"] for hit in response["hits"]["hits"]]

    def get_health(self):
        """Log the output of the cluster ``_cat/health`` API."""
        response = self._opensearch_client.cat.health()
        logger.info(response)

    def add(self, texts: list[str], **kwargs):
        """Embed ``texts`` and bulk-index them into ``kwargs["collection_name"]``.

        The collection is created first if it does not exist. An empty
        ``texts`` list indexes nothing and makes no embedding request.

        Raises:
            AssertionError: if ``collection_name`` is not given.
            ValueError: if the embedding service returns a different number
                of vectors than texts sent.
        """
        collection_name = kwargs.get("collection_name")
        assert collection_name is not None, "Collection name is required."
        if not self._opensearch_client.indices.exists(index=collection_name):
            self.create_collection(
                embedding_dim=self._embedding_client.get_embedding_dim(),
                collection_name=collection_name,
            )

        if not texts:
            return

        embeddings = self._embedding_client.embed_documents(texts)
        if len(embeddings) != len(texts):
            raise ValueError(
                f"Expected {len(texts)} embeddings, got {len(embeddings)}."
            )

        actions = [
            {
                "_op_type": "index",
                "_index": collection_name,
                "_source": {
                    "page_content": text,
                    "vector": vector,
                },
            }
            for text, vector in zip(texts, embeddings)
        ]

        helpers.bulk(
            client=self._opensearch_client,
            actions=actions,
            timeout=30,
            max_retries=3,
        )

        # Make the new documents visible to search immediately.
        self._opensearch_client.indices.refresh(index=collection_name)
        return

    @override
    def query(self, query: str, **kwargs: Any) -> list[str]:
        """Return up to ``top_k`` (default 5) texts nearest to ``query``.

        A missing collection yields an empty list (with a warning).
        """
        collection_name = kwargs.get("collection_name")
        top_k = kwargs.get("top_k", 5)
        assert collection_name is not None, "Collection name is required."
        if not self._opensearch_client.indices.exists(index=collection_name):
            logger.warning(
                f"querying {query}, but collection {collection_name} does not exist. retun a empty list."
            )
            return []
        query_vector = self._embedding_client.embed_query(query)
        return self._search_by_vector(
            collection_name=collection_name, query_vector=query_vector, top_k=top_k
        )

    @override
    def delete(self, collection_name: str, **kwargs: Any):
        """Delete the whole index; raises ValueError if it does not exist."""
        if not self._opensearch_client.indices.exists(index=collection_name):
            raise ValueError(f"Collection {collection_name} does not exist.")
        self._opensearch_client.indices.delete(index=collection_name)

    def is_empty(self, collection_name: str):
        """Return True if the collection holds no documents or does not exist.

        A missing index counts as empty, matching :meth:`query`, which
        returns no results for it.
        """
        if not self._opensearch_client.indices.exists(index=collection_name):
            return True
        response = self._opensearch_client.count(index=collection_name)
        return response["count"] == 0
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import requests
|
|
16
|
+
|
|
17
|
+
from veadk.config import getenv
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Embeddings:
    """Minimal client for an HTTP embedding endpoint.

    POSTs ``{"model": ..., "input": [...]}`` as JSON with a Bearer token to
    ``api_base``, and reads ``item["embedding"]`` from each entry of the
    response's ``data`` list.
    """

    # Each input text is truncated to this many characters before sending.
    MAX_CHARS = 4000

    def __init__(
        self,
        model: "str | None" = None,
        api_base: "str | None" = None,
        api_key: "str | None" = None,
        dim: "int | None" = None,
        timeout: "float | None" = None,
    ):
        """Create a client.

        Args:
            model: embedding model name; defaults to ``MODEL_EMBEDDING_NAME``.
            api_base: full endpoint URL; defaults to ``MODEL_EMBEDDING_API_BASE``.
            api_key: bearer token; defaults to ``MODEL_EMBEDDING_API_KEY``.
            dim: embedding dimension; defaults to ``int(MODEL_EMBEDDING_DIM)``.
            timeout: passed to ``requests.post``; ``None`` means no timeout.

        Arguments left as ``None`` are looked up through ``getenv`` here, when
        the instance is created. Evaluating them as default values would read
        the environment once at import time, and a missing
        ``MODEL_EMBEDDING_DIM`` could then make importing this module fail.
        """
        if model is None:
            model = getenv("MODEL_EMBEDDING_NAME")
        if api_base is None:
            api_base = getenv("MODEL_EMBEDDING_API_BASE")
        if api_key is None:
            api_key = getenv("MODEL_EMBEDDING_API_KEY")
        if dim is None:
            dim = int(getenv("MODEL_EMBEDDING_DIM"))

        self.model = model
        self.url = api_base
        self.api_key = api_key
        self.dim = dim
        self.timeout = timeout

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding per text; an empty input returns ``[]``.

        Raises:
            requests.HTTPError: if the endpoint answers with an error status.
        """
        if not texts:
            return []
        data = {
            "model": self.model,
            "input": [text[: self.MAX_CHARS] for text in texts],
        }
        response = requests.post(
            self.url, headers=self.headers, json=data, timeout=self.timeout
        )
        response.raise_for_status()
        result = response.json()
        return [item["embedding"] for item in result["data"]]

    def embed_query(self, text: str) -> list[float]:
        """Return the embedding of a single text."""
        return self.embed_documents([text])[0]

    def get_embedding_dim(self) -> int:
        """Return the configured embedding dimension."""
        return self.dim
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|