veadk-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of veadk-python might be problematic. Click here for more details.
- veadk/__init__.py +31 -0
- veadk/a2a/__init__.py +13 -0
- veadk/a2a/agent_card.py +45 -0
- veadk/a2a/remote_ve_agent.py +19 -0
- veadk/a2a/ve_a2a_server.py +77 -0
- veadk/a2a/ve_agent_executor.py +78 -0
- veadk/a2a/ve_task_store.py +37 -0
- veadk/agent.py +253 -0
- veadk/cli/__init__.py +13 -0
- veadk/cli/main.py +278 -0
- veadk/cli/services/agentpilot/__init__.py +17 -0
- veadk/cli/services/agentpilot/agentpilot.py +77 -0
- veadk/cli/services/veapig/__init__.py +17 -0
- veadk/cli/services/veapig/apig.py +224 -0
- veadk/cli/services/veapig/apig_utils.py +332 -0
- veadk/cli/services/vefaas/__init__.py +17 -0
- veadk/cli/services/vefaas/template/deploy.py +44 -0
- veadk/cli/services/vefaas/template/src/app.py +30 -0
- veadk/cli/services/vefaas/template/src/config.py +58 -0
- veadk/cli/services/vefaas/vefaas.py +346 -0
- veadk/cli/services/vefaas/vefaas_utils.py +408 -0
- veadk/cli/services/vetls/__init__.py +17 -0
- veadk/cli/services/vetls/vetls.py +87 -0
- veadk/cli/studio/__init__.py +13 -0
- veadk/cli/studio/agent_processor.py +247 -0
- veadk/cli/studio/fast_api.py +232 -0
- veadk/cli/studio/model.py +116 -0
- veadk/cloud/__init__.py +13 -0
- veadk/cloud/cloud_agent_engine.py +144 -0
- veadk/cloud/cloud_app.py +123 -0
- veadk/cloud/template/app.py +30 -0
- veadk/cloud/template/config.py +55 -0
- veadk/config.py +131 -0
- veadk/consts.py +17 -0
- veadk/database/__init__.py +17 -0
- veadk/database/base_database.py +45 -0
- veadk/database/database_factory.py +80 -0
- veadk/database/kv/__init__.py +13 -0
- veadk/database/kv/redis_database.py +109 -0
- veadk/database/local_database.py +43 -0
- veadk/database/relational/__init__.py +13 -0
- veadk/database/relational/mysql_database.py +114 -0
- veadk/database/vector/__init__.py +13 -0
- veadk/database/vector/opensearch_vector_database.py +205 -0
- veadk/database/vector/type.py +50 -0
- veadk/database/viking/__init__.py +13 -0
- veadk/database/viking/viking_database.py +378 -0
- veadk/database/viking/viking_memory_db.py +521 -0
- veadk/evaluation/__init__.py +17 -0
- veadk/evaluation/adk_evaluator/__init__.py +13 -0
- veadk/evaluation/adk_evaluator/adk_evaluator.py +291 -0
- veadk/evaluation/base_evaluator.py +242 -0
- veadk/evaluation/deepeval_evaluator/__init__.py +17 -0
- veadk/evaluation/deepeval_evaluator/deepeval_evaluator.py +223 -0
- veadk/evaluation/eval_set_file_loader.py +28 -0
- veadk/evaluation/eval_set_recorder.py +91 -0
- veadk/evaluation/utils/prometheus.py +142 -0
- veadk/knowledgebase/__init__.py +17 -0
- veadk/knowledgebase/knowledgebase.py +83 -0
- veadk/knowledgebase/knowledgebase_database_adapter.py +259 -0
- veadk/memory/__init__.py +13 -0
- veadk/memory/long_term_memory.py +119 -0
- veadk/memory/memory_database_adapter.py +235 -0
- veadk/memory/short_term_memory.py +124 -0
- veadk/memory/short_term_memory_processor.py +90 -0
- veadk/prompts/__init__.py +13 -0
- veadk/prompts/agent_default_prompt.py +30 -0
- veadk/prompts/prompt_evaluator.py +20 -0
- veadk/prompts/prompt_memory_processor.py +55 -0
- veadk/prompts/prompt_optimization.py +158 -0
- veadk/runner.py +252 -0
- veadk/tools/__init__.py +13 -0
- veadk/tools/builtin_tools/__init__.py +13 -0
- veadk/tools/builtin_tools/lark.py +67 -0
- veadk/tools/builtin_tools/las.py +23 -0
- veadk/tools/builtin_tools/vesearch.py +49 -0
- veadk/tools/builtin_tools/web_scraper.py +76 -0
- veadk/tools/builtin_tools/web_search.py +192 -0
- veadk/tools/demo_tools.py +58 -0
- veadk/tools/load_knowledgebase_tool.py +144 -0
- veadk/tools/sandbox/__init__.py +13 -0
- veadk/tools/sandbox/browser_sandbox.py +27 -0
- veadk/tools/sandbox/code_sandbox.py +30 -0
- veadk/tools/sandbox/computer_sandbox.py +27 -0
- veadk/tracing/__init__.py +13 -0
- veadk/tracing/base_tracer.py +172 -0
- veadk/tracing/telemetry/__init__.py +13 -0
- veadk/tracing/telemetry/exporters/__init__.py +13 -0
- veadk/tracing/telemetry/exporters/apiserver_exporter.py +60 -0
- veadk/tracing/telemetry/exporters/apmplus_exporter.py +101 -0
- veadk/tracing/telemetry/exporters/base_exporter.py +28 -0
- veadk/tracing/telemetry/exporters/cozeloop_exporter.py +69 -0
- veadk/tracing/telemetry/exporters/inmemory_exporter.py +88 -0
- veadk/tracing/telemetry/exporters/tls_exporter.py +78 -0
- veadk/tracing/telemetry/metrics/__init__.py +13 -0
- veadk/tracing/telemetry/metrics/opentelemetry_metrics.py +73 -0
- veadk/tracing/telemetry/opentelemetry_tracer.py +167 -0
- veadk/types.py +23 -0
- veadk/utils/__init__.py +13 -0
- veadk/utils/logger.py +59 -0
- veadk/utils/misc.py +33 -0
- veadk/utils/patches.py +85 -0
- veadk/utils/volcengine_sign.py +199 -0
- veadk/version.py +15 -0
- veadk_python-0.1.0.dist-info/METADATA +124 -0
- veadk_python-0.1.0.dist-info/RECORD +110 -0
- veadk_python-0.1.0.dist-info/WHEEL +5 -0
- veadk_python-0.1.0.dist-info/entry_points.txt +2 -0
- veadk_python-0.1.0.dist-info/licenses/LICENSE +201 -0
- veadk_python-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import io
|
|
16
|
+
import json
|
|
17
|
+
import os
|
|
18
|
+
import uuid
|
|
19
|
+
from typing import Any, BinaryIO, Literal, Optional, TextIO
|
|
20
|
+
|
|
21
|
+
import requests
|
|
22
|
+
import tos
|
|
23
|
+
from pydantic import BaseModel, Field
|
|
24
|
+
from volcengine.auth.SignerV4 import SignerV4
|
|
25
|
+
from volcengine.base.Request import Request
|
|
26
|
+
from volcengine.Credentials import Credentials
|
|
27
|
+
|
|
28
|
+
from veadk.config import getenv
|
|
29
|
+
from veadk.database.base_database import BaseDatabase
|
|
30
|
+
from veadk.utils.logger import get_logger
|
|
31
|
+
|
|
32
|
+
# Module-level logger for this VikingDB adapter.
logger = get_logger(__name__)

# Knowledge-base service domain (Beijing-region endpoint); all requests below
# are sent to https://<domain><path>.
g_knowledge_base_domain = "api-knowledgebase.mlp.cn-beijing.volces.com"
# REST paths of the knowledge-base HTTP API.
create_collection_path = "/api/knowledge/collection/create"
search_knowledge_path = "/api/knowledge/collection/search_knowledge"
list_collections_path = "/api/knowledge/collection/list"
get_collections_path = "/api/knowledge/collection/info"
doc_add_path = "/api/knowledge/doc/add"
doc_info_path = "/api/knowledge/doc/info"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class VolcengineTOSConfig(BaseModel):
    """Configuration of the Volcengine TOS bucket used to stage uploads.

    Defaults are resolved from environment variables once, at module import
    time (the ``getenv`` calls run when this class body is executed).
    """

    # TOS service endpoint, e.g. "tos-cn-beijing.volces.com".
    endpoint: Optional[str] = Field(
        description="VikingDB TOS endpoint",
        default=getenv("DATABASE_TOS_ENDPOINT", "tos-cn-beijing.volces.com"),
    )
    # Region the bucket lives in.
    region: Optional[str] = Field(
        description="VikingDB TOS region",
        default=getenv("DATABASE_TOS_REGION", "cn-beijing"),
    )
    # Target bucket name; no fallback default, so this may be None if unset.
    bucket: Optional[str] = Field(
        description="VikingDB TOS bucket",
        default=getenv("DATABASE_TOS_BUCKET"),
    )
    # Key prefix under which uploaded objects are written.
    base_key: Optional[str] = Field(
        description="VikingDB TOS base key",
        default="veadk",
    )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class VikingDatabaseConfig(BaseModel):
    """Credentials and scope for the VikingDB knowledge-base service.

    Environment-variable defaults are read once, at module import time.
    """

    # Volcengine access key (VOLCENGINE_ACCESS_KEY).
    volcengine_ak: Optional[str] = Field(
        description="VikingDB access key",
        default=getenv("VOLCENGINE_ACCESS_KEY"),
    )
    # Volcengine secret key (VOLCENGINE_SECRET_KEY).
    volcengine_sk: Optional[str] = Field(
        description="VikingDB secret key",
        default=getenv("VOLCENGINE_SECRET_KEY"),
    )
    # Knowledge-base project that collections are scoped to.
    project: Optional[str] = Field(
        description="VikingDB project name",
        default=getenv("DATABASE_VIKING_PROJECT"),
    )
    # Region used when signing requests (see prepare_request).
    region: Optional[str] = Field(
        description="VikingDB region",
        default=getenv("DATABASE_VIKING_REGION"),
    )
    # TOS staging configuration; built lazily per instance.
    tos: Optional[VolcengineTOSConfig] = Field(
        description="VikingDB TOS configuration",
        default_factory=VolcengineTOSConfig,
    )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def prepare_request(
    method, path, config: VikingDatabaseConfig, params=None, data=None, doseq=0
):
    """Build and SigV4-sign a volcengine ``Request`` for the knowledge-base API.

    Args:
        method: HTTP method, e.g. "POST".
        path: API path (one of the module-level ``*_path`` constants).
        config: Supplies the access key / secret key and signing region.
        params: Optional query parameters. Scalar numbers/bools are
            stringified and list values are joined with "," — note this
            mutates the caller's dict in place.
        data: Optional JSON-serializable request body.
        doseq: When truthy, list-valued params are left as sequences.

    Returns:
        A signed ``volcengine.base.Request.Request`` ready to be sent.
    """
    ak = config.volcengine_ak
    sk = config.volcengine_sk

    if params:
        for key in params:
            value = params[key]
            # Query values must be strings for signing. isinstance replaces
            # the original type() chain; bool is a subclass of int, so
            # (int, float) also covers True/False.
            if isinstance(value, (int, float)):
                params[key] = str(value)
            elif isinstance(value, list) and not doseq:
                # Stringify elements so non-str lists don't raise TypeError
                # inside join; str elements are unchanged.
                params[key] = ",".join(str(item) for item in value)

    r = Request()
    r.set_shema("https")  # NB: "shema" is the volcengine SDK's own spelling.
    r.set_method(method)
    r.set_connection_timeout(10)
    r.set_socket_timeout(10)
    r.set_headers(
        {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
    )
    if params:
        r.set_query(params)
    r.set_path(path)
    if data is not None:
        r.set_body(json.dumps(data))
    # Sign with service name "air" (the knowledge-base service identifier).
    credentials = Credentials(ak, sk, "air", config.region)
    SignerV4.sign(r, credentials)
    return r
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class VikingDatabase(BaseModel, BaseDatabase):
    """Knowledge-base database backed by Volcengine VikingDB.

    Documents are first staged in a TOS bucket (``_upload_to_tos``), then
    registered with the knowledge-base HTTP API (``_add_doc``); queries go
    straight to the search endpoint.
    """

    # Credentials, project scope and TOS settings; defaults come from env vars.
    config: VikingDatabaseConfig = Field(
        default_factory=VikingDatabaseConfig,
        description="VikingDB configuration",
    )

    def _upload_to_tos(
        self,
        data: str | list[str] | TextIO | BinaryIO | bytes,
        **kwargs: Any,
    ):
        """Upload ``data`` to TOS and return ``(http_status, "<bucket>/<key>")``.

        Accepts a file path, a text/binary file stream, a plain string, a
        list of strings (joined with newlines), or raw bytes. For raw bytes
        the object suffix comes from the ``file_ext`` kwarg.
        """
        file_ext = kwargs.get(
            "file_ext", ".pdf"
        )  # when bytes data, file_ext is required
        ak = self.config.volcengine_ak
        sk = self.config.volcengine_sk
        tos_bucket = self.config.tos.bucket
        tos_endpoint = self.config.tos.endpoint
        tos_region = self.config.tos.region
        tos_key = self.config.tos.base_key
        client = tos.TosClientV2(ak, sk, tos_endpoint, tos_region, max_connections=1024)
        if isinstance(data, str) and os.path.isfile(data):  # Process file path
            file_ext = os.path.splitext(data)[1]
            new_key = f"{tos_key}/{str(uuid.uuid4())}{file_ext}"
            with open(data, "rb") as f:
                upload_data = f.read()

        elif isinstance(
            data,
            (io.TextIOWrapper, io.BufferedReader),  # file type: TextIO | BinaryIO
        ):  # Process file stream
            # Try to get the file extension from the file name, and use a
            # default value if there is none.
            file_ext = ".unknown"
            if hasattr(data, "name"):
                _, file_ext = os.path.splitext(data.name)
            new_key = f"{tos_key}/{str(uuid.uuid4())}{file_ext}"
            # BUGFIX: the original tested isinstance(data, typing.TextIO),
            # which never matches a real text stream, so text streams were
            # uploaded as the raw str returned by read(). Check the concrete
            # io class instead and encode text content to bytes.
            if isinstance(data, io.TextIOWrapper):
                upload_data = data.read().encode("utf-8")
            else:
                # Binary stream: read the content as-is.
                upload_data = data.read()

        elif isinstance(data, str):  # Process ordinary strings
            new_key = f"{tos_key}/{str(uuid.uuid4())}.txt"
            upload_data = data.encode("utf-8")  # Encode as byte type

        elif isinstance(data, list):  # Process list of strings
            new_key = f"{tos_key}/{str(uuid.uuid4())}.txt"
            # Join the strings in the list with newlines and encode as bytes.
            upload_data = "\n".join(data).encode("utf-8")

        elif isinstance(data, bytes):  # Process bytes data
            new_key = f"{tos_key}/{str(uuid.uuid4())}{file_ext}"
            upload_data = data

        else:
            raise ValueError(f"Unsupported data type: {type(data)}")

        resp = client.put_object(tos_bucket, new_key, content=upload_data)
        tos_url = f"{tos_bucket}/{new_key}"

        return resp.resp.status, tos_url

    def _add_doc(self, collection_name: str, tos_url: str, doc_id: str, **kwargs: Any):
        """Register a staged TOS object as a document; return ``doc_id``.

        Raises:
            ValueError: when the service reports a non-zero code or returns
                no data for the document.
        """
        request_params = {
            "collection_name": collection_name,
            "project": self.config.project,
            "add_type": "tos",
            "doc_id": doc_id,
            "tos_path": tos_url,
        }

        doc_add_req = prepare_request(
            method="POST", path=doc_add_path, config=self.config, data=request_params
        )
        rsp = requests.request(
            method=doc_add_req.method,
            url="https://{}{}".format(g_knowledge_base_domain, doc_add_req.path),
            headers=doc_add_req.headers,
            data=doc_add_req.body,
        )

        result = rsp.json()
        if result["code"] != 0:
            logger.error(f"Error in add_doc: {result['message']}")
            # BUGFIX: previously returned {"error": message}, which `add`
            # then handed to callers as a doc_id. Raise instead, consistent
            # with query/create_collection error handling.
            raise ValueError(f"Error in add_doc: {result['message']}")

        doc_add_data = result["data"]
        if not doc_add_data:
            raise ValueError(f"doc {doc_id} has no data.")

        return doc_id

    def add(self, data: str | list[str] | TextIO | BinaryIO | bytes, **kwargs: Any):
        """Stage ``data`` in TOS, then register it as a knowledge-base doc.

        Args:
            data: str, file path or file stream: Both file or file.read() are acceptable.
            **kwargs: collection_name(required)
        Returns:
            {
                "tos_url": "tos://<bucket>/<key>",
                "doc_id": "<doc_id>",
            }
        """
        collection_name = kwargs.get("collection_name")
        assert collection_name is not None, "collection_name is required"

        status, tos_url = self._upload_to_tos(data=data, **kwargs)
        if status != 200:
            raise ValueError(f"Error in upload_to_tos: {status}")
        doc_id = self._add_doc(
            collection_name=collection_name,
            tos_url=tos_url,
            doc_id=str(uuid.uuid4()),
        )
        return {
            "tos_url": f"tos://{tos_url}",
            "doc_id": doc_id,
        }

    def delete(self, **kwargs: Any):
        # collection_name = kwargs.get("collection_name")
        # TODO: delete from VikingDB — not implemented yet; intentionally a no-op.
        ...

    def query(self, query: str, **kwargs: Any) -> list[str]:
        """Search the collection and return matching chunk texts.

        Args:
            query: query text
            **kwargs: collection_name(required), top_k(optional, default 5)

        Returns: list of str, the search result

        Raises:
            ValueError: on a service error or when no results are found.
        """
        collection_name = kwargs.get("collection_name")
        assert collection_name is not None, "collection_name is required"
        request_params = {
            "query": query,
            "limit": int(kwargs.get("top_k", 5)),
            "name": collection_name,
            "project": self.config.project,
        }
        search_req = prepare_request(
            method="POST",
            path=search_knowledge_path,
            config=self.config,
            data=request_params,
        )
        resp = requests.request(
            method=search_req.method,
            url="https://{}{}".format(g_knowledge_base_domain, search_req.path),
            headers=search_req.headers,
            data=search_req.body,
        )

        result = resp.json()
        if result["code"] != 0:
            logger.error(f"Error in search_knowledge: {result['message']}")
            raise ValueError(f"Error in search_knowledge: {result['message']}")

        if not result["data"]["result_list"]:
            raise ValueError(f"No results found for collection {collection_name}")

        # Keep only the text content of each returned chunk.
        return [chunk["content"] for chunk in result["data"]["result_list"]]

    def create_collection(
        self,
        collection_name: str,
        description: str = "",
        version: Literal[2, 4] = 4,
        data_type: Literal[
            "unstructured_data", "structured_data"
        ] = "unstructured_data",
        chunking_strategy: Literal["custom_balance", "custom"] = "custom_balance",
        chunk_length: int = 500,
        merge_small_chunks: bool = True,
    ):
        """Create a knowledge-base collection in the configured project.

        Args:
            collection_name: Name of the collection to create.
            description: Human-readable description.
            version: Knowledge-base schema version (2 or 4).
            data_type: Whether documents are unstructured or structured.
            chunking_strategy: Preprocessing chunking strategy.
            chunk_length: Target chunk length for preprocessing.
            merge_small_chunks: Merge undersized chunks during preprocessing.

        Returns:
            The raw JSON response of the create API.

        Raises:
            ValueError: when the service reports a non-zero code.
        """
        request_params = {
            "name": collection_name,
            "project": self.config.project,
            "description": description,
            "version": version,
            "data_type": data_type,
            "preprocessing": {
                "chunking_strategy": chunking_strategy,
                "chunk_length": chunk_length,
                "merge_small_chunks": merge_small_chunks,
            },
        }

        create_collection_req = prepare_request(
            method="POST",
            path=create_collection_path,
            config=self.config,
            data=request_params,
        )
        resp = requests.request(
            method=create_collection_req.method,
            url="https://{}{}".format(
                g_knowledge_base_domain, create_collection_req.path
            ),
            headers=create_collection_req.headers,
            data=create_collection_req.body,
        )

        result = resp.json()
        if result["code"] != 0:
            logger.error(f"Error in create_collection: {result['message']}")
            raise ValueError(f"Error in create_collection: {result['message']}")
        return result

    def collection_exists(self, collection_name: str) -> bool:
        """Return True if ``collection_name`` exists in the configured project.

        Raises:
            ValueError: when the list API itself reports an error.
        """
        request_params = {
            "project": self.config.project,
        }
        list_collections_req = prepare_request(
            method="POST",
            path=list_collections_path,
            config=self.config,
            data=request_params,
        )
        resp = requests.request(
            method=list_collections_req.method,
            url="https://{}{}".format(
                g_knowledge_base_domain, list_collections_req.path
            ),
            headers=list_collections_req.headers,
            data=list_collections_req.body,
        )

        result = resp.json()
        if result["code"] != 0:
            logger.error(f"Error in list_collections: {result['message']}")
            raise ValueError(f"Error in list_collections: {result['message']}")

        collections = result["data"]["collection_list"]
        if not collections:
            # BUGFIX: an empty project previously raised ValueError here,
            # breaking the check-then-create pattern; a boolean predicate
            # should simply answer False.
            return False

        return collection_name in {c["collection_name"] for c in collections}
|